Closed: warpuv closed this PR 3 weeks ago.
@danthe3rd Can you review this PR please? This change fixes the integration with FSDP + activation_checkpointing
Thanks for the PR!
It seems you're basically trying to revert the change introduced in #706. I don't have full context on that old PR, but I wonder whether there are ways to achieve both goals at once.
For example, we could take your PR and also replace the check if (ctx == nullptr) above with if (x.requires_grad). I believe this should get us everything we want?
@lw thank you for your review and suggestions.
The “if” conditional on x.requires_grad changes the behavior of the forward recomputation: x.requires_grad has a different value there because the input is detached during the recomputation phase, and as a result save_for_backward is not called.
I have pushed an alternative solution using torch::GradMode::is_enabled(); I believe both goals are achieved this way.
Dear @zyan0, do you have any objections to this change?
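For readers skimming the thread, here is a rough Python-level sketch of the gating idea (the real change lives in the C++ autograd code, so the function below and its shapes are purely illustrative, and torch.is_grad_enabled() stands in for torch::GradMode::is_enabled()):

```python
import torch
import torch.nn.functional as F

def dual_gemm_silu_identity_mul_sketch(x, w1, w2):
    """Illustrative stand-in for the fused op: two GEMMs, SiLU, and a
    pointwise multiply, plus the intermediates the backward pass needs."""
    x1 = x @ w1                 # first GEMM
    x2 = x @ w2                 # second GEMM
    out = F.silu(x1) * x2

    # Gate on grad mode rather than on x.requires_grad: during checkpoint
    # recomputation the input may be detached (requires_grad == False),
    # but grad mode is enabled, so the intermediates are still returned
    # and save_for_backward keeps working. Under torch.no_grad() inference
    # the extra tensors can be skipped entirely.
    if torch.is_grad_enabled():
        return out, x1, x2
    return out, None, None
```

This mirrors the trade-off discussed above: grad-mode gating preserves the checkpoint recomputation path while still leaving room to skip the intermediates in pure inference.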
@lw, @zyan0 Is this solution OK, or do you have any objections to this change?
Hi, I believe the first version of the fix https://github.com/facebookresearch/xformers/pull/1127/commits/46d282341cf978f7580c61d1e315702ff763f540 was simpler. Can you revert to that one? Then we can merge.
@lw my guess is that it was planned to optimize the inference path, but this was never done.
The current implementation of dual_gemm_silu_identity_mul produces two additional intermediate tensors that are not part of the output but are used in the backward pass. I can imagine implementing the forward pass without producing these two intermediate tensors to improve speed in inference mode.
@danthe3rd do you still think the first solution is better? If so, I will revert to it. The second solution leaves the option of optimizing the inference path sometime in the future.
I discussed it with @danthe3rd and we agree that it's OK to undo the separation of apply and forward that was introduced in #706. If we ever need it again, we will evaluate other options.
@lw Ok, I've uploaded the first version. Please check it.
According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; the apply() method has to be used instead. After removing the direct forward() call, activation checkpointing starts working.
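For illustration only (a toy Function, not the actual xformers SwiGLU op), the distinction matters because apply() runs the autograd machinery that creates ctx and records the graph node, whereas a direct forward() call bypasses it and therefore cannot cooperate with torch.utils.checkpoint:

```python
import torch
from torch.utils.checkpoint import checkpoint

class ToySiLU(torch.autograd.Function):
    """Minimal custom Function standing in for the real fused op."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * torch.sigmoid(x)

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        s = torch.sigmoid(x)
        return grad_out * (s + x * s * (1 - s))

x = torch.randn(8, requires_grad=True)

# Supported usage: apply() sets up ctx and registers the backward node,
# so activation checkpointing can recompute the forward and still reach
# the saved tensors.
y = checkpoint(ToySiLU.apply, x, use_reentrant=False)
y.sum().backward()

# Unsupported usage: calling forward() directly would require building ctx
# by hand and skips the autograd bookkeeping, which is what broke
# FSDP + activation checkpointing here.
# y = ToySiLU.forward(some_ctx, x)
```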
What does this PR do?
Fixes #1126
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.