Streaming from Flash Attention

💡

Hire: Complete Parts A, B, and the first part of C. You should be able to apply the insights from Flash Attention to the streaming mean or, with guidance, arrive at the same insight through discussion.

Strong Hire: Complete Parts A, B, and C. You should be able to recognize the core insight from either the paper or the streaming mean, and extend that idea to all parts of C.

Before attempting this question, you should study the Flash Attention paper. Prepare for it as though your interviewer had told you in advance that this was going to be the subject of your interview.

This question won't ask about details in the paper; instead, it tests your understanding of the paper's core insights in a hands-on way.

A. Flash Attention Paper

We can start by asking questions about what the core insights of the paper were.

What were the core insights of the Flash Attention paper?

Most people should understand the selling point from its abstract: fewer memory reads and writes. However, a more subtle contribution that makes this work is a streamed version of softmax. If the candidate recognizes both of these insights, this is a big green flag. However, I don't expect most people to catch on to both, so the rest of this question dives into the second one. For more details, see How Flash Attention works.
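To make the streamed softmax concrete, here is a minimal sketch in the same generator style used in Parts B and C. It is an illustration under my own assumptions rather than the paper's kernel: scalar values instead of vectors, one `(score, value)` pair at a time, and a hypothetical helper name, `streaming_softmax_average`. It keeps a running max, a running denominator, and a running numerator, and rescales the old sums by $e^{m_{\text{old}} - m_{\text{new}}}$ whenever the max grows.

```python
import math


def streaming_softmax_average(pairs):
    """Yield sum_i softmax(s)_i * v_i over the (score, value) pairs seen so far.

    Hypothetical helper for illustration: scalar values, one pair at a time.
    """
    running_max = -math.inf  # m: largest score seen so far
    denom = 0.0              # sum of exp(s_i - m)
    numer = 0.0              # sum of exp(s_i - m) * v_i
    for score, value in pairs:
        new_max = max(running_max, score)
        # move the old sums from the old max to the new max;
        # on the first step this is exp(-inf) == 0.0, which zeroes the empty sums
        correction = math.exp(running_max - new_max)
        weight = math.exp(score - new_max)
        denom = denom * correction + weight
        numer = numer * correction + weight * value
        running_max = new_max
        yield numer / denom


# Scores near 1000 would make a naive math.exp(score) raise OverflowError.
print(list(streaming_softmax_average([(1000.0, 1.0), (1001.0, 3.0)])))
```

Subtracting the running max is what keeps `math.exp` in range, and the `correction` factor is the same kind of multiplicative fix-up that Part C asks candidates to rediscover.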

How does Flash Attention reduce memory reads and writes?

In short, Flash Attention's main contribution is to avoid materializing the full set of attention scores $QK^T$ in GPU main memory (HBM). That tensor has shape n_heads x seq_len x seq_len for every sequence in the batch, which is humongous, especially for long contexts. To accomplish this, Flash Attention fuses two back-to-back matrix multiplies: $QK^T$, and the product of the softmaxed scores with $V$. Candidates should be able to give a summary along these lines. Bonus points if they understand tiling, and how Flash Attention fuses tiling for two consecutive matrix multiplies. For more details, see When to fuse multiple matrix multiplies.
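A quick back-of-the-envelope calculation shows just how large the score tensor gets. The sizes below are assumptions chosen for illustration (32 heads, a 32,768-token context, 2-byte fp16 scores, one sequence), not numbers from the paper.

```python
# Hypothetical sizes, chosen only for illustration
n_heads = 32
seq_len = 32_768
bytes_per_score = 2  # fp16

# One score per (head, query position, key position)
score_bytes = n_heads * seq_len * seq_len * bytes_per_score
print(score_bytes / 2**30, "GiB")  # → 64.0 GiB
```

That is 64 GiB for a single attention call on a single sequence, and a standard unfused implementation writes this tensor out and reads it back.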

💡

These questions are just for checking off a box. I would likely not correct the candidate very much if they made a mistake, because the bulk of the question is in the coming parts: if they can complete the remaining parts, they will have derived one of the paper's core insights on their own anyway.

B. Streaming Dot Product

Say we have two streams of numbers, and we'd like to compute a streaming dot product.

How long can these streams be?

Let's consider them to be infinitely long, and let's also trust that Python can handle arbitrarily large integer values.

Will the streams ever contain invalid values?

Let's assume that all values will be valid floats. No NaN, inf, or imaginary numbers.

Do we have a bound on the value of the final product? What if it overflows the datatype it's stored in?

For simplicity, let's assume Python can hold arbitrarily large numbers. We don't have a bound in advance.

Pick a representation for the stream of numbers. Hint: You cannot use a data structure to hold all values explicitly.

You can use an iterator or a generator. In my own solutions, I'll opt for a generator, but either works.

import random

def make_stream():
    while True:
        # 1.0 - random.random() lies in (0, 1], so the denominator is never zero
        yield random.random() / (1.0 - random.random())  # unknown min and max

a = make_stream()
b = make_stream()

Implement a streaming dot product, and design an interface for passing the stream in.

Note that you should also be streaming the output back out. Make sure your own function also returns an iterator or generator.

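def streaming_dot_product(a_stream, b_stream):
    total = 0
    # zip pulls one value from each stream per step, so this works on infinite streams
    for a, b in zip(a_stream, b_stream):
        total += a * b
        yield total  # stream each running total back out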
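The same running dot product can also be written with the standard library's lazy building blocks. This is an alternative sketch rather than the expected answer, and the name `streaming_dot_product_accumulate` is hypothetical. `map(operator.mul, ...)` multiplies pairs lazily and `itertools.accumulate` yields the running sums, so it also works on infinite streams; `itertools.islice` takes a finite prefix for inspection.

```python
import itertools
import operator


def streaming_dot_product_accumulate(a_stream, b_stream):
    # Lazily multiply element pairs, then lazily keep a running sum of those products.
    # Like the loop version, this never stores more than the current running total.
    return itertools.accumulate(map(operator.mul, a_stream, b_stream))


# Nothing is consumed until values are requested, so infinite streams are fine.
a = itertools.count(1)     # 1, 2, 3, ...
b = itertools.repeat(2.0)  # 2.0, 2.0, 2.0, ...
print(list(itertools.islice(streaming_dot_product_accumulate(a, b), 4)))  # → [2.0, 6.0, 12.0, 20.0]
```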

C. Streaming Standardized Sum

Now, we'd like to standardize and sum values from a stream.

How would you compute the streaming mean?

This is a common interview question, so you may already have memorized the answer. If you haven't, the interviewer can work with you to come up with the answer.

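The key insight is that you always keep a copy of the count so far, $k$. Then, you can "replace" the denominator in the old mean by multiplying the old mean by $\frac{k}{k+1}$, and then add the new value's share, $\frac{x_{k+1}}{k+1}$. Altogether, $\mu_{k+1} = \frac{k}{k+1}\mu_k + \frac{x_{k+1}}{k+1}$.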

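def streaming_mean(x_stream):
    count = 0  # k, the number of values seen so far
    mean = 0
    for x in x_stream:
        new_count = count + 1
        # rescale the old mean's denominator from k to k + 1, then add the new value's share
        mean = mean * (count / new_count) + x / new_count
        count = new_count
        yield mean
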
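As a worked example of a single update step: suppose the first $k = 2$ values are 2.0 and 4.0, so their mean is 3.0, and then $x = 9.0$ arrives. The numbers are made up for illustration.

```python
k, old_mean, x = 2, 3.0, 9.0  # mean of [2.0, 4.0] is 3.0; 9.0 is the new value

# rescale the old mean from denominator k to k + 1, then add the new value's share
new_mean = old_mean * (k / (k + 1)) + x / (k + 1)

print(new_mean)  # matches (2.0 + 4.0 + 9.0) / 3 == 5.0, up to floating-point rounding
```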
How would you compute the streaming variance?

This question can be broken down into two parts.

  1. Handle the streaming count. Luckily, just like in the streaming mean, we can apply a multiplicative "correction factor" by multiplying by $\frac{k}{k+1}$.
  2. To handle the streaming mean, we can apply a similar idea. In the previous bullet point, we "replaced" the denominator using a multiplicative correction factor. Now, we "replace" the mean that was subtracted using an additive correction factor.

This additive correction factor takes some explaining though. First, start with the formulation for variance, assuming we have $k$ entries.

$$\sigma_k^2 = \frac{1}{k}\sum_{i=1}^k (x_i - \mu_k)^2$$

Let's focus on the quadratic term. How do we "replace" the old mean $\mu_k$ with the new one $\mu_{k+1}$? Let's expand the quadratic to see.

$$\sigma_k^2 = \frac{1}{k}\sum_i (x_i^2 - 2\mu_kx_i + \mu_k^2)$$

Next, let's distribute the summation. We notice that $\mu_k$ is independent of the summation, so we can pull that out.

$$\sigma_k^2 = \frac{1}{k}\sum_i x_i^2 - 2\mu_k\frac{1}{k}\sum_i x_i + \frac{1}{k}\sum_i \mu_k^2$$

There are a few simplifications we can make:

  • The first term doesn't include the old $\mu_k$, so it doesn't need fixing; we can leave it alone.
  • There's something crazy in the second term: it contains the mean itself! Since $\mu_k = \frac{1}{k}\sum_ix_i$, the second term simplifies to $2\mu_k^2$.
  • Again, $\mu_k$ is independent of the summation over $i$, so the third term sums $k$ copies of $\mu_k^2$ and divides by $k$, leaving just $\mu_k^2$.

$$\sigma_k^2 = \frac{1}{k}\sum_i x_i^2 - 2\mu_k^2 + \mu_k^2 = \frac{1}{k}\sum_i x_i^2 - \mu_k^2$$
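The identity above is easy to sanity-check numerically against a two-pass population variance. `statistics.pvariance` divides by $k$, matching the $\frac{1}{k}$ used here; the data is made up for illustration.

```python
import statistics

xs = [1.0, 2.0, 4.0, 7.0]
k = len(xs)
mu = sum(xs) / k

# "mean of squares minus square of mean", the last line of the derivation
one_pass = sum(x * x for x in xs) / k - mu**2
# two-pass definition: average squared deviation from the mean
two_pass = statistics.pvariance(xs)

print(one_pass, two_pass)  # → 5.25 5.25
```

One caveat: when the mean is large relative to the spread, the subtraction in `one_pass` can lose precision (catastrophic cancellation), even though the algebra is exact.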

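Crazily enough, this means that to correct the old variance, we simply add back $\mu_k^2$ (which recovers the mean of squares $\frac{1}{k}\sum_i x_i^2$), apply the usual $\frac{k}{k+1}$ count correction along with the new value's $\frac{x_{k+1}^2}{k+1}$, and then subtract the new $\mu_{k+1}^2$. The order matters: $-\mu_k^2$ is not an average over the $x_i$, so it has to be removed before the $\frac{k}{k+1}$ rescaling rather than rescaled along with it.

$$\sigma_{k+1}^2 = \frac{k}{k+1}\left(\sigma_k^2 + \mu_k^2\right) + \frac{x_{k+1}^2}{k+1} - \mu_{k+1}^2$$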

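import itertools

def streaming_variance(x_stream):
    # tee gives the mean its own copy of the stream; zipping x_stream with
    # streaming_mean(x_stream) would let both pull from one generator and skip values
    x_stream, x_for_mean = itertools.tee(x_stream)
    mean = 0
    count = 0
    variance = 0
    for x, new_mean in zip(x_stream, streaming_mean(x_for_mean)):
        new_count = count + 1

        # add back the old mean**2 to recover the mean of squares, (1/k) * sum(x_i**2)
        mean_of_squares = variance + mean**2.

        # correct the count in the denominator and include the new value
        mean_of_squares = mean_of_squares * (count / new_count) + x**2. / new_count

        # subtract the new mean**2
        variance = mean_of_squares - new_mean**2.

        count = new_count
        mean = new_mean
        yield variance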
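Because the formulation above subtracts two possibly large quantities, it can lose precision when the mean is large relative to the spread. A commonly used alternative is Welford's algorithm, which is a different technique from the derivation above: it tracks the running mean and the running sum of squared deviations (`m2`) directly. Here is a minimal sketch in the same generator style; the name `welford_variance` is hypothetical.

```python
def welford_variance(x_stream):
    """Yield the population variance of the values seen so far (Welford's algorithm)."""
    count = 0
    mean = 0.0
    m2 = 0.0  # running sum of squared deviations from the current mean
    for x in x_stream:
        count += 1
        delta = x - mean          # deviation from the old mean
        mean += delta / count     # same update as streaming_mean
        m2 += delta * (x - mean)  # uses both the old and the new mean
        yield m2 / count


# Shifting every value by 1e9 leaves the variance unchanged (5.25 here); Welford's
# update stays close to that, while the mean-of-squares form loses precision.
print(list(welford_variance([1e9 + 1, 1e9 + 2, 1e9 + 4, 1e9 + 7]))[-1])
```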

💡

Most interviews will probably end around here. The streaming mean may be a short discussion, but the streaming variance takes time to derive and explain, even for candidates who have seen it before and already know the answer.

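How would you now construct a function that standardizes a stream?

import itertools
import math

def streaming_standardize(x_stream):
    # one independent copy of the stream each for x, the mean and the variance
    x_stream, x_for_mean, x_for_var = itertools.tee(x_stream, 3)
    for x, mean, var in zip(x_stream, streaming_mean(x_for_mean), streaming_variance(x_for_var)):
        # divide by the standard deviation, not the variance; emit 0.0 while the variance is still 0
        yield (x - mean) / math.sqrt(var) if var > 0 else 0.0
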
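How would you now compute the streaming standardized sum?

import itertools
import math

def streaming_standardized_sum(x_stream):
    # yields sum_i (x_i - mean_k) / std_k over the first k values, re-standardizing old values as stats change
    x_stream, x_for_mean, x_for_var = itertools.tee(x_stream, 3)
    total, count, old_mean, old_std = 0.0, 0, 0.0, 0.0
    for x, mean, var in zip(x_stream, streaming_mean(x_for_mean), streaming_variance(x_for_var)):
        std = math.sqrt(var) if var > 0 else 0.0
        # undo the old std, then shift each of the `count` old values from the old mean to the new one
        unscaled = total * old_std + count * (old_mean - mean)
        # add the new value, then rescale by the new std (emit 0.0 while the std is still 0)
        total = (unscaled + x - mean) / std if std > 0 else 0.0
        count, old_mean, old_std = count + 1, mean, std
        yield total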
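One property helps when testing this: since $\sum_{i=1}^k (x_i - \mu_k) = 0$, re-standardizing every value seen so far with the current mean and standard deviation makes the exact sum 0, so a correct streaming implementation should stay at 0 up to rounding error. The sketch below is a hypothetical single-loop variant (the name `standardized_sum_single_loop` is an assumption) that tracks the count, mean, and variance inline with Welford's update instead of teeing the stream, applies the same corrections to the running total, and is compared against a brute-force recomputation on a made-up list.

```python
import math


def standardized_sum_single_loop(x_stream):
    """Yield sum_i (x_i - mean_k) / std_k over the first k values (hypothetical helper)."""
    count, mean, m2 = 0, 0.0, 0.0  # Welford state: count, mean, sum of squared deviations
    total, old_std = 0.0, 0.0
    for x in x_stream:
        old_mean = mean
        count += 1
        delta = x - mean
        mean += delta / count
        m2 += delta * (x - mean)
        std = math.sqrt(m2 / count)
        # undo the old std, move the count - 1 old values to the new mean, add the new value
        unscaled = total * old_std + (count - 1) * (old_mean - mean) + (x - mean)
        total = unscaled / std if std > 0 else 0.0
        old_std = std
        yield total


def standardized_sum_brute_force(xs):
    """Recompute the same quantity from scratch for one prefix."""
    k = len(xs)
    mean = sum(xs) / k
    std = math.sqrt(sum((x - mean) ** 2 for x in xs) / k)
    return sum((x - mean) / std for x in xs) if std > 0 else 0.0


xs = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
for k, total in enumerate(standardized_sum_single_loop(xs), start=1):
    print(k, total, standardized_sum_brute_force(xs[:k]))  # both are 0 up to rounding
```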