Meta/Facebook released the LLaMA and LLaMA 2 models on GitHub, together with a guide to help you use them. Naturally, coming from Meta, the model is written in PyTorch. Surprisingly, though, the repo is short enough that you can read and understand it in a day or two.

The structure of the repo is as follows: the main branch holds the LLaMA 2 model, while the older version has been moved to another branch.


This is just the model code. To use it, for example to pretrain it from scratch (if you have such resources) or fine-tune it, you need to look into the scripts in the companion repo llama-recipes.


Every language model for text starts with a tokenizer that breaks a string into tokens. In LLaMA 2, the tokenizer module defines the class Tokenizer with the encode() and decode() methods.

The LLaMA tokenizer uses a SentencePiece model, which is based on BPE and Unigram. It adds BOS, EOS, and PAD as special tokens. Tokens are numbered as integers, and the encode and decode functions convert between a string and a list of such integers.
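
As a rough idea of how thin this wrapper is, here is a minimal sketch of such a tokenizer built directly on the sentencepiece library. The attribute names (sp_model, n_words) are chosen for illustration and are not guaranteed to match the repo exactly:

from typing import List
from sentencepiece import SentencePieceProcessor

class Tokenizer:
    def __init__(self, model_path: str):
        # load the pretrained SentencePiece model file (tokenizer.model)
        self.sp_model = SentencePieceProcessor(model_file=model_path)
        self.n_words = self.sp_model.vocab_size()    # vocabulary size
        self.bos_id = self.sp_model.bos_id()         # special token ids
        self.eos_id = self.sp_model.eos_id()
        self.pad_id = self.sp_model.pad_id()

    def encode(self, s: str, bos: bool, eos: bool) -> List[int]:
        t = self.sp_model.encode(s)                  # string -> list of token ids
        if bos:
            t = [self.bos_id] + t
        if eos:
            t = t + [self.eos_id]
        return t

    def decode(self, t: List[int]) -> str:
        return self.sp_model.decode(t)               # list of token ids -> string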


There are several building block functions defined in the model code.

The static part is the dataclass ModelArgs, which holds the parameters for inference or construction of the model. The defaults are listed below (a sketch of the dataclass follows the list):

  • dim = 4096
  • n_heads = n_layers = 32
  • max_seq_len = 2048, but this number is doubled in the model, hence effective sequence length limit is 4096
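
For concreteness, here is a sketch of what such a dataclass looks like. The field names follow the params.json example shown later in this article, but treat the exact set of fields and defaults as an approximation rather than the authoritative definition:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelArgs:
    dim: int = 4096                    # embedding dimension
    n_layers: int = 32                 # number of transformer blocks
    n_heads: int = 32                  # number of query heads
    n_kv_heads: Optional[int] = None   # key/value heads; defaults to n_heads if unset
    vocab_size: int = -1               # filled in later from the tokenizer
    multiple_of: int = 256             # round the feedforward hidden dim up to this
    ffn_dim_multiplier: Optional[float] = None
    norm_eps: float = 1e-5             # epsilon for RMSNorm
    max_batch_size: int = 32
    max_seq_len: int = 2048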

The building block functions are described one by one as follows:

Position encoding

The LLaMA model uses RoPE as the position encoding. Essentially it performs complex multiplication $xe^{it}$ for input $x$ and position $t$. The $e^{it}$ part is pre-computed and cached, up to the max sequence length, in the following function:

def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
    return freqs_cis

where the function torch.polar(abs, angle) computes $\text{abs}\cdot e^{i\,\text{angle}}$, with the magnitude fixed to 1 here. The output of this function is a complex tensor of shape $T\times\frac{D}{2}$, for maximum sequence length $T$ and per-head dimension $D$. When it is used, the input tensor is usually of shape $N\times T\times H\times\frac{D}{2}$ for a batch of $N$ sequences and $H$ heads. To perform element-wise multiplication of the position encoding tensor and the input tensor, the following function is used to reshape the position encoding tensor to fit the input tensor using view():

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    ndim = x.ndim
    assert 0 <= 1 < ndim
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

In practice, this function reshapes the position encoding into shape $(1,T,1,\frac{D}{2})$, adding singleton axes so that it broadcasts over the batch and head dimensions of the input. The actual rotary embedding operation is defined in another function:

def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
    xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3)
    xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3)
    return xq_out.type_as(xq), xk_out.type_as(xk)

The first two lines reinterpret xq and xk, of shape $(N,T,H,D)$, as complex tensors of shape $(N,T,H,\frac{D}{2})$ by grouping consecutive features into (real, imaginary) pairs. The third line reshapes the frequency tensor for broadcasting. Finally, xq_out is the element-wise product of xq_ with the position encoding, converted back from complex to real (view_as_real splits each complex number back into two features) and flattened into shape $(N,T,H,D)$.

The rotary embedding (RoPE) is explained as follows: the $d$-dimensional input vector $x$ is rearranged as $d/2$ pairs $(x_m^{(1)},x_m^{(2)})$. Here $m=1,\dots,T$ denotes the sequence position. Treating each pair as coordinates in a 2D plane, the transformation rotates it by an angle $m\theta$ proportional to the position:

\[\textrm{RoPE}(x_m^{(1)},x_m^{(2)},m) = \begin{pmatrix}\cos m\theta & -\sin m\theta\\ \sin m\theta & \cos m\theta \end{pmatrix}\begin{pmatrix}x_m^{(1)}\\ x_m^{(2)}\end{pmatrix}\]

The transformed output satisfies:

\[\begin{aligned} & \langle \text{RoPE}(x_m^{(1)},x_m^{(2)},m),\; \text{RoPE}(x_n^{(1)},x_n^{(2)},n)\rangle\\ = & \langle \text{RoPE}(x_m^{(1)},x_m^{(2)},m-n),\; \text{RoPE}(x_n^{(1)},x_n^{(2)},0)\rangle \end{aligned}\]

which means the dot product depends only on the relative position. In the implementation, we can consider a pair of features of $x$ at position $m$ as a complex number, encoded as $xe^{im\theta}$. The features $1,\dots,d$ are grouped into consecutive pairs (feature $2i-1$ with feature $2i$), and pair $i$ uses the angle $\theta_i = 10000^{-2(i-1)/d}$. The inner product above, $\langle x, y\rangle = x\cdot \bar{y}$, is the complex inner product.
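
To make the "relative" property concrete in this complex form, write a pair of query features at position $m$ as a complex number $q$ and a pair of key features at position $n$ as $k$. Then

\[\langle qe^{im\theta},\; ke^{in\theta}\rangle = qe^{im\theta}\,\overline{ke^{in\theta}} = q\bar{k}\,e^{i(m-n)\theta} = \langle qe^{i(m-n)\theta},\; k\rangle\]

so the score between a query at position $m$ and a key at position $n$ depends on the positions only through the difference $m-n$.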

Lastly, to support grouped-query attention in the multi-head attention module (where there are fewer key/value heads than query heads), there is a function to replicate the input tensor x for n_rep times along the third (head) dimension:

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    bs, slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return (
        x[:, :, :, None, :]
        .expand(bs, slen, n_kv_heads, n_rep, head_dim)
        .reshape(bs, slen, n_kv_heads * n_rep, head_dim)
    )

At input, the shape of x is (bs, slen, n_kv_heads, head_dim) and at output it becomes (bs, slen, n_rep*n_kv_heads, head_dim). The replication uses expand(), which creates a stride-0 view without copying data; the final reshape() then merges the two head dimensions (and may materialize the result).
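
Functionally, this should behave the same as torch.repeat_interleave along the head dimension, which can be checked with a quick sketch (the shapes below are arbitrary examples):

import torch

x = torch.randn(2, 5, 4, 16)         # (bs, slen, n_kv_heads, head_dim)
n_rep = 8                            # e.g., 32 query heads / 4 key-value heads

a = repeat_kv(x, n_rep)              # the function above
b = torch.repeat_interleave(x, repeats=n_rep, dim=2)
print(a.shape)                       # torch.Size([2, 5, 32, 16])
print(torch.equal(a, b))             # True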

Transformer building blocks

The overall transformer model is defined in the class Transformer, which has the following workflow (a condensed code sketch follows the list):

  1. The input is a tensor of token ids (shape: batch size $N\times$ sequence length $L$) plus an integer start_pos
    • the parameter start_pos supports the key/value cache: during autoregressive generation only the new tokens are fed in, starting at this position
  2. Convert the tokens into the embedding h using the ParallelEmbedding module from fairscale (a PyTorch extension). The output tensor h has embedding dimension $d$
  3. Extract the pre-computed encoding freqs_cis at position range start_pos:start_pos+seq_len to match the input
  4. Create a mask of size $L\times L$ such that its upper triangular elements above offset start_pos are all $-\infty$ and all other values are zero. This is the mask for causal attention.
  5. Run a sequence of TransformerBlock modules. Each block transforms h, given the other parameters start_pos, freqs_cis, and mask.
  6. Process the final h with RMSNorm
  7. Apply ColumnParallelLinear from fairscale with no activation function, where
    • the input h is a sequence of vectors of hidden dimension $d$
    • the output is a sequence of logits of the length of the tokenizer vocabulary size
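
Putting these steps together, here is a condensed sketch of the forward pass. It follows the workflow above; attribute names such as tok_embeddings, layers, norm, and output are used for illustration, and fairscale/device details are omitted:

def forward(self, tokens: torch.Tensor, start_pos: int):
    _bsz, seqlen = tokens.shape
    h = self.tok_embeddings(tokens)                       # (N, L) -> (N, L, dim)
    freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
    mask = None
    if seqlen > 1:
        # causal mask: -inf above the diagonal offset by start_pos
        mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device)
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(h)
    for layer in self.layers:                             # stack of TransformerBlock
        h = layer(h, start_pos, freqs_cis, mask)
    h = self.norm(h)                                      # final RMSNorm
    return self.output(h).float()                         # (N, L, vocab_size) logits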

Fairscale is an alternative to torch.nn.parallel.DistributedDataParallel (DDP) that does sharding of data, model, and optimizer.

The building blocks that support the overall transformer architecture are as follows:

The class RMSNorm calculates

\[x := \frac{x}{\sqrt{\bar{x^2} + \epsilon}}\]

where the mean is taken over the squares $x^2$ along the last dimension (i.e., the embedding dimension), and the result is then scaled element-wise by a learnable weight vector. This is a simplified variant of layer normalization (no mean subtraction and no bias).
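
In code, this is only a few lines. A sketch that matches the formula above, with the learnable per-feature scale initialized to ones:

import torch
from torch import nn

class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))       # learnable per-feature scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # normalize by the root mean square along the last (embedding) dimension
        rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return (x * rms) * self.weight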

The class Attention: This is where the parallelism is applied in the case of distributed model execution

  • defines weight matrices self.wq, self.wk, self.wv to multiply with x
    • for the embedding dimension $d$ and number of query heads $H$, the attention dimension is $d_H = d/H$
    • weight matrix $W^Q$ is of shape $d\times d$ but matrices $W^K,W^V$ are of shape $d\times H_{kv}d_H$, where $H_{kv}$ does not necessarily equal $H$
  • defines the weight matrix $W^O$ as self.wo for the output, of shape $d\times d$
    • the same shape as $W^Q$, to make the output shape match the input shape
  • defines self.cache_k and self.cache_v as tensors of shape $(N,T,H_{kv},d_H)$ for batch size $N$, sequence length $T$, number of key/value heads $H_{kv}$, and attention dimension $d_H$
  • the workflow takes input x, start_pos, freqs_cis, and mask (a condensed code sketch follows this list):
    1. input tensor of shape $(N,T,d)$ is transformed into $X_Q=W^QX, X_K=W^KX, X_V=W^VX$
    2. reshape $X_Q$ into shape $(N,T,H,d_H)$ and reshape $X_K,X_V$ into shape $(N,T,H_{kv},d_H)$
    3. apply rotary embedding on $X_Q,X_K$ with the tensor freqs_cis
    4. write $X_K,X_V$ into the key/value cache (kept on the CUDA device) at positions start_pos to start_pos+T, read back the cached keys/values from position 0 up to start_pos+T, and replicate them $n=H/H_{kv}$ times with repeat_kv()
    5. transform $X_Q$ and the replicated $X_K,X_V$ into shape $(N,H,T,d_H)$. At this point, all three tensors are in the same shape
    6. calculate the attention score $A=S(X_QX_K^\top / \sqrt{d_H} + M)$, where
      • the inner product $X_QX_K^\top$ is multiplying along the $d_H$ dimension using torch.matmul()
      • function $S(\cdot)$ is the softmax function, applied along the key-position ($T$) dimension
      • the mask $M$ is optional
      • the argument to $S(\cdot)$ is of shape $(N,H,T,T)$, as is the output tensor $A$
    7. then the output is calculated as $X_O = AX_V$
      • tensor $A$ has shape $(N,H,T,T)$ and $X_V$ has shape $(N,H,T,d_H)$; the multiplication is along the last $T$ dimension of $A$ and the second-to-last $T$ dimension of $X_V$
      • the multiplication is done by torch.matmul(A, Xv) and the output has shape $(N,H,T,d_H)$, which is then transposed into $(N,T,H,d_H)$. The remaining $T$ dimension corresponds to the query positions of $X_Q$
    8. Final output: $O = W^O X_O$, to bring back a tensor of dimension $d$
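
Here is a condensed sketch of this workflow in code, meant to mirror the numbered steps rather than reproduce the file verbatim (KV-cache handling included; fairscale and device details omitted, and the head-count attribute names are illustrative):

def forward(self, x, start_pos, freqs_cis, mask):
    bsz, seqlen, _ = x.shape
    xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)                    # step 1
    xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)             # step 2
    xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
    xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
    xq, xk = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis)             # step 3
    self.cache_k[:bsz, start_pos:start_pos + seqlen] = xk              # step 4: update cache
    self.cache_v[:bsz, start_pos:start_pos + seqlen] = xv
    keys = repeat_kv(self.cache_k[:bsz, :start_pos + seqlen], self.n_rep)
    values = repeat_kv(self.cache_v[:bsz, :start_pos + seqlen], self.n_rep)
    xq, keys, values = (t.transpose(1, 2) for t in (xq, keys, values))  # step 5
    scores = torch.matmul(xq, keys.transpose(2, 3)) / self.head_dim ** 0.5
    if mask is not None:
        scores = scores + mask                                          # step 6
    scores = torch.softmax(scores.float(), dim=-1).type_as(xq)
    output = torch.matmul(scores, values)                               # step 7
    output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    return self.wo(output)                                              # step 8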

The class FeedForward: This is the position-wise, fully-connected layer used by TransformerBlock. It computes the following, where $s(\cdot)$ is the SiLU (swish) activation and $\odot$ denotes element-wise multiplication:

\[Y := W_2(s(W_1 X)\odot W_3 X)\]


  • weight matrices $W_1,W_3$ are of shape $(d_{in},d_h)$ and matrix $W_2$ is of shape $(d_h,d_{in})$
  • $d_h$ is specified in the TransformerBlock as $4d_{in}$, but is then adjusted to $\lfloor\frac23 d_h\rfloor$ (optionally scaled by ffn_dim_multiplier) and rounded up to a multiple of a factor (e.g., 256)
    • for example: the 13B model specifies $d_{in}=5120$, hence $d_h=13824$ ($\lfloor\frac83 d_{in}\rfloor = 13653$, rounded up to a multiple of 256); each of the 2 checkpoint shards stores a 6912-wide slice of this hidden dimension
  • $W_1,W_3$ are implemented as ColumnParallelLinear (with gather_output=False, so their outputs stay sharded across ranks), while $W_2$ is implemented as RowParallelLinear (with input_is_parallel=True), which combines the partial results across the parallel ranks at its output; a plain-PyTorch sketch follows this list
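
A plain-PyTorch sketch of this module, using nn.Linear in place of the fairscale parallel layers:

import torch.nn.functional as F
from torch import nn

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, multiple_of, ffn_dim_multiplier=None):
        super().__init__()
        hidden_dim = int(2 * hidden_dim / 3)                  # 4*dim becomes 8/3*dim
        if ffn_dim_multiplier is not None:
            hidden_dim = int(ffn_dim_multiplier * hidden_dim)
        # round up to the next multiple of `multiple_of` (e.g., 256)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        self.w1 = nn.Linear(dim, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, hidden_dim, bias=False)

    def forward(self, x):
        # SwiGLU: silu(W1 x) gated element-wise by (W3 x), then projected by W2
        return self.w2(F.silu(self.w1(x)) * self.w3(x))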

The class TransformerBlock: It connects an Attention module, a FeedForward module, and two RMSNorm modules (one for attention and one for feedforward). Its workflow is:

\[\begin{aligned} h &:= x + \text{Attention}(\text{LN}_{attn}(x))\\ y &:= h + \text{FF}(\text{LN}_{FF}(h)) \end{aligned}\]

where:

  • the pre-LN architecture is used, at both the attention and feedforward sub-layers
  • there is only self-attention, without the cross-attention module of the vanilla Transformer, since this is a decoder-only architecture (a minimal code sketch of the block follows)
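
A minimal sketch of the block, matching the two equations above (the FeedForward arguments follow the description in the previous section):

from torch import nn

class TransformerBlock(nn.Module):
    def __init__(self, args):
        super().__init__()
        self.attention = Attention(args)
        self.feed_forward = FeedForward(args.dim, 4 * args.dim,
                                        args.multiple_of, args.ffn_dim_multiplier)
        self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)   # pre-LN for attention
        self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)         # pre-LN for feedforward

    def forward(self, x, start_pos, freqs_cis, mask):
        h = x + self.attention(self.attention_norm(x), start_pos, freqs_cis, mask)
        return h + self.feed_forward(self.ffn_norm(h))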

Generation Code

In the generation code, the class Llama and the function sample_top_p() are defined.

The function sample_top_p() takes a probability distribution tensor probs and a probability threshold p as input (a code sketch follows the steps below). Then:

  1. Sort probs in descending order into probs_sort (remembering the original indices), then compute the cumulative sum probs_sum
  2. Find the mask probs_sum - probs_sort > p; the mask at position $i$ means the cumulative probability of positions 0 to $i-1$ is strictly greater than $p$
  3. Zero out the masked positions in probs_sort (i.e., from position $i$ to the end) and renormalize it to sum to 1
  4. Randomly sample from the renormalized distribution, then map the sampled value back to its original index
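
The steps above translate almost line-for-line into PyTorch; a sketch:

def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    # sort descending and remember the original vocabulary indices
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    # mask positions where the cumulative mass before them already exceeds p
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))    # renormalize to sum to 1
    next_token = torch.multinomial(probs_sort, num_samples=1)
    return torch.gather(probs_idx, -1, next_token)            # map back to original token ids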

The class Llama ties everything together. Its constructor holds a Transformer object and a Tokenizer object. The static factory method build() takes as inputs ckpt_dir (checkpoint directory), tokenizer_path, max_seq_len, max_batch_size, model_parallel_size, and seed, and returns a Llama object. In this build() function:

  • it initializes torch.distributed with the nccl backend (the recommended backend for GPU; the other backends, gloo and mpi, are mainly for CPU)
  • it seeds the RNG with torch.manual_seed() (the default seed is 1)
  • it sorts the *.pth files in the checkpoint dir; based on the current process's index in the distributed cluster (from get_model_parallel_rank()), it loads the corresponding file
    • the number of processes in the cluster must exactly match the number of shards in the checkpoint dir
    • all checkpoints (e.g., consolidated.{00,01}.pth) are of the same byte size
  • model parameters are loaded from params.json in the checkpoint dir and update ModelArgs, except that max_seq_len and max_batch_size are provided to build() directly. Example:
{"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
  • it creates the Tokenizer from tokenizer_path and uses it to set the model's vocabulary size
  • it creates the Transformer object with the prepared arguments, then loads the checkpoint; each process holds one shard
  • it creates and returns the Llama object holding the Tokenizer and Transformer models

In the Llama class, the method generate() is decorated with PyTorch's inference-mode decorator (@torch.inference_mode()). It takes as input prompt_tokens (a batch of prompts, as a list of lists of vocabulary ids), max_gen_len, temperature, top_p, logprobs (boolean), and echo (boolean). It returns the list of output tokens (list of lists of ids) as well as the corresponding logprobs.

  • at the start, it validates that the batch is within the model's max batch size and that every prompt is shorter than the model's max sequence length
    • the model has already been created, and the pre-computed RoPE tensor cannot handle sequences beyond that length
  • it then sets total_len to the minimum of max_seq_len and max_gen_len + max_prompt_len; this is the maximum output sequence length this call will produce
  • it then prepares the tensors:
    • tokens: a tensor of size (batch_size, total_len) filled with the pad id, which is retrieved from the tokenizer
      • afterwards, the tokens tensor is filled in with the input prompts (left-aligned)
    • token_logprobs: a tensor of size (batch_size, total_len) filled with zeros
    • eos_reached: a tensor of size (batch_size,) filled with False, indicating whether each sequence has produced the EOS token
    • input_text_mask: a boolean tensor marking which positions of tokens hold prompt (non-pad) tokens

The workflow for generate() is as follows (a simplified sketch of the decoding loop follows this list):

  1. if min_prompt_len equals total_len, run the model once to get the logits, then calculate token_logprobs from the cross entropy between the logits and the original tokens
  2. otherwise, run autoregressively with the position cursor cur_pos going from min_prompt_len up to total_len:
    a. if temperature = 0, simply pick the next token by argmax
    b. if temperature > 0, apply softmax (scaled by the temperature) and pick the next token with sample_top_p()
    c. the next token is written into the tensor tokens at position cur_pos only where that position is not already occupied by a prompt token (per input_text_mask); otherwise the prompt token is kept
    d. then update token_logprobs by comparing the model's output logits to the actual "next tokens"
      • the model is fed the tokens in the range prev_pos:cur_pos across the batch; only the last output position is used for next_token, which corresponds to position cur_pos
      • the initial prev_pos is 0; each iteration advances prev_pos to cur_pos
      • comparing the output logits to the tokens in the range prev_pos+1:cur_pos+1 tells how accurate the model output is, using as much information as possible; all positions in prev_pos+1:cur_pos+1 are weighted equally in the cross entropy because the whole subsequence was presented to the model
    e. update eos_reached by checking whether EOS has been generated at cur_pos; terminate the autoregressive for-loop once every sequence in the batch has produced EOS
    f. update prev_pos to cur_pos before the next iteration
      • the first iteration uses the subsequence starting from position 0; after that, the input to the model is only the single token produced in the previous iteration
      • this design minimizes the computation:
        • from the second iteration onward, all tokens except the last have already been seen
        • only the attention layer needs the entire sequence; the feedforward layer processes each token individually
        • the attention layer is given the start position, which it uses to insert the new token into the key/value cache; afterwards, the entire sequence can be read from the cache
  3. at the end, convert the generated tensor tokens into lists of tokens, cut off at EOS, and optionally also produce the list of logprobs
  4. output: return the list of output tokens (list of lists of ids) and the logprobs
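
As an illustration only, here is a much-simplified sketch of that decoding loop (logprobs and the echo option omitted; model, tokenizer, tokens, input_text_mask, eos_reached, and the length variables are assumed to be prepared as described above):

prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
    # feed only the tokens the model has not processed yet; the KV cache holds the rest
    logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
    if temperature > 0:
        probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
        next_token = sample_top_p(probs, top_p)
    else:
        next_token = torch.argmax(logits[:, -1], dim=-1)
    next_token = next_token.reshape(-1)
    # keep the prompt token where one exists; otherwise use the generated token
    next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
    tokens[:, cur_pos] = next_token
    eos_reached |= (~input_text_mask[:, cur_pos]) & (next_token == tokenizer.eos_id)
    prev_pos = cur_pos
    if eos_reached.all():
        break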

The methods text_completion() and chat_completion() are applications of the generate() method. Both take similar inputs to generate() (e.g., temperature, top-p). In particular, in text_completion() (a usage sketch follows this list),

  • prompts are presented as a list of strings, each converted into a list of token ids using the tokenizer in a loop
  • the converted prompts are passed to self.generate() to get the generated tokens and logprobs
  • the generated tokens are decoded back into strings and returned, optionally together with the logprobs
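
A hypothetical end-to-end usage of these pieces might look like the following; the paths are placeholders, and in practice the script is launched through torchrun as shown in a later section:

generator = Llama.build(
    ckpt_dir="llama-2-7b/",                # placeholder path to the downloaded checkpoint
    tokenizer_path="tokenizer.model",      # placeholder path to the SentencePiece model
    max_seq_len=512,
    max_batch_size=4,
)
results = generator.text_completion(
    ["I believe the meaning of life is"],
    max_gen_len=64,
    temperature=0.6,
    top_p=0.9,
)
print(results[0]["generation"])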

and in chat_completion(),

  • prompts are presented as a list of Dialog objects, each of which is a list of Message, a typed dict of role and content
  • check that the Dialog does not contain any "unsafe" strings, i.e., the special tags ([INST], [/INST], <<SYS>>, <</SYS>>)
  • format the input: the first message's role can be "system". If so, merge the first (system) message with the second (user) message using the template

    <<SYS>>
    {first message}
    <</SYS>>

    {second message}
  • the list of dialog is assumed to alternate user prompts and assistant answers; each (prompt, answer) pair is then formatted as

    [INST] {prompt} [/INST] {answer}

    and the formatted pairs are concatenated. The last message in the dialog is the final user prompt, also wrapped with the same [INST] ... [/INST] template and concatenated.

  • the formatted prompts are then tokenized, collected in the list prompt_tokens, and sent to self.generate()


The repo has example scripts (e.g., example_text_completion.py and example_chat_completion.py). These are not vanilla PyTorch code but use fairscale, hence you cannot run them with a bare python interpreter. The suggested command (from its README) looks like this:

torchrun --nproc_per_node 1 example_text_completion.py \
    --ckpt_dir ../llama-2-7b \
    --tokenizer_path ../llama-2-tokenizer/tokenizer.model

The version on Hugging Face has the fairscale dependency removed, hence we can run it with the python interpreter directly. Example code:

from transformers import AutoTokenizer
import transformers
import torch

model = "meta-llama/Llama-2-7b-chat-hf"
tokenizer = AutoTokenizer.from_pretrained(model)
pipeline = transformers.pipeline(
    "text-generation",
    model=model,
    torch_dtype=torch.float16,
    trust_remote_code=True,
    device_map="auto",
)
sequences = pipeline(
    'Hi! I like cooking. Can you suggest some recipes?\n',
    do_sample=True,
    top_k=10,
    eos_token_id=tokenizer.eos_token_id,
    max_length=500,
)
for seq in sequences:
    print(f"Result: {seq['generated_text']}")

where trust_remote_code controls whether the pipeline may run custom model code downloaded from the hub, and device_map should be set to "auto" when using the accelerate library. Other parameters control the decoding strategy, such as multinomial sampling, beam-search, top-K, and top-P. Only beam-search and sampling support num_return_sequences greater than 1.

What is Fairscale?

It is an alternative to PyTorch DDP and a tool for sharding. To use it (a short sketch follows the list):

  • to shard the optimizer: use the wrapper fairscale.optim.oss.OSS(params=model.parameters(), optim=torch.optim.SGD, **otherargs) to create an optimizer instead of the plain torch.optim.SGD; then wrap the original model with ShardedDDP(model, optimizer); afterward, train the model as usual
  • similarly, there are wrappers to parallelize models, e.g. fairscale.nn.Pipe(model, balance=[a, b]) to run a model across two GPUs such that the layers are split between them according to the balance list (e.g., balance=[2, 1] places twice as many layers on the first GPU)
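
A short sketch of the optimizer-sharding pattern described above, assuming a model, a dataloader, and a loss_fn already exist and torch.distributed has been initialized:

import torch
from fairscale.optim.oss import OSS
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP

# shard the optimizer state across ranks, then wrap the model
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-3)
model = ShardedDDP(model, optimizer)

model.train()
for batch, target in dataloader:
    optimizer.zero_grad()
    loss = loss_fn(model(batch), target)
    loss.backward()
    optimizer.step()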

In the LLaMA 2 code, the parallelism is determined by the environment variables WORLD_SIZE and LOCAL_RANK (set by torchrun), where WORLD_SIZE should match the number of checkpoint shards, e.g., 2 for the 13B model. The model is sharded over the attention heads of the multi-head attention module.

  • the model is defined in an unsharded format, in which the matrices are declared using ColumnParallelLinear and RowParallelLinear
  • in the Attention class, self.cache_k and self.cache_v are sized by the number of local heads to hold the key and value tensors (xk and xv)
  • since ColumnParallelLinear is used for wq, wk, and wv with gather_output=False, all matrix multiplications stay local; RowParallelLinear is used for wo with input_is_parallel=True, and its output is combined across ranks
  • example: the 13B model has 2 shards, with model parameters dim=5120, n_heads=40, n_layers=40. Based on the code, the attention dimension head_dim is 5120/40=128 and the matrices $W^Q,W^K,W^V,W^O$ should all be 5120×5120 (after the input vector is multiplied with the matrices, it is reshaped to (40,128) for the attention score computation). But in the .pth files, these matrices have shape 5120×2560, because each shard holds only half of the attention heads.