FlashAttention
FlashAttention est un algorithme d’attention exacte et IO-aware qui accélère le calcul d’attention des Transformers de 2 à 15x en minimisant les transferts de données entre la mémoire principale du GPU (HBM) et sa mémoire rapide (SRAM), tout en réduisant la consommation mémoire de quadratique à linéaire par rapport à la longueur de la séquence.
- Créateur
- Tri Dao (Princeton University, Together AI) et collaborateurs
- Versions
- v1 (2022, A100), v2 (2023, A100), v3 (2024, H100), v4 (mars 2026, B200)
- Performance FA4
- 1 605 TFLOPS sur B200 (71% d’utilisation), 2,7x plus rapide que Triton
- Principe clé
- Tiling + online softmax pour éviter de matérialiser la matrice d’attention N×N en HBM
- Mémoire
- O(N) au lieu de O(N²) pour la séquence de longueur N
- Exactitude
- Résultat mathématiquement identique à l’attention standard (aucune approximation)
- Licence
- Open Source (BSD-3-Clause)
- Installation
pip install flash-attn(FA2),pip install fa4(FA4)
Pourquoi FlashAttention existe
Le goulot d’étranglement de l’attention standard
Le mécanisme d’attention est au cœur de tous les LLM modernes. Sa formule de base est simple : Attention(Q, K, V) = softmax(QKT / √d) × V. Mais cette formule cache un problème majeur de performance.
Dans l’implémentation naïve, le calcul d’attention pour une séquence de N tokens produit une matrice intermédiaire S = QKT de taille N×N. Cette matrice doit être écrite en mémoire GPU (HBM), puis relue pour appliquer le softmax, puis relue encore pour le produit avec V. Pour N = 128K tokens, cette matrice occupe N² × sizeof(float16) ≈ 32 Go, soit plus que la mémoire de la plupart des GPU.
Même pour des séquences plus courtes où la matrice tient en mémoire, les allers-retours entre le GPU et sa mémoire HBM sont extrêmement coûteux. Le calcul d’attention est fondamentalement memory-bound : le GPU passe plus de temps à transférer les données qu’à les calculer. Sur un H100, les Tensor Cores peuvent effectuer 989 TFLOPS, mais la bande passante HBM est « seulement » de 3,35 To/s. Le ratio calcul/mémoire rend les opérations d’attention sous-optimales.
L’idée clé : IO-awareness
FlashAttention part d’une observation simple : au lieu de calculer la matrice d’attention complète, de l’écrire en HBM, puis de la relire, on peut découper le calcul en blocs (tiling) et effectuer chaque bloc entièrement dans la SRAM rapide du GPU, sans jamais matérialiser la matrice N×N complète en HBM.
Le défi technique est le softmax, qui est normalement une opération globale (elle nécessite de connaître la somme de toutes les exponentielles pour normaliser). FlashAttention résout ce problème avec l’online softmax : une variante incrémentale qui accumule les statistiques (max et somme des exponentielles) au fur et à mesure du traitement des blocs, puis applique une correction de rescaling pour obtenir le résultat exact.
Le résultat est un algorithme qui calcule le même résultat exact que l’attention standard, mais avec des accès mémoire réduits de O(N²) à O(N), et une empreinte mémoire linéaire au lieu de quadratique.
L’évolution de FlashAttention : v1 à v4
FlashAttention-1 (mai 2022, Ampere/A100)
La version originale introduit le tiling et l’online softmax. Sur A100, elle atteint 25 à 40% d’utilisation des FLOPS théoriques et apporte des speedups de 2 à 4x par rapport à l’attention PyTorch standard. L’économie mémoire est de 10x à la longueur de séquence 2K et 20x à 4K. FlashAttention-1 a permis aux modèles de passer de contextes de 2-4K tokens (GPT-3, OPT) à des fenêtres beaucoup plus grandes.
FlashAttention-2 (2023, Ampere/A100)
FA2 améliore le parallélisme et le partitionnement du travail entre les warps du GPU. L’utilisation des FLOPS passe à environ 70% sur A100. Le training de GPT-2 est accéléré de 3 à 5x par rapport à l’implémentation HuggingFace de base, atteignant 225 TFLOPS/s par A100 (72% d’utilisation). FA2 gère des dimensions de tête allant jusqu’à 256 et supporte les GPU AMD ROCm.
FA2 est aujourd’hui la version la plus largement déployée, intégrée dans PyTorch, HuggingFace Transformers, vLLM, et la quasi-totalité des frameworks d’entraînement et d’inférence. C’est le « workhorse » de l’écosystème LLM.
FlashAttention-3 (juillet 2024, Hopper/H100)
FA3 exploite trois nouvelles capacités des GPU Hopper. Le warp-specialization pour chevaucher le calcul et le transfert de données via les instructions WGMMA asynchrones. L’interleaving des opérations matmul et softmax en pipeline à 2 étages. Et le support du FP8 avec quantization par blocs et traitement incohérent pour limiter la perte de précision.
Les performances sur H100 atteignent 740 TFLOPS en FP16 (75% d’utilisation) et approchent 1,2 PFLOPS en FP8, avec une erreur numérique 2,6x inférieure à l’attention FP8 de base. FA3 est accepté comme Spotlight Poster à NeurIPS 2024.
FlashAttention-4 (mars 2026, Blackwell/B200)
FA4 est une refonte complète, conçue de zéro pour l’architecture Blackwell de NVIDIA. Le problème central que FA4 résout est le scaling asymétrique du matériel : entre le H100 (Hopper) et le B200 (Blackwell), les Tensor Cores passent de ~1 à ~2,25 PFLOPS, mais les unités de fonctions spéciales (SFU) pour le softmax exponentiel et la bande passante de la shared memory n’ont pas suivi.
Conséquence : sur Blackwell, le goulot d’étranglement n’est plus le matmul mais le calcul de l’exponentielle dans le softmax (forward) et le trafic shared memory (backward). FA4 introduit plusieurs innovations pour contourner ces limites :
Pipeline ping-pong avec double tuile Query. FA4 traite deux tuiles de Query par CTA (128 tokens chacune) en alternance, gardant les deux pipelines saturés en permanence.
Softmax distribué et émulé. Le calcul de l’exponentielle est réparti entre l’instruction matérielle MUFU.EX2 et une émulation logicielle sur les FMA, doublant effectivement le débit du softmax.
Tensor Memory (TMEM). Les accumulateurs intermédiaires sont stockés dans la TMEM de Blackwell (une nouvelle mémoire on-chip proche des Tensor Cores) au lieu des registres, libérant de la pression sur les registres et permettant de garder plusieurs MMA en vol simultanément.
Mode 2-CTA MMA. Deux CTA (Cooperative Thread Arrays) coopèrent sur une même tuile de multiplication matricielle, chacune gérant la moitié des opérandes. Cela divise approximativement par deux le trafic shared memory pour le backward pass.
Warpgroup de correction dédié. Le rescaling en ligne du softmax est retiré du chemin critique et confié à un warpgroup spécialisé, minimisant les opérations non-matmul.
| Version | GPU cible | Performance peak | Utilisation GPU | Speedup vs précédent |
|---|---|---|---|---|
| FA1 (2022) | A100 (Ampere) | ~200 TFLOPS (FP16) | 25-40% | 2-4x vs PyTorch standard |
| FA2 (2023) | A100 (Ampere) | ~225 TFLOPS (FP16) | ~70% | ~1,5x vs FA1 |
| FA3 (2024) | H100 (Hopper) | 740 TFLOPS (FP16), ~1,2 PFLOPS (FP8) | 75% | 1,5-2x vs FA2 sur H100 |
| FA4 (2026) | B200 (Blackwell) | 1 605 TFLOPS (BF16) | 71% | 3,6x vs FA2, 1,3x vs cuDNN |
CuTe-DSL : la fin des compilations interminables
FA4 est entièrement écrit en CuTe-DSL, un langage spécifique embarqué dans Python développé par l’équipe CUTLASS de NVIDIA. Les versions précédentes reposaient sur des templates C++ qui nécessitaient des temps de compilation allant de minutes à heures. CuTe-DSL compile 20 à 30x plus vite, et l’installation se fait en secondes via pip install fa4.
L’intégration avec PyTorch est immédiate via FlexAttention. Les chercheurs peuvent écrire des variantes d’attention personnalisées (ALiBi, sliding window, document masking, soft-capping) comme de simples fonctions Python score_mod, qui sont compilées JIT en kernels FA4. Les gains mesurés sont de 1,2x à 3,2x par rapport au backend Triton précédent de FlexAttention.
Compatibilité matérielle
| GPU | FA2 | FA3 | FA4 |
|---|---|---|---|
| NVIDIA A100 / A800 (Ampere) | Oui (optimal) | Non | Non |
| NVIDIA RTX 3090/4090 (Ampere/Ada) | Oui | Non | Non |
| NVIDIA H100 / H200 (Hopper) | Oui | Oui (optimal) | Oui (gains modestes via CuTe-DSL) |
| NVIDIA B200 / B300 (Blackwell datacenter) | Oui (sous-optimal) | Non compatible | Oui (optimal) |
| NVIDIA RTX 5090 (Blackwell desktop, SM120) | Oui | Non | Non (architecture silicon différente) |
| AMD MI200x / MI300x / MI355x (ROCm) | Oui (backend CK ou Triton) | Oui (via backend Triton) | Non (CuTe-DSL est NVIDIA-spécifique) |
Intégration dans l’écosystème
FlashAttention est intégré dans la quasi-totalité de l’écosystème LLM :
PyTorch. FlexAttention supporte FA4 comme backend sur Hopper et Blackwell depuis mars 2026. Les variantes d’attention personnalisées sont compilées JIT en kernels FA4.
vLLM. La version 0.17.0 (7 mars 2026) intègre FA4, activé automatiquement sur hardware compatible. L’intégration a pris seulement deux jours après la publication de FA4, témoignant de la maturité de l’écosystème.
HuggingFace Transformers. FA2 est l’attention par défaut pour la plupart des modèles. Le flag attn_implementation="flash_attention_2" active explicitement FlashAttention.
Frameworks d’entraînement. Megatron-LM, DeepSpeed, FSDP de PyTorch, et tous les principaux frameworks de pré-entraînement utilisent FlashAttention. L’entraînement de GPT-3 avec FA est environ 3x plus rapide qu’avec l’implémentation de base.
FlashAttention vs autres optimisations
FlashAttention vs PagedAttention
Ces deux technologies sont complémentaires, pas concurrentes. FlashAttention optimise le calcul d’attention (comment les opérations matmul et softmax sont exécutées sur le GPU). PagedAttention optimise la gestion mémoire du KV cache (comment les clés et valeurs sont stockées et allouées). En production, les deux sont utilisés ensemble : FlashAttention pour le calcul rapide, PagedAttention pour l’utilisation efficace de la mémoire.
FlashAttention vs attention standard
FlashAttention calcule le même résultat exact que l’attention standard. Ce n’est pas une approximation. La seule différence est dans l’ordre des opérations et la gestion de la mémoire. L’attention standard matérialise la matrice N×N en HBM ; FlashAttention la calcule par blocs en SRAM. Le résultat numérique est identique (aux erreurs d’arrondi flottant près, qui sont du même ordre que l’attention standard).
FlashAttention vs attention sparse/approximée
Contrairement aux méthodes d’attention sparse (Longformer, BigBird) ou approximée (Linformer, Performer), FlashAttention est exacte. Elle ne réduit pas le nombre de paires de tokens considérées ; elle rend le calcul complet plus efficace. Cela signifie qu’il n’y a aucun compromis sur la qualité du modèle. FlashAttention peut d’ailleurs être combinée avec des patterns d’attention sparse (block-sparse attention) pour des gains encore plus importants sur les très longues séquences.
Impact sur l’industrie du LLM
FlashAttention a eu un impact transformateur sur l’ensemble de l’écosystème IA. Avant FlashAttention, les fenêtres de contexte des LLM étaient limitées à 2-4K tokens (GPT-3, OPT). Avec FlashAttention, elles sont passées à 128K (GPT-4), puis à 1M+ tokens (Gemini, Claude). Cette expansion des fenêtres de contexte a rendu possibles le RAG sur de longs documents, l’analyse de codebases entières, et les conversations multi-tours avec mémoire étendue.
Du côté de l’entraînement, FlashAttention a réduit les coûts de manière significative. L’entraînement de BERT-large est accéléré d’environ 15% par rapport au baseline MLPerf, et le training de GPT-2 de 3x. Ces gains se traduisent directement en économies sur les factures GPU.
FlashAttention est également un facteur clé du « CUDA moat » de NVIDIA. Tri Dao et les chercheurs qui optimisent FlashAttention travaillent exclusivement sur GPU NVIDIA, et FA4 est écrit en CuTe-DSL qui est spécifique à CUDA. Le portage vers AMD ROCm est significativement plus difficile pour FA4 que pour les versions précédentes en C++. Cela renforce la position dominante de NVIDIA dans l’infrastructure IA.
L’adoption de FlashAttention suit une trajectoire remarquable : le dépôt GitHub accumule plus de 19 100 étoiles et la technique est intégrée dans plus d’une quinzaine de frameworks LLM majeurs. Le délai entre la publication de FA4 (5 mars 2026) et son intégration dans vLLM 0.17 (7 mars 2026) illustre la réactivité de l’écosystème : seulement deux jours entre la publication amont et la disponibilité en production.
L’impact économique est direct et mesurable. Pour un cluster de 100 GPU H100 à 2 $/h par GPU, le passage de l’attention standard à FlashAttention peut réduire la facture de training de 30 à 50%. Sur l’inférence, les gains se composent avec le KV cache, la quantization, et le décodage spéculatif.
Utilisation pratique
Installation
# FlashAttention-2 (compatible Ampere, Ada, Hopper)
pip install flash-attn --no-build-isolation
# FlashAttention-4 (Hopper et Blackwell)
pip install fa4
# Avec vLLM (FA4 activé automatiquement sur hardware compatible)
pip install vllm==0.17.0
vllm serve meta-llama/Llama-3.3-70B-Instruct
Utilisation avec PyTorch
# FA2 direct
from flash_attn import flash_attn_func
output = flash_attn_func(q, k, v, causal=True)
# Via FlexAttention (FA4 backend sur Blackwell)
from torch.nn.attention.flex_attention import flex_attention
output = flex_attention(q, k, v)
# Via HuggingFace Transformers
model = AutoModelForCausalLM.from_pretrained(
"meta-llama/Llama-3.3-70B-Instruct",
attn_implementation="flash_attention_2"
)
attn_implementation est nécessaire.
FlashAttention et attention block-sparse
FlashAttention supporte nativement les patterns d’attention block-sparse. Au lieu de calculer l’attention sur toutes les paires de tokens (dense), on peut spécifier un masque par blocs qui définit quels blocs de tokens s’attendent mutuellement. Cela permet d’implémenter efficacement des patterns comme la sliding window, l’attention locale + globale (à la Longformer), ou des masques de documents pour les entraînements multi-documents.
FA4 sur Blackwell supporte le block-sparse dans les deux passes (forward et backward), avec le support pack-GQA pour les mask mods diffusés. L’avantage par rapport aux implémentations sparse alternatives est que FlashAttention conserve toutes ses optimisations IO-aware (tiling, online softmax) même en mode sparse, ce qui n’est pas le cas des implémentations naïves basées sur les sparse tensors.
Limites et considérations
Dépendance matérielle. Chaque version de FlashAttention est optimisée pour une génération de GPU spécifique. FA4 ne fonctionne que sur Blackwell datacenter (SM100), pas sur les GPU grand public « Blackwell » (RTX 5090, SM120). Le portage AMD ROCm est limité à FA2 et FA3 (via Triton).
Dimension de tête. FA2 supporte toutes les dimensions de tête jusqu’à 256, mais le backward pass pour les dimensions > 192 nécessite des GPU datacenter (A100/H100). Les GPU grand public sont limités à 256 sans dropout.
Pas de support Windows natif. FlashAttention est principalement développé et testé sous Linux. Le support Windows progresse mais reste expérimental.
Compilation longue (FA2). Sans ninja, la compilation de FA2 peut prendre jusqu’à 2 heures. Avec ninja sur une machine 64 cœurs, elle descend à 3-5 minutes. FA4 élimine ce problème grâce à CuTe-DSL.
Questions fréquentes sur FlashAttention
FlashAttention modifie-t-il la qualité du modèle ?
Non. FlashAttention calcule le résultat mathématiquement exact de l’attention standard. Ce n’est pas une approximation. Il réorganise les opérations et optimise les accès mémoire, mais le résultat numérique est identique (aux erreurs d’arrondi flottant près, qui sont du même ordre que l’attention classique). Vous pouvez activer FlashAttention sans aucun impact sur la qualité de votre modèle, que ce soit pour l’entraînement ou l’inférence.
Quelle version de FlashAttention utiliser ?
Sur GPU NVIDIA Ampere (A100, RTX 3090) ou Ada (RTX 4090), utilisez FA2. Sur Hopper (H100, H200), FA3 est optimal, avec FA4 offrant des gains modestes via la compilation CuTe-DSL. Sur Blackwell datacenter (B200, B300), FA4 est indispensable (FA3 n’est pas compatible). Sur AMD ROCm (MI300x, MI355x), FA2 via le backend composable_kernel ou Triton. Pour les RTX 5090 « Blackwell desktop », FA2 reste la seule option.
FlashAttention fonctionne-t-il sur CPU ou Apple Silicon ?
FlashAttention est spécifiquement conçu pour les GPU NVIDIA (et partiellement AMD ROCm). Il n’existe pas de version CPU ou Apple Silicon. Sur ces plateformes, les frameworks comme PyTorch utilisent leurs propres implémentations d’attention optimisées (Metal Performance Shaders sur Apple Silicon, par exemple). Les gains de FlashAttention sont spécifiques à l’architecture mémoire des GPU.
Quelle est la différence entre FlashAttention et FlashDecoding ?
FlashAttention optimise le forward pass complet de l’attention (utilisé pendant l’entraînement et le prefill). Flash-Decoding (2023, Meta) étend ces optimisations spécifiquement au décodage autorégressif (génération token par token), où la dimension Query est très petite (souvent 1). Flash-Decoding charge les K et V du KV cache en parallèle pour atteindre jusqu’à 8x de speedup sur les séquences très longues pendant la génération.
FlashAttention est-il compatible avec le décodage spéculatif ?
Oui. FlashAttention accélère le forward pass de vérification dans le décodage spéculatif, où le modèle cible vérifie les tokens candidats proposés par le draft model. Les gains sont cumulatifs : FlashAttention réduit le coût de chaque forward pass, et le décodage spéculatif réduit le nombre de forward passes nécessaires. Les frameworks comme vLLM et SGLang activent automatiquement les deux optimisations ensemble.