Tips for Vectorization
"Vectorizing" just means "removing loops". Instead of looping manually, execute a single, large batched operation that is well optimized. Here's a simple example, where we multiple a matrix by a scalar.
import torch
A = torch.rand((16, 16))
scalar = torch.rand(1)
for i in range(A.shape[0]):
    for j in range(A.shape[1]):
        A[i, j] *= scalar
This can be simplified to just a single line.
A *= scalar
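If you want to convince yourself of the speedup, a rough comparison might look like the following. This is only a quick wall-clock sketch; torch.utils.benchmark is the better tool for careful measurements.
import time
import torch
A = torch.rand((256, 256))
scalar = torch.rand(1)
start = time.perf_counter()
B = A.clone()
for i in range(B.shape[0]):
    for j in range(B.shape[1]):
        B[i, j] *= scalar
loop_time = time.perf_counter() - start
start = time.perf_counter()
C = A * scalar
vectorized_time = time.perf_counter() - start
assert torch.allclose(B, C)  # both versions compute the same result
print(f"loop: {loop_time:.4f}s, vectorized: {vectorized_time:.6f}s")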
However, not all vectorizations are this trivial to recognize.
Use advanced indexing
Say I have a covariance matrix. How would I forcibly set the diagonal to be all ones? Here is the non-vectorized program.
A = torch.rand((16, 16))
dim = min(A.shape) # handle non-square matrices
for i in range(dim):
    A[i, i] = 1
To fix this, we can simply assemble all of the coordinates in advance, then pass them all into a single advanced indexing operation.
A = torch.rand((16, 16))
dim = min(A.shape) # handle non-square matrices
A[range(dim), range(dim)] = 1
However, with a quick search or prompt, you'll realize that PyTorch actually has a convenience function available for us. This won't always be the case.
A.fill_diagonal_(1) # The underscore generally means the op is in-place
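As a side note, fill_diagonal_ also works on non-square 2-D tensors (it fills min(rows, cols) entries), so the min(A.shape) bookkeeping from before isn't necessary.
B = torch.rand((16, 8))
B.fill_diagonal_(1) # fills the 8 entries on the main diagonal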
Next, how would I forcibly set values to mirror across the diagonal? Again, here is the non-vectorized program.
A = torch.rand((16, 16))
assert A.shape[0] == A.shape[1]
for i in range(A.shape[0]):
    for j in range(i): # only the lower triangle; looping over range(A.shape[1]) would be redundant
        A[i, j] = A[j, i]
One option is to simply assemble all of the indices.
💡
Note: One easy way to speed up a program is to assemble all of the indices in advance and perform one advanced indexing operation, instead of many simple ones.
ys, xs = [], []
for i in range(A.shape[0]):
    for j in range(i):
        ys.append(i)
        xs.append(j)
A[ys, xs] = A[xs, ys]
However, even the index-generating loop can be vectorized. The following takes more lines of logic, but especially for large matrices, and especially if the tensors are created directly on the GPU, it will run much faster.
indices = torch.arange(A.shape[0])
y, x = torch.meshgrid(indices, indices, indexing='ij')
mask = y > x # lower-triangle positions, mirroring the loop above
A[y[mask], x[mask]] = A[x[mask], y[mask]]
Funnily enough, just like before, PyTorch has a built-in function that generates these indices for us.
N = A.shape[0]
y, x = torch.triu_indices(N, N, offset=1) # y < x, i.e., upper-triangle coordinates
A[x, y] = A[y, x] # copy the upper triangle into the lower, as in the loop above
Finally, how would you zero out values k or more positions off of the diagonal? Fortunately, since we now know about the existence of triu_indices, this becomes fairly simple.
y, x = torch.triu_indices(N, N, offset=k)
A[y, x] = A[x, y] = 0
Now, say we additionally want to preserve the first column of values.
y, x = torch.triu_indices(N, N, offset=k)
y, x = y[x > 0], x[x > 0] # drop all indices for the first column (x = 0)
A[y, x] = A[x, y] = 0
This isn't a completely made-up problem. This is, more or less, how attention coefficients are modified in the paper Efficient Streaming Language Models with Attention Sinks by Xiao et al.
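To make the connection concrete, here is a minimal sketch (not the paper's actual implementation) of a streaming-attention-style update built from the same indexing tricks: keep a sliding window of recent positions plus a few initial "sink" positions, and zero out everything else. The sizes here are hypothetical.
N, window, n_sink = 8, 3, 1 # hypothetical sizes, for illustration only
scores = torch.rand(N, N) # stand-in for a matrix of attention coefficients
# Zero out future positions (the upper triangle)...
y, x = torch.triu_indices(N, N, offset=1)
scores[y, x] = 0
# ...and everything `window` or more positions in the past,
# except the first `n_sink` columns, which are always kept.
y, x = torch.tril_indices(N, N, offset=-window) # positions with row - col >= window
y, x = y[x >= n_sink], x[x >= n_sink] # preserve the sink columns
scores[y, x] = 0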
Broadcast tensors
This happens all the time and is very important to get used to. Broadcasting aligns dimensions from the right and treats missing leading dimensions as size one, so the easiest way to keep things unambiguous is to explicitly give both tensors the same number of dimensions. As a result, you'll often need code like the following:
- To add a dimension to the start of the tensor, use A[None] or A.unsqueeze(0).
- To add a dimension to the end of the tensor, use A[..., None] or A.unsqueeze(-1).
- You can also insert multiple dimensions at once, such as A[None, None], which adds two leading dimensions, or A[None, ..., None], which adds one dimension at the start and one at the end.
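For concreteness, here is what these patterns do to the shape of a (16, 8) tensor.
A = torch.rand(16, 8)
print(A[None].shape) # torch.Size([1, 16, 8])
print(A[..., None].shape) # torch.Size([16, 8, 1])
print(A[None, ..., None].shape) # torch.Size([1, 16, 8, 1])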
Say you're taking the dot product of one vector with every row of a matrix. You could start off with the for loop implementation.
import torch
N, D = 16, 8
A = torch.rand(N, D)
x = torch.rand(D)
y = torch.empty(N)
for i in range(N):
    y[i] = A[i] @ x
Alternatively, we can make sure to add a dimension in the right spot, then broadcast.
A @ x[:, None] # (N, D) @ (D, 1) -> (N, 1); squeeze(-1) recovers the (N,) shape of y
You could also use this idea to take the outer product of two vectors.
N = 16
a = torch.rand(N)
b = torch.rand(N)
outer = a[:, None] * b[None] # (N, 1) * (1, N) -> (N, N), so outer[i, j] = a[i] * b[j]
One common application is to normalize every row of a tensor by that row's magnitude.
A = torch.rand(16, 8)
magnitude = torch.norm(A, dim=1) # (16,)
A = A / magnitude[:, None] # (16, 8) / (16, 1)
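Equivalently, passing keepdim=True to the norm keeps the reduced dimension around, so the manual indexing isn't needed.
A = torch.rand(16, 8)
A = A / torch.norm(A, dim=1, keepdim=True) # (16, 8) / (16, 1)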