This is the paper that proposed Grouped-Query Attention (GQA). MQA speeds up decoder inference by sharing a single $K,V$ tensor across all attention heads, but it is found to degrade quality. GQA generalizes MQA: just as group norm interpolates between instance norm and layer norm, GQA sits between MHA and MQA by partitioning the query heads into groups and letting each group share one $K,V$ tensor. So with $h$ heads, MQA has one $K,V$ head, MHA has $h$, and GQA has $G$, where $1 \le G \le h$. The paper also proposes a method to uptrain an existing multi-head attention model into an MQA or GQA one.
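As a shape-level illustration (my own sketch, not the paper's code; all dimensions are toy values), the three variants differ only in how many $K,V$ projection heads they carry:

```python
import torch

d_model, h, d_k = 512, 8, 64  # toy sizes, not from the paper

# Query projections are always one per head.
W_q = torch.randn(h, d_model, d_k)

# The variants differ only in the number of K,V heads:
W_kv_mha = torch.randn(h, d_model, d_k)  # MHA: one K,V per query head (G = h)
W_kv_mqa = torch.randn(1, d_model, d_k)  # MQA: one K,V shared by all heads (G = 1)
G = 4
W_kv_gqa = torch.randn(G, d_model, d_k)  # GQA: one K,V per group of h // G heads
```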

The paper has the following figure for GQA:

As attention is $\textrm{softmax}(QK^\top/\sqrt{d_k})\,V$, there is one query $Q$ (blue) per head (in the case of self-attention, each is projected from the same input $X$ but with a different projection matrix). The keys $K$ (red) and values $V$ (orange) can map to the queries one-to-one (multi-head attention), all heads sharing a single $K,V$ (multi-query attention), or in $G$ groups (GQA). GQA is thus a trade-off between MHA and MQA.
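A minimal runnable sketch of the grouped forward pass (my code with assumed shapes, not the paper's implementation): each group's $K,V$ is broadcast to the $h/G$ query heads it serves, so $G=h$ recovers MHA and $G=1$ recovers MQA.

```python
import math
import torch

def gqa_attention(q, k, v):
    """q: (batch, h, seq, d_k); k, v: (batch, G, seq, d_k) with h % G == 0."""
    b, h, n, d_k = q.shape
    G = k.shape[1]
    # Repeat each group's K,V for the h // G query heads it serves.
    k = k.repeat_interleave(h // G, dim=1)             # (batch, h, seq, d_k)
    v = v.repeat_interleave(h // G, dim=1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # (batch, h, seq, seq)
    return torch.softmax(scores, dim=-1) @ v           # (batch, h, seq, d_k)

q = torch.randn(2, 8, 16, 64)  # h = 8 query heads
k = torch.randn(2, 4, 16, 64)  # G = 4 K,V groups
v = torch.randn(2, 4, 16, 64)
out = gqa_attention(q, k, v)   # (2, 8, 16, 64)
```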

Uptraining converts an MHA model into an MQA (or GQA) one: take the $h$ key and value projection matrices from the MHA checkpoint, mean-pool them into a single projection matrix (or into $G$ per-group matrices), then rerun a small fraction ($\alpha$) of the original pre-training steps to fine-tune.
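A sketch of just the mean-pooling step, under assumed toy shapes (the uptraining itself is ordinary continued pre-training, omitted here):

```python
import torch

h, G, d_model, d_k = 8, 4, 512, 64  # assumed toy sizes

W_k = torch.randn(h, d_model, d_k)  # the h key projections of the MHA checkpoint

# MQA: mean-pool all h head projections into one.
W_k_mqa = W_k.mean(dim=0, keepdim=True)                  # (1, d_model, d_k)

# GQA: mean-pool each block of h // G adjacent head projections.
W_k_gqa = W_k.view(G, h // G, d_model, d_k).mean(dim=1)  # (G, d_model, d_k)

# Same pooling for W_v; then fine-tune for a fraction alpha of the
# original pre-training steps.
```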

The paper experiments with this on T5-XXL, an MHA model. Uptraining an MQA version with $\alpha=0.05$ lowers quality a bit but cuts generation time to less than a third. GQA with $G=8$ gives a good trade-off: quality very close to MHA with speed very close to MQA.
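The speedup comes largely from memory bandwidth: each decode step reloads the cached $K,V$, and that cache shrinks by a factor of $h/G$. A back-of-envelope illustration with hypothetical dimensions (not T5-XXL's):

```python
def kv_cache_bytes(layers, kv_heads, d_k, seq_len, batch, bytes_per=2):
    # 2 tensors (K and V) per layer, each of shape (batch, kv_heads, seq_len, d_k).
    return 2 * layers * kv_heads * d_k * seq_len * batch * bytes_per

# Hypothetical model: 24 layers, 32 query heads, d_k = 128, fp16 cache.
for name, g in [("MHA", 32), ("GQA-8", 8), ("MQA", 1)]:
    gb = kv_cache_bytes(24, g, 128, seq_len=2048, batch=8) / 2**30
    print(f"{name}: {gb:.2f} GiB")  # -> MHA 6.00, GQA-8 1.50, MQA 0.19
```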

Bibliographic data

@inproceedings{ainslie2023gqa,
   title = "GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints",
   author = "Joshua Ainslie and James Lee-Thorp and Michiel de Jong and Yury Zemlyanskiy and Federico Lebrón and Sumit Sanghai",
   booktitle = "Proceedings of the Conference on Empirical Methods in Natural Language Processing (EMNLP)",
   month = "Dec",
   year = "2023",
   address = "Singapore",
   eprint = "2305.13245",
}