> [!info] Course code
> Use the companion repository for runnable notebooks, figures, and implementation references for this lecture:
> - Theory notebook: [notebooks/multi_head_attention/lecture_walkthrough.ipynb](https://github.com/Montekkundan/llm/blob/main/notebooks/multi_head_attention/lecture_walkthrough.ipynb)
> - Serious model anchor: [picollm/accelerated/gpt.py](https://github.com/Montekkundan/llm/blob/main/picollm/accelerated/gpt.py)
> - Runtime/kernel anchor: [picollm/accelerated/flash_attention.py](https://github.com/Montekkundan/llm/blob/main/picollm/accelerated/flash_attention.py)
## What This Concept Is
If one attention pattern is useful, several different attention patterns at the same time are even more useful. Multi-head attention lets the model look at the same sequence through several learned views in parallel instead of forcing every relationship into one single attention map.
A good way to picture it is this: the same sentence is on the table, but different heads can focus on different kinds of structure such as nearby syntax, longer-range references, or delimiter patterns.
## Foundation Terms You Need First
A **head** is one full attention pathway with its own projections. That means each head gets its own way of turning the input into queries, keys, and values. After the heads run in parallel, their outputs are joined together by **concatenation** and then mixed again through one **output projection**.
So the key distinction in this note is between one head's local view and the model's final combined view. Each head reasons in a narrower subspace, and the output projection folds those partial views back into one representation.
<video src="https://assets.montek.dev/lectures/media/llm/concepts/Multi-head%20Attention/01_parallel_attention_heads.mp4" controls></video>
```mermaid
flowchart TD
A["Input sequence X"] --> B["Project to Q, K, V"]
B --> C1["Head 1 attention"]
B --> C2["Head 2 attention"]
B --> C3["Head 3 attention"]
B --> C4["Head h attention"]
C1 --> D["Concatenate head outputs"]
C2 --> D
C3 --> D
C4 --> D
D --> E["Output projection W^O"]
```
## How this lecture maps to picoLLM
The notebook introduces the canonical multi-head construction first.
Then inspect `picollm/accelerated/gpt.py` and notice that the serious stack is more opinionated:
- grouped-query attention changes the KV side of the usual MHA story
- RoPE is applied to queries and keys before the attention scores are computed
- QK normalization and fused kernels matter for stability and speed
That is the intended progression:
- theory first
- picoLLM implementation second
- `rasbt` as a clean, concept-first external comparison
- `nanochat` as a systems-first external comparison
This lecture covers three connected ideas:
- the multi-head derivation: the projection matrices, the head split, the concatenate-and-project step, and why the usual $d_k = d_{\text{model}} / h$ heuristic keeps parameter count and compute roughly stable[^1]
- what multi-head attention buys conceptually: mixtures of attention patterns, representational subspaces, head specialization, and the limits of “just add more heads”[^3]
- the modern implementation story: head redundancy, pruning, head collapse, multi-query and grouped-query attention, KV-cache pressure, and fused-kernel details[^4]
## Why multi-head attention exists
### The single-distribution bottleneck
A single [[Glossary#Attention head|attention head]] gives each query token one probability distribution over keys. That means one normalized weighting has to express every notion of relevance the model currently needs: subject agreement, object linking, clause tracking, positional structure, and anything else the layer wants to retrieve.
That is the core bottleneck multi-head attention is designed to relieve.
The original Transformer explicitly motivates multi-head attention as a way to let the model jointly attend to information from different representation subspaces at different positions.[^1]
The simplest lecture line is:
> One head gives one distribution. Many heads give many simultaneous alignment channels.
> [!question] Quick check
> Why can one attention distribution per token be a bottleneck?
>> [!answer] one distribution has to express every kind of relevance at once, such as syntax, entity links, modifier structure, and positional cues.
### Multi-head attention parallelizes representational pathways
There is a subtle but important nuance if you want more depth: multi-head attention is not the only possible way for a model to attend to multiple things. In principle, depth can also build multiple attention behaviors across layers. Some empirical work even argues that part of MHA’s advantage is training stability and parallelization of focus, rather than an absolute expressive impossibility for single-head alternatives.[^5]
That makes a good conceptual framing:
Multi-head attention is one architectural way of turning “multiple relational focuses” into a first-class parallel computation pattern.
## The canonical formulation
### Per-head definition
The original Transformer defines multi-head attention as:
$$
\operatorname{head}_i = \operatorname{Attention}(QW_i^Q, KW_i^K, VW_i^V)
$$
and then:
$$
\operatorname{MultiHead}(Q, K, V) = \operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)W^O
$$
where each head has its own learned projections $W_i^Q$, $W_i^K$, and $W_i^V$, and the concatenated output is remixed by $W^O$.[^1]
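The two equations can be written almost literally as a loop over heads. This is a minimal NumPy sketch, not picoLLM code: it assumes self-attention ($Q = K = V = X$), a single unbatched sequence, no mask, and random weights, and the helper names (`attention`, `multi_head_loop`) are hypothetical.

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the row max for numerical stability before exponentiating.
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(q, k, v):
    # Scaled dot-product attention for one head: softmax(q k^T / sqrt(d_k)) v
    d_k = q.shape[-1]
    return softmax(q @ k.T / np.sqrt(d_k)) @ v

def multi_head_loop(x, w_q, w_k, w_v, w_o):
    # w_q, w_k, w_v: lists of h per-head matrices W_i^Q, W_i^K, W_i^V, each (d_model, d_k).
    # w_o: (h * d_v, d_model) output projection W^O.
    heads = [attention(x @ wq, x @ wk, x @ wv) for wq, wk, wv in zip(w_q, w_k, w_v)]
    return np.concatenate(heads, axis=-1) @ w_o   # Concat(head_1, ..., head_h) W^O

rng = np.random.default_rng(0)
L, d_model, h = 5, 16, 4
d_k = d_model // h
x = rng.standard_normal((L, d_model))
w_q = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
w_k = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
w_v = [rng.standard_normal((d_model, d_k)) for _ in range(h)]
w_o = rng.standard_normal((h * d_k, d_model))
out = multi_head_loop(x, w_q, w_k, w_v, w_o)
print(out.shape)  # (5, 16)
```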
This is the equation you should be able to reconstruct from memory after the scaled dot-product attention lecture.
### Shapes and the standard heuristic
With:
- $d_{\text{model}}$ as total model width
- $h$ as number of heads
- $d_k = d_v = d_{\text{model}} / h$ in the standard setup
each head works in a lower-dimensional subspace, and the concatenation restores the original width before the output projection.[^1]
This is the standard “constant compute” heuristic used in the original Transformer. It is a design choice, not a law of nature, but it is one of the most important implementation conventions in modern transformer code.
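For concreteness, the base model of the original Transformer uses $d_{\text{model}} = 512$ and $h = 8$, so the rule gives 64-dimensional heads, and concatenating them restores the full width before $W^O$:[^1]

$$
d_{\text{model}} = 512,\quad h = 8 \;\Rightarrow\; d_k = d_v = \frac{512}{8} = 64,\qquad h\,d_v = 8 \cdot 64 = 512 = d_{\text{model}}
$$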
### What modern code actually does
In practice, implementations rarely use one separate matrix multiply per head. They use large projection matrices:
- $W^Q: d_{\text{model}} \to h d_k$
- $W^K: d_{\text{model}} \to h d_k$
- $W^V: d_{\text{model}} \to h d_v$
Then the result is reshaped from $[B, L, h d_k]$ to $[B, L, h, d_k]$ and transposed to $[B, h, L, d_k]$, attention is computed in batch over heads, and the head outputs are merged back to $[B, L, h d_v]$ before applying $W^O$.
This “single matmul plus reshape” view is the bridge between the paper formula and real implementation.
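A NumPy sketch of that bridge, under the same assumptions as before (self-attention, no mask, random weights; `split_heads`, `merge_heads`, and `mha_fused` are hypothetical names). It checks that one big projection followed by a reshape gives the same result as running each head separately, where head $i$'s $W_i^Q$ is simply the $i$-th block of $d_k$ columns of the big $W^Q$.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def split_heads(t, h):
    # [B, L, h*d] -> [B, L, h, d] -> [B, h, L, d]
    B, L, width = t.shape
    return t.reshape(B, L, h, width // h).transpose(0, 2, 1, 3)

def merge_heads(t):
    # [B, h, L, d] -> [B, L, h, d] -> [B, L, h*d]
    B, h, L, d = t.shape
    return t.transpose(0, 2, 1, 3).reshape(B, L, h * d)

def mha_fused(x, W_q, W_k, W_v, W_o, h):
    # One large matmul per projection, then a reshape into heads.
    q, k, v = (split_heads(x @ W, h) for W in (W_q, W_k, W_v))
    scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(q.shape[-1])   # [B, h, L, L]
    return merge_heads(softmax(scores) @ v) @ W_o

rng = np.random.default_rng(0)
B, L, d_model, h = 2, 6, 32, 4
d_k = d_model // h
x = rng.standard_normal((B, L, d_model))
W_q, W_k, W_v, W_o = (rng.standard_normal((d_model, d_model)) / np.sqrt(d_model) for _ in range(4))
fused = mha_fused(x, W_q, W_k, W_v, W_o, h)

# Per-head reference: head i uses columns [i*d_k, (i+1)*d_k) of each big projection matrix.
heads = []
for i in range(h):
    cols = slice(i * d_k, (i + 1) * d_k)
    q_i, k_i, v_i = x @ W_q[:, cols], x @ W_k[:, cols], x @ W_v[:, cols]
    heads.append(softmax(q_i @ k_i.transpose(0, 2, 1) / np.sqrt(d_k)) @ v_i)
looped = np.concatenate(heads, axis=-1) @ W_o
print(np.allclose(fused, looped))  # True
```

The reshape works because the projected channels are laid out head-major, which is exactly the channel order that concatenation produces.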
> [!example] Code for this section
> - Notebook: [notebooks/multi_head_attention/lecture_walkthrough.ipynb](https://github.com/Montekkundan/llm/blob/main/notebooks/multi_head_attention/lecture_walkthrough.ipynb)
> [!tip] TensorTonic follow-up
> - [TensorTonic: Transformers Multi-Head Attention](https://www.tensortonic.com/research/transformer/transformers-multi-head-attention)
> Use it here to practice the split-head, per-head attention, and merge-back pattern from this section.
## Parameter count and compute
### Why parameter count is often independent of head count
One of the best lecture facts here is that, under the standard $d_k = d_v = d_{\text{model}} / h$ rule, the total projection parameter count is roughly independent of the number of heads.
Why?
- each head gets smaller as $h$ increases
- but there are more of them
- the total projected width remains about $d_{\text{model}}$
So the Q/K/V projection blocks together still cost about $3 d_{\text{model}}^2$, and the output projection costs about $d_{\text{model}}^2$, ignoring biases.[^1]
This is why increasing head count changes the factorization of computation more than the total parameter budget.
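The counting argument fits in a few lines of Python. This sketch assumes $h$ divides $d_{\text{model}}$ and the standard $d_k = d_v = d_{\text{model}} / h$ rule; `mha_param_count` is a hypothetical helper.

```python
def mha_param_count(d_model, h, bias=False):
    # Standard heuristic: d_k = d_v = d_model / h (assumes h divides d_model).
    assert d_model % h == 0, "h must divide d_model"
    d_k = d_model // h
    qkv = 3 * h * (d_model * d_k)   # h heads, each with its own W_i^Q, W_i^K, W_i^V
    out = (h * d_k) * d_model       # W^O maps the concatenated heads back to d_model
    total = qkv + out
    if bias:
        total += 3 * h * d_k + d_model
    return total

for h in (1, 2, 4, 8, 16):
    print(h, mha_param_count(512, h))  # every line prints 1048576 (= 4 * 512**2)
```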
> [!question] Quick check
> If you increase head count while keeping $d_{\text{model}}$ fixed and use $d_k = d_{\text{model}} / h$, does parameter count grow dramatically?
>> [!answer] not usually. Each head gets narrower as head count increases, so the total projection width stays about the same.
### Compute asymptotics
Under the standard setup, multi-head attention does not change the asymptotic cost of full [[Glossary#Self-attention|self-attention]]: $O(L^2 d_{\text{model}})$ for the score and weighted-sum steps plus $O(L d_{\text{model}}^2)$ for the projections. The sequence-length quadratic term still dominates for long contexts.[^1]
So the practical effect of head count is mostly:
- changing constants
- changing subspace dimension per head
- changing memory layout and KV-cache structure
- changing inductive bias
This is why “more heads” is not a free lunch even if parameter count stays stable.
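A rough operation count makes the "constants, not asymptotics" point concrete. This sketch counts multiply-accumulates for the matrix products of one sequence only (no softmax, masking, biases, or batch), assumes $h$ divides $d_{\text{model}}$, and uses a hypothetical helper name.

```python
def mha_matmul_macs(L, d_model, h):
    # Multiply-accumulates for the matmuls of one sequence; ignores softmax, masks, biases.
    assert d_model % h == 0
    d_k = d_model // h
    projections = 4 * L * d_model * d_model  # Q, K, V projections plus W^O
    attention = 2 * h * L * L * d_k          # Q_i K_i^T and A_i V_i for every head
    return projections, attention

for h in (1, 8, 32):
    proj, attn = mha_matmul_macs(L=4096, d_model=512, h=h)
    print(h, proj, attn, attn / proj)  # identical for every h; ratio = L / (2 * d_model) = 4.0
```

Because $h \cdot d_k = d_{\text{model}}$, head count cancels out of both terms, and the attention term overtakes the projections once $L > 2 d_{\text{model}}$.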
## Why too many heads can hurt
### Shrinking head dimension creates a bottleneck
The original Transformer reports an important empirical result: at fixed overall compute, single-head attention is worse than a good multi-head setting, but too many heads also hurts performance.[^1]
This makes intuitive sense once you notice what happens at fixed $d_{\text{model}}$:
- as $h$ increases
- each head dimension $d_k$ shrinks
- each head has less representational bandwidth
Low-rank bottleneck analyses formalize this more sharply: when the head size is smaller than the sequence length, each head's query-key score matrix is rank-limited, and that constraint can hurt what the layer can express.[^3]
This gives a useful heuristic:
More heads help only if each head still has enough dimensional bandwidth to do meaningful work.
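One way to see the bandwidth limit: the pre-softmax logit matrix of a head is $(X W_i^Q)(X W_i^K)^\top$, a product of an $L \times d_k$ and a $d_k \times L$ matrix, so its rank is at most $d_k$. This NumPy sketch with random weights only demonstrates that rank bound on the logits; softmax is nonlinear, and the precise expressivity argument is the one in the low-rank bottleneck paper.

```python
import numpy as np

rng = np.random.default_rng(0)
L, d_model = 32, 64
x = rng.standard_normal((L, d_model))

ranks = {}
for h in (1, 4, 16, 32):
    d_k = d_model // h
    w_q = rng.standard_normal((d_model, d_k))
    w_k = rng.standard_normal((d_model, d_k))
    logits = (x @ w_q) @ (x @ w_k).T            # L x L pre-softmax scores for one head
    ranks[h] = (d_k, int(np.linalg.matrix_rank(logits)))
    print(f"h={h:2d}  d_k={d_k:2d}  rank(logits)={ranks[h][1]}")
```

At fixed $d_{\text{model}}$, raising $h$ shrinks $d_k$, and with it the largest rank any single head's score matrix can reach.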
### Head count is not the same as representational capacity
It is easy to assume “more heads = more expressive.” That is only partially true. More heads give more parallel attention channels, but if each channel becomes too narrow, the model may gain routing diversity while losing subspace capacity.
That tradeoff is one of the central reasons head count should be treated as an architectural hyperparameter, not a monotonic quality knob.
<video src="https://assets.montek.dev/lectures/media/llm/concepts/Multi-head%20Attention/02_head_count_tradeoff.mp4" controls></video>
## Gradients and coupling across heads
### Heads are computed separately but trained jointly
Each attention head computes its own Q/K/V projections and its own attention matrix. That can make you think the heads are independent experts.
They are not.
The key reason is the output projection $W^O$. After head outputs are concatenated, $W^O$ mixes channels from all heads jointly. That means the [[Glossary#Loss|loss]] couples head learning even if the attention computations themselves are parallel.
This is a very useful sentence:
> Heads are separable in forward structure, but coupled by training.
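
The coupling is visible in a simple identity: $\operatorname{Concat}(\operatorname{head}_1, \ldots, \operatorname{head}_h)W^O = \sum_i \operatorname{head}_i W_i^O$, where $W_i^O$ is the $i$-th block of $d_v$ rows of $W^O$. This NumPy sketch uses random stand-ins for the head outputs and for the upstream gradient $G = \partial \text{Loss} / \partial \text{Output}$; all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
L, h, d_v, d_model = 5, 4, 8, 32
heads = [rng.standard_normal((L, d_v)) for _ in range(h)]   # stand-ins for per-head outputs
w_o = rng.standard_normal((h * d_v, d_model))

concat_then_project = np.concatenate(heads, axis=-1) @ w_o
# Equivalent view: each head writes through its own row block of W^O, and the writes are summed.
sum_of_blocks = sum(heads[i] @ w_o[i * d_v:(i + 1) * d_v] for i in range(h))
print(np.allclose(concat_then_project, sum_of_blocks))  # True

# Backward: the gradient reaching head 0 is G @ (W_0^O)^T, and G depends on the summed output of
# all heads. The update for W_0^O is head_0^T @ G.
G = rng.standard_normal((L, d_model))
grad_head_0 = G @ w_o[:d_v].T
grad_w_o_block_0 = heads[0].T @ G
```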
### Gradient intuition
For a single head, the gradients follow the same structure as scaled dot-product attention:
- gradients flow from the output back through the weighted sum with `V`
- then through [[Glossary#Softmax|softmax]]
- then through the scaled query-key score matrix
- then into the projection matrices
The $1 / \sqrt{d_k}$ term still matters because it controls gradient magnitudes through the [[Glossary#Logits|logits]], exactly as in the single-head derivation.[^1]
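The reason for the scale factor can be checked numerically. Assuming query and key components are independent with mean 0 and variance 1 (the assumption used in the original paper's footnote), $q \cdot k$ has variance $d_k$, and dividing by $\sqrt{d_k}$ brings it back to about 1. This is a Monte Carlo sketch with arbitrary sample sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
variances = {}
for d_k in (16, 64, 256):
    q = rng.standard_normal((n, d_k))
    k = rng.standard_normal((n, d_k))
    dots = (q * k).sum(axis=1)                  # n independent query-key dot products
    variances[d_k] = (dots.var(), (dots / np.sqrt(d_k)).var())
    print(d_k, round(variances[d_k][0], 1), round(variances[d_k][1], 2))
```

Without the scale, larger heads would push logits into the saturated region of softmax, where gradients become very small.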
Once the head outputs are concatenated and passed through $W^O$, the gradient entering any one head depends on how the model uses all heads together.
That is the easiest way to explain why heads can specialize, compete, or collapse into similar behavior.
## What heads actually learn
### Specialization is real
Interpretability work on Transformer heads finds that some heads learn recognizable roles:
- fixed positional offsets
- delimiter and separator focus
- syntactic relations
- coreference-like links
- rare-word or alignment behaviors[^8][^10]
This is one of the most compelling reasons MHA is a good lecture topic: it is one of the few places where the architecture’s internal structure often maps to human-interpretable circuits.
### Redundancy is also real
At the same time, pruning work shows that many trained heads can be removed with surprisingly small impact in some settings.[^4]
This creates an important tension:
- some heads are highly meaningful and important
- many others appear redundant
That tension is central to both interpretability and model compression research.
### Head collapse
Without explicit pressure for diversity, several heads can converge to very similar patterns. This is often described as head collapse or attention collapse.[^11]
This failure mode is worth naming directly, because you may look at a few heatmaps and conclude “all my heads look the same” without realizing that this is a known phenomenon rather than a bug unique to your code.
> [!question] Quick check
> How would you quantify head collapse instead of just eyeballing attention maps?
>> [!answer] compare heads with measures such as pairwise cosine similarity or KL-divergence between flattened attention patterns.
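
Both measures from the answer above take only a few lines. This NumPy sketch assumes attention probabilities for one example and one layer with shape `[h, L, L]`; the helper names and the toy data (two heads forced to be identical) are hypothetical.

```python
import numpy as np

def pairwise_cosine(attn):
    # attn: [h, L, L] attention probabilities. Returns [h, h] cosine similarities of flattened maps.
    flat = attn.reshape(attn.shape[0], -1)
    flat = flat / np.linalg.norm(flat, axis=1, keepdims=True)
    return flat @ flat.T

def pairwise_kl(attn, eps=1e-12):
    # Mean over query rows of KL(head_i row || head_j row). Asymmetric; 0 on the diagonal.
    p = attn[:, None] + eps                     # [h, 1, L, L]
    q = attn[None, :] + eps                     # [1, h, L, L]
    return (p * np.log(p / q)).sum(-1).mean(-1)  # [h, h]

rng = np.random.default_rng(0)
h, L = 4, 6
logits = rng.standard_normal((h, L, L))
logits[1] = logits[0]                           # force a "collapsed" pair of heads
attn = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
cos, kl = pairwise_cosine(attn), pairwise_kl(attn)
print(np.round(cos, 2))
print(np.round(kl, 3))
```

Averaging these matrices over many examples gives a more reliable collapse signal than any single heatmap.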
## Head management: pruning, regularization, and variants
### Head pruning
Michel et al. show that many heads can be pruned at test time with little performance degradation in several settings, while some attention types, especially encoder-decoder attention, appear more dependent on multi-headedness.[^4]
Voita et al. reach a related conclusion from a different angle: only a small subset of heads often does the heavy lifting, and those heads frequently have interpretable roles.[^10]
This gives a strong research-facing message:
Head importance is highly uneven.
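The pruning idea can be expressed with a per-head gate applied before $W^O$, loosely following the head-mask formulation in Michel et al., where a head's importance is estimated from the magnitude of the loss gradient with respect to its gate. This NumPy sketch is a simplification under stated assumptions: a single example with a random stand-in for $\partial \text{Loss} / \partial \text{Output}$ (the paper averages over data), and hypothetical helper names.

```python
import numpy as np

def masked_mha_output(heads, w_o, mask):
    # heads: [h, L, d_v] per-head attention outputs; mask: [h] gates (1.0 = keep, 0.0 = pruned).
    h, L, d_v = heads.shape
    gated = heads * mask[:, None, None]
    concat = gated.transpose(1, 0, 2).reshape(L, h * d_v)  # head-major channels, like Concat
    return concat @ w_o

def head_importance(heads, w_o, upstream_grad):
    # Output = sum_i mask_i * (head_i @ W_i^O), so dLoss/dmask_i = <upstream_grad, head_i @ W_i^O>.
    h, L, d_v = heads.shape
    contrib = np.stack([heads[i] @ w_o[i * d_v:(i + 1) * d_v] for i in range(h)])
    return np.abs((contrib * upstream_grad).sum(axis=(1, 2)))

rng = np.random.default_rng(0)
h, L, d_v, d_model = 4, 5, 8, 32
heads = rng.standard_normal((h, L, d_v))
w_o = rng.standard_normal((h * d_v, d_model))
G = rng.standard_normal((L, d_model))          # stand-in for dLoss/dOutput on one example

scores = head_importance(heads, w_o, G)
mask = np.ones(h)
mask[np.argmin(scores)] = 0.0                  # prune the least important head
pruned_out = masked_mha_output(heads, w_o, mask)
```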
### DropHead and diversity-promoting methods
DropHead regularizes attention by dropping whole heads during training, analogous to structured dropout at the head level.[^12] Other work explicitly promotes diversity or repulsion between heads to combat collapse.[^11]
These methods are useful to teach because they show that “multiple heads” alone does not guarantee multiple useful behaviors.
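Dropping whole heads is a small change to ordinary dropout: the random mask is drawn per head instead of per element. This NumPy sketch is an assumption-laden illustration, not the Scheduled DropHead recipe: the inverted-dropout rescaling by $h / \text{kept}$ and the "keep at least one head" rule are choices made here, and the paper's schedule for the drop rate is omitted.

```python
import numpy as np

def drop_head(heads, p, rng, training=True):
    # heads: [h, L, d_v]. Zero whole heads with probability p during training.
    # ASSUMPTION: inverted-dropout rescaling by h / kept; the paper's exact normalization may differ.
    if not training or p == 0.0:
        return heads
    h = heads.shape[0]
    keep = rng.random(h) >= p
    if not keep.any():                 # ASSUMPTION: always keep at least one head
        keep[rng.integers(h)] = True
    return heads * keep[:, None, None] * (h / keep.sum())

rng = np.random.default_rng(0)
heads = rng.standard_normal((8, 5, 16))
train_out = drop_head(heads, p=0.25, rng=rng)
eval_out = drop_head(heads, p=0.25, rng=rng, training=False)  # identity at eval time
```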
### Talking-Heads and cross-head interaction
Talking-Heads attention adds learned mixing across the head dimension before and after softmax, allowing information to move between heads at the score or probability stage.[^13]
That is a nice conceptual extension because it softens the strict separation between heads without abandoning the multi-head structure entirely.
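A NumPy sketch of the idea: two learned $h \times h$ matrices mix information across the head axis, once on the logits and once on the softmax output. It simplifies the paper by using the same number of heads everywhere; the function and argument names are hypothetical. With identity mixing matrices it reduces to ordinary multi-head attention.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def talking_heads_attention(q, k, v, p_logits, p_weights):
    # q, k: [h, L, d_k]; v: [h, L, d_v]; p_logits, p_weights: [h, h] head-mixing matrices.
    logits = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])   # [h, L, L]
    logits = np.einsum("hqk,hg->gqk", logits, p_logits)        # mix across heads before softmax
    weights = softmax(logits, axis=-1)
    weights = np.einsum("hqk,hg->gqk", weights, p_weights)     # mix across heads after softmax
    return weights @ v                                         # [h, L, d_v]

rng = np.random.default_rng(0)
h, L, d = 4, 6, 8
q, k, v = (rng.standard_normal((h, L, d)) for _ in range(3))
p_logits = np.eye(h) + 0.1 * rng.standard_normal((h, h))
p_weights = np.eye(h) + 0.1 * rng.standard_normal((h, h))
out = talking_heads_attention(q, k, v, p_logits, p_weights)
```

After the post-softmax mix, rows are no longer guaranteed to sum to 1; that is part of the design rather than a bug in the sketch.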
### Mixture-of-head and routing variants
More recent lines of work treat heads almost like experts, with explicit gating, routing, or sparse selection. The general motivation is straightforward:
- head importance is heterogeneous
- some heads dominate
- routing and sparsity can exploit that uneven usefulness
Even if you do not go deep into these models, this is worth mentioning as a modern continuation of the head-pruning and head-specialization story.
## Inference efficiency: MHA, MQA, and GQA
### Why KV-cache becomes the bottleneck
In autoregressive decoding, the model reuses past keys and values through the [[Glossary#KV cache|KV cache]]. For standard MHA, every query head has its own key head and value head, so KV-cache size scales with the number of key/value heads (times head dimension, layer count, and cached tokens).
That becomes a memory-bandwidth problem during decoding: each step must load many cached key/value tensors.
### Multi-query attention
Multi-query attention (MQA) shares keys and values across query heads, dramatically reducing KV-cache size and memory traffic during generation.[^14]
The tradeoff is:
- much better decoding efficiency
- potentially less representational flexibility than full MHA
### Grouped-query attention
Grouped-query attention (GQA) sits between MHA and MQA. Instead of one key/value head shared by all query heads, it partitions the query heads into groups, and each group shares one key/value head.[^15]
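One simple way to implement the grouping is to repeat each key/value head across its group of query heads and then run ordinary batched attention. This NumPy sketch assumes contiguous groups (query heads `0..g-1` share KV head 0, and so on) and no mask; `grouped_query_attention` is a hypothetical name, and production kernels usually avoid materializing the repeated tensors.

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def grouped_query_attention(q, k, v):
    # q: [h_q, L, d_k]; k, v: [h_kv, L, d_k] with h_q divisible by h_kv.
    # h_kv == h_q is MHA, h_kv == 1 is MQA, anything in between is GQA.
    h_q, h_kv = q.shape[0], k.shape[0]
    assert h_q % h_kv == 0
    group = h_q // h_kv
    k = np.repeat(k, group, axis=0)    # query heads [g*group, (g+1)*group) share KV head g
    v = np.repeat(v, group, axis=0)
    scores = q @ k.transpose(0, 2, 1) / np.sqrt(q.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(0)
L, d = 6, 16
q = rng.standard_normal((8, L, d))                       # 8 query heads
for n_kv in (8, 2, 1):                                   # MHA, GQA with 2 groups, MQA
    k = rng.standard_normal((n_kv, L, d))
    v = rng.standard_normal((n_kv, L, d))
    print(n_kv, grouped_query_attention(q, k, v).shape)  # always (8, 6, 16)
```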
This is a useful lecture comparison:
- MHA: one KV set per head
- MQA: one shared KV set
- GQA: a few shared KV groups
That makes GQA one of the most practical modern compromises for LLM inference.
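The memory difference is simple arithmetic: the cache stores one key and one value vector per KV head, per layer, per cached token. The configuration below (32 layers, 32 query heads of size 128, 4096 cached tokens, 2-byte elements) is a hypothetical example, not a specific model.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # Factor 2: one cached key vector and one cached value vector per KV head.
    return 2 * n_layers * n_kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Hypothetical decoder: only the number of KV heads changes between the three variants.
for name, n_kv in (("MHA", 32), ("GQA, 8 groups", 8), ("MQA", 1)):
    size = kv_cache_bytes(n_layers=32, n_kv_heads=n_kv, head_dim=128, seq_len=4096)
    print(f"{name:>14}: {size / 2**20:.0f} MiB")
```

This prints 2048 MiB for MHA, 512 MiB for GQA with 8 groups, and 64 MiB for MQA: the cache shrinks by exactly the ratio of query heads to KV heads.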
<video src="https://assets.montek.dev/lectures/media/llm/concepts/Multi-head%20Attention/03_mha_mqa_gqa_comparison.mp4" controls></video>
## Implementation and systems realities
### Modern kernels do not implement MHA naively
High-performance implementations usually rely on fused scaled dot-product attention kernels rather than literal step-by-step matmul-softmax-matmul code. PyTorch’s `MultiheadAttention` and `scaled_dot_product_attention` APIs route to optimized backends when possible.[^16][^17]
This matters because MHA’s real runtime bottlenecks are:
- attention-matrix memory traffic
- softmax and masking overhead
- KV-cache bandwidth during decoding
### FlashAttention and exact efficiency
FlashAttention is a critical systems point because it shows that exact attention can be much faster if the computation is scheduled in an IO-aware way.[^18]
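In multi-head settings, this is especially important because a naive implementation materializes one dense $L \times L$ attention matrix per head and per batch element. FlashAttention avoids materializing those matrices, so the memory savings compound quickly as sequence length and head count grow.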
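The core trick that makes tiled exact attention possible is the online softmax: process keys in blocks, keep a running row max and row sum, and rescale earlier partial results when the max changes. This NumPy sketch shows only that idea for one head on the CPU. The real kernel also tiles queries, keeps tiles in on-chip SRAM, and fuses the backward pass, none of which is modeled here.

```python
import numpy as np

def naive_attention(q, k, v):
    s = q @ k.T / np.sqrt(q.shape[-1])                   # materializes the full L x L matrix
    p = np.exp(s - s.max(axis=1, keepdims=True))
    return (p / p.sum(axis=1, keepdims=True)) @ v

def tiled_attention(q, k, v, block=16):
    # Online softmax over key/value blocks: only an [L, block] slice of scores exists at a time.
    L, d = q.shape
    out = np.zeros((L, v.shape[1]))
    row_max = np.full((L, 1), -np.inf)
    row_sum = np.zeros((L, 1))
    for start in range(0, k.shape[0], block):
        kb, vb = k[start:start + block], v[start:start + block]
        s = q @ kb.T / np.sqrt(d)                        # [L, block]
        new_max = np.maximum(row_max, s.max(axis=1, keepdims=True))
        correction = np.exp(row_max - new_max)           # rescale earlier partial sums (0 at first)
        p = np.exp(s - new_max)
        row_sum = row_sum * correction + p.sum(axis=1, keepdims=True)
        out = out * correction + p @ vb
        row_max = new_max
    return out / row_sum

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((50, 8)) for _ in range(3))
print(np.allclose(tiled_attention(q, k, v, block=16), naive_attention(q, k, v)))  # True
```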
### Mask semantics and dropout bugs
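PyTorch’s fused SDPA has subtle mask semantics: for a boolean `attn_mask`, `True` means the position may take part in attention, whereas a boolean `attn_mask` in `nn.MultiheadAttention` uses `True` to mark positions that are not allowed to attend. SDPA is also a plain function with no notion of training mode, so you must pass `dropout_p=0.0` explicitly in eval if you do not want dropout applied.[^16][^17]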
This is one of those engineering facts worth saying plainly:
Your math can be correct and your implementation can still be wrong because of API semantics.
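The two boolean conventions are logical negations of each other, which is exactly why mixing them up silently inverts a mask. This NumPy sketch (no PyTorch calls) builds a causal mask in both conventions, converts it to an additive float mask, and checks that future positions receive zero weight.

```python
import numpy as np

L = 4
# "True = may attend" (the convention documented for a boolean attn_mask in
# torch.nn.functional.scaled_dot_product_attention): causal, lower-triangular.
allowed = np.tril(np.ones((L, L), dtype=bool))

# "True = masked out" (the convention documented for a boolean attn_mask in
# torch.nn.MultiheadAttention): the same causal mask is the logical NOT.
blocked = ~allowed

# Either boolean form can be turned into an additive float mask that is added to the logits.
additive = np.where(allowed, 0.0, -np.inf)

rng = np.random.default_rng(0)
logits = rng.standard_normal((L, L)) + additive
weights = np.exp(logits - logits.max(axis=1, keepdims=True))
weights /= weights.sum(axis=1, keepdims=True)
print(np.round(weights, 2))  # upper triangle is exactly 0
```

Passing `blocked` where an API expects `allowed` would let every token attend only to the future, and the code would still run without errors.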
> [!example] Notebook walkthroughs in this lecture
>
> If you want to study this note in code, use these notebook sections. If the viewer ignores the fragment, search for the exact heading text in the notebook:
>
> - [`One big projection plus reshape gives many heads`](https://github.com/Montekkundan/llm/blob/main/notebooks/multi_head_attention/lecture_walkthrough.ipynb#one-big-projection-plus-reshape-gives-many-heads)
> - [`Multi-head attention end to end`](https://github.com/Montekkundan/llm/blob/main/notebooks/multi_head_attention/lecture_walkthrough.ipynb#multi-head-attention-end-to-end)
> - [`Parameter count stays roughly fixed at fixed model width`](https://github.com/Montekkundan/llm/blob/main/notebooks/multi_head_attention/lecture_walkthrough.ipynb#parameter-count-stays-roughly-fixed-at-fixed-model-width)
> - [`Many heads mean many attention distributions`](https://github.com/Montekkundan/llm/blob/main/notebooks/multi_head_attention/lecture_walkthrough.ipynb#many-heads-mean-many-attention-distributions)
> - [`MHA versus GQA versus MQA KV-cache size`](https://github.com/Montekkundan/llm/blob/main/notebooks/multi_head_attention/lecture_walkthrough.ipynb#mha-versus-gqa-versus-mqa-kv-cache-size)
>
> A useful study order is:
>
> 1. inspect the reshape-based implementation first
> 2. compare parameter count at fixed model width
> 3. look at head diversity and redundancy together
> 4. then connect MHA to MQA and GQA in the runtime story
>
> <video src="https://assets.montek.dev/lectures/media/llm/concepts/Multi-head%20Attention/04_head_specialization_and_coupling.mp4" controls></video>
> [!tip] TensorTonic practice for this lecture
>
> If you want to practice this lecture in a more implementation-focused format, work through these TensorTonic exercises:
>
> - [TensorTonic: Transformers Multi-Head Attention](https://www.tensortonic.com/research/transformer/transformers-multi-head-attention)
> - [TensorTonic: GPT-2 MHA](https://www.tensortonic.com/research/gpt2/gpt2-mha)
>
> They are good follow-ups because they make the head split-and-merge pattern explicit:
>
> - splitting one model width into multiple heads
> - running attention independently per head
> - concatenating head outputs back into one stream
> - comparing the textbook Transformer view with a GPT-style decoder implementation
<div style="display:flex; gap:1rem; margin:1.5rem 0; flex-wrap:wrap;">
<div style="flex:1; min-width:220px; border:1px solid var(--background-modifier-border); border-radius:12px; padding:1rem; background:var(--background-secondary);">
<div style="font-size:0.85em; color:var(--text-muted); margin-bottom:0.35rem;">Previous</div>
<div><a class="internal-link" data-href="Scaled Dot-Product Attention" href="Scaled%20Dot-Product%20Attention">Scaled Dot-Product Attention</a></div>
</div>
<div style="flex:1; min-width:220px; border:1px solid var(--background-modifier-border); border-radius:12px; padding:1rem; background:var(--background-secondary);">
<div style="font-size:0.85em; color:var(--text-muted); margin-bottom:0.35rem;">Next</div>
<div><a class="internal-link" data-href="Feed-Forward Network" href="Feed-Forward%20Network">Feed-Forward Network</a></div>
</div>
</div>
### References
[^1]: Ashish Vaswani et al., "Attention Is All You Need," 2017. https://papers.neurips.cc/paper/7181-attention-is-all-you-need.pdf
[^3]: Srinadh Bhojanapalli et al., "Low-Rank Bottleneck in Multi-head Attention Models," 2020. https://proceedings.mlr.press/v119/bhojanapalli20a/bhojanapalli20a.pdf
[^4]: Paul Michel, Omer Levy, and Graham Neubig, "Are Sixteen Heads Really Better than One?," 2019. https://papers.neurips.cc/paper/9551-are-sixteen-heads-really-better-than-one.pdf
[^5]: Liyuan Liu, Jialu Liu, and Jiawei Han, "Multi-head or Single-head? An Empirical Comparison for Transformer Training," 2021. https://arxiv.org/abs/2106.09650
[^8]: Kevin Clark et al., "What Does BERT Look at? An Analysis of BERT's Attention," 2019. https://www-nlp.stanford.edu/pubs/clark2019what.pdf
[^10]: Elena Voita, David Talbot, Fedor Moiseev, Rico Sennrich, and Ivan Titov, "Analyzing Multi-Head Self-Attention: Specialized Heads Do the Heavy Lifting, the Rest Can Be Pruned," 2019. https://arxiv.org/abs/1905.09418
[^11]: Bang An et al., "Repulsive Attention: Rethinking Multi-head Attention as Bayesian Inference," 2020. https://aclanthology.org/2020.emnlp-main.17.pdf
[^12]: Wangchunshu Zhou, Tao Ge, and Ke Xu, "Scheduled DropHead: A Regularization Method for Transformer Models," 2020. https://arxiv.org/abs/2004.13342
[^13]: Noam Shazeer, Zhenzhong Lan, Youlong Cheng, Nan Ding, and Le Hou, "Talking-Heads Attention," 2020. https://arxiv.org/abs/2003.02436
[^14]: Noam Shazeer, "Fast Transformer Decoding: One Write-Head is All You Need," 2019. https://arxiv.org/abs/1911.02150
[^15]: Joshua Ainslie et al., "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints," 2023. https://aclanthology.org/2023.emnlp-main.298.pdf
[^16]: PyTorch, "torch.nn.MultiheadAttention," 2025. https://docs.pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html
[^17]: PyTorch, "torch.nn.functional.scaled_dot_product_attention," 2025. https://docs.pytorch.org/docs/stable/generated/torch.nn.functional.scaled_dot_product_attention.html
[^18]: Tri Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness," 2022. https://openreview.net/forum?id=H4DqfPSibmx