Autodiff: checkpointing strategy · Issue #936 · tracel-ai/burn · GitHub

Autodiff: checkpointing strategy #936

@louisfd

Description


In autodiff, we should have a checkpointing strategy to reduce memory consumption (see, for instance, https://www-sop.inria.fr/tropics/papers/DauvergneHascoet06.pdf).

Currently, for most operations run in the forward pass, a state is saved for the backward pass. This state often consists of several tensors, so the saved states accumulate and consume a lot of memory.

A way to use less memory in the backward pass would be, instead of keeping the state in memory, to recompute the operation's forward pass to re-obtain the state just before computing its backward pass. This trades extra computation for lower memory consumption.
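The two options above can be sketched as an enum that either holds the saved state or holds what is needed to recompute it lazily. This is a minimal illustration, not Burn's actual API; all names (`Checkpoint`, `State`, `retrieve`) are hypothetical.

```rust
/// Toy stand-in for the tensors an operation saves for its backward step.
/// (Hypothetical type, not Burn's.)
#[derive(Clone, Debug, PartialEq)]
struct State(Vec<f32>);

/// Either the state was saved eagerly during the forward pass, or the
/// inputs and the forward function were kept so the state can be
/// recomputed lazily, just before the backward step needs it.
enum Checkpoint {
    Saved(State),
    Recompute {
        inputs: Vec<f32>,
        forward: fn(&[f32]) -> State,
    },
}

impl Checkpoint {
    /// Called just before computing this operation's backward step.
    fn retrieve(&self) -> State {
        match self {
            Checkpoint::Saved(state) => state.clone(),
            Checkpoint::Recompute { inputs, forward } => forward(inputs.as_slice()),
        }
    }
}

fn main() {
    // Element-wise doubling as a toy forward pass.
    fn double(xs: &[f32]) -> State {
        State(xs.iter().map(|x| x * 2.0).collect())
    }

    let saved = Checkpoint::Saved(double(&[1.0, 2.0]));
    let recomputed = Checkpoint::Recompute {
        inputs: vec![1.0, 2.0],
        forward: double,
    };

    // Both strategies hand the same state to the backward pass;
    // only the memory/compute cost differs.
    assert_eq!(saved.retrieve(), recomputed.retrieve());
    println!("{:?}", saved.retrieve());
}
```

In the `Saved` case the state occupies memory for the whole span between the forward and backward passes; in the `Recompute` case only the (often much smaller) inputs do.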

This leads to a tradeoff between compute and memory. Some operations, like matrix multiplication, are "compute-bound": the bottleneck is the arithmetic itself. Others, such as element-wise multiplication, are "memory-bound": the computation is so simple that moving the data is the bottleneck.

For compute-bound operations, it is better to keep the state than to recompute it. For memory-bound operations, recomputing is the better choice.

Also, if many operations are tagged as memory-bound, this will greatly help kernel fusion with Burn-Fusion, which will then be able to fuse kernels transparently during the backward pass.

The current strategy, where every state is saved, then simply becomes a special case of the new strategy in which every operation is considered compute-bound.
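Putting the tagging and the special case together, the per-operation decision could look like the sketch below. Again, all names (`BoundKind`, `Policy`, `strategy_for`) are hypothetical, chosen only to mirror the wording of this issue.

```rust
/// Rough cost profile of an operation. (Hypothetical tagging.)
#[derive(Clone, Copy, Debug, PartialEq)]
enum BoundKind {
    /// e.g. matrix multiplication: the arithmetic dominates.
    Compute,
    /// e.g. element-wise multiplication: data movement dominates.
    Memory,
}

/// What the autodiff graph does with an operation's forward state.
#[derive(Debug, PartialEq)]
enum Strategy {
    SaveState,
    Recompute,
}

/// Global policy: either balance compute against memory per operation,
/// or save everything, which is the current behavior and is equivalent
/// to tagging every operation as compute-bound.
enum Policy {
    Balanced,
    SaveEverything,
}

fn strategy_for(policy: &Policy, kind: BoundKind) -> Strategy {
    match (policy, kind) {
        (Policy::SaveEverything, _) => Strategy::SaveState,
        (Policy::Balanced, BoundKind::Compute) => Strategy::SaveState,
        (Policy::Balanced, BoundKind::Memory) => Strategy::Recompute,
    }
}

fn main() {
    // Under the balanced policy, only compute-bound ops keep their state.
    assert_eq!(strategy_for(&Policy::Balanced, BoundKind::Compute), Strategy::SaveState);
    assert_eq!(strategy_for(&Policy::Balanced, BoundKind::Memory), Strategy::Recompute);
    // Saving everything reproduces the current strategy for all op kinds.
    assert_eq!(strategy_for(&Policy::SaveEverything, BoundKind::Memory), Strategy::SaveState);
    println!("matmul under Balanced -> {:?}", strategy_for(&Policy::Balanced, BoundKind::Compute));
}
```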

Metadata



Labels

performance — Anything related to performance
very hard — Reserved for framework experts: Extremely challenging.


Status

Done
