ActivationCheckpointing#
- class vision_architectures.utils.activation_checkpointing.ActivationCheckpointing(fn_checkpoint_level, training_checkpoint_level)[source]#
Bases: Module
This class performs activation checkpointing during training. Users assign a checkpointing level to each module / function in their architecture; while training, a module / function is checkpointed if the training checkpoint level is greater than or equal to the level assigned to it.
A general guide to the activation checkpointing levels in this repository:
Level 0: No checkpointing
Level 1: Single layers are checkpointed, e.g. linear layer + activation, conv layer + dropout
Level 2: Small blocks are checkpointed, e.g. residual blocks, attention blocks, MLP blocks
Level 3: Medium-sized modules are checkpointed, e.g. transformer layers, decoder blocks
Level 4: Large modules are checkpointed, e.g. groups of transformer layers, decoder stages
Level 5: Very large modules are checkpointed, e.g. entire encoders, decoders, etc.
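To make the rule concrete, the decision reduces to a single comparison. The following is a minimal sketch of that logic, not the actual implementation; `maybe_checkpoint` is a hypothetical helper, and `torch.utils.checkpoint.checkpoint` is the standard PyTorch primitive:

```python
import torch.utils.checkpoint

def maybe_checkpoint(fn, fn_checkpoint_level, training_checkpoint_level, *args, **kwargs):
    """Sketch: checkpoint fn only when the training level reaches fn's assigned level."""
    if training_checkpoint_level >= fn_checkpoint_level:
        # Recompute fn's activations during the backward pass instead of storing them
        return torch.utils.checkpoint.checkpoint(fn, *args, use_reentrant=False, **kwargs)
    # Below the threshold: run fn normally and keep its activations in memory
    return fn(*args, **kwargs)
```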
- __init__(fn_checkpoint_level, training_checkpoint_level)[source]#
Initialize the ActivationCheckpointing class.
- Parameters:
  - fn_checkpoint_level (int) – Level at which the module / function should be checkpointed
  - training_checkpoint_level (int) – Checkpointing level at which the model is being trained
Example

```python
import torch.nn as nn
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing

class MyModel(nn.Module):
    def __init__(self, training_checkpointing_level: int = 0):
        super().__init__()
        # Assign the network to self so it is registered and reachable in forward
        self.my_network = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
        )
        self.activation_checkpointing_level2 = ActivationCheckpointing(2, training_checkpointing_level)

    def forward(self, x):
        y = self.activation_checkpointing_level2(self.my_network, x)
        return y
```
In this example, a training_checkpointing_level of 2 or greater will checkpoint my_network during training. If it is less than 2, the network will not be checkpointed.
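For instance, continuing the example above (a hypothetical usage sketch):

```python
model = MyModel(training_checkpointing_level=3)  # 3 >= 2: my_network is checkpointed
model = MyModel(training_checkpointing_level=0)  # 0 < 2: no checkpointing
```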
- __call__(fn, *fn_args, use_reentrant=False, **fn_kwargs)[source]#
Checkpoint the module / function if the training checkpoint level is greater than or equal to the checkpoint level set for the module / function.
- Parameters:
  - fn (Callable) – The module / function to checkpoint
  - *fn_args – Arguments to pass to the module / function
  - use_reentrant (bool) – Passed on to torch.utils.checkpoint.checkpoint. Defaults to False.
  - **fn_kwargs – Keyword arguments to pass to the module / function
- Returns:
  The output of calling fn with the given arguments, computed under activation checkpointing if checkpointing is enabled for this level, or by calling fn directly otherwise.
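As an illustration, the wrapper can be called directly on any callable. A hedged sketch (the module, shapes, and levels below are arbitrary choices, not part of the API):

```python
import torch
from torch import nn
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing

block = nn.Linear(16, 16)
ckpt = ActivationCheckpointing(fn_checkpoint_level=2, training_checkpoint_level=3)

x = torch.randn(4, 16, requires_grad=True)  # checkpointing needs a grad-requiring input
y = ckpt(block, x)  # 3 >= 2, so block runs under torch.utils.checkpoint during training
```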