Deconstructing Torch Compile
Aug 16, 2025
Breaking down how torch.compile() optimizes large models
torch.compile is a high-level entry point that takes PyTorch code built from high-level functions such as nn.Linear(),
lowers it to primitive PrimTorch operations, and then optimizes those operations with generated Triton kernels.
TorchDynamo
TorchDynamo is a graph-capture tool that takes Python code and extracts a graph (FX IR) of tensor operations. It hooks into the Python interpreter and intercepts the bytecode of each function just before it runs. As it steps through that bytecode, every tensor op it meets is recorded as a node in the graph. Capture therefore happens just in time, the first time the code is encountered, rather than in a separate ahead-of-time pass. Some code cannot be captured in a single graph. An example is the following:
import random
import torch

def f(x):
    # Python-level branch on a Python-level random number
    if random.random() > 0.5:
        return torch.sin(x)
    else:
        return torch.cos(x)
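TorchDynamo cannot represent this Python-level branch as a single graph node. When it reaches code it cannot trace, it produces a graph break: the graph captured up to that point is handed to the compiler, the unsupported code runs in ordinary Python, and tracing resumes afterward.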
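A graph break can be observed with a small custom backend. The sketch below is illustrative only: `counting_backend`, `captured_graphs`, and the function `f` are hypothetical names, and it uses a data-dependent branch (`if y.sum() > 0`) rather than the random example, since that is a case where a break is commonly reported. The exact number of graphs may vary between PyTorch versions.

```python
import torch
from torch.fx import GraphModule

captured_graphs = []  # one entry per FX graph that TorchDynamo hands to the backend

def counting_backend(gm: GraphModule, example_inputs):
    # A custom backend receives each captured graph. Here we only record it
    # and return the graph's own forward, so it runs without further optimization.
    captured_graphs.append(gm)
    return gm.forward

def f(x):
    y = x * 2
    # Branching on a tensor's value is data-dependent control flow,
    # which TorchDynamo does not fold into a single graph.
    if y.sum() > 0:
        return torch.sin(y)
    return torch.cos(y)

compiled_f = torch.compile(f, backend=counting_backend)
x = torch.ones(4)
out = compiled_f(x)

# More than one graph here means the function was split at a graph break.
print(len(captured_graphs))
```

The result is still correct after a break; the cost is that each piece is compiled separately and control returns to Python in between.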
TorchDynamo Optimizations
In general, TorchDynamo works best when values and control flow are static at runtime. The random example above can be rewritten as follows, moving both the randomness and the branching into tensor ops so that TorchDynamo can capture them. Branching in general can often be replaced with torch.where(), which is branchless: it evaluates both branches and picks a result elementwise based on the condition.
# Previous random generation
if random.random() > 0.5:
    x = torch.sin(x)
# Better solution
r = torch.rand((), device=x.device)  # 0-dim random tensor in [0, 1)
x = torch.where(r > 0.5, torch.sin(x), x)
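One way to check that the rewritten version captures as a single graph is to compile it with `fullgraph=True`, which raises an error on a graph break instead of splitting the function. This is a minimal sketch: `maybe_sin` is a hypothetical name, and `backend="eager"` is chosen only so the example does not depend on a compiler toolchain.

```python
import torch

def maybe_sin(x):
    # Draw the random number as a 0-dim tensor so it becomes part of the graph.
    r = torch.rand((), device=x.device)
    # torch.where computes both candidates and selects elementwise by condition.
    return torch.where(r > 0.5, torch.sin(x), x)

# fullgraph=True asks TorchDynamo to raise an error rather than split the
# function at a graph break. backend="eager" runs the captured graph as-is,
# without Inductor code generation.
compiled = torch.compile(maybe_sin, fullgraph=True, backend="eager")

x = torch.linspace(0.0, 1.0, steps=5)
out = compiled(x)
print(out)
```

Because both branches are always evaluated, this trade only pays off when the branches are cheap tensor ops, as they are here.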
Another common optimization is to combine tensors with torch.stack() and apply a single aggregate function to the result, rather than looping over the tensors in Python.
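As a small illustration of the stacking pattern, the sketch below sums a list of tensors both ways. The variable names and shapes are made up for the example.

```python
import torch

tensors = [torch.full((3,), float(i)) for i in range(4)]

# Loop version: one Python-level iteration (and one add op) per tensor.
total_loop = torch.zeros(3)
for t in tensors:
    total_loop = total_loop + t

# Stacked version: build a single (4, 3) tensor, then reduce along dim 0.
stacked = torch.stack(tensors)
total_stacked = stacked.sum(dim=0)

print(total_stacked)  # each element is 0 + 1 + 2 + 3 = 6
```

Both versions give the same result; the stacked one expresses the work as two tensor ops instead of one op per list element.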
The final product of TorchDynamo is an FX graph, a compact IR that compiler backends can process further.
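To get a feel for what an FX graph contains, a custom backend can walk its nodes. In this sketch `inspecting_backend`, `node_summaries`, and `g` are hypothetical names; the exact targets printed may differ between PyTorch versions.

```python
import torch

node_summaries = []  # (op kind, target) for every node of each captured graph

def inspecting_backend(gm, example_inputs):
    # gm is a torch.fx.GraphModule; its graph is a flat list of nodes.
    for node in gm.graph.nodes:
        node_summaries.append((node.op, str(node.target)))
    return gm.forward

@torch.compile(backend=inspecting_backend)
def g(x):
    return torch.relu(x) + 1

result = g(torch.tensor([-1.0, 2.0]))

for op, target in node_summaries:
    print(op, target)
```

Each graph starts with `placeholder` nodes for its inputs, continues with nodes for the recorded operations, and ends with an `output` node.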
AOTAutograd
If you are also training the model and need gradients, the backward pass has to be traced as well so that backpropagation can be compiled. A naive solution would let PyTorch run the forward pass and then build the backward pass at runtime, as eager autograd does. AOTAutograd (ahead-of-time autograd) instead minimizes runtime overhead by generating the backward graph ahead of time. It relies on the backward formula that each differentiable PyTorch operation defines. Note that the resulting backward graph is completely static, just like the forward graph.
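The forward and backward graphs produced by AOTAutograd can be captured with a custom backend built on the `aot_autograd` helper. This is a sketch under assumptions: it relies on `torch._dynamo.backends.common.aot_autograd` and `functorch.compile.make_boxed_func` as used in PyTorch's custom-backend documentation, and `make_compiler`, `captured`, and `loss_fn` are hypothetical names.

```python
import torch
from torch._dynamo.backends.common import aot_autograd
from functorch.compile import make_boxed_func

captured = {}  # maps "forward" / "backward" to the traced graph module

def make_compiler(name):
    def compiler(gm, example_inputs):
        captured[name] = gm
        # AOTAutograd expects a "boxed" callable; make_boxed_func adapts gm.forward.
        return make_boxed_func(gm.forward)
    return compiler

backend = aot_autograd(
    fw_compiler=make_compiler("forward"),
    bw_compiler=make_compiler("backward"),
)

def loss_fn(x, w):
    return torch.sin(x @ w).sum()

x = torch.randn(3, 4)
w = torch.randn(4, 2, requires_grad=True)

compiled_loss = torch.compile(loss_fn, backend=backend)
compiled_loss(x, w).backward()
compiled_grad = w.grad.clone()

# Reference gradient from ordinary eager autograd.
w.grad = None
loss_fn(x, w).backward()
eager_grad = w.grad.clone()

print(sorted(captured))
```

Calling `captured["backward"].print_readable()` afterward shows the backward pass as an ordinary graph of tensor ops, which is what lets later stages optimize it like the forward graph.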
Another important benefit is activation rematerialization. Activations from the forward pass normally have to be kept in memory because the backward pass needs them to compute the weight gradients. Because AOTAutograd sees the forward and backward graphs together, it can choose to recompute activations that are cheap to compute instead of storing them, which saves GPU memory. Having the static graphs up front also enables further optimizations such as dead code elimination and kernel fusion.
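The store-versus-recompute trade-off can be illustrated in eager mode with `torch.utils.checkpoint`. Note that this is a related, manually applied technique, not what AOTAutograd does internally; it is shown here only because it makes the idea concrete. The function `block` and the tensor shapes are made up for the example.

```python
import torch
from torch.utils.checkpoint import checkpoint

def block(x, w):
    # The tanh output is an intermediate activation needed by backward.
    return torch.tanh(x @ w).pow(2)

x = torch.randn(8, 16)
w = torch.randn(16, 16, requires_grad=True)

# Standard autograd: intermediate activations of `block` are kept in memory.
block(x, w).sum().backward()
grad_saved = w.grad.clone()

# Checkpointed: only the inputs are kept, and the activations inside `block`
# are recomputed during the backward pass.
w.grad = None
checkpoint(block, x, w, use_reentrant=False).sum().backward()
grad_recomputed = w.grad.clone()

print(torch.allclose(grad_saved, grad_recomputed))
```

The gradients agree in both cases; the difference is that the checkpointed version spends extra compute to hold less memory between the forward and backward passes.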
PrimTorch
As the machine learning field has grown, so has the number of functions PyTorch has needed to support. PrimTorch was created in response: a set of roughly 250 primitive operations that high-level PyTorch functions decompose into. After this step, we are left with forward and backward graphs made up of primitive operations.
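Decomposition can be observed by tracing a small function with `make_fx` and a decomposition table. This is a sketch under assumptions: it uses `torch.fx.experimental.proxy_tensor.make_fx` and `torch._decomp.core_aten_decompositions`, which lower to ATen-level operators rather than all the way to the prim set, and the exact operators printed depend on the PyTorch version. `layer` and `forward` are hypothetical names.

```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import core_aten_decompositions

layer = torch.nn.Linear(4, 2)

def forward(x):
    return torch.nn.functional.silu(layer(x))

x = torch.randn(3, 4)

# Trace `forward` down to ATen-level operators, applying PyTorch's registered
# decompositions so composite ops are rewritten in terms of simpler ones.
gm = make_fx(forward, decomposition_table=core_aten_decompositions())(x)

op_names = [str(n.target) for n in gm.graph.nodes if n.op == "call_function"]
print(op_names)
```

The high-level `nn.Linear` call no longer appears in the traced graph; it has been replaced by lower-level operators such as a matrix multiply with a bias add.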
Backends
All that remains is to optimize this code to run on a GPU. This can be done through several torch.compile backends; the default is TorchInductor. For GPUs, TorchInductor generates kernels in Triton for the primitive operations in the graph and fuses operations into a single kernel wherever it can, so intermediate results do not have to travel through GPU memory between ops. It then schedules these kernels to run on the GPU. The generated code tiles the computation, optimizes cache access, and coalesces memory accesses as much as possible, and loops such as those in matmuls can be unrolled to reduce branching. The goal of all these optimizations is a small number of large kernels that load data as few times as possible, make good use of the cache, and keep threads and streaming multiprocessors busy.
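The available backends can be listed and selected by name. The sketch below assumes a standard PyTorch 2.x install; the model is a made-up example, and the backend list will vary with the installation.

```python
import torch
import torch._dynamo

# Names of the compiler backends registered in this PyTorch install.
available = torch._dynamo.list_backends()
print(available)

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.ReLU())

# torch.compile only wraps the model here; compilation itself is deferred
# until the first call with real inputs.
compiled_model = torch.compile(model, backend="inductor")
```

Passing `backend="inductor"` is equivalent to the default. To look at the code TorchInductor generates, running a script with the environment variable `TORCH_LOGS=output_code` should print it, though the logging options may differ between versions.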