Attention from scratch

💡

Hire: Completes Parts A and B. You should demonstrate working knowledge of self-attention and the ability to convert that understanding into code. You should also be comfortable enough with the actual formulation of softmax to understand how to mask attention scores.

Strong Hire: Completes Parts A and B, finishes calculating the Part C gradients by hand, and ideally writes at least pseudocode for the final implementation. A strong candidate may still be given guidance on that last step, but the first two parts should be done independently.

This question largely tests your ability to convert concepts into code: self-attention, attention variants, and backpropagation.

A. Write attention

Write multi-headed attention.

What hyperparameters should I use, for n_heads, d_head, etc.?

Suggest reasonable defaults. For example, n_heads=32, d_head=16, d_model=512, seq_len=64. There are two important details to note here:

  • Ensure all dimensions have different values, so that we can easily tell when dimensions have been flipped while we're writing code.
  • Per Model Architecture in numbers to memorize, it's slightly more sightly to have n_heads * d_head = d_model.
Which variant of attention should I write?

Write the multi-headed attention introduced by the original transformer paper. Consider just the attention that was used in the original encoder branch (If you're not sure what this means, just write the simplest version you can think of). Assume there is a one-to-one mapping between queries, keys, and values.

What are my inputs?

Assume you are given Q, K, and V in any shape you deem appropriate. You will also need the final cross-head mixing weights $W_O$. You can pull numbers from Math to memorize or just make up any power of two that makes your tensors small and fast to work with.

What is the batch size?

For now, you can ignore the batch dimension. It doesn't tangibly affect the problem's difficulty or takeaways; it just adds more bookkeeping.

Construct your input tensor(s). Somehow indicate the name of each tensor dimension.
import torch

d_model = 512
n_heads = 32
d_head = 16
seq_len = 64

Q = torch.randn(n_heads, seq_len, d_head)
K = torch.randn(n_heads, seq_len, d_head)
V = torch.randn(n_heads, seq_len, d_head)
W_O = torch.randn(n_heads * d_head, d_model)
Write your attention function using the tensors you defined above as inputs.
import math

import torch.nn.functional as F
from torch import Tensor

# NOTE: Ask yourself: what is the dimensionality of the tensor at each
# point? Write out the dimensions explicitly to make them clear for both
# you and your interviewer.
# NOTE: I've used rather unreadable single-letter variable names just
# because this will make it easier for us to reference tensors later
# in the problem.
def attention(Q: Tensor, K: Tensor, V: Tensor, W_O: Tensor) -> Tensor:
    S = Q @ K.transpose(-2, -1) / math.sqrt(d_head)  # (n_heads, seq_len, seq_len)
    A = F.softmax(S, dim=-1)                         # (n_heads, seq_len, seq_len)
    O = A @ V                                        # (n_heads, seq_len, d_head)
    O = O.transpose(0, 1).reshape(seq_len, -1)       # (seq_len, n_heads * d_head)
    return O @ W_O                                   # (seq_len, d_model)

out = attention(Q, K, V, W_O)  # make sure it runs
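
As an optional sanity check (not something you would write in the interview), you can compare against PyTorch's built-in fused attention; this assumes PyTorch 2.0 or newer:

ref = F.scaled_dot_product_attention(Q, K, V)         # (n_heads, seq_len, d_head)
ref = ref.transpose(0, 1).reshape(seq_len, -1) @ W_O  # (seq_len, d_model)
assert torch.allclose(out, ref, atol=1e-4)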

B. Write Llama3 attention

Bridge the gap between your implementation and state-of-the-art.

Update your input tensors for grouped-query attention.
import torch

d_model = 512
n_heads = 32                   # total number of query heads
d_head = 16
seq_len = 64
group = 4                      # number of query heads per key/value head
n_kv_heads = n_heads // group  # 8 key/value heads

Q = torch.randn(n_kv_heads, group, seq_len, d_head)
K = torch.randn(n_kv_heads, 1, seq_len, d_head)
V = torch.randn(n_kv_heads, 1, seq_len, d_head)
W_O = torch.randn(n_heads * d_head, d_model)
Implement grouped-query attention.

Ideally, you should simply reuse your previous attention function.

def gqa(Q: Tensor, K: Tensor, V: Tensor, W_O: Tensor) -> Tensor:
    # Broadcast each key/value head across its group of query heads, then merge
    # (n_kv_heads, group) into a single head dimension so we can reuse attention.
    K = K.repeat(1, group, 1, 1).reshape(n_heads, seq_len, d_head)
    V = V.repeat(1, group, 1, 1).reshape(n_heads, seq_len, d_head)
    Q = Q.reshape(n_heads, seq_len, d_head)
    return attention(Q, K, V, W_O)
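
A quick check that it runs and produces the same output shape as before:

out = gqa(Q, K, V, W_O)
assert out.shape == (seq_len, d_model)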

We are missing a critical component of attention — namely, the attention mask.

Conceptually, what is the attention mask, and why do we need it?

Intuitively, the attention mask ensures that each token can only "attend" to tokens that came before it. Concretely, we implement this by zeroing out the attention weights above the diagonal, i.e., the post-softmax activations that are multiplied by V.

How would you implement the attention mask, at a high level?

This is actually a non-trivial question, unless you've looked at the implementation or written the attention from scratch before.

  • One natural reaction is to simply set the attention weights above the diagonal to zero after the softmax. However, the attention weights in each row are then no longer guaranteed to sum to one.
  • Another, more subtly buggy possibility is to zero out the upper triangle before the softmax. However, zeros before the softmax become non-zero after it: after all, $e^0 = 1$.
  • Instead, the correct answer is to set the upper triangle to negative infinity before the softmax, since $e^{-\infty} = 0$ (see the tiny example below).
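
To make the last point concrete, here is a tiny throwaway example showing that $-\infty$ scores become exact zeros after the softmax while the row still sums to one:

row = torch.tensor([1.0, 2.0, float('-inf')])
print(F.softmax(row, dim=-1))  # tensor([0.2689, 0.7311, 0.0000])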
Implement masked attention.
def attention(Q: Tensor, K: Tensor, V: Tensor, W_O: Tensor) -> Tensor:
    S = Q @ K.transpose(-2, -1) / math.sqrt(d_head)  # (n_heads, seq_len, seq_len)

    # Mask out everything above the diagonal. You could alternatively subtract a
    # large constant, e.g. `S -= mask * 1e9`. (Subtracting `mask * float('-inf')`
    # would not work: 0 * inf is NaN.)
    mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    S.masked_fill_(mask, float('-inf'))

    A = F.softmax(S, dim=-1)                         # (n_heads, seq_len, seq_len)
    O = A @ V                                        # (n_heads, seq_len, d_head)
    O = O.transpose(0, 1).reshape(seq_len, -1)       # (seq_len, n_heads * d_head)
    return O @ W_O                                   # (seq_len, d_model)
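
As another optional sanity check (again assuming PyTorch 2.0 or newer), you can compare against PyTorch's fused causal attention; here we go back to the Part A tensor shapes:

Q = torch.randn(n_heads, seq_len, d_head)
K = torch.randn(n_heads, seq_len, d_head)
V = torch.randn(n_heads, seq_len, d_head)

out = attention(Q, K, V, W_O)
ref = F.scaled_dot_product_attention(Q, K, V, is_causal=True)
ref = ref.transpose(0, 1).reshape(seq_len, -1) @ W_O
assert torch.allclose(out, ref, atol=1e-4)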

💡

Most interviews will likely end here. However, if you practice writing attention from scratch in advance, you should get further, at least to computing derivatives for backpropagation. And really, this is something you can refine with practice. I'm giving you the answer here, so there is no reason not to practice the most obvious question.

C. Backpropagate

Implement backpropagation for self-attention. Consider just a single head, and ignore $W_O$.

Let's start by calculating manually, on paper. How would you work out what to implement for backpropagation? In particular, what is the key insight in backpropagation that makes autodiff — and your manual calculation — much simpler?

The key insight in backpropagation is that you can take the derivative of each operation's output with respect to that operation's input, separately for each operation, and then string these local derivatives together using the chain rule. Knowing this, we simply need to calculate derivatives starting from the output of attention and working backwards toward each of the inputs ($Q$, $K$, and $V$) that our attention operated on, operation by operation.
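
As a minimal, non-attention illustration of "local derivative times upstream gradient":

x = torch.tensor(0.7, requires_grad=True)
y = x ** 2        # local derivative: dy/dx = 2x
z = torch.sin(y)  # local derivative: dz/dy = cos(y)
z.backward()
assert torch.allclose(x.grad, torch.cos(y.detach()) * 2 * x.detach())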

Calculate gradient for V

V is the simplest to calculate gradients for. The final output $O = AV$, so

$$\frac{\partial L}{\partial V} = A^T\frac{\partial L}{\partial O}$$

Let's denote all loss gradients with respect to a tensor as $\frac{\partial L}{\partial X} = dX$, which makes our expression

$$dV = A^T dO$$

This is our first gradient!
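
You can convince yourself of this with a tiny throwaway autograd check:

A = torch.softmax(torch.randn(4, 4), dim=-1)  # stand-in attention weights
V = torch.randn(4, 3, requires_grad=True)
dO = torch.randn(4, 3)                        # pretend upstream gradient
(A @ V).backward(dO)
assert torch.allclose(V.grad, A.T @ dO)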

Calculate the gradient for our post-softmax attention coefficients.

Now, we need gradients for everything else, so let's start with the partial of the loss with respect to the post-softmax attention coefficients, which we'll call $A$.

$$\frac{\partial L}{\partial A} = \frac{\partial L}{\partial O}V^T$$

Rewritten in terms of full gradients, we have

$$dA = dO V^T$$

This is the last of the simple ones.
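
The same kind of throwaway autograd check works here:

A = torch.softmax(torch.randn(4, 4), dim=-1).requires_grad_()
V = torch.randn(4, 3)
dO = torch.randn(4, 3)
(A @ V).backward(dO)
assert torch.allclose(A.grad, dO @ V.T)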

Calculate the gradient for pre-softmax attention scores.

Let's call the pre-softmax scaled scores $S$. We'll need to do this one step by step. To save some writing, let's define $\mathcal{L} = \sum_\ell e^{S_{i\ell}}$ (not to be confused with the loss $L$).

$$A_{ij} = \text{softmax}(S_i)_j = \frac{e^{S_{ij}}}{\sum_\ell e^{S_{i\ell}}} = \frac{e^{S_{ij}}}{\mathcal{L}}$$

To compute the partial, use the quotient rule. My childhood mnemonic device was "low d-high minus high d-low, all over low low". Remember that $\frac{d}{dx}(e^x) = e^x$.

$$\frac{\partial A_{ij}}{\partial S_{ik}} = \frac{\mathcal{L}(\frac{d}{dS_{ik}}e^{S_{ij}}) - e^{S_{ij}}(\frac{d}{dS_{ik}}\mathcal{L})}{\mathcal{L}^2}$$

Focus on the first partial in the numerator. Let's say $j = k$. Then, we get that

$$\frac{d}{dS_{ij}}(e^{S_{ij}}) = e^{S_{ij}}$$

Alternatively, if $j \neq k$, then we get

$$\frac{d}{dS_{ik}}(e^{S_{ij}}) = 0$$

We can simplify this using an indicator variable that is 1 if $j = k$ and 0 otherwise.

$$\frac{d}{dS_{ik}}e^{S_{ij}} = e^{S_{ij}}\mathbb{1}_{j=k}$$

Plugging this back in, we get

$$\frac{\partial A_{ij}}{\partial S_{ik}} = \frac{\mathcal{L}(\mathbb{1}_{j=k}e^{S_{ij}}) - e^{S_{ij}}e^{S_{ik}}}{\mathcal{L}^2} = \frac{e^{S_{ij}}}{\mathcal{L}}\mathbb{1}_{j=k} - \frac{e^{S_{ij}}}{\mathcal{L}}\frac{e^{S_{ik}}}{\mathcal{L}}$$

Notice the fractional terms are simply $A_{ij}, A_{ik}$! Let's replace those fractions.

$$\frac{\partial A_{ij}}{\partial S_{ik}} = A_{ij}\mathbb{1}_{j=k} - A_{ij}A_{ik} = A_{ij}(\mathbb{1}_{j=k} - A_{ik})$$

Now, we need to write the full partial derivative of the loss with respect to the pre-softmax values. Note that each pre-softmax value affects every post-softmax value in its row. As a result, we need to sum all of those contributions to the gradient.

$$\begin{aligned}\frac{\partial L}{\partial S_{ik}} &= \sum_j (\frac{\partial A_{ij}}{\partial S_{ik}} \frac{\partial L}{\partial A_{ij}}) \\ &= \sum_j \frac{\partial L}{\partial A_{ij}}A_{ij}(\mathbb{1}_{j=k} -A_{ik}) \\&= \sum_j (\frac{\partial L}{\partial A_{ij}}A_{ij}\mathbb{1}_{j=k}) - A_{ik} \sum_j (\frac{\partial L}{\partial A_{ij}}A_{ij}) \\&= A_{ik}(\frac{\partial L}{\partial A_{ik}} - \sum_j\frac{\partial L}{\partial A_{ij}}A_{ij})\end{aligned}$$

Later in our code, we'll compute the inner summation first, which we'll denote

$$C_i = \sum_j \frac{\partial L}{\partial A_{ij}} A_{ij}$$

Again, rewriting as full gradients, we now have the following.

$$dS = A \odot (dA - C)$$

where $\odot$ is an elementwise product and $C$ is the vector of row sums defined above, broadcast across each row.
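
A one-row throwaway check of this formula against autograd ($C$ is a scalar here because there is only one row):

S = torch.randn(6, requires_grad=True)
A = torch.softmax(S, dim=-1)
dA = torch.randn(6)
A.backward(dA)
A = A.detach()
C = (dA * A).sum()
assert torch.allclose(S.grad, A * (dA - C))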

Calculate the gradient for our Q and K.

Finally, compute the gradients for $Q, K$. The attention scores are defined as

$$S = \frac{QK^T}{\sqrt{d_h}}$$

Following the same convention we used for $V$ to decide which side each factor lands on, we have the following for $Q$:

$$\frac{\partial L}{\partial Q} = \frac{1}{\sqrt{d_h}}\frac{\partial L}{\partial S}(K^T)^T = \frac{1}{\sqrt{d_h}}\frac{\partial L}{\partial S}K$$

Then, we have the following for $K$:

$$\frac{\partial L}{\partial K^T} = \frac{1}{\sqrt{d_h}} Q^T\frac{\partial L}{\partial S} \implies \frac{\partial L}{\partial K} = \frac{1}{\sqrt{d_h}}\left(\frac{\partial L}{\partial S}\right)^T Q$$

Again, in full-gradient notation:

$$dQ = \frac{dS\,K}{\sqrt{d_h}}, \qquad dK = \frac{dS^T Q}{\sqrt{d_h}}$$
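
One last throwaway autograd check for these two, with $d_h = 3$:

Q = torch.randn(5, 3, requires_grad=True)
K = torch.randn(5, 3, requires_grad=True)
dS = torch.randn(5, 5)
(Q @ K.T / 3 ** 0.5).backward(dS)
assert torch.allclose(Q.grad, dS @ K.detach() / 3 ** 0.5)
assert torch.allclose(K.grad, dS.T @ Q.detach() / 3 ** 0.5)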

💡

I don't expect anyone to fully implement a working backward pass in the allotted interview time. However, you should be able to finish calculating each of the partials, possibly with guidance, and you may get time to start an implementation or at least produce pseudocode.

Implement backpropagation for self-attention.
import math

import torch
import torch.nn.functional as F

def backward_self_attention(Q, K, V, dO):
    # Recompute the forward pass (we only need S and A for the gradients below)
    S = Q @ K.transpose(-2, -1) / math.sqrt(d_head)  # (seq_len, seq_len)
    A = F.softmax(S, dim=-1)                         # (seq_len, seq_len)

    # 1. Compute gradient with respect to V: dV = A^T dO
    dV = A.transpose(-2, -1) @ dO

    # 2. Compute gradient with respect to A: dA = dO V^T
    dA = dO @ V.transpose(-2, -1)

    # 3. Backprop through softmax to get dS from dA.
    C = (dA * A).sum(dim=-1, keepdim=True)
    dS = A * (dA - C)

    # 4. Compute gradients with respect to Q and K
    dQ = dS @ K / math.sqrt(d_head)
    dK = dS.transpose(-2, -1) @ Q / math.sqrt(d_head)

    return dQ, dK, dV

# Example dimensions and random inputs for a single-head case
Q = torch.randn(seq_len, d_head, requires_grad=True)
K = torch.randn(seq_len, d_head, requires_grad=True)
V = torch.randn(seq_len, d_head, requires_grad=True)
dO = torch.randn(seq_len, d_head)  # upstream gradient

dQ, dK, dV = backward_self_attention(Q, K, V, dO)
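
Because the inputs above were created with requires_grad=True, you can verify the manual gradients against autograd (a sanity check, not part of the expected interview answer):

O = F.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d_head), dim=-1) @ V
O.backward(dO)
assert torch.allclose(dQ, Q.grad, atol=1e-5)
assert torch.allclose(dK, K.grad, atol=1e-5)
assert torch.allclose(dV, V.grad, atol=1e-5)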