In the module-based API, we deliberately clear intermediate tensors to minimize memory usage (see https://github.com/NVIDIA/TransformerEngine/pull/509). For example, in `_Linear.backward`:
https://github.com/NVIDIA/TransformerEngine/blob/56e0b351d0b7db6e622d3aa3eac6c6a1bf1ce4ab/transformer_engine/pytorch/module/linear.py#L587-L588
We should mark these autograd `backward` methods with `once_differentiable` for clarity, so it is explicit that differentiating through them is unsupported once their saved tensors have been freed.
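For concreteness, here is a minimal sketch of the proposal using a toy `torch.autograd.Function` (not TransformerEngine's actual `_Linear`; all names are illustrative):

```python
import torch
from torch.autograd.function import once_differentiable


class _Scale(torch.autograd.Function):
    """Toy stand-in for _Linear: saves tensors and frees them in backward."""

    @staticmethod
    def forward(ctx, inp, weight):
        ctx.save_for_backward(inp, weight)
        return inp * weight

    @staticmethod
    @once_differentiable  # backward runs under no_grad; double backward raises
    def backward(ctx, grad_output):
        inp, weight = ctx.saved_tensors
        grad_inp = grad_output * weight
        grad_weight = grad_output * inp
        # A module that clears intermediates (as _Linear does) would free
        # the saved tensors' storage at this point.
        return grad_inp, grad_weight


x = torch.randn(3, requires_grad=True)
w = torch.randn(3, requires_grad=True)
y = _Scale.apply(x, w).sum()
(gx,) = torch.autograd.grad(y, x, create_graph=True)
# gx.sum().backward()  # RuntimeError: backward was marked @once_differentiable
```

With the decorator, attempting a double backward fails with an explicit error instead of silently computing gradients from freed tensors.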
Note that the functional API in https://github.com/NVIDIA/TransformerEngine/pull/707 does not manually deallocate intermediate tensors in `backward` (the operators themselves deallocate internally, achieving the same memory savings as the module-based API). Once that PR is merged, it should be a better fit for this use case.
In particular, not deallocating in `backward` would be useful for benchmarking just the backward pass, since that requires replaying `backward` on the same graph.
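As an illustration of that use case, a backward-only timing loop typically looks like the sketch below (plain `torch.nn.Linear` stands in for a TransformerEngine module; names and sizes are illustrative). With a module that frees its saved tensors during the first `backward`, later iterations would be expected to fail:

```python
import time

import torch

# Stand-in model; with a module that clears its intermediates in backward
# (as the module-based API does), the second iteration would error out.
lin = torch.nn.Linear(1024, 1024)
x = torch.randn(512, 1024, requires_grad=True)
loss = lin(x).sum()

for i in range(5):
    start = time.perf_counter()
    # Replay only the backward pass; retain_graph keeps the autograd graph
    # alive, but cannot resurrect tensors freed by backward itself.
    loss.backward(retain_graph=True)
    print(f"iter {i}: backward took {time.perf_counter() - start:.6f} s")
```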