How to read transformer code

Relative to other deep learning code, there are two specific reasons why Large Language Model code is a bit less readable:

  1. Number of tensor dimensions: There are many dimensions that need to be handled individually, so tensors can easily end up with 5 or more dimensions.
  2. Rearranging for matrix multiplies: The dimension we're batching across differs from operation to operation, so many lines of code are dedicated to just rearranging dimensions in a tensor.

There are other reasons, like vectorization, which you'll see in Tips for Vectorization, but let's focus on the two above. This guide will give us insight into how to read, and even how to build, transformer code.

💡

Note: You might be wondering why you would ever read transformer code. Inevitably, you'll need to dive into the internals on the job, and knowing what you're looking at will help tremendously, even if your understanding is only high-level. The hardest part is associating high-level concepts with specific lines of code, so that's what this guide aims to solve. Plus, you'll be expected to know this in AI interviews.

Why are there so many tensor dimensions?

In transformer code, tensors tend to have a ridiculous number of dimensions, depending on which part of the transformer you're in. To start off, there are two dimensions that tell you how many tokens are in this tensor: You have the number of requests you're handling and the number of tokens per request. We will call the former the batch_size and the latter the sequence length, or seq_len.
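Here is a quick illustration of those two token dimensions. It is a minimal sketch that assumes a hypothetical vocabulary of 1,000 tokens and d_model = 128. An embedding layer turns a (batch_size, seq_len) grid of token IDs into a (batch_size, seq_len, d_model) tensor, so the number of tokens is the product of the first two dimensions.

```python
import torch

# Hypothetical sizes, chosen only for illustration.
batch_size, seq_len, vocab_size, d_model = 4, 64, 1000, 128

# One row of token IDs per request: (batch_size, seq_len)
token_ids = torch.randint(0, vocab_size, (batch_size, seq_len))

# The embedding layer appends a feature dimension: (batch_size, seq_len, d_model)
embed = torch.nn.Embedding(vocab_size, d_model)
x = embed(token_ids)

# Total number of tokens in the tensor = batch_size * seq_len
n_tokens = x.shape[0] * x.shape[1]
print(x.shape, n_tokens)  # torch.Size([4, 64, 128]) 256
```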

Confusingly, for some operations, you have to combine the two into one "batch" dimension. For example, you may have a matrix multiply like torch.bmm[1] that only accepts 3D tensors, meaning a single dimension has to hold every dimension you want to batch across. You then see a series of confusing-looking operations like the following.

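import torch

batch_size, seq_len = 4, 64
x = torch.randn(batch_size, seq_len, 16, 8)
w = torch.randn(batch_size, seq_len, 8, 12)  # inner dims must match (8)
x = x.reshape(-1, *x.shape[2:])  # fold batch_size and seq_len: (256, 16, 8)
w = w.reshape(-1, *w.shape[2:])  # (256, 8, 12)
y = torch.bmm(x, w)              # (256, 16, 12)
y = y.reshape(batch_size, seq_len, *y.shape[1:])  # unfold: (4, 64, 16, 12)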
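The footnote on torch.bmm mentions an alternative: torch.matmul treats every dimension except the last two as a batch dimension. The sketch below uses illustrative shapes (a 16×8 by 8×12 product per token) and compares the two routes. torch.bmm needs the fold and unfold reshapes, while torch.matmul computes the same result without them.

```python
import torch

batch_size, seq_len = 4, 64
x = torch.randn(batch_size, seq_len, 16, 8)
w = torch.randn(batch_size, seq_len, 8, 12)

# torch.bmm: fold (batch_size, seq_len) into one batch dim, multiply, then unfold.
y_bmm = torch.bmm(x.reshape(-1, 16, 8), w.reshape(-1, 8, 12))
y_bmm = y_bmm.reshape(batch_size, seq_len, 16, 12)

# torch.matmul: all leading dims are batch dims, so no reshapes are needed.
y_matmul = torch.matmul(x, w)

print(y_matmul.shape)  # torch.Size([4, 64, 16, 12])
```

The trade-off is the one the footnote describes. torch.bmm fails loudly if a tensor isn't exactly 3D, which acts as a shape check. torch.matmul broadcasts and accepts more shapes silently.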

For your initial layers, like the embedding layer, that's all the complexity there is. In your attention code, however, you'll have two more dimensions to keep track of: one for the number of heads (n_head) and one for the head dimension (d_head).

Furthermore, newer architectures (for example, grouped-query attention) share each key head across several query heads, so that's another hyperparameter, or possibly another dimension, you need to track. Your tensors end up having quite a few dimensions.
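Here is a minimal sketch of what "multiple queries per key" means for shapes. It assumes grouped-query attention, where each key head is shared by a fixed-size group of query heads. The sizes are hypothetical, and the tensors use the (batch, head, seq, d_head) layout that attention code usually works in. The sketch shows two common ways to line the heads up: copying key heads with repeat_interleave, or adding a group dimension and letting broadcasting do the sharing.

```python
import torch

# Hypothetical sizes: 8 query heads share 2 key heads, 4 query heads per group.
batch_size, seq_len, d_head = 2, 5, 8
n_kv_heads, n_q_per_kv = 2, 4
n_q_heads = n_kv_heads * n_q_per_kv

q = torch.randn(batch_size, n_q_heads, seq_len, d_head)
k = torch.randn(batch_size, n_kv_heads, seq_len, d_head)

# Option 1: copy each key head so every query head gets its own copy.
# repeat_interleave yields [k0, k0, k0, k0, k1, k1, k1, k1] along dim 1.
k_rep = k.repeat_interleave(n_q_per_kv, dim=1)  # (batch, n_q_heads, seq, d_head)
scores_rep = q @ k_rep.transpose(-2, -1)         # (batch, n_q_heads, seq, seq)

# Option 2: add a group dimension; the size-1 dim in k broadcasts across the
# query heads of its group, so k is never explicitly copied.
q_grp = q.reshape(batch_size, n_kv_heads, n_q_per_kv, seq_len, d_head)
k_grp = k.reshape(batch_size, n_kv_heads, 1, seq_len, d_head)
scores_grp = q_grp @ k_grp.transpose(-2, -1)     # (batch, n_kv, n_q_per_kv, seq, seq)
scores_grp = scores_grp.reshape(batch_size, n_q_heads, seq_len, seq_len)
```

Both options give the same scores. The second is where the extra, fifth dimension comes from.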

import torch

# batch_size, seq_len, n_kv_groups, n_q_heads_per_group, d_head
q = torch.randn(4, 64, 2, 8, 8)
k = torch.randn(4, 64, 2, 1, 8)  # size-1 dim broadcasts one key head across its group

Although this doesn't entirely fix the issue, it's worth noting that PyTorch does support named tensors (as a prototype feature), which can at least make the above more readable.

q = torch.randn(
    4, 64, 2, 8, 8,
    names=('batch_size', 'seq_len', 'n_kv_groups', 'n_q_heads_per_group', 'd_head')
)
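Named dimensions also let you reorder by name instead of by position. Below is a small sketch with hypothetical 4D attention shapes. It uses PyTorch's prototype named-tensor API, which emits a warning when used and may change.

```python
import torch

# Named tensors are a prototype feature in PyTorch; expect a UserWarning.
q = torch.randn(4, 64, 16, 8, names=("batch_size", "seq_len", "n_head", "d_head"))

# Reorder by name rather than by index; equivalent to q.permute(0, 2, 1, 3).
q_heads_first = q.align_to("batch_size", "n_head", "seq_len", "d_head")
print(q_heads_first.names)  # ('batch_size', 'n_head', 'seq_len', 'd_head')
print(q_heads_first.shape)  # torch.Size([4, 16, 64, 8])

# Drop the names before passing the tensor to ops that don't support them.
q_plain = q_heads_first.rename(None)
```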

JAX doesn't support named tensors in the same way, but it does let you name array axes for sharding (for example, the axis names of a jax.sharding.Mesh).

Why do we keep moving dimensions around?

Because of how a transformer works, the dimension we operate over changes from operation to operation. For example, attention mixes information across tokens, while the feed-forward network operates within each token. Even within attention itself, there are switches to be aware of. Take the typical attention formulation:

$$\text{softmax}(\frac{QK^T}{\sqrt{d_h}})V$$
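Before simplifying, here is roughly what the full formula looks like for a single head. This sketch uses hypothetical sizes and compares a hand-written version against torch.nn.functional.scaled_dot_product_attention (available in PyTorch 2.0 and later), whose default scale is $1/\sqrt{d_h}$.

```python
import math
import torch
import torch.nn.functional as F

batch_size, seq_len, d_head = 4, 64, 8  # hypothetical sizes
Q = torch.randn(batch_size, seq_len, d_head)
K = torch.randn(batch_size, seq_len, d_head)
V = torch.randn(batch_size, seq_len, d_head)

# softmax(Q K^T / sqrt(d_h)) V, written out step by step
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)  # (batch_size, seq_len, seq_len)
weights = torch.softmax(scores, dim=-1)               # each row sums to 1 over keys
out_manual = weights @ V                              # (batch_size, seq_len, d_head)

# PyTorch's built-in version (no mask, no dropout) computes the same thing
out_sdpa = F.scaled_dot_product_attention(Q, K, V)
```

The scaling and the softmax leave the shape of `scores` unchanged. Only the two matrix products change the dimensions.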

For the sake of explanation, I'm going to grossly oversimplify this expression by dropping every operation that does not change the dimensionality of the activations (the scaling and the softmax). Let's calculate this small part of the attention[2]:

$$QK^TV$$
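Tracking shapes for a single head and a single batch element makes the switch concrete. Here $n$ is my shorthand for seq_len, and $d_h$ is the head dimension, as in the formula above:

$$\underbrace{Q}_{n \times d_h}\,\underbrace{K^T}_{d_h \times n}\,\underbrace{V}_{n \times d_h} = \underbrace{(QK^T)}_{n \times n}\,\underbrace{V}_{n \times d_h} = \underbrace{QK^TV}_{n \times d_h}$$

The first product contracts over $d_h$, comparing the features of every pair of tokens. The second contracts over $n$, mixing information across tokens. So even this stripped-down expression switches which dimension it operates over.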

For a single head, each batch element is just a few 2-dimensional tensors (i.e., matrices), so the code looks pretty straightforward. We multiply seq_len x d_head matrices (and their transposes) and batch along batch_size. We would expect to see something like the following.

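import torch

Q = torch.randn(4, 64, 8)   # batch_size, seq_len, d_head
KT = torch.randn(4, 8, 64)  # batch_size, d_head, seq_len (K already transposed)
V = torch.randn(4, 64, 8)   # batch_size, seq_len, d_head

y = Q @ KT @ V  # batch_size, seq_len, d_head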

That looks decent. Now let's batch across all of the heads at once. The last two dimensions, seq_len and d_head, need to stay in the last two positions, since those are the dimensions involved in the matrix multiply. As a result, we add the n_head dimension in the second position. Everything else stays the same.

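import torch

Q = torch.randn(4, 16, 64, 8)   # batch_size, n_head, seq_len, d_head
KT = torch.randn(4, 16, 8, 64)  # batch_size, n_head, d_head, seq_len
V = torch.randn(4, 16, 64, 8)   # batch_size, n_head, seq_len, d_head

y = Q @ KT @ V  # batch_size, n_head, seq_len, d_head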

That doesn't look bad either. Based on this snippet, keep one fact in mind: for attention, we need Q to have shape (batch_size, n_head, seq_len, d_head). We will use this fact later.

Now, recall that we first need to project the input $X$ into queries, keys, and values using model weights $W_Q, W_K, W_V$.

$$\begin{aligned}Q &= X W_Q\\K &= X W_K\\V &= X W_V\end{aligned}$$

Here's what that might look like in code. For simplicity, let's just focus on queries. These weights $W_Q, W_K, W_V$ are usually stored as d_model x (n_head * d_head).

import torch

X = torch.randn(4, 64, 128)  # batch_size, seq_len, d_model
W_Q = torch.randn(128, 128)  # d_model, n_head * d_head (16 * 8 = 128)

Q = X @ W_Q  # batch_size, seq_len, n_head * d_head

Now, let the chaos begin. We have a tensor Q that has shape batch_size, seq_len, n_head * d_head. However, recall what we said before: Q needs to have shape batch_size, n_head, seq_len, d_head.

(batch_size, seq_len, n_head * d_head)  # what we have
(batch_size, n_head, seq_len, d_head)   # what we need

We can accomplish this shape change by splitting the last dimension, then reordering dimensions. Here's what that looks like.

import torch

batch_size, seq_len, n_head, d_head, d_model = 4, 64, 16, 8, 128
X = torch.randn(batch_size, seq_len, d_model)
W_Q = torch.randn(d_model, n_head * d_head)

Q = X @ W_Q  # batch_size, seq_len, n_head * d_head
Q = Q.reshape(batch_size, seq_len, n_head, d_head)  # split the last dimension
Q = Q.permute(0, 2, 1, 3)  # batch_size, n_head, seq_len, d_head
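One way to convince yourself that the reshape and permute are right is a direct check. Head h of the result should equal X multiplied by the h-th block of d_head columns of W_Q. The small check below reuses the illustrative sizes from the snippet above. It assumes, as that reshape implies, that each head's columns in W_Q are contiguous.

```python
import torch

batch_size, seq_len, n_head, d_head, d_model = 4, 64, 16, 8, 128
X = torch.randn(batch_size, seq_len, d_model)
W_Q = torch.randn(d_model, n_head * d_head)

# Project, split the last dim into (n_head, d_head), move heads before seq_len.
Q = (X @ W_Q).reshape(batch_size, seq_len, n_head, d_head).permute(0, 2, 1, 3)

# Head h should equal X times columns [h * d_head, (h + 1) * d_head) of W_Q.
for h in range(n_head):
    Q_h = X @ W_Q[:, h * d_head:(h + 1) * d_head]  # (batch_size, seq_len, d_head)
    assert torch.allclose(Q[:, h], Q_h, atol=1e-3)
```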

We would have to repeat that for K and V, then run our attention computation from above. From just a few snippets, you can start to see how attention code ends up fairly gnarly. So as you read through transformer code, remember: most of those reshapes and permutes are there to get the right dimensions next to each other for a matmul.
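Putting the pieces together, here is a minimal multi-head attention sketch. The helper names split_heads and merge_heads are hypothetical (my own, not from any library). The sketch omits masking, dropout, and the output projection. merge_heads is the inverse of split_heads: it moves the heads back next to d_head and flattens them.

```python
import math
import torch

def split_heads(t, n_head):
    # (batch_size, seq_len, n_head * d_head) -> (batch_size, n_head, seq_len, d_head)
    batch_size, seq_len, width = t.shape
    return t.reshape(batch_size, seq_len, n_head, width // n_head).permute(0, 2, 1, 3)

def merge_heads(t):
    # (batch_size, n_head, seq_len, d_head) -> (batch_size, seq_len, n_head * d_head)
    batch_size, n_head, seq_len, d_head = t.shape
    return t.permute(0, 2, 1, 3).reshape(batch_size, seq_len, n_head * d_head)

def multi_head_attention(X, W_Q, W_K, W_V, n_head):
    Q = split_heads(X @ W_Q, n_head)
    K = split_heads(X @ W_K, n_head)
    V = split_heads(X @ W_V, n_head)
    d_head = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_head)  # (batch, n_head, seq, seq)
    out = torch.softmax(scores, dim=-1) @ V              # (batch, n_head, seq, d_head)
    return merge_heads(out)                              # (batch, seq, n_head * d_head)

batch_size, seq_len, n_head, d_head, d_model = 4, 64, 16, 8, 128
X = torch.randn(batch_size, seq_len, d_model)
W_Q, W_K, W_V = (torch.randn(d_model, n_head * d_head) for _ in range(3))
y = multi_head_attention(X, W_Q, W_K, W_V, n_head)
print(y.shape)  # torch.Size([4, 64, 128])
```

The permute inside merge_heads is why you often see `.contiguous()` or `reshape` right after attention in real codebases. `view` fails on non-contiguous tensors whose strides don't fit the new shape, while `reshape` copies when it has to.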

Takeaways

In short, most of the complexity of transformer code isn't even due to the "math". It's just a whole lot of preparation for the math, so that the matrix multiplication is as efficient as possible. Keep this in mind as you connect code with concepts, so that you know what to filter out when you're hunting for certain logic in the codebase.

For what it's worth, there are workarounds to the above problems.

  • einsum: You could use einsum, which lets you express batched matrix multiplications, transpositions, contractions, etc. in a single call. However, in exchange for that flexibility, einsum can be less efficient, so you may have to translate it into explicit matmuls anyway for an efficient, production-ready codebase, whether for training or inference.
  • functorch.dim: You could also use functorch's first-class dimensions, which are more powerful than PyTorch's own named tensors. The readability is vastly improved, and they don't have the same inefficiencies that einsum does. This functional form of PyTorch is less popular, but the first-class treatment of named dimensions in its examples seems really promising.
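To see what einsum buys in readability, here is a sketch that computes multi-head attention scores directly from (batch_size, seq_len, n_head, d_head) tensors, with no permutes. It checks the result against the permute-then-matmul route. The sizes are hypothetical, and this is a readability comparison, not a benchmark.

```python
import torch

batch_size, seq_len, n_head, d_head = 4, 64, 16, 8
q = torch.randn(batch_size, seq_len, n_head, d_head)
k = torch.randn(batch_size, seq_len, n_head, d_head)

# einsum: each axis gets a letter; "s" and "t" are query and key positions.
# "d" is absent from the output, so it is summed over (the dot product).
scores_einsum = torch.einsum("bshd,bthd->bhst", q, k)  # (batch, n_head, seq, seq)

# Explicit route: move heads forward, then batch the matmul over (batch, n_head).
q_p = q.permute(0, 2, 1, 3)                  # (batch, n_head, seq, d_head)
k_p = k.permute(0, 2, 1, 3)
scores_matmul = q_p @ k_p.transpose(-2, -1)  # (batch, n_head, seq, seq)
```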

  1. torch.bmm can only take 3D tensors. Granted, you could just switch to torch.matmul if you wanted to handle N-dimensional tensors and weren't strictly "shape-checking" your tensors. 

  2. For simplicity, I've omitted details such as grouped-query attention, the scaling factor, multi-head weights, etc. This is not the full equation for any attention implementation, but it's a part of every attention implementation.