Activation checkpointing in PyTorch is a useful memory management technique: it reduces GPU memory usage by not storing intermediate activations during the forward pass, recomputing the relevant parts of the forward model on the fly during the backward pass instead. Blog post explaining more about this feature of PyTorch: https://medium.com/pytorch/how-activation-checkpointing-enables-scaling-up-training-deep-learning-models-7a93ae01ff2d
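
For reference, this is roughly what the raw PyTorch API looks like (a minimal sketch; the layer sizes and the `block` name here are just placeholders):

```python
import torch
from torch.utils.checkpoint import checkpoint

# Placeholder segment of a model; in practice this would be any expensive sub-network.
block = torch.nn.Sequential(
    torch.nn.Linear(128, 128),
    torch.nn.ReLU(),
    torch.nn.Linear(128, 128),
)

x = torch.randn(32, 128, requires_grad=True)

# Activations inside `block` are not stored; they are recomputed during backward.
y = checkpoint(block, x, use_reentrant=False)
y.sum().backward()
```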
It would be neat to have this as an option in Caskade: each module could carry a checkpoint flag which, when set, makes the entire module's forward call get recomputed on the fly during the backward pass (see the sketch below).
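
A hypothetical sketch of what the flag could look like, assuming a module whose forward logic lives in a single method that can be wrapped. The `checkpoint` flag and `_forward` name are placeholders, not the actual Caskade API:

```python
import torch
from torch.utils.checkpoint import checkpoint as _ckpt


class CheckpointableModule(torch.nn.Module):
    """Placeholder base class illustrating the proposed per-module checkpoint flag."""

    def __init__(self, checkpoint: bool = False):
        super().__init__()
        self.checkpoint = checkpoint  # per-module flag proposed in this issue

    def _forward(self, x):
        # Subclasses implement their actual forward logic here.
        raise NotImplementedError

    def forward(self, x):
        if self.checkpoint and torch.is_grad_enabled():
            # Recompute this module's forward during backward instead of
            # storing its intermediate activations.
            return _ckpt(self._forward, x, use_reentrant=False)
        return self._forward(x)
```

One wrinkle a real implementation would have to handle: `torch.utils.checkpoint` only tracks gradients through tensor arguments, so modules whose forward takes non-tensor or keyword arguments would need some extra plumbing.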