Medusa
Medusa est un framework d’accélération de l’inférence des LLM qui ajoute des « têtes Medusa » (couches FFN supplémentaires) au modèle pour prédire plusieurs tokens futurs en parallèle, sans nécessiter de modèle brouillon séparé. Combiné à un mécanisme d’attention en arbre, Medusa accélère la génération d’environ 2× tout en étant simple à implémenter.
- Catégorie
- Technique d’accélération d’inférence / Parallel decoding
- Origine
- Together AI / Princeton / UIUC (septembre 2023), inspiré de Stern et al. (2018)
- Principe
- K têtes de prédiction FFN attachées au modèle, chacune prédisant un token futur différent
- Speedup
- ~2× (varie selon le modèle et la tâche)
- Entraînement
- Léger : seules les têtes sont entraînées, le modèle principal reste gelé
- Frameworks
- TensorRT-LLM, vLLM
- Relation
- Alternative au speculative decoding classique (pas de modèle brouillon séparé)
Le problème que Medusa résout
Le speculative decoding classique accélère les LLM de 2-3× mais impose de maintenir un modèle brouillon séparé : il faut le stocker, le servir, s’assurer qu’il utilise le même tokenizer, et le mettre à jour quand le modèle cible change. Pour beaucoup d’équipes, cette complexité est un frein à l’adoption.
Medusa prend l’approche inverse : au lieu d’utiliser un modèle externe pour proposer des tokens, pourquoi ne pas étendre le modèle lui-même pour qu’il prédise plusieurs tokens à la fois ? C’est un retour à l’idée originale de Stern et al. (2018) dans « Blockwise Parallel Decoding for Deep Autoregressive Models », avec des ingrédients modernes qui la rendent pratique.
Comment fonctionnent les têtes Medusa
Architecture des têtes
Dans un LLM decoder-only standard, une seule « LM head » (couche de projection linéaire) convertit la représentation cachée du dernier bloc en distribution de probabilité sur le vocabulaire pour prédire le prochain token. Medusa ajoute K têtes supplémentaires, chacune responsable de prédire un token futur spécifique :
La tête Medusa 1 prédit le token à la position t+2 (le token d’après le prochain). La tête Medusa 2 prédit le token à la position t+3. La tête Medusa K prédit le token à la position t+K+1.
Chaque tête est un FFN simple : une seule couche de réseau feed-forward avec une connexion résiduelle. L’entrée est la représentation cachée de la dernière couche du modèle (la même entrée que la LM head standard). La tête prend cette représentation, la projette dans un espace intermédiaire, applique une activation, et la reprojette vers la taille du vocabulaire.
Attention en arbre (tree-structured verification)
Chaque tête Medusa peut proposer non pas un seul token mais ses top-K candidats (par exemple, les 3 tokens les plus probables). Avec K=5 têtes et 3 candidats par tête, on obtient un arbre de continuations possibles avec potentiellement des centaines de chemins.
Medusa vérifie toutes les branches de cet arbre en un seul passage du modèle cible, grâce à un masque d’attention spécialisé qui encode la structure arborescente. Chaque nœud de l’arbre ne « voit » que ses ancêtres dans l’arbre (pas les branches alternatives). Le modèle évalue la probabilité de chaque chemin et accepte le chemin le plus long dont tous les tokens sont valides.
L’attention en arbre est la vraie source du speedup de Medusa : au lieu de vérifier une seule séquence linéaire de K tokens, on vérifie un arbre entier de continuations, augmentant considérablement la probabilité qu’au moins un chemin long soit accepté.
Entraînement des têtes
L’entraînement de Medusa est remarquablement simple et peu coûteux :
Le modèle principal reste gelé. Seules les têtes Medusa sont entraînées. Le modèle de base ne change pas, préservant toutes ses capacités et ses propriétés d’alignement.
Données d’entraînement flexibles. On peut utiliser le même corpus qui a servi à entraîner le modèle original, ou générer un nouveau corpus en faisant tourner le modèle lui-même (données synthétiques auto-générées).
Convergence rapide. L’entraînement des têtes converge en quelques heures sur GPU, beaucoup plus vite que l’entraînement d’un modèle brouillon séparé. Les têtes Medusa ajoutent un nombre modeste de paramètres au modèle total.
Medusa-1 vs. Medusa-2
Medusa-1 entraîne les têtes avec le modèle principal gelé. C’est simple et rapide, mais les prédictions des têtes ne passent pas par un rejection sampling exact. La qualité de sortie est quasi-identique au modèle de base, avec de très légères différences possibles.
Medusa-2 introduit un mécanisme de vérification plus rigoureux, s’approchant du rejection sampling exact du speculative decoding classique. Les tokens sont acceptés avec une probabilité basée sur la correspondance entre la distribution de la tête Medusa et celle du modèle cible. La qualité est plus proche de la garantie exacte du speculative decoding, au prix d’un taux d’acceptation légèrement plus faible.
Medusa vs. EAGLE : les différences
| Critère | Medusa | EAGLE-3 |
|---|---|---|
| Architecture de la tête | FFN simple (1 couche) par position future | Couche Transformer complète + fusion multi-couches |
| Entrée de la tête | Features de la dernière couche uniquement | Features de plusieurs couches internes (fusion) |
| Mode de prédiction | Parallèle (chaque tête prédit indépendamment) | Autorégressif au niveau features (utilise le token précédemment prédit) |
| Robustesse aux erreurs | Limitée (têtes indépendantes) | Élevée (training-time test sur ses propres erreurs) |
| Taux d’acceptation | Bon | Meilleur (stable même aux positions éloignées) |
| Paramètres ajoutés | Très léger (~100M pour K=5 têtes) | Léger (~1-2 % du modèle) |
| Simplicité | Plus simple (FFN standard) | Plus complexe (couche Transformer + fusion) |
| Speedup | ~2× | 2-3× |
En résumé : Medusa est plus simple et plus rapide à déployer. EAGLE-3 est plus performant (meilleur taux d’acceptation, speedup supérieur) mais plus complexe. Pour un premier déploiement de parallel decoding, Medusa est un excellent point d’entrée. Pour maximiser les performances, EAGLE-3 est l’état de l’art.
Déployer Medusa en production
Medusa est supporté par les principaux frameworks de serving :
TensorRT-LLM (NVIDIA). Support natif de Medusa avec des kernels GPU optimisés. NVIDIA a publié des benchmarks montrant un speedup de 1,9× sur LLaMA 3.1 avec Medusa sur HGX H200 avec NVLink Switch. L’intégration est directe : il suffit de fournir les poids des têtes Medusa en plus du modèle principal.
vLLM. Support de Medusa via le module Speculators. Configuration simple : spécifier le chemin vers les poids Medusa dans les paramètres de serving.
Comparaison avec un modèle PyTorch standard. Medusa offre des gains particulièrement visibles quand le modèle principal n’est pas aussi optimisé que dans TensorRT-LLM. Sur une implémentation PyTorch standard, les gains relatifs de Medusa sont plus prononcés car le baseline est plus lent.
Limites de Medusa
Pas de garantie de qualité exacte (Medusa-1). Contrairement au speculative decoding classique avec rejection sampling, Medusa-1 ne garantit pas que les tokens acceptés suivent exactement la distribution du modèle cible. Medusa-2 améliore ce point mais avec un taux d’acceptation plus faible.
Taux d’acceptation inférieur à EAGLE-3. Les têtes Medusa, étant des FFN simples prédisant indépendamment, ont un taux d’acceptation qui décroît aux positions éloignées (la tête 5 est moins précise que la tête 1). EAGLE-3, avec sa prédiction autoréessive au niveau features et son training-time test, maintient un taux d’acceptation plus stable.
Overhead paramétrique. Bien que modeste, Medusa ajoute des paramètres (environ 100M pour Kangaroo’s comparison : 67M Medusa-1 pour un petit modèle vs. 591M pour des modèles plus grands). Pour des modèles très contraints en mémoire, chaque Mo compte.
Dépendance au modèle de base. Les têtes Medusa sont spécifiques à un modèle donné. Si vous mettez à jour le modèle de base (nouvelle version, nouveau fine-tuning), les têtes doivent être ré-entraînées. C’est rapide (quelques heures), mais c’est un pas de maintenance supplémentaire.
Verdict
Medusa est la technique de parallel decoding la plus accessible. Son architecture simple (des FFN attachés au modèle), son entraînement rapide (quelques heures, modèle gelé) et son intégration dans les frameworks majeurs (TRT-LLM, vLLM) en font un excellent premier pas pour accélérer l’inférence d’un LLM. Le speedup de ~2× sans modèle brouillon séparé est un rapport effort/gain difficile à battre.
EAGLE-3 le surpasse en performance pure (meilleur taux d’acceptation, speedup supérieur), mais au prix d’une complexité accrue. Pour les déploiements où la simplicité prime, Medusa reste pertinent. Pour les déploiements optimisés pour la latence minimale, EAGLE-3 et P-EAGLE sont désormais les meilleures options. Mais Medusa a le mérite historique d’avoir démocratisé le parallel decoding et rendu cette famille de techniques accessible à la communauté open-source.
Questions fréquentes sur Medusa
Medusa change-t-il les réponses du LLM ?
Avec Medusa-1, de très légères différences sont possibles car le mécanisme de vérification n’est pas un rejection sampling exact. En pratique, les utilisateurs ne perçoivent pas de différence de qualité. Medusa-2 se rapproche de la garantie exacte du speculative decoding classique. Pour les applications critiques nécessitant une qualité strictement identique, le speculative decoding avec rejection sampling exact est préférable.
Combien de temps faut-il pour entraîner les têtes Medusa ?
Quelques heures sur GPU pour un modèle de taille standard (7B-70B). L’entraînement est très efficient car seules les têtes (quelques dizaines à centaines de millions de paramètres) sont mises à jour, le modèle principal restant gelé. Le corpus d’entraînement peut être le même que celui du modèle original ou des données générées par le modèle lui-même.
Medusa est-il compatible avec la quantification ?
Oui. Le modèle principal peut être quantifié (INT4, INT8, FP8) normalement, et les têtes Medusa fonctionnent par-dessus. La quantification réduit la mémoire et la bande passante (accélérant chaque passe), tandis que Medusa réduit le nombre de passes. Les deux optimisations sont complémentaires et cumulables.
Peut-on utiliser Medusa sur un modèle MoE ?
En théorie, oui. Les têtes Medusa prennent en entrée la représentation cachée de la dernière couche, quelle que soit l’architecture sous-jacente (dense ou MoE). En pratique, les implémentations actuelles dans les frameworks ciblent principalement les modèles denses, mais l’extension aux MoE est techniquement directe.
Medusa ou speculative decoding : lequel déployer en premier ?
Si vous avez déjà un modèle brouillon compatible (même famille, plus petit) : le speculative decoding classique est le plus simple (pas d’entraînement requis, garantie de qualité exacte). Si vous n’avez pas de brouillon ou souhaitez éviter la complexité de deux modèles : Medusa est le meilleur point d’entrée (quelques heures d’entraînement des têtes, un seul modèle à servir). Pour maximiser la performance à terme : migrez vers EAGLE-3 ou P-EAGLE.