Tips for "graphifying"
There isn't an official name for this process, but many times on the job, you'll find yourself converting from "eager mode" to "graph mode".
- In eager mode, each operation executes immediately when it is called. This means that tensor values are readily available, which is great for debugging. However, every time you access tensor values, those values must be moved from C++ to Python. This incurs overhead and slows down the program. This is what frameworks like PyTorch do: great for developer productivity but bad for runtime speed.
- In graph mode, your code specifies a computation graph, which you can later provide inputs for and execute. Only the output tensor is returned, so the intermediate tensor values are inaccessible (see footnote 1). This means that during execution, there is no Python overhead, and the program generally runs faster. This is what frameworks like JAX (see footnote 2) use: bad for developer productivity but good for runtime speed.
PyTorch and NumPy are eager-mode-first, but you can benefit from graph mode if you choose to. In particular, PyTorch can use CUDA graphs or TorchScript's JIT trace. Alternatively, you might just switch frameworks entirely and use a graph-mode-first framework like JAX. Either way, you'll often find yourself converting from eager mode to graph mode for its speed.
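To make the two modes concrete, here is a minimal sketch that runs the same function eagerly and through torch.jit.trace. The function f is a made-up example, chosen because it has no branches and is therefore safe to trace.

```python
import torch

def f(x):
    # A small, branch-free function (hypothetical example): safe to trace.
    return torch.tanh(x) * 2 + 1

x = torch.randn(4)

# Eager mode: each op runs immediately, so intermediates can be inspected.
eager_out = f(x)

# Graph mode via tracing: record the ops once, then replay the recording.
traced_f = torch.jit.trace(f, x)
graph_out = traced_f(x)

print(traced_f.graph)  # the recorded computation graph
print(torch.allclose(eager_out, graph_out))
```

The traced function only knows about the operations it saw during tracing, which is exactly why the restrictions in the rest of this post matter.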
There are a few kinds of logic to remove when converting into a static graph:
- No conditional execution is allowed. This means that you can't have input-dependent tensors in an if statement's condition, a while loop's condition, or a for loop's start, step, or end values. For example, you could not write if x.sum() > 0: where x is an input to the computation graph.
- For CUDA graphs, you additionally must have static sizes for all tensors. This means you cannot have input-dependent sizes such as x[x > 0] (see footnote 3). All tensor sizes must be fully determined when the graph is captured.
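The size restriction is easy to see on the CPU. The sketch below uses the two inputs from footnote 3 and shows that boolean indexing produces a result whose shape depends on the data, while a torch.where mask always keeps the input's shape. Zeroing out the rejected entries is just one illustrative choice of fixed-size alternative, not the only one.

```python
import torch

a = torch.tensor([0.0, 0.0, 1.0])
b = torch.tensor([1.0, 1.0, 1.0])

# Boolean indexing: the output size depends on the values in the input.
dynamic_a = a[a > 0]
dynamic_b = b[b > 0]
print(dynamic_a.shape, dynamic_b.shape)

# Fixed-size alternative: keep every slot, zero out the ones we reject.
static_a = torch.where(a > 0, a, torch.zeros_like(a))
static_b = torch.where(b > 0, b, torch.zeros_like(b))
print(static_a.shape, static_b.shape)
```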
Replace if with where
Let's see a hello world example of un-traceable code.
import torch

class NonTraceableModel(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x + 2

model = NonTraceableModel()
x = torch.randn(3)
traced_fn = torch.jit.trace(model, x)
The trace call on the last line simply records all PyTorch operations that were executed during the tracing process. This means that at tracing time, the value of x determines what computation graph is produced; at runtime, the value of x does not determine which if block is run. Here's why that's a problem. Suppose we trace the model above with an input that takes the x + 2 branch.
>>> model(torch.tensor(1))  # since x.sum() > 0, we run x * 2
tensor(2)  # what we expect
>>> traced_model = torch.jit.trace(model, torch.tensor(0))  # forces x + 2
>>> traced_model(torch.tensor(1))  # runs x + 2 :(
tensor(3)  # does not match expectations!
Luckily, PyTorch recognizes this problem and issues a warning. The problem is precisely what the warning says: if the input x changes, the trace will not generalize correctly!
graphvseager.py:5: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
return x * 2 if x.sum() > 0 else x / 2
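If you want to reproduce the mismatch yourself, the following sketch traces the model with an input that takes the x + 2 branch and then compares eager and traced outputs on an input that should take the x * 2 branch. Collecting the warnings with the standard warnings module is my own addition for illustration; the exact warning text may vary across PyTorch versions.

```python
import warnings
import torch

class NonTraceableModel(torch.nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x + 2

model = NonTraceableModel()

# Trace with an input whose sum is not positive, so x + 2 gets recorded.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    traced = torch.jit.trace(model, torch.tensor(0.0))

probe = torch.tensor(1.0)
eager_out = model(probe)    # takes the x * 2 branch
traced_out = traced(probe)  # replays the recorded x + 2 branch

print("warnings raised while tracing:", len(caught))
print("eager:", eager_out.item(), "traced:", traced_out.item())
```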
To fix this, use torch.where, a PyTorch operation that has conditional logic built in.
import torch

class TraceableModel(torch.nn.Module):
    def forward(self, x):
        return torch.where(x.sum() > 0, x * 2, x + 2)

model = TraceableModel()
x = torch.randn(3)
traced_fn = torch.jit.trace(model, x)
💡 Note: In the above code, both branches of the conditional are evaluated, so x * 2 and x + 2 are both computed. This could be a problem if both branches are expensive to compute. Unfortunately, in a static computation graph, this is largely unavoidable.
Now, the traced model behaves as expected!
>>> traced_fn(torch.tensor(1))
tensor(2)
>>> traced_fn(torch.tensor(0))
tensor(2)
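A quick way to gain confidence in a rewrite like this is to compare the traced torch.where model against the original branching logic on inputs that exercise both branches. The sketch below does that; eager_reference and the three test inputs are hypothetical names and values picked for illustration.

```python
import torch

def eager_reference(x):
    # The original Python-level branching, used only as ground truth.
    return x * 2 if x.sum() > 0 else x + 2

class TraceableModel(torch.nn.Module):
    def forward(self, x):
        return torch.where(x.sum() > 0, x * 2, x + 2)

traced = torch.jit.trace(TraceableModel(), torch.randn(3))

inputs = [
    torch.tensor([1.0, 2.0, 3.0]),     # positive sum: x * 2
    torch.tensor([-1.0, -2.0, -3.0]),  # negative sum: x + 2
    torch.zeros(3),                    # zero sum: x + 2
]
results = [torch.equal(traced(x), eager_reference(x)) for x in inputs]
print(results)
```

Because the condition now lives inside the graph as a tensor operation, the same trace handles both cases.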
Use fixed-size masks
Before jumping in, let's see what the "Hello World" CUDA graph code would look like.
import torch

graph = torch.cuda.CUDAGraph()
stream = torch.cuda.Stream()

# Allocate all tensors BEFORE graph capture
x = torch.randn(3, device="cuda")
w = torch.randn(3, 3, device="cuda")

def model(w, x):
    return w @ x

# Warm-up run on a side stream to initialize operations
stream.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(stream):
    y = model(w, x)
torch.cuda.current_stream().wait_stream(stream)

# Capture graph
with torch.cuda.graph(graph):
    y = model(w, x)

# Replay graph
graph.replay()
print(y)
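The boilerplate above replays the graph on the same x it was captured with. To feed new data, a common pattern is to copy into the captured input buffer and then replay, since the graph always reads from and writes to the same memory. The sketch below assumes that pattern; run_with_static_buffers, static_x, and static_y are hypothetical names, and the function returns None on machines without a GPU.

```python
import torch

def run_with_static_buffers(new_inputs):
    """Replay one captured graph on several inputs (sketch, needs CUDA)."""
    if not torch.cuda.is_available():
        return None

    # Static buffers: the graph will always read static_x and write static_y.
    static_x = torch.zeros(3, device="cuda")
    w = torch.eye(3, device="cuda") * 2

    # Warm up on a side stream before capture.
    side = torch.cuda.Stream()
    side.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(side):
        _ = w @ static_x
    torch.cuda.current_stream().wait_stream(side)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_y = w @ static_x

    outputs = []
    for x in new_inputs:
        static_x.copy_(x)  # overwrite the captured input in place
        graph.replay()     # rerun the recorded kernels
        outputs.append(static_y.clone().cpu())
    return outputs

print(run_with_static_buffers([torch.ones(3), torch.arange(3.0)]))
```

This is also why every tensor needs a fixed size: the buffers are allocated once and reused for every replay.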
Now that we've seen boilerplate for CUDA graphs, let's see some non-CUDA-graph-able code to start. Let's say we want to re-arrange x so that all positive values are placed first. We want to preserve the order of the input otherwise.
def model(w, x):
    x = torch.cat([x[x > 0], x[x <= 0]])
    return w @ x
Notice that x[x > 0] can have different sizes, depending on the contents of x. As a result, this code cannot be captured in a CUDA graph. Luckily, we have a workaround. The following exploration presents one way of circumventing this issue.
>>> x = torch.rand(3) * 2 - 1
>>> x
tensor([ 0.1900, -0.7620, -0.1896])
>>> mask = x <= 0  # True (1) if non-positive, False (0) if positive
>>> mask
tensor([False, True, True])
>>> indices = torch.argsort(mask, stable=True)
>>> indices
tensor([0, 1, 2])
>>> x = x[indices]
>>> x
tensor([ 0.1900, -0.7620, -0.1896])
>>> torch.cat([x[x > 0], x[x <= 0]]) # equivalent!
tensor([ 0.1900, -0.7620, -0.1896])
This is a hack, but it achieves our original goal using static-sized tensors. In short, we use the fact that False (0) sorts before True (1), and ask PyTorch to perform a stable sort. This effectively places all positive values before all non-positive values, while preserving the original order within each group. Let's plug this back into our script.
def model(w, x):
    mask = x <= 0
    indices = torch.argsort(mask, stable=True)
    x = x[indices]
    return w @ x
And ta-da!
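The example input above happened to be sorted already, so here is a sketch that checks the argsort trick against the original torch.cat version on several inputs, including one that actually needs reordering. The names partition_dynamic and partition_static are hypothetical, and casting the mask to int8 before sorting is a precaution of mine in case a given build does not sort boolean tensors; the original code sorts the boolean mask directly.

```python
import torch

def partition_dynamic(x):
    # Original version: output pieces have data-dependent sizes.
    return torch.cat([x[x > 0], x[x <= 0]])

def partition_static(x):
    # Fixed-size version: 0 for positive, 1 for non-positive, stable sort.
    key = (x <= 0).to(torch.int8)
    order = torch.argsort(key, stable=True)
    return x[order]

torch.manual_seed(0)
samples = [torch.rand(8) * 2 - 1 for _ in range(5)]
samples.append(torch.tensor([-1.0, 2.0, 0.0, 3.0, -4.0]))

all_match = all(
    torch.equal(partition_dynamic(x), partition_static(x)) for x in samples
)
print(all_match)
print(partition_static(samples[-1]))
```

Every intermediate in partition_static has the same length as x, which is the property CUDA graphs need.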
You'd be surprised at how many common functions you can no longer use. For example, you can't use torch.nonzero() anymore, because it returns a different number of coordinates depending on how many truthy values there are. Just like here, you can usually get by with masks instead. For example, take the following function, which needs a simple fix.
def model(x, w):
    indices = (x == 0).nonzero()
    x[indices] = 1e-4
    return w @ x
You can replace this with a mask instead.
def model(x, w):
    x[x == 0] = 1e-4
    return w @ x
Again, ta-da!
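If you would rather not mutate the input in place, torch.where and Tensor.masked_fill express the same mask-based replacement out of place, and both return a tensor with the same shape as x. The sketch below compares the three spellings on the CPU; the function names are hypothetical, and I have not benchmarked them against each other.

```python
import torch

def fill_zeros_inplace(x):
    # Mask assignment from above, applied to a copy so x stays untouched.
    x = x.clone()
    x[x == 0] = 1e-4
    return x

def fill_zeros_where(x):
    # Out-of-place: choose 1e-4 where x is zero, otherwise keep x.
    return torch.where(x == 0, torch.full_like(x, 1e-4), x)

def fill_zeros_masked_fill(x):
    # Out-of-place: masked_fill returns a new tensor of the same shape.
    return x.masked_fill(x == 0, 1e-4)

x = torch.tensor([0.0, 1.5, 0.0, -2.0])
a = fill_zeros_inplace(x)
b = fill_zeros_where(x)
c = fill_zeros_masked_fill(x)
print(a, b, c)
```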
1. There are exceptions to this. JAX has the ability to encode introspection tools in the graph at construction time, via jax.debug.
2. You may be wondering: What about TensorFlow or MXNet? These are dying frameworks. Use PyTorch or JAX instead.
3. For example, say x = [0, 0, 1]; then x[x > 0] would have size 1. However, if x = [1, 1, 1], then x[x > 0] would have size 3. This is not allowed.