Preparing argmax for training
💡
Hire: Complete parts A, B, and the first half of C. You should be able to write argmax, report its algorithmic complexity, and vectorize existing code.
Strong Hire: Additionally complete part D or the second half of C. Both are tricky: even PhDs have gotten part D wrong or missed the cleanest way to do C. Getting either correct will land you in the strong hire category.
A. Implement argmax
Implement argmax on a matrix of values, without external libraries.
How do we handle NaNs?
You may assume that NaNs are not included in the input.
Are there non-numeric values we need to handle?
All values are guaranteed to be numeric.
What dimensionality is my input?
In any other interview, this is guaranteed to be a simple list of numbers. However, in an AI interview, you should see where this is leading: You're probably going to end up working on a matrix or a high-dimensional tensor. Given that, we should ask first about the tensor's dimensionality.
Our input is a list of lists of floats.
How should we handle empty matrices?
The interviewer may just flip it back on you: What do you think is a reasonable way to handle this?
If the interviewer asks you this, it probably doesn't matter what you pick, because it's not the point of the question, but you should still have a reasonable response. For example, you could suggest throwing a ValueError. Or, in the code below, return the empty list.
What if there are multiple of the maximum value?
Again, the interviewer may flip it back on you: What do you think is a reasonable way to handle this?
Again, this is likely a question the interviewer would turn back on you. You could suggest returning the index of the first or the last one. Then, the interviewer may ask: Why not return all of the indices that match the maximum? The answer is simple: Then you may end up using O(N) memory. The less obvious answer would earn you bonus points though: Returning all indices would result in a dynamically-sized array, making this code unsuitable for static graph construction later on, if you use TorchScript, CUDA graphs, or a framework like JAX.
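To make the two policy questions above concrete, here is a minimal single-row sketch. The name `argmax_row` and the `keep` parameter are hypothetical, introduced only for illustration. It assumes we raise a ValueError on empty input and lets the caller pick whether the first or the last tied maximum wins.

```python
def argmax_row(lst: list[float], keep: str = "first") -> int:
    """Argmax of one row; `keep` (hypothetical) picks which tied maximum wins."""
    if not lst:
        # One reasonable policy for empty input; returning a sentinel is another.
        raise ValueError("argmax of an empty row is undefined")
    max_val, max_idx = lst[0], 0
    for idx in range(1, len(lst)):
        val = lst[idx]
        # Strict '>' keeps the first maximum; allowing '==' lets later ties win.
        if val > max_val or (keep == "last" and val == max_val):
            max_val, max_idx = val, idx
    return max_idx


first = argmax_row([3.0, 7.0, 7.0, 1.0])               # tie between indices 1 and 2
last = argmax_row([3.0, 7.0, 7.0, 1.0], keep="last")
```

Either way the output is a single integer per row, so the result keeps a fixed size regardless of how many ties there are.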
Solution
def argmax(lsts: list[list[float]]) -> list[int]:
    # TIP: Always allocate the output explicitly, in advance. If you can,
    # avoid .append which may cause dynamic array resizing. In other words,
    # that may cause many runtime memory allocations, which slows down your
    # program. No one will fault you for using .append. Just a pro tip, and
    # extra brownie points if you can explain this in an interview.
    max_idxs = [-1] * len(lsts)
    for row_idx, lst in enumerate(lsts):
        max_val, max_idx = -float('inf'), -1
        for idx, val in enumerate(lst):
            if val > max_val:
                max_val, max_idx = val, idx
        max_idxs[row_idx] = max_idx
    return max_idxs

assert argmax([[1, 4.1], [-2, 3.5]]) == [1, 1]
If you are having a hard time: How would you approach this for a single vector first?
# I considered making the default max_idx=0. This could be a good idea
# or a horrible one. It could be an okay idea, because it handles the
# case where our input is only `[-float('inf')] * 100`. However, it
# is potentially a bad idea, because if there's a bug in our code that
# causes the for loop to be a no-op (or the input is empty), we wouldn't
# be able to tell.
def argmax(lst: list[float]) -> int:
    max_val, max_idx = -float('inf'), -1
    for idx, val in enumerate(lst):
        if val > max_val:
            max_val, max_idx = val, idx
    return max_idx

assert argmax([1, 4.1, -2, 3.5]) == 1
One really funny way to work around the problem is to use Python's built-in max. Technically not illegal, because it's a built-in, but it does defeat the point of the question.
def argmax(lst: list[float]) -> int:
    return max(range(len(lst)), key=lambda idx: lst[idx])

assert argmax([1, 4.1, -2, 3.5]) == 1
B. Algorithmic complexities
Interviewer: Let's assess the current performance of your algorithm.
What is the time complexity and space complexity of your algorithm?
Assuming the input matrix is M x N, the algorithm is O(MN) time and O(M) space. Note that enumerate is a lazily-evaluated iterator, so it does not allocate additional memory.
Can we improve our time complexity? How would you do it?
No, you must take at least one pass over all data to find the maximum.
Can we improve our space complexity? How would you do it?
Technically, you could pull the same trick again of storing your outputs in the input, but in most cases, we wouldn't expect argmax to operate in place.
💡
The initial question was fairly simple, so the interviewer would be expected to cover one more part. Personally, I'd jump straight to part C at this point, since I'd want to test the AI-relevant skills this question was building up to.
Without changing the algorithmic complexity, what are some ways to make this argmax run faster?
Idea #1. We could parallelize argmax over subsets of the input list.
Response: However, in standard CPython, the GIL lets only one thread execute Python bytecode at a time, no matter how many threads you start. In other words, multithreading would not speed up this pure-Python loop. Threads would only help if our task's slowdown came from something that releases the GIL, like file loading or database requests.
So, in the extreme case where our list is so big that it doesn't fit in RAM, we could conceivably parallelize loading from disk. For what it's worth though, that case is really hard to reach. Even on consumer laptops, RAM is around 16 GB, which can store 8 billion FP16 values. That corresponds to a massive 89,000 x 89,000 matrix! No matrix, even in Large Language Models, is this large.
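The back-of-the-envelope numbers above can be reproduced in a few lines. This sketch assumes decimal gigabytes (16 x 10^9 bytes) and 2 bytes per FP16 value; the variable names are illustrative only.

```python
import math

ram_bytes = 16 * 10**9        # 16 GB, assuming decimal gigabytes
bytes_per_value = 2           # FP16 stores each value in 2 bytes
n_values = ram_bytes // bytes_per_value

# Side length of the largest square matrix whose entries all fit in RAM.
side = math.isqrt(n_values)

print(f"{n_values:,} values, roughly a {side:,} x {side:,} matrix")
```

With binary gigabytes (2^30 bytes each) the side length comes out slightly larger, but the conclusion is the same.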
Idea #2. You could vectorize operations in a custom kernel. Fortunately, in reality, tensor libraries have already built highly-optimized C implementations.
C. Vectorize
Say you can use tensor libraries now. Start by using any library methods you need to simplify your function above, including built-ins that implement argmax already.
Solution. This is possibly one of the most commonly called operations, so I'd expect you to be able to write this without thinking much.
import numpy as np
import torch

# pytorch offers argmax as a method on the tensor class. Note you have to
# pass in the correct dim=-1 to generalize to N-D tensors. Convert with
# .tolist() before comparing: asserting on a multi-element tensor raises.
assert torch.tensor([[1, 4.1], [-2, 3.5]]).argmax(dim=-1).tolist() == [1, 1]
# numpy offers both the function np.argmax and the method ndarray.argmax.
# Note `axis` vs. `dim`.
assert np.argmax(np.array([[1, 4.1], [-2, 3.5]]), axis=-1).tolist() == [1, 1]
Vectorize your program without library methods related to max or argmax.
Can I use x method?
You can use any torch method except torch.max and torch.argmax. Ditto for numpy.
Solution. It turns out that you need to both vectorize and make the code graph-able at the same time. See Tips for "graphifying".
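def argmax(x: torch.Tensor) -> torch.Tensor:
    # Assumes a 2-D floating-point tensor. Indices must be integers (long),
    # and torch.full expects a shape tuple, hence (n_rows,).
    n_rows, n_cols = x.shape
    max_idxs = torch.full((n_rows,), -1, dtype=torch.long, device=x.device)
    max_vals = torch.full((n_rows,), float('-inf'), dtype=x.dtype, device=x.device)
    for idx in range(n_cols):
        cur_vals = x[:, idx]
        is_better = cur_vals > max_vals
        # torch.where keeps every shape static, unlike boolean-mask indexing.
        max_idxs = torch.where(is_better, idx, max_idxs)
        max_vals = torch.where(is_better, cur_vals, max_vals)
    return max_idxs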
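Since the same restriction applies to numpy, here is a sketch of the same column sweep in NumPy, using np.where in place of torch.where. The function name `argmax_np` is hypothetical, and the sketch assumes a 2-D floating-point array.

```python
import numpy as np


def argmax_np(x: np.ndarray) -> np.ndarray:
    """Row-wise argmax of a 2-D float array without np.max or np.argmax."""
    n_rows, n_cols = x.shape
    max_idxs = np.full(n_rows, -1, dtype=np.int64)
    max_vals = np.full(n_rows, -np.inf, dtype=x.dtype)
    for idx in range(n_cols):
        cur_vals = x[:, idx]
        is_better = cur_vals > max_vals
        # Update the index first, while max_vals still holds the old maxima.
        max_idxs = np.where(is_better, idx, max_idxs)
        max_vals = np.where(is_better, cur_vals, max_vals)
    return max_idxs


x = np.array([[1.0, 4.1], [-2.0, 3.5], [7.0, 7.0]])
result = argmax_np(x)
```

The strict comparison means ties resolve to the first maximum, which matches what np.argmax documents.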
Generalize this to higher-dimensional tensors with arbitrarily many dimensions. Always take the argmax along the last dimension.
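def argmax(x: torch.Tensor) -> torch.Tensor:
    # Same sweep as before; every leading dimension acts as a batch dimension.
    # Assumes a floating-point tensor.
    max_idxs = torch.full(x.shape[:-1], -1, dtype=torch.long, device=x.device)
    max_vals = torch.full(x.shape[:-1], float('-inf'), dtype=x.dtype, device=x.device)
    for idx in range(x.shape[-1]):
        cur_vals = x[..., idx]
        is_better = cur_vals > max_vals
        max_idxs = torch.where(is_better, idx, max_idxs)
        max_vals = torch.where(is_better, cur_vals, max_vals)
    return max_idxs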
💡
Most interviews would end here, if not earlier. This question is better suited for AI-specific interviews because the non-AI parts (e.g., parts A and B) are less cumbersome.
Still with tensors that contain arbitrarily many dimensions, take the argmax along an arbitrary dimension the user passes to you.
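def argmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    # movedim (rather than transpose) brings `dim` to the front while keeping
    # the remaining dimensions in their original order, so the output shape
    # matches torch.argmax(x, dim). Assumes a floating-point tensor.
    x = x.movedim(dim, 0)
    max_idxs = torch.full(x.shape[1:], -1, dtype=torch.long, device=x.device)
    max_vals = torch.full(x.shape[1:], float('-inf'), dtype=x.dtype, device=x.device)
    for idx in range(x.shape[0]):
        cur_vals = x[idx]
        is_better = cur_vals > max_vals
        max_idxs = torch.where(is_better, idx, max_idxs)
        max_vals = torch.where(is_better, cur_vals, max_vals)
    return max_idxs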
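One detail worth checking when bringing the reduced dimension to the front: `transpose(0, dim)` swaps two dimensions, so the remaining dimensions come out in a different order than `torch.argmax` produces, whereas `movedim(dim, 0)` preserves their order. A quick shape check, offered as a sketch with an arbitrary example shape:

```python
import torch

x = torch.randn(2, 3, 4)
dim = 2

swapped = x.transpose(0, dim)  # shape (4, 3, 2): dims 0 and 2 trade places
moved = x.movedim(dim, 0)      # shape (4, 2, 3): other dims keep their order

reference = torch.argmax(x, dim=dim)  # shape (2, 3)

print(tuple(swapped.shape[1:]), tuple(moved.shape[1:]), tuple(reference.shape))
```

Only the `movedim` layout leaves a result whose shape lines up with the built-in without a further permutation.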
Now, take the argmax along all dimensions except the one the user provided.
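def argmax(x: torch.Tensor, dim: int) -> torch.Tensor:
    # Keep `dim` and flatten everything else. The returned index points into
    # the flattened remaining dimensions (row-major, in their original order).
    # Assumes a floating-point tensor.
    x = x.movedim(dim, 0)
    x = x.reshape(x.shape[0], -1)
    max_idxs = torch.full((x.shape[0],), -1, dtype=torch.long, device=x.device)
    max_vals = torch.full((x.shape[0],), float('-inf'), dtype=x.dtype, device=x.device)
    for idx in range(x.shape[1]):
        cur_vals = x[:, idx]
        is_better = cur_vals > max_vals
        max_idxs = torch.where(is_better, idx, max_idxs)
        max_vals = torch.where(is_better, cur_vals, max_vals)
    return max_idxs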
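Because the reduced dimensions are flattened, each returned value is a flat index. If the interviewer asks for per-dimension coordinates instead, a row-major unravel can be sketched in plain Python. The helper name `unravel` is hypothetical, and the sketch assumes the flattening was row-major, as `reshape` does on a contiguous layout.

```python
def unravel(flat_idx: int, shape: tuple[int, ...]) -> tuple[int, ...]:
    """Convert a row-major flat index into one coordinate per dimension."""
    coords = []
    # Peel off the fastest-varying (last) dimension first.
    for size in reversed(shape):
        flat_idx, coord = divmod(flat_idx, size)
        coords.append(coord)
    return tuple(reversed(coords))


coords_2d = unravel(7, (3, 4))
coords_3d = unravel(23, (2, 3, 4))
```

The same divmod loop could be vectorized over a tensor of flat indices with integer division and remainder.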
D. Argmax during training
Say we immediately use the argmax index to select a row from a second matrix.
seq_len, n_experts, d_model = 128, 12, 64
X = torch.randn(seq_len, d_model)
W = torch.randn(n_experts, d_model)
A = torch.randn(seq_len, n_experts)
idx = argmax(A, dim=1)
Y = (X * W[idx]).sum(dim=-1)  # per-token dot product with its selected row of W; shape (seq_len,)
Would this code run if you ran it through a pytorch optimizer, differentiating the Y tensor with respect to A?
Yes, it would technically run, as long as W requires gradients: you can backpropagate through just the rows in W that were selected. No gradient reaches A, though.
If it runs, is it correct? If it doesn't run, why not?
Even though it runs, it's not correct, because we didn't actually backpropagate the indexing operation itself. We just backpropagated through whatever rows were selected. Generally speaking, indexing is not differentiable. For the same reason, algorithms like k-means can't be differentiated end to end as-is.
The most concrete way to see this is to realize that if all indices were 0, let's say, only the first row of W would ever be updated by gradient descent. All of the other rows would never be updated, and A, which decides the selection, would receive no gradient at all. So clearly, something is wrong.
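This is easy to verify empirically. The sketch below uses smaller, arbitrary sizes and the built-in `Tensor.argmax` so that it is self-contained. It also assumes the output is a per-token dot product between each row of X and its selected row of W, which is one plausible reading of the setup above.

```python
import torch

torch.manual_seed(0)
seq_len, n_experts, d_model = 8, 4, 3
X = torch.randn(seq_len, d_model)
W = torch.randn(n_experts, d_model, requires_grad=True)
A = torch.randn(seq_len, n_experts, requires_grad=True)

idx = A.argmax(dim=1)          # integer tensor: the autograd graph is cut here
Y = (X * W[idx]).sum(dim=-1)   # assumed per-token dot product, shape (seq_len,)
Y.sum().backward()

# Which rows of W were ever selected, and which rows received gradient?
selected = torch.zeros(n_experts, dtype=torch.bool)
selected[idx] = True
row_has_grad = W.grad.abs().sum(dim=1) > 0

print("A.grad:", A.grad)                      # no gradient reached A
print("rows of W with gradient:", row_has_grad.tolist())
```

The backward pass completes without error, which is exactly what makes this bug easy to miss.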
To answer succinctly, the solution is to use something like Gumbel-softmax, where you can smoothly approach a one-hot distribution towards the end of training.
To explain at a conceptual level, the fix is to use a linear combination of all options rather than selecting only one. Then, you can gradually force that linear combination to look one-hot over time.
Regardless of your familiarity with Gumbel-softmax, a simple followup would be to implement it, which tests how easily you can translate math into code.
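For reference, here is a minimal sketch of that followup, not a definitive implementation. It adds Gumbel(0, 1) noise to the logits, applies a temperature-scaled softmax, and optionally uses a straight-through estimator so the forward pass is one-hot while gradients follow the soft weights. The usage part reuses the per-token dot-product assumption and small arbitrary sizes. PyTorch also ships `torch.nn.functional.gumbel_softmax`, which you would normally use instead.

```python
import torch


def gumbel_softmax(logits: torch.Tensor, tau: float = 1.0, hard: bool = False) -> torch.Tensor:
    # Gumbel(0, 1) samples via inverse transform: g = -log(-log(u)), u ~ U(0, 1).
    u = torch.rand_like(logits).clamp_min(1e-10)  # avoid log(0)
    gumbels = -torch.log(-torch.log(u))
    # Lower tau pushes the softmax output closer to one-hot.
    y_soft = torch.softmax((logits + gumbels) / tau, dim=-1)
    if not hard:
        return y_soft
    # Straight-through: one-hot in the forward pass, y_soft's gradient backward.
    index = y_soft.argmax(dim=-1, keepdim=True)
    y_hard = torch.zeros_like(y_soft).scatter_(-1, index, 1.0)
    return y_hard - y_soft.detach() + y_soft


torch.manual_seed(0)
seq_len, n_experts, d_model = 8, 4, 3
X = torch.randn(seq_len, d_model)
W = torch.randn(n_experts, d_model)
A = torch.randn(seq_len, n_experts, requires_grad=True)

weights = gumbel_softmax(A, tau=0.5, hard=True)  # (seq_len, n_experts)
Y = (X * (weights @ W)).sum(dim=-1)              # linear combination of rows of W
Y.sum().backward()                               # A now receives a gradient
```

In training you would typically anneal `tau` downward over time, so the weights start out soft and end up close to one-hot.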