❓ Questions and Help
What is your question?
I found there are some differences between the fairseq checkpoint_wrapper and the fairscale checkpoint_wrapper.
Code
fairseq checkpoint_wrapper
def checkpoint_wrapper(m, offload_to_cpu=False):
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:
    - wraps an nn.Module, so that all subsequent calls will use checkpointing
    - handles keyword arguments in the forward
    - handles non-Tensor outputs from the forward

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))
    """
    # should I check whether original_forward has already been set?
    assert not hasattr(
        m, "precheckpoint_forward"
    ), "checkpoint function has already been applied?"
    m.precheckpoint_forward = m.forward
    m.forward = functools.partial(
        _checkpointed_forward,
        m.precheckpoint_forward,  # original_forward
        offload_to_cpu,
    )
    return m
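For context on why the difference matters: here the wrapper stores the bound method m.forward on the module and then binds it into the functools.partial that becomes the new m.forward, so the module ends up holding a strong reference back to itself. Below is a minimal sketch of that effect; it is not fairseq code, and _fake_checkpointed_forward is a hypothetical stand-in that just calls through.

import functools
import gc
import weakref

import torch.nn as nn


def _fake_checkpointed_forward(original_forward, offload_to_cpu, *args, **kwargs):
    # Hypothetical stand-in for fairseq's _checkpointed_forward: just call through.
    return original_forward(*args, **kwargs)


def wrap_like_fairseq(m):
    m.precheckpoint_forward = m.forward    # bound method, holds a ref to m
    m.forward = functools.partial(         # m -> partial -> bound method -> m
        _fake_checkpointed_forward, m.precheckpoint_forward, False
    )
    return m


gc.disable()  # make the demonstration deterministic: no automatic GC passes

m = wrap_like_fairseq(nn.Linear(4, 4))
alive = weakref.ref(m)

del m
print(alive() is None)   # False: the cycle keeps the module (and its parameters) alive
gc.collect()
print(alive() is None)   # True: only a full garbage-collection pass reclaims it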
fairscale checkpoint_wrapper
def checkpoint_wrapper(
    module: nn.Module,
    offload_to_cpu: bool = False,
) -> nn.Module:
    """
    A friendlier wrapper for performing activation checkpointing.

    Compared to the PyTorch version, this version:
    - wraps an nn.Module, so that all subsequent calls will use checkpointing
    - handles keyword arguments in the forward
    - handles non-Tensor outputs from the forward
    - supports offloading activations to CPU

    Usage::

        checkpointed_module = checkpoint_wrapper(my_module, offload_to_cpu=True)
        a, b = checkpointed_module(x, y=3, z=torch.Tensor([1]))

    To understand the benefits of checkpointing and the `offload_to_cpu` flag,
    let's divide activations into 2 types: inner activations and outer
    activations w.r.t. the checkpointed modules. The inner ones are saved
    by activation checkpointing, the outer ones are saved by offload_to_cpu.

    In terms of GPU memory savings:

        - When inner ones are large in size and outer ones are small,
          checkpointing helps a lot, offload_to_cpu may help a little.
        - When inner ones are small and outer ones are large,
          checkpointing helps little, offload_to_cpu helps a lot.
        - When both inner and outer are large, both help and the
          benefit is additive.

    ..Note::

        The first and last layers are not likely to benefit from the `offload_to_cpu` flag
        because (1) there are typically other references to the first layer's input, so
        the GPU memory won't be freed; (2) the input to the last layer is immediately
        used by the backward pass and won't result in memory savings.

    Args:
        module (nn.Module):
            The module to be wrapped
        offload_to_cpu (bool):
            Whether to offload activations to CPU.

    Returns:
        (nn.Module):
            Wrapped module
    """
    # Patch the batchnorm layers in case there are any in this module.
    patch_batchnorm(module)

    # The use of weakref here is to prevent creating a ref cycle: m -> m.forward -> m.
    # When such cycle exists, gc won't collect the module when the module is freed.
    # That causes GPU memory to be leaked. See the unit test for how we catch that.
    #
    # We prefer this over a class wrapper since the class wrapper would have to
    # proxy a lot of fields and methods.
    module.forward = functools.partial(  # type: ignore
        _checkpointed_forward, type(module).forward, weakref.ref(module), offload_to_cpu
    )
    return module
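The fairscale version avoids that cycle: type(module).forward is the plain function on the class (no reference to the instance), and the instance itself is passed through weakref.ref(module), so the partial stored on the module holds no strong reference back to it. Below is a minimal sketch of the pattern, not fairscale code; the _checkpointed_forward here is a hypothetical stand-in that only resolves the weak reference and calls through (the real one also runs activation checkpointing and the optional CPU offload).

import functools
import weakref

import torch
import torch.nn as nn


def _checkpointed_forward(original_forward, weak_self, offload_to_cpu, *args, **kwargs):
    # Hypothetical stand-in: re-resolve the module at call time, then call through.
    module = weak_self()
    assert module is not None, "module was freed while its forward was still in use"
    return original_forward(module, *args, **kwargs)


def wrap_like_fairscale(module):
    module.forward = functools.partial(
        _checkpointed_forward,
        type(module).forward,   # function on the class, not a bound method
        weakref.ref(module),    # weak reference: no strong ref back to `module`
        False,
    )
    return module


m = wrap_like_fairscale(nn.Linear(4, 4))
with torch.no_grad():
    m(torch.randn(2, 4))        # forward still works through the patched attribute

alive = weakref.ref(m)
del m
print(alive() is None)          # True: plain refcounting frees it, no gc.collect() needed

In this sketch the trade-off is a small indirection at call time (re-resolving the weak reference) in exchange for the wrapped module being freed promptly by reference counting alone.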
What have you tried?
I want to know which checkpoint_wrapper is safer. Should the fairseq checkpoint_wrapper be changed to match the fairscale checkpoint_wrapper?
What's your environment?
fairseq Version (e.g., 0.12.2 or main):
PyTorch Version (e.g., 2.4.0+cu121)
OS (e.g., Linux): Ubuntu 22.04
How you installed fairseq (pip, source): source
Build command you used (if compiling from source): pip install -e .