This is the paper that proposed Multi-Query Attention (MQA). The author is from Google, and the idea is explained in detail using TensorFlow code. First, traditional dot-product attention (a single head, single query) looks like this:
def DotProductAttention(q, K, V):
    """Dot-Product Attention on one query.
    Uses einsum() for generalized tensor contractions.
    Args:
        q: a vector with shape [k]
        K: a matrix with shape [m, k]
        V: a matrix with shape [m, v]
    Returns:
        y: a vector with shape [v]
    """
    logits = tf.einsum("k,mk->m", q, K)       # similarity of q to each of the m keys
    weights = tf.nn.softmax(logits)           # attention distribution over the m positions
    return tf.einsum("m,mv->v", weights, V)   # weighted sum of the value vectors
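As a quick shape check, here is a minimal usage sketch (not from the paper; it assumes TensorFlow 2.x and arbitrary example sizes):

import tensorflow as tf

m, k, v = 6, 4, 8                  # arbitrary example sizes
q = tf.random.normal([k])          # one query vector
K = tf.random.normal([m, k])       # m key vectors
V = tf.random.normal([m, v])       # m value vectors
y = DotProductAttention(q, K, V)
print(y.shape)                     # (8,): one output vector of size v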
Converting this into multi-head attention (as in the Transformer paper) amounts to running the above $h$ times in parallel. The queries are obtained by projecting the input vector $x$ with $h$ different learned linear projections $P_q$. Similarly, a sequence $M$ of length $m$ (e.g., $x$'s own sequence, or the encoder output in the case of cross-attention) is projected with $h$ different learned projections $P_k, P_v$ to form $K, V$. The output vectors of the $h$ heads are projected with $P_o$ and summed to give the output of the attention layer.
def MultiHeadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-Head Attention on one query.
    Args:
        x: a vector with shape [d]
        M: a matrix with shape [m, d]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        y: a vector with shape [d]
    """
    q = tf.einsum("d,hdk->hk", x, P_q)          # project x into h query vectors
    K = tf.einsum("md,hdk->hmk", M, P_k)        # per-head keys
    V = tf.einsum("md,hdv->hmv", M, P_v)        # per-head values
    logits = tf.einsum("hk,hmk->hm", q, K)      # for simplicity, without scaling
    weights = tf.nn.softmax(logits)
    o = tf.einsum("hm,hmv->hv", weights, V)     # per-head attention output
    y = tf.einsum("hv,hdv->d", o, P_o)          # project each head back to d and sum over heads
    return y
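To make the "run the single-query attention $h$ times in parallel" statement concrete, here is a small equivalence check (a sketch with arbitrary sizes, not from the paper): the multi-head function should match $h$ independent DotProductAttention calls, each followed by its own slice of the output projection, summed over heads.

d, h, k, v, m = 16, 4, 4, 4, 6   # arbitrary example sizes with k = v = d/h
x   = tf.random.normal([d])
M   = tf.random.normal([m, d])
P_q = tf.random.normal([h, d, k])
P_k = tf.random.normal([h, d, k])
P_v = tf.random.normal([h, d, v])
P_o = tf.random.normal([h, d, v])

y = MultiHeadAttention(x, M, P_q, P_k, P_v, P_o)

# equivalent per-head computation: h independent single-query attentions,
# each projected back to d dimensions with its own slice of P_o, then summed
y_check = tf.add_n([
    tf.einsum("v,dv->d",
              DotProductAttention(tf.einsum("d,dk->k", x, P_q[i]),
                                  tf.einsum("md,dk->mk", M, P_k[i]),
                                  tf.einsum("md,dv->mv", M, P_v[i])),
              P_o[i])
    for i in range(h)])
print(tf.reduce_max(tf.abs(y - y_check)).numpy())  # ~0 up to floating-point error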
Batched multi-head attention
The above concerns attention for a single query over a single sequence. It can be extended to batched operation:
- generate queries $Q$ from $n$ different positions in a sequence (i.e., $X$ is no longer a vector but a sequence of $n$ vectors), all of which interact with the same keys $K$ and values $V$
- process a batch of $b$ different, non-interacting sequences at once
- a mask with $-\infty$ at illegal positions is added to the logits; a sketch of how such a causal mask can be constructed follows this list
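One common way to build such a mask for autoregressive (causal) self-attention, where position $i$ may only attend to positions $\le i$, is sketched below. The helper name causal_mask is hypothetical and not from the paper, and it assumes $n = m$:

def causal_mask(b, h, n, m):
    """Returns a [b, h, n, m] mask: 0 where attention is allowed, -inf where it is not."""
    allowed = tf.linalg.band_part(tf.ones([n, m]), -1, 0)   # lower-triangular: key index <= query index
    mask = tf.where(allowed > 0, 0.0, float("-inf"))
    return tf.broadcast_to(mask, [b, h, n, m])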
def MultiHeadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-Head Attention.
    Args:
        X: a tensor with shape [b, n, d]
        M: a tensor with shape [b, m, d]
        mask: a tensor with shape [b, h, n, m]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,hdk->bhmk", M, P_k)
    V = tf.einsum("bmd,hdv->bhmv", M, P_v)
    logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y
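A minimal call with arbitrary example sizes (a usage sketch; self-attention, so M = X and m = n, reusing the causal_mask helper sketched above):

b, n, d, h, k, v = 2, 5, 16, 4, 4, 4     # arbitrary example sizes with k = v = d/h
X = tf.random.normal([b, n, d])
P_q = tf.random.normal([h, d, k]); P_k = tf.random.normal([h, d, k])
P_v = tf.random.normal([h, d, v]); P_o = tf.random.normal([h, d, v])
Y = MultiHeadAttentionBatched(X, X, causal_mask(b, h, n, n), P_q, P_k, P_v, P_o)
print(Y.shape)                           # (2, 5, 16)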
In the paper, for the sake of performance analysis, some simplifications are assumed:
- let $m = n \le d$ (identical context length for $Q$, $K$, $V$)
- let $k = v = d/h$, as suggested by Vaswani et al. (2017), i.e., $d \gg k = v$

Each einsum() then costs at most $O(bnd^2)$ operations, so the total number of arithmetic operations is $\Theta(bnd^2)$.
Total memory accessed is the sum of the sizes of all tensors involved:
- $X, M, Q, K, V, O, Y$ are $O(bnd)$
- logits and weights are $O(bhn^2)$
- the projection tensors $P_q, P_k, P_v, P_o$ are $O(d^2)$
- total memory complexity is $O(bnd + bhn^2 + d^2)$
- the ratio of memory access to arithmetic operations is $O(\frac{1}{k} + \frac{1}{bn})$ (a numeric sketch follows this list)
- a low ratio means good performance on modern GPU/TPU hardware, whose computational capacity is roughly 100x higher than its memory bandwidth
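To get a feel for the ratio, here is a back-of-the-envelope sketch with illustrative sizes (the numbers are not from the paper):

b, n, d, h = 128, 1024, 1024, 8                  # illustrative sizes, not from the paper
k = v = d // h                                   # k = v = d/h = 128
arithmetic = b * n * d * d                       # Theta(b n d^2), ignoring constant factors
memory = b * n * d + b * h * n * n + d * d       # O(bnd + bhn^2 + d^2)
print(memory / arithmetic)                       # ~0.009
print(1 / k + 1 / (b * n))                       # ~0.008, the same order of magnitude

So the batched, non-incremental case stays comfortably compute-bound.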
Incremental multi-head attention
In autoregressive generation, tokens are produced sequentially; this data dependency prevents parallelism across positions. However, the previously projected keys $K$ and values $V$ can be reused when generating the next token, so only the newly generated input needs to be projected. This saves a substantial amount of computation when the context length $m$ is large. Below is how the projected tensors $K$ and $V$ are built incrementally from the ones used to generate the previous token:
def MultiHeadSelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """Multi-Head Self-Attention (one step).
    Args:
        x: a tensor with shape [b, d]
        prev_K: tensor with shape [b, h, m, k]
        prev_V: tensor with shape [b, h, m, v]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        y: a tensor with shape [b, d]
        new_K: tensor with shape [b, h, m+1, k]
        new_V: tensor with shape [b, h, m+1, v]
    """
    M = x                                           # self-attention: the new memory entry is the current input
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    new_K = tf.concat([prev_K,
                       tf.expand_dims(tf.einsum("bd,hdk->bhk", M, P_k), axis=2)],
                      axis=2)
    new_V = tf.concat([prev_V,
                       tf.expand_dims(tf.einsum("bd,hdv->bhv", M, P_v), axis=2)],
                      axis=2)
    logits = tf.einsum("bhk,bhmk->bhm", q, new_K)
    weights = tf.nn.softmax(logits)
    o = tf.einsum("bhm,bhmv->bhv", weights, new_V)
    y = tf.einsum("bhv,hdv->bd", o, P_o)            # (the paper's listing writes a capital 'O' here by mistake)
    return y, new_K, new_V
- Note that there is no longer an "$n$" dimension in the einsum() calls, since projections and dot products are performed for one new token at a time.
- The tensors $K, V$ grow by one in the sequence-length dimension each time this attention function is run.
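A sketch of how this function would be driven in a greedy decoding loop (arbitrary sizes, not from the paper; the caches start empty, and feeding y back in as the next input stands in for the real sample-and-embed step):

b, d, h, k, v, n_steps = 2, 16, 4, 4, 4, 5     # arbitrary example sizes
P_q = tf.random.normal([h, d, k]); P_k = tf.random.normal([h, d, k])
P_v = tf.random.normal([h, d, v]); P_o = tf.random.normal([h, d, v])
K = tf.zeros([b, h, 0, k])                     # empty per-head key cache
V = tf.zeros([b, h, 0, v])                     # empty per-head value cache
x = tf.random.normal([b, d])                   # embedding of the first token
for _ in range(n_steps):
    y, K, V = MultiHeadSelfAttentionIncremental(x, K, V, P_q, P_k, P_v, P_o)
    x = y                                      # in a real decoder: sample a token and re-embed it
print(K.shape)                                 # (2, 4, 5, 4): the cache grew by one entry per step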
Performance:
- assume generating $n$ tokens, i.e., calling the attention function $n$ times; the total number of arithmetic operations is $\Theta(bnd^2)$ (unchanged)
- memory access due to $K, V$ is $\Theta(bn^2d)$, since the growing caches are reloaded at every step
- memory access due to $P_q, P_k, P_v, P_o$ is $\Theta(nd^2)$
- the total amount of memory access is $\Theta(bn^2d + nd^2)$
- the ratio of memory access to arithmetic operations is $\Theta(\frac{n}{d} + \frac{1}{b})$ (a numeric sketch follows this list)
- to make this efficient, we need $n \ll d$ (limiting sequence length, i.e., context size) and $b \gg 1$ (a large batch)
- the $\frac{n}{d}$ term comes from reloading the $K, V$ tensors at each step
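Plugging in the same illustrative sizes as before (a sketch, not from the paper) shows why incremental decoding becomes memory-bound:

b, n, d, h = 128, 1024, 1024, 8
print(n / d + 1 / b)      # ~1.008: memory traffic is on par with arithmetic, dominated by reloading K, V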
Multi-Query Attention
This is the same as multi-head attention, except that the different heads share a single set of keys and values. The batched version:
def MultiQueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-Query Attention.
    Args:
        X: a tensor with shape [b, n, d]
        M: a tensor with shape [b, m, d]
        mask: a tensor with shape [b, h, n, m]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [d, k]
        P_v: a tensor with shape [d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)       # per-head queries, as before
    K = tf.einsum("bmd,dk->bmk", M, P_k)         # a single set of keys shared by all heads
    V = tf.einsum("bmd,dv->bmv", M, P_v)         # a single set of values shared by all heads
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y
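One way to see the relationship to multi-head attention (an equivalence sketch with arbitrary sizes, not from the paper): multi-query attention gives the same result as multi-head attention in which every head is handed the same key/value projection, i.e., P_k and P_v tiled h times.

b, n, d, h, k, v = 2, 5, 16, 4, 4, 4             # arbitrary example sizes
X = tf.random.normal([b, n, d])
mask = tf.zeros([b, h, n, n])                    # no masking, for the comparison only
P_q = tf.random.normal([h, d, k]); P_o = tf.random.normal([h, d, v])
P_k = tf.random.normal([d, k]);    P_v = tf.random.normal([d, v])
Y_mqa = MultiQueryAttentionBatched(X, X, mask, P_q, P_k, P_v, P_o)
Y_mha = MultiHeadAttentionBatched(X, X, mask, P_q,
                                  tf.tile(P_k[None], [h, 1, 1]),   # same key projection for every head
                                  tf.tile(P_v[None], [h, 1, 1]),   # same value projection for every head
                                  P_o)
print(tf.reduce_max(tf.abs(Y_mqa - Y_mha)).numpy())                # ~0 up to floating-point error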
and the incremental version:
def MultiQuerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """Multi-Query Self-Attention (one step).
    Args:
        x: a tensor with shape [b, d]
        prev_K: tensor with shape [b, m, k]
        prev_V: tensor with shape [b, m, v]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [d, k]
        P_v: a tensor with shape [d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        y: a tensor with shape [b, d]
        new_K: tensor with shape [b, m+1, k]
        new_V: tensor with shape [b, m+1, v]
    """
    M = x                                           # self-attention: the new memory entry is the current input
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    new_K = tf.concat([prev_K,
                       tf.expand_dims(tf.einsum("bd,dk->bk", M, P_k), axis=-2)],
                      axis=-2)
    new_V = tf.concat([prev_V,
                       tf.expand_dims(tf.einsum("bd,dv->bv", M, P_v), axis=-2)],
                      axis=-2)
    logits = tf.einsum("bhk,bmk->bhm", q, new_K)
    weights = tf.nn.softmax(logits)
    o = tf.einsum("bhm,bmv->bhv", weights, new_V)
    y = tf.einsum("bhv,hdv->bd", o, P_o)
    return y, new_K, new_V
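The decoding loop looks just like the multi-head one, except that the caches no longer carry a head dimension, making them $h$ times smaller (a usage sketch with arbitrary sizes, not from the paper):

b, d, h, k, v, n_steps = 2, 16, 4, 4, 4, 5       # arbitrary example sizes
P_q = tf.random.normal([h, d, k]); P_o = tf.random.normal([h, d, v])
P_k = tf.random.normal([d, k]);    P_v = tf.random.normal([d, v])
K = tf.zeros([b, 0, k])                          # shared key cache: no head dimension
V = tf.zeros([b, 0, v])                          # shared value cache
x = tf.random.normal([b, d])
for _ in range(n_steps):
    y, K, V = MultiQuerySelfAttentionIncremental(x, K, V, P_q, P_k, P_v, P_o)
    x = y                                        # in a real decoder: sample a token and re-embed it
print(K.shape)                                   # (2, 5, 4) instead of (2, 4, 5, 4) for multi-head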
Note that there is no "$h$" dimension in the einsum() subscripts for $K$ and $V$, since they are shared among all heads. Considering the case of calling the incremental multi-query attention $n$ times:
- the total number of arithmetic operations is $\Theta(bnd^2)$ (unchanged)
- total memory access due to $x, q, o, y$ is $\Theta(bnd)$
- total memory access due to $K, V$ is $\Theta(bn^2k)$
- total memory access due to $P_q, P_k, P_v, P_o$ is $\Theta(nd^2)$
- the total amount of memory access is $\Theta(bnd + bn^2k + nd^2)$
- the ratio of memory access to arithmetic operations is $\Theta(\frac{1}{d} + \frac{n}{dh} + \frac{1}{b})$ (a numeric sketch follows this list)
- the $\frac{n}{d}$ term is reduced by a factor of $h$, because the shared $K, V$ caches are $h$ times smaller
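With the same illustrative sizes as before (again a sketch, not from the paper), the improvement over incremental multi-head attention is roughly the factor $h$ on the dominant term:

b, n, d, h = 128, 1024, 1024, 8
print(n / d + 1 / b)                 # incremental multi-head:  ~1.01 (memory-bound)
print(1 / d + n / (d * h) + 1 / b)   # incremental multi-query: ~0.13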
Bibliographic data
@unpublished{shazeer2019fast,
    title  = "Fast Transformer Decoding: One Write-Head is All You Need",
    author = "Noam Shazeer",
    year   = "2019",
    note   = "arXiv:1911.02150",
}