Mécanisme d’Attention : l’innovation qui a engendré les Transformers et les LLM
Le mécanisme d’attention est une technique qui permet à un modèle de deep learning de pondérer dynamiquement l’importance de chaque partie de l’entrée lors du traitement, en se « concentrant » sur les éléments les plus pertinents pour la tâche en cours. Inventé pour la traduction automatique (Bahdanau et al., 2014), il est devenu la brique fondamentale de l’architecture Transformer et de tous les LLM modernes.
- Principe
- Pondérer dynamiquement les éléments de l’entrée selon leur pertinence
- Composants clés
- Query (Q), Key (K), Value (V), scores d’alignement, softmax
- Variantes historiques
- Additive (Bahdanau, 2014), multiplicative (Luong, 2015)
- Variante dominante
- Scaled dot-product attention + multi-head (Transformer, 2017)
- Types
- Self-attention, cross-attention, causal (masked) attention
- Complexité
- O(n²) en longueur de séquence (quadratique)
- Papiers clés
- Bahdanau et al. (2014), Vaswani et al. (2017, « Attention Is All You Need »)
L’intuition : pourquoi l’attention ?
Imaginez que vous traduisez la phrase « Le chat noir dort sur le tapis rouge » en anglais. Pour traduire « black », vous devez vous concentrer sur « noir » dans la phrase source, pas sur l’ensemble de la phrase. Pour « sleeps », c’est « dort » qui compte. Le mécanisme d’attention formalise cette capacité de « regarder au bon endroit au bon moment ».
Avant l’attention, les modèles Seq2Seq comprimaient toute la phrase source dans un unique vecteur de contexte de taille fixe, créant un goulot d’étranglement. L’attention résout ce problème en donnant au decoder un accès direct à chaque état caché de l’encoder, avec des poids de pertinence calculés dynamiquement à chaque étape de décodage.
L’attention de Bahdanau (2014) : l’attention additive
Proposée dans le papier « Neural Machine Translation by Jointly Learning to Align and Translate », l’attention de Bahdanau est la première formulation moderne du mécanisme d’attention pour le NLP.
Le principe : à chaque pas de décodage t, calculer un score d’alignement entre l’état caché du decoder et chaque état caché de l’encoder, puis utiliser ces scores pour construire un vecteur de contexte pondéré.
Pour chaque pas de décodage t :
1. Score d'alignement (additif) :
e_tj = v^T · tanh(W_q · s_{t-1} + W_k · h_j)
(s_{t-1} = état du decoder, h_j = état j de l'encoder)
2. Poids d'attention (softmax) :
α_tj = softmax(e_tj) = exp(e_tj) / Σ exp(e_tk)
3. Vecteur de contexte (somme pondérée) :
c_t = Σ α_tj · h_j
4. Le decoder utilise c_t et s_{t-1} pour produire la sortie
L’appellation « additive » vient du fait que le score d’alignement additionne les projections du query (état du decoder) et du key (état de l’encoder) avant d’appliquer tanh. Cette approche a deux matrices de poids apprenables (W_q et W_k) et un vecteur v, ce qui la rend plus expressive mais plus coûteuse que l’alternative multiplicative.
L’attention de Luong (2015) : l’attention multiplicative
Luong et al. (2015) proposent une variante simplifiée avec trois fonctions de score d’alignement :
Dot product : score(s_t, h_j) = s_t^T · h_j (pas de paramètres supplémentaires, très rapide)
General : score(s_t, h_j) = s_t^T · W · h_j (une matrice de poids apprise)
Concat : similaire à Bahdanau, score = v^T · tanh(W · [s_t ; h_j])
Autre différence : dans l’attention de Luong, le vecteur de contexte est calculé après que le RNN decoder a produit l’état caché du pas courant, tandis que chez Bahdanau, le contexte influence directement l’entrée du RNN decoder. En pratique, le dot product attention est devenu le standard car il est le plus rapide à calculer et scale bien avec la dimension.
Scaled Dot-Product Attention : le cœur du Transformer
Le Transformer (Vaswani et al., 2017) généralise et formalise l’attention multiplicative avec le mécanisme Query-Key-Value (QKV) :
Attention(Q, K, V) = softmax(Q · K^T / √d_k) · V
Où :
- Q (Queries) = ce que l'on cherche (dimension d_k)
- K (Keys) = ce contre quoi on compare (dimension d_k)
- V (Values) = ce que l'on extrait (dimension d_v)
- √d_k = facteur de mise à l'échelle (scaling factor)
Les trois étapes conceptuelles :
1. Calcul des scores : le produit scalaire Q·K^T mesure la similarité entre chaque query et chaque key. Le résultat est une matrice n×n (où n est la longueur de la séquence).
2. Normalisation : la division par √d_k empêche les scores d’être trop grands quand la dimension est élevée (ce qui saturerait le softmax et produirait des gradients quasi nuls). Le softmax convertit les scores en probabilités sommant à 1.
3. Agrégation : les poids d’attention sont multipliés par les values pour produire la sortie. Chaque position de la séquence reçoit un mélange pondéré des values de toutes les autres positions.
Multi-Head Attention : regarder sous plusieurs angles
Au lieu d’exécuter une seule opération d’attention, le Transformer utilise plusieurs « têtes » en parallèle. Chaque tête projette Q, K, V dans un sous-espace différent avec ses propres matrices de poids, effectue l’attention indépendamment, puis les résultats sont concaténés et projetés.
MultiHead(Q, K, V) = Concat(head_1, ..., head_h) · W_O
Où chaque head_i = Attention(Q · W_Q_i, K · W_K_i, V · W_V_i)
Transformer original : h=8 têtes, d_model=512
→ chaque tête opère en dimension d_k = d_v = 512/8 = 64
L’intérêt : différentes têtes peuvent capturer différents types de relations. Une tête peut se spécialiser dans les dépendances syntaxiques (sujet-verbe), une autre dans la coréférence (pronom-antécédent), une troisième dans les relations sémantiques. Les recherches en interprétabilité ont confirmé que les têtes d’attention se spécialisent effectivement dans des patterns linguistiques distincts.
Types d’attention dans les Transformers
Self-attention (auto-attention)
Dans la self-attention, Q, K et V proviennent tous de la même séquence. Chaque token « regarde » tous les autres tokens de la même séquence pour construire sa représentation. C’est le mécanisme de base de l’encoder du Transformer et des modèles comme BERT.
Exemple : dans « Le chat qui mangeait dormait », la self-attention permet au token « dormait » de fortement attendre au token « chat » (le sujet), malgré les mots intermédiaires.
Causal attention (attention masquée)
Dans le decoder, chaque token ne peut regarder que les tokens qui le précèdent (pas ceux qui le suivent). Un masque triangulaire force les scores d’attention des positions futures à -∞ avant le softmax, ce qui les ramène à 0. C’est ce qui rend les modèles GPT et les LLM autoregressifs : chaque token est prédit uniquement à partir du contexte passé.
Cross-attention (attention croisée)
Dans la cross-attention, Q provient du decoder et K, V proviennent de l’encoder. C’est le mécanisme qui relie l’encoder au decoder dans les architectures encoder-decoder (T5, BART, Whisper). Il permet au decoder de « consulter » l’entrée encodée à chaque étape de génération.
| Type d’attention | Source Q | Source K, V | Usage |
|---|---|---|---|
| Self-attention | Même séquence | Même séquence | Encoder (BERT), chaque couche du Transformer |
| Causal (masked) | Même séquence | Même séquence (masquée) | Decoder autoregressif (GPT, LLM) |
| Cross-attention | Decoder | Encoder | Traduction, transcription (T5, Whisper) |
Le défi de la complexité quadratique
Le produit Q·K^T produit une matrice n×n, où n est la longueur de la séquence. Le coût en calcul et en mémoire est donc O(n²). Pour une séquence de 1024 tokens, la matrice d’attention a ~1 million d’entrées. Pour 100K tokens, c’est 10 milliards. C’est le principal goulot d’étranglement des Transformers sur les longues séquences.
Plusieurs approches tentent de réduire cette complexité :
FlashAttention : ne réduit pas la complexité théorique mais optimise drastiquement l’utilisation de la mémoire GPU en exploitant la hiérarchie mémoire (SRAM vs HBM). C’est le standard dans tous les frameworks modernes.
Sparse attention : chaque token n’attende qu’à un sous-ensemble des autres tokens (voisins, tokens espacés régulièrement). Utilisé dans GPT-3 et Longformer.
Linear attention : remplace le softmax par une approximation linéaire, ramenant la complexité à O(n). Moins expressif mais beaucoup plus rapide.
Sliding window attention : chaque token n’attende qu’aux tokens dans une fenêtre locale. Mistral utilise cette approche avec un fenêtre de 4096 tokens.
State Space Models : Mamba et ses dérivés remplacent l’attention par une récurrence sélective à complexité linéaire, offrant une alternative radicalement différente.
L’attention au-delà du NLP
Vision par ordinateur : les Vision Transformers (ViT) appliquent la self-attention aux patches d’images, remplaçant la convolution par l’attention pour la reconnaissance d’images.
Audio : Whisper (OpenAI) utilise l’attention dans un encoder-decoder Transformer pour la transcription vocale multilingue.
Multimodal : dans les modèles comme CLIP et Flamingo, la cross-attention relie les représentations textuelles et visuelles, permettant au modèle de « raisonner » sur des images à partir de texte.
Bioinformatique : AlphaFold utilise l’attention pour modéliser les interactions entre acides aminés et prédire la structure 3D des protéines.
Recommandation : les systèmes de recommandation utilisent l’attention pour pondérer l’historique d’interactions d’un utilisateur et prédire ses préférences.
Implémentation : Scaled Dot-Product Attention en PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
class ScaledDotProductAttention(nn.Module):
def __init__(self, d_k):
super().__init__()
self.scale = math.sqrt(d_k)
def forward(self, Q, K, V, mask=None):
# Q, K, V shapes: (batch, heads, seq_len, d_k)
scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
if mask is not None:
scores = scores.masked_fill(mask == 0, float('-inf'))
attn_weights = F.softmax(scores, dim=-1)
output = torch.matmul(attn_weights, V)
return output, attn_weights
class MultiHeadAttention(nn.Module):
def __init__(self, d_model=512, n_heads=8):
super().__init__()
self.d_k = d_model // n_heads
self.n_heads = n_heads
self.W_q = nn.Linear(d_model, d_model)
self.W_k = nn.Linear(d_model, d_model)
self.W_v = nn.Linear(d_model, d_model)
self.W_o = nn.Linear(d_model, d_model)
self.attention = ScaledDotProductAttention(self.d_k)
def forward(self, Q, K, V, mask=None):
batch_size = Q.size(0)
# Projections linéaires et reshape pour multi-head
Q = self.W_q(Q).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
K = self.W_k(K).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
V = self.W_v(V).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
out, weights = self.attention(Q, K, V, mask)
# Concaténation et projection finale
out = out.transpose(1, 2).contiguous().view(batch_size, -1, self.n_heads * self.d_k)
return self.W_o(out)
En production, vous utiliserez nn.MultiheadAttention de PyTorch (qui intègre FlashAttention automatiquement via torch.nn.functional.scaled_dot_product_attention) plutôt que cette implémentation manuelle. Mais comprendre le code ci-dessus est essentiel pour saisir ce qui se passe « sous le capot » des LLM.
Évolutions récentes de l’attention
Grouped Query Attention (GQA) : utilisée dans LLaMA 2/3 et Mistral. Plusieurs têtes de queries partagent les mêmes K et V, réduisant la taille du cache KV à l’inférence sans perte de qualité significative.
Multi-Query Attention (MQA) : cas extrême de GQA où toutes les têtes partagent un seul jeu de K et V. Utilisée dans PaLM et Falcon. Inférence très rapide mais qualité légèrement inférieure à GQA.
Sliding Window Attention : chaque token n’attende qu’à une fenêtre locale de tokens voisins, réduisant la complexité à O(n × w) où w est la taille de la fenêtre. Mistral l’utilise avec w=4096.
Ring Attention : technique de parallélisation qui distribue les séquences très longues sur plusieurs GPU, chacun calculant l’attention sur sa portion et échangeant les résultats. Permet de traiter des séquences de millions de tokens.
Visualiser l’attention : interprétabilité
Un avantage majeur de l’attention est son interprétabilité. Les poids d’attention (α_tj) forment une matrice qui montre « qui regarde qui ». En visualisant ces matrices comme des heatmaps, on peut comprendre quelles relations le modèle a capturées :
En traduction : les poids d’attention révèlent l’alignement entre mots source et cible (« chat » → « cat »).
En self-attention : on peut observer les dépendances syntaxiques apprises (le verbe « attende » au sujet, l’adjectif au nom qu’il qualifie).
En vision : les attention maps montrent les régions de l’image sur lesquelles le modèle se concentre pour sa prédiction.
Des outils comme BertViz (pour les modèles de langue) et Attention Rollout (pour les ViT) permettent d’explorer ces visualisations interactivement. Dans les ViT, l’attention rollout combine récursivement les matrices d’attention de toutes les couches pour produire une carte d’attention globale montrant quels patches de l’image ont le plus influencé la prédiction finale.
Questions fréquentes sur le mécanisme d’attention
Quelle est la différence entre attention et self-attention ?
L’attention « classique » (Bahdanau, cross-attention) relie deux séquences différentes : le decoder consulte l’encoder. La self-attention relie une séquence à elle-même : chaque token consulte tous les autres tokens de la même séquence. La self-attention est le mécanisme central des Transformers et des LLM. Elle permet de capturer les relations internes d’une séquence (dépendances syntaxiques, coréférences) sans passer par une boucle récurrente.
Pourquoi diviser par √d_k dans le scaled dot-product attention ?
Quand la dimension d_k est grande, le produit scalaire Q·K^T produit des valeurs élevées. Or, le softmax est très sensible à la magnitude de ses entrées : des valeurs trop grandes poussent la distribution vers un « one-hot » (un seul token reçoit tout le poids), ce qui produit des gradients quasi nuls et ralentit l’apprentissage. Diviser par √d_k recentre les scores dans une plage où le softmax produit des distributions plus lisses et informatives.
Combien de têtes d’attention faut-il utiliser ?
Le Transformer original utilise 8 têtes pour d_model=512. Les LLM modernes utilisent généralement 32 à 128 têtes. La règle empirique est que d_k = d_model / h doit rester suffisamment grand (≥ 64) pour que chaque tête ait assez de capacité. Des travaux récents (Grouped Query Attention, Multi-Query Attention) réduisent le nombre de têtes pour K et V tout en gardant un nombre élevé de têtes pour Q, offrant un meilleur compromis vitesse/qualité à l’inférence.
L’attention est-elle plus puissante que la récurrence ?
Pour capturer les dépendances à longue portée, oui. L’attention accède directement à n’importe quel token en un seul pas, tandis que la récurrence doit propager l’information à travers une chaîne d’états cachés. Mais l’attention a un coût quadratique qui la rend impraticable pour les très longues séquences sans optimisations. Les architectures hybrides (Mamba + attention) combinent le meilleur des deux approches : récurrence sélective pour l’efficacité, attention pour la précision de récupération d’information.
FlashAttention change-t-il le résultat de l’attention ?
Non. FlashAttention produit exactement le même résultat que l’attention standard. C’est une optimisation purement algorithmique et matérielle qui réorganise les calculs pour minimiser les transferts de données entre les différents niveaux de mémoire du GPU (HBM et SRAM). Le résultat : 2 à 4× plus rapide et significativement moins de mémoire consommée, sans aucun compromis sur la qualité. C’est intégré par défaut dans PyTorch 2.0+ et dans tous les frameworks d’entraînement de LLM.