Description
In autodiff, we should have a checkpointing strategy to reduce memory consumption (see for instance https://www-sop.inria.fr/tropics/papers/DauvergneHascoet06.pdf).
Currently, for most operations run in the forward pass, a state is saved for the backward pass. This state often consists of several tensors, so these states accumulate and can take up a lot of memory.
A way to reduce the memory used by the backward pass would be to not keep the state in memory at all, and instead recompute the operation's forward pass to re-obtain the state just before computing its backward pass. This leads to more computation, but less memory consumption.
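A minimal sketch of the two options, assuming plain `Vec<f32>` in place of tensors and a toy graph `h = x * x`, `y = sin(h)`. The backward of `sin` needs `h`: it can either be kept in memory, or rebuilt from `x` (the checkpoint) by rerunning `square`. The names `State`, `square`, `backward_sin` and `build_states` are hypothetical and are not Burn's API.

```rust
/// Hypothetical sketch (not Burn's API): how a backward step obtains its state.
enum State {
    /// The intermediate was kept in memory during the forward pass (current behaviour).
    Saved(Vec<f32>),
    /// Only a closure over an earlier checkpoint is kept; the state is rebuilt on demand.
    Recompute(Box<dyn Fn() -> Vec<f32>>),
}

impl State {
    /// Returns the state, recomputing it if it was not kept.
    fn get(&self) -> Vec<f32> {
        match self {
            State::Saved(v) => v.clone(),
            State::Recompute(f) => f(),
        }
    }
}

/// Forward op 1: h = x * x (element-wise).
fn square(x: &[f32]) -> Vec<f32> {
    x.iter().map(|v| v * v).collect()
}

/// Backward of op 2, y = sin(h): dL/dh = dL/dy * cos(h), so it needs h.
fn backward_sin(grad_y: &[f32], h_state: &State) -> Vec<f32> {
    let h = h_state.get();
    grad_y.iter().zip(h.iter()).map(|(g, h)| g * h.cos()).collect()
}

/// Builds both variants of the state for the same input x.
fn build_states(x: &[f32]) -> (State, State) {
    // Strategy A: keep h in memory until the backward pass.
    let saved = State::Saved(square(x));
    // Strategy B: keep only x (the checkpoint) and recompute h when the backward pass needs it.
    let x_ckpt = x.to_vec();
    let recompute = State::Recompute(Box::new(move || square(&x_ckpt)));
    (saved, recompute)
}
```

Both variants produce the same gradient; the second one only holds `x` alive and pays for one extra `square` call during the backward pass.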
This leads to a tradeoff between compute and memory. Some operations, like matrix multiplication, are "compute-bound", meaning the bottleneck is generally the arithmetic itself, while others, such as element-wise multiplication, are "memory-bound", meaning the computation is so simple that moving the data is the bottleneck.
For compute-bound operations, it is better to keep the state than to recompute. But for memory-bound operations, we would benefit from recomputing.
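One way to make the compute-bound / memory-bound distinction concrete is arithmetic intensity (floating-point operations per byte moved), as in the roofline model. The sketch below uses idealized counts for f32 data and an assumed threshold of 10 flops/byte; the real crossover point depends on the device, and all names here (`OpKind`, `Decision`, `classify`, `decide`) are hypothetical, not existing Burn types.

```rust
/// Hypothetical classification of operations (not an existing Burn type).
#[derive(Clone, Copy, Debug, PartialEq)]
enum OpKind {
    ComputeBound,
    MemoryBound,
}

#[derive(Clone, Copy, Debug, PartialEq)]
enum Decision {
    SaveState,
    Recompute,
}

/// Floating-point operations performed per byte moved (roofline model).
fn arithmetic_intensity(flops: f64, bytes: f64) -> f64 {
    flops / bytes
}

/// n x n by n x n f32 matmul: 2n^3 flops; reads A and B, writes C (3n^2 values, 4 bytes each).
fn matmul_intensity(n: usize) -> f64 {
    let n = n as f64;
    arithmetic_intensity(2.0 * n * n * n, 3.0 * n * n * 4.0)
}

/// Element-wise f32 multiply of length len: 1 flop per element; reads two inputs, writes one output.
fn elementwise_mul_intensity(len: usize) -> f64 {
    let len = len as f64;
    arithmetic_intensity(len, 3.0 * len * 4.0)
}

/// `threshold` (flops/byte) is an assumption; the real value depends on the hardware.
fn classify(intensity: f64, threshold: f64) -> OpKind {
    if intensity >= threshold {
        OpKind::ComputeBound
    } else {
        OpKind::MemoryBound
    }
}

/// Keep the state of compute-bound ops; recompute the state of memory-bound ops.
fn decide(kind: OpKind) -> Decision {
    match kind {
        OpKind::ComputeBound => Decision::SaveState,
        OpKind::MemoryBound => Decision::Recompute,
    }
}
```

With these counts a 1024 x 1024 matmul has an intensity of about 171 flops/byte, while element-wise multiplication stays at 1/12 flop/byte regardless of size, which is why recomputing the latter is cheap relative to the memory it frees.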
Also, tagging many operations as memory-bound will greatly help kernel fusion with Burn-Fusion, which will be able to fuse kernels transparently during the backward pass.
The current strategy, where every state is saved, would simply become a specific case of the new strategy, where everything is considered compute-bound.
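To illustrate how the current behaviour fits in as a special case, here is a sketch of a pluggable strategy, assuming a hypothetical `CheckpointStrategy` trait that tags each operation by name. `SaveAll` reproduces today's behaviour by treating everything as compute-bound; `TagBased` is an example tagging that only keeps the state of a few heavy operations. None of these names exist in Burn; they only show the shape of the idea.

```rust
#[derive(Clone, Copy, Debug, PartialEq)]
enum OpKind {
    ComputeBound,
    MemoryBound,
}

/// Hypothetical strategy interface: tags each operation of the graph.
trait CheckpointStrategy {
    fn kind_of(&self, op: &str) -> OpKind;
}

/// Current behaviour: every op is considered compute-bound, so every state is saved.
struct SaveAll;

impl CheckpointStrategy for SaveAll {
    fn kind_of(&self, _op: &str) -> OpKind {
        OpKind::ComputeBound
    }
}

/// Example tagging (assumed op names): only a few heavy ops keep their state.
struct TagBased;

impl CheckpointStrategy for TagBased {
    fn kind_of(&self, op: &str) -> OpKind {
        match op {
            "matmul" | "conv2d" => OpKind::ComputeBound,
            _ => OpKind::MemoryBound,
        }
    }
}

/// Number of forward states kept in memory for a sequence of ops.
fn saved_states(strategy: &dyn CheckpointStrategy, ops: &[&str]) -> usize {
    ops.iter()
        .filter(|&&op| strategy.kind_of(op) == OpKind::ComputeBound)
        .count()
}
```

For the op sequence `["matmul", "add", "mul", "exp", "matmul"]`, `SaveAll` keeps all 5 states while `TagBased` keeps only the 2 matmul states and recomputes the rest.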