This is the paper that proposed Multi-Query Attention (MQA). Its author is from Google, and the idea is explained in detail with TensorFlow code. First, the traditional (single-head) dot-product attention looks like this:

import tensorflow as tf

def DotProductAttention(q, K, V):
    """Dot-Product Attention on one query.
    Using einsum() for generalized contractions

    Args:
        q: a vector with shape [k]
        K: a matrix with shape [m, k]
        V: a matrix with shape [m, v]

    Returns:
        y: a vector with shape [v]
    """
    logits = tf.einsum("k,mk->m", q, K)  # one score per key position
    weights = tf.nn.softmax(logits)      # normalize over the m positions
    return tf.einsum("m,mv->v", weights, V)
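
As a sanity check of what the function computes, here is a plain-Python transcription run on a tiny, hand-checkable input. This is only a sketch: the name `dot_product_attention` and the numbers are made up for illustration, and no TensorFlow is needed. The query matches the first key, so the first value row gets weight $e/(1+e)\approx 0.731$ and the second gets $1/(1+e)\approx 0.269$.

```python
import math

def dot_product_attention(q, K, V):
    # logits[j] = q . K[j], one score per key position
    logits = [sum(qi * ki for qi, ki in zip(q, row)) for row in K]
    # softmax over the m positions (subtract the max for numerical stability)
    top = max(logits)
    exps = [math.exp(z - top) for z in logits]
    total = sum(exps)
    weights = [e / total for e in exps]
    # y = sum_j weights[j] * V[j]
    return [sum(w * row[c] for w, row in zip(weights, V)) for c in range(len(V[0]))]

q = [1.0, 0.0]
K = [[1.0, 0.0], [0.0, 1.0]]
V = [[10.0, 0.0], [0.0, 10.0]]
print(dot_product_attention(q, K, V))  # ≈ [7.311, 2.689]
```

The output is a convex combination of the rows of $V$, so here its entries always sum to 10.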

To convert this into multi-head attention (as in the Transformer paper), run $h$ copies of the above in parallel. Each head's query $q$ is projected from the input vector $x$ by its own slice of $P_q$, i.e., by $h$ different learned linear projections. Similarly, a sequence $M$ of length $m$ (e.g., the sequence $x$ belongs to in self-attention, or the encoder output in cross-attention) is projected with $P_k,P_v$, also one projection per head, to form $K,V$. The output vectors of the $h$ heads are each projected back to dimension $d$ with $P_o$ and then summed.

def MultiHeadAttention(x, M, P_q, P_k, P_v, P_o):
    """Multi-Head Attention on one query.

    Args:
        x: a vector with shape [d]
        M: a matrix with shape [m, d]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]

    Returns:
        y: a vector with shape [d]
    """
    q = tf.einsum("d,hdk->hk", x, P_q)
    K = tf.einsum("md,hdk->hmk", M, P_k)
    V = tf.einsum("md,hdv->hmv", M, P_v)
    logits = tf.einsum("hk,hmk->hm", q, K)  # for simplicity, without scaling
    weights = tf.nn.softmax(logits)
    o = tf.einsum("hm,hmv->hv", weights, V)
    y = tf.einsum("hv,hdv->d", o, P_o)
    return y
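
To check the claim that multi-head attention is $h$ single-head attentions run in parallel and summed after the output projection, here is a minimal NumPy transcription (an assumption: NumPy's `einsum` stands in for `tf.einsum` so the sketch runs without TensorFlow; the helper `softmax` and the random sizes are illustrative only).

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def dot_product_attention(q, K, V):
    # single head: q [k], K [m, k], V [m, v] -> [v]
    weights = softmax(np.einsum("k,mk->m", q, K))
    return np.einsum("m,mv->v", weights, V)

def multi_head_attention(x, M, P_q, P_k, P_v, P_o):
    # x [d], M [m, d], P_q/P_k [h, d, k], P_v/P_o [h, d, v] -> [d]
    q = np.einsum("d,hdk->hk", x, P_q)
    K = np.einsum("md,hdk->hmk", M, P_k)
    V = np.einsum("md,hdv->hmv", M, P_v)
    weights = softmax(np.einsum("hk,hmk->hm", q, K))
    o = np.einsum("hm,hmv->hv", weights, V)
    return np.einsum("hv,hdv->d", o, P_o)

rng = np.random.default_rng(0)
h, d, m = 4, 16, 5
k = v = d // h
x, M = rng.normal(size=d), rng.normal(size=(m, d))
P_q, P_k = rng.normal(size=(h, d, k)), rng.normal(size=(h, d, k))
P_v, P_o = rng.normal(size=(h, d, v)), rng.normal(size=(h, d, v))

y = multi_head_attention(x, M, P_q, P_k, P_v, P_o)
# same result one head at a time: project, attend, project back, then sum
y_loop = sum(
    P_o[i] @ dot_product_attention(x @ P_q[i], M @ P_k[i], M @ P_v[i])
    for i in range(h)
)
print(np.allclose(y, y_loop))  # True
```

The final einsum `"hv,hdv->d"` both projects each head with $P_o$ and sums over $h$, which is why no explicit concatenation of heads appears.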

Batched multi-head attention

The above handles only a single query position in a single sequence. The batched version below extends it as follows:

  • Generate queries $Q$ from $n$ different positions in a sequence (i.e., $X$ is a sequence of $n$ vectors rather than a single vector), all of which interact with the same keys $K$ and values $V$.
  • Process a batch of $b$ different, non-interacting sequences at once.
  • Add a mask to the logits before the softmax, with $-\infty$ at illegal positions (e.g., future tokens in a decoder) and $0$ elsewhere.

def MultiHeadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-Head Attention.

    Args:
        X: a tensor with shape [b, n, d]
        M: a tensor with shape [b, m, d]
        mask: a tensor with shape [b, h, n, m]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]

    Returns:
        Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,hdk->bhmk", M, P_k)
    V = tf.einsum("bmd,hdv->bhmv", M, P_v)
    logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = tf.nn.softmax(logits + mask)
    O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y
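
Here is a NumPy sketch of the batched function together with a causal mask, checking that position $i$ only depends on the first $i+1$ tokens. The `causal_mask` helper is hypothetical, and the mask is passed as an `[n, n]` array that broadcasts to `[b, h, n, m]` rather than being materialized at full shape (an assumption made for brevity).

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_head_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    # X [b, n, d], M [b, m, d], mask broadcastable to [b, h, n, m]
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", M, P_k)
    V = np.einsum("bmd,hdv->bhmv", M, P_v)
    weights = softmax(np.einsum("bhnk,bhmk->bhnm", Q, K) + mask)
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

def causal_mask(n):
    # 0 where key position j <= query position i, -inf where j > i (a future token)
    future = np.triu(np.ones((n, n), dtype=bool), k=1)
    return np.where(future, -np.inf, 0.0)

rng = np.random.default_rng(1)
b, n, h, d = 2, 6, 4, 16
k = v = d // h
X = rng.normal(size=(b, n, d))
P_q, P_k = rng.normal(size=(h, d, k)), rng.normal(size=(h, d, k))
P_v, P_o = rng.normal(size=(h, d, v)), rng.normal(size=(h, d, v))

# self-attention (M = X), one mask shared by every sequence and head
Y = multi_head_attention_batched(X, X, causal_mask(n), P_q, P_k, P_v, P_o)

# output position i of sequence 0 must not change if later tokens are dropped
i = 2
prefix = X[:1, : i + 1]
Y_prefix = multi_head_attention_batched(prefix, prefix, causal_mask(i + 1), P_q, P_k, P_v, P_o)
print(np.allclose(Y[0, i], Y_prefix[0, i]))  # True
```

Because the diagonal is never masked, every row keeps at least one finite logit, so the softmax never sees an all-$-\infty$ row.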

In the paper, some simplifying assumptions are made for the sake of performance analysis:

  • Let $m=n\le d$ (queries and keys/values have the same sequence length, no longer than the model dimension)
  • Let $k=v=d/h$, as suggested by Vaswani et al. (2017), so $k=v$ is smaller than $d$ by a factor of $h$

Then each einsum() has complexity $O(bnd^2)$, so the total number of arithmetic operations is $\Theta(bnd^2)$.

Total memory accessed is the sum of the sizes of all tensors involved:

  • $X,M,Q,K,V,O,Y$ are $O(bnd)$
  • logits and weights are $O(bhn^2)$
  • projection tensors $P_q,P_k,P_v,P_o$ are $O(d^2)$
  • total memory access is $O(bnd+bhn^2+d^2)$
  • ratio of memory access to arithmetic operations is $O(\frac{1}{k}+\frac{1}{bn})$
    • a low ratio means good performance on GPU/TPU
    • on modern GPUs/TPUs, arithmetic throughput can be about two orders of magnitude higher than memory bandwidth, so each value loaded from memory must feed many operations
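
To make the counting above concrete, the sketch below tallies the tensor sizes and einsum costs of one batched call and compares the ratio with $\frac{1}{k}+\frac{1}{bn}$. The sizes ($d=1024$, $h=16$) are arbitrary illustrative choices, and the constant factors are rough, so only the order of magnitude and the trend in $b$ are meaningful.

```python
def batched_mha_cost(b, n, d, h):
    """Rough counts for one MultiHeadAttentionBatched call, assuming m = n and
    k = v = d / h. Constant factors are approximate."""
    k = d // h
    # Q, K, V and Y projections: b*n*d*(h*k) = b*n*d^2 multiply-adds each;
    # logits and O: b*h*n*n*k = b*n^2*d each (<= b*n*d^2 because n <= d)
    flops = 4 * b * n * d * d + 2 * b * n * n * d
    memory = (
        7 * b * n * d        # X, M, Q, K, V, O, Y
        + 2 * b * h * n * n  # logits and weights
        + 4 * d * d          # P_q, P_k, P_v, P_o (each h*d*k = d^2)
    )
    return flops, memory, memory / flops

for b, n in [(1, 128), (64, 128)]:
    _, _, ratio = batched_mha_cost(b, n, d=1024, h=16)
    bound = 1 / (1024 // 16) + 1 / (b * n)  # the O(1/k + 1/(bn)) expression
    print(f"b={b:3d} n={n}: memory/flops = {ratio:.4f}, 1/k + 1/(bn) = {bound:.4f}")
```

Growing the batch amortizes the $O(d^2)$ projection weights over more tokens, which is the $\frac{1}{bn}$ term shrinking.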

Incremental multi-head attention

In autoregressive generation, tokens are generated one at a time, and this data dependency prevents parallelizing over positions. However, the keys $K$ and values $V$ projected for earlier tokens can be reused when generating the next token, so only the newest input $x$ needs to be projected. The saving grows with the context length $m$, since the $m$ earlier tokens are never re-projected. Below is how the projected tensors $K$ and $V$ are built incrementally from those used to generate the previous token:

def MultiHeadSelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """Multi-Head Self-Attention (one step).

    Args:
        x: a tensor with shape [b, d]
        prev_K: tensor with shape [b, h, m, k]
        prev_V: tensor with shape [b, h, m, v]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]

    Returns:
        y: a tensor with shape [b, d]
        new_K: tensor with shape [b, h, m+1, k]
        new_V: tensor with shape [b, h, m+1, v]
    """
    M = x  # assuming M is x
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    new_K = tf.concat([prev_K,
                       tf.expand_dims(tf.einsum("bd,hdk->bhk", M, P_k), axis=2)],
                      axis=2)
    new_V = tf.concat([prev_V,
                       tf.expand_dims(tf.einsum("bd,hdv->bhv", M, P_v), axis=2)],
                      axis=2)
    logits = tf.einsum("bhk,bhmk->bhm", q, new_K)
    weights = tf.nn.softmax(logits)
    o = tf.einsum("bhm,bhmv->bhv", weights, new_V)
    y = tf.einsum("bhv,hdv->bd", o, P_o)  # uses 'o' from above (the paper's listing writes 'O' here)
    return y, new_K, new_V

  • Note that there is no $n$ dimension in the einsum() calls, since the projections and dot products are computed for one new token at a time.
  • The tensors $K,V$ grow by one along the sequence-length dimension each time this function is called.
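
A NumPy sketch can confirm that feeding tokens one at a time through the incremental step, starting from empty caches, reproduces the batched self-attention with a causal mask. NumPy is assumed in place of TensorFlow (`np.concatenate` and `[:, :, None, :]` replace `tf.concat` and `tf.expand_dims`); `mha_incremental` and `mha_causal` are hypothetical names for this illustration.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def mha_incremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    # x: [b, d]; prev_K: [b, h, m, k]; prev_V: [b, h, m, v]; self-attention, so M = x
    q = np.einsum("bd,hdk->bhk", x, P_q)
    k_new = np.einsum("bd,hdk->bhk", x, P_k)[:, :, None, :]  # [b, h, 1, k]
    v_new = np.einsum("bd,hdv->bhv", x, P_v)[:, :, None, :]  # [b, h, 1, v]
    new_K = np.concatenate([prev_K, k_new], axis=2)
    new_V = np.concatenate([prev_V, v_new], axis=2)
    weights = softmax(np.einsum("bhk,bhmk->bhm", q, new_K))
    o = np.einsum("bhm,bhmv->bhv", weights, new_V)
    return np.einsum("bhv,hdv->bd", o, P_o), new_K, new_V

def mha_causal(X, P_q, P_k, P_v, P_o):
    # batched self-attention over all n positions, masking out future tokens
    n = X.shape[1]
    mask = np.where(np.triu(np.ones((n, n), dtype=bool), k=1), -np.inf, 0.0)
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", X, P_k)
    V = np.einsum("bmd,hdv->bhmv", X, P_v)
    weights = softmax(np.einsum("bhnk,bhmk->bhnm", Q, K) + mask)
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

rng = np.random.default_rng(2)
b, n, h, d = 2, 5, 4, 16
k = v = d // h
X = rng.normal(size=(b, n, d))
P_q, P_k, P_v, P_o = (rng.normal(size=(h, d, k)) for _ in range(4))  # k == v here

# start from empty caches (sequence length 0) and feed one token per step
K_cache, V_cache = np.zeros((b, h, 0, k)), np.zeros((b, h, 0, v))
steps = []
for t in range(n):
    y, K_cache, V_cache = mha_incremental(X[:, t], K_cache, V_cache, P_q, P_k, P_v, P_o)
    steps.append(y)

print(K_cache.shape)  # (2, 4, 5, 4)
print(np.allclose(np.stack(steps, axis=1), mha_causal(X, P_q, P_k, P_v, P_o)))  # True
```

No mask is needed in the incremental step: the cache only ever holds the current and earlier tokens.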

Performance:

  • Generating $n$ tokens calls the attention function $n$ times; the total number of arithmetic operations is $\Theta(bnd^2)$ (the same as in the batched case)
  • Memory access due to $K,V$ is $\Theta(bn^2d)$
  • Memory access due to $P_q,P_k,P_v,P_o$ is $\Theta(nd^2)$
  • Total memory access is $\Theta(bn^2d + nd^2)$
  • Ratio of memory access to arithmetic operations: $\Theta(\frac{n}{d}+\frac{1}{b})$
    • to make it efficient, we need $n\ll d$ (a limited sequence length, i.e., context size) and $b\gg 1$ (a larger batch)
    • the $\frac{n}{d}$ term comes from reloading the $K,V$ tensors at each step
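
The $\frac{n}{d}$ term can be seen by summing per-step costs over a decode of $n$ tokens: step $m$ rereads a cache of size $2bhmk = 2bmd$, while its arithmetic is dominated by the $4bd^2$ projection cost. A rough sketch under the simplifications above; the sizes are arbitrary and constant factors approximate.

```python
def incremental_mha_cost(b, n, d, h):
    """Rough counts summed over n calls of the incremental multi-head step
    (self-attention, k = v = d / h). Constant factors are approximate."""
    k = d // h
    flops = memory = 0
    for m in range(1, n + 1):        # cache length after appending the new token
        flops += 4 * b * d * d       # project x to q, k, v and o back to y
        flops += 2 * b * h * m * k   # q.K and weights.V over the whole cache
        memory += 2 * b * h * m * k  # reload the K and V caches
        memory += 4 * d * d          # reload P_q, P_k, P_v, P_o
        memory += 4 * b * d          # x, q, o, y
    return flops, memory, memory / flops

for b, n in [(1, 1024), (64, 128), (64, 1024)]:
    ratio = incremental_mha_cost(b, n, d=1024, h=16)[2]
    print(f"b={b:2d} n={n:4d}: memory/flops = {ratio:.3f}, n/d + 1/b = {n / 1024 + 1 / b:.3f}")
```

At a fixed batch size the ratio climbs with $n$, and at a fixed $n$ it falls with $b$, matching the two terms of $\Theta(\frac{n}{d}+\frac{1}{b})$.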

Multi-Query Attention

Multi-query attention is the same as multi-head attention, except that the different heads share a single set of keys and values. The batched version:

def MultiQueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """Multi-Query Attention.

    Args:
        X: a tensor with shape [b, n, d]
        M: a tensor with shape [b, m, d]
        mask: a tensor with shape [b, h, n, m]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [d, k]
        P_v: a tensor with shape [d, v]
        P_o: a tensor with shape [h, d, v]

    Returns:
        Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,dk->bmk", M, P_k)
    V = tf.einsum("bmd,dv->bmv", M, P_v)
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
    weights = tf.nn.softmax(logits + mask)  # mask added before softmax, as in the multi-head version
    O = tf.einsum("bhnm,bmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y
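
One way to see that multi-query attention is multi-head attention with a constraint: give every head a copy of the single key/value projection and the two produce the same output. The NumPy sketch below (NumPy assumed in place of TensorFlow; random sizes and an all-zero mask chosen for simplicity) checks this using `np.broadcast_to` to tile $P_k$ and $P_v$ across the $h$ heads.

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)  # subtract the max for stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def multi_query_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,dk->bmk", M, P_k)  # no head axis: one K shared by all heads
    V = np.einsum("bmd,dv->bmv", M, P_v)  # likewise one shared V
    weights = softmax(np.einsum("bhnk,bmk->bhnm", Q, K) + mask)
    O = np.einsum("bhnm,bmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

def multi_head_attention_batched(X, M, mask, P_q, P_k, P_v, P_o):
    Q = np.einsum("bnd,hdk->bhnk", X, P_q)
    K = np.einsum("bmd,hdk->bhmk", M, P_k)
    V = np.einsum("bmd,hdv->bhmv", M, P_v)
    weights = softmax(np.einsum("bhnk,bhmk->bhnm", Q, K) + mask)
    O = np.einsum("bhnm,bhmv->bhnv", weights, V)
    return np.einsum("bhnv,hdv->bnd", O, P_o)

rng = np.random.default_rng(3)
b, n, h, d = 2, 6, 4, 16
k = v = d // h
X = rng.normal(size=(b, n, d))
mask = np.zeros((n, n))  # all positions allowed, for simplicity
P_q, P_o = rng.normal(size=(h, d, k)), rng.normal(size=(h, d, v))
P_k, P_v = rng.normal(size=(d, k)), rng.normal(size=(d, v))  # single K/V projection

Y_mqa = multi_query_attention_batched(X, X, mask, P_q, P_k, P_v, P_o)
# an MHA whose h key/value projections are all copies of the shared MQA projection
Y_mha = multi_head_attention_batched(
    X, X, mask, P_q,
    np.broadcast_to(P_k, (h, d, k)), np.broadcast_to(P_v, (h, d, v)), P_o,
)
print(np.allclose(Y_mqa, Y_mha))  # True
```

The queries and output projections still differ per head, so the heads can attend to different positions; only the keys and values they attend over are shared.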

and the incremental version:

def MultiQuerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """Multi-Query Self-Attention (one step).

    Args:
        x: a tensor with shape [b, d]
        prev_K: tensor with shape [b, m, k]
        prev_V: tensor with shape [b, m, v]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [d, k]
        P_v: a tensor with shape [d, v]
        P_o: a tensor with shape [h, d, v]

    Returns:
        y: a tensor with shape [b, d]
        new_K: tensor with shape [b, m+1, k]
        new_V: tensor with shape [b, m+1, v]
    """
    M = x  # assuming M is x
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    K = tf.concat([prev_K,
                   tf.expand_dims(tf.einsum("bd,dk->bk", M, P_k), axis=-2)],
                  axis=-2)
    V = tf.concat([prev_V,
                   tf.expand_dims(tf.einsum("bd,dv->bv", M, P_v), axis=-2)],
                  axis=-2)
    logits = tf.einsum("bhk,bmk->bhm", q, K)
    weights = tf.nn.softmax(logits)
    o = tf.einsum("bhm,bmv->bhv", weights, V)  # per-head weighted sum of the shared values
    y = tf.einsum("bhv,hdv->bd", o, P_o)
    return y, K, V
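
The practical payoff is the size of the cache carried between decoding steps: `[b, h, m, k]` per tensor for multi-head attention versus `[b, m, k]` for multi-query attention. A small calculator for one attention layer, assuming a 16-bit storage format (2 bytes per value) and arbitrary example sizes; `kv_cache_bytes` is a hypothetical helper.

```python
def kv_cache_bytes(b, m, d, h, shared_kv, bytes_per_value=2):
    """Bytes held by the K and V caches of one attention layer after m steps,
    with k = v = d / h. bytes_per_value=2 assumes a 16-bit format."""
    k = d // h
    kv_heads = 1 if shared_kv else h  # MQA stores one K/V head, MHA stores h
    return 2 * b * kv_heads * m * k * bytes_per_value  # factor 2: K and V

b, m, d, h = 32, 2048, 1024, 16
mha = kv_cache_bytes(b, m, d, h, shared_kv=False)
mqa = kv_cache_bytes(b, m, d, h, shared_kv=True)
print(mha // 2**20, "MiB vs", mqa // 2**20, "MiB")  # 256 MiB vs 16 MiB
```

The cache shrinks by exactly a factor of $h$, and since it is reread at every step, so does the corresponding memory traffic.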

Note that $K$ and $V$, and the einsum() calls that produce them, have no $h$ dimension, since they are shared among all heads. Calling incremental multi-query attention $n$ times:

  • the total number of arithmetic operations is $\Theta(bnd^2)$ (same)
  • total memory access due to $x,q,o,y$ is $\Theta(bnd)$
  • total memory access due to $K,V$ is $\Theta(bn^2k)$
  • total memory access due to $P_q,P_k,P_v,P_o$ is $\Theta(nd^2)$
  • total amount of memory access is $\Theta(bnd+bn^2k+nd^2)$
  • Ratio of memory access to arithmetic operations is $\Theta(\frac{1}{d} + \frac{n}{dh} + \frac{1}{b})$
    • the $\frac{n}{d}$ term of multi-head attention is reduced by a factor of $h$
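
Putting the two $\Theta$ expressions side by side shows how much longer a context multi-query attention can handle before memory traffic dominates. This sketch evaluates the formulas exactly as stated in the text with constants dropped, so only the scaling is meaningful; $d=1024$, $h=16$, $b=128$ are arbitrary example values.

```python
def mha_ratio(n, d, h, b):
    # Θ(n/d + 1/b): rereading the per-head K, V caches dominates for long contexts
    return n / d + 1 / b

def mqa_ratio(n, d, h, b):
    # Θ(1/d + n/(d*h) + 1/b): the shared K, V cache is h times smaller
    return 1 / d + n / (d * h) + 1 / b

d, h, b = 1024, 16, 128
for n in (128, 1024, 4096):
    print(f"n={n:4d}: MHA {mha_ratio(n, d, h, b):.3f}, MQA {mqa_ratio(n, d, h, b):.3f}")
```

With these sizes, MQA at $n=4096$ has a smaller $n$-dependent term than MHA at $n=1024$.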

Bibliographic data

@unpublished{
   title = "Fast Transformer Decoding: One Write-Head is All You Need",
   author = "Noam Shazeer",
   year = "2019",
   arXiv = "1911.0215",
}