facebookresearch / xformers

Hackable and optimized Transformers building blocks, supporting a composable construction.
https://facebookresearch.github.io/xformers/

[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp #1127

Status: Closed (warpuv closed this 3 weeks ago)

warpuv commented 1 month ago

According to the docs (https://pytorch.org/docs/stable/autograd.html#torch.autograd.Function), the forward() method should not be called directly; apply() must be used instead. After removing the direct forward() call, activation checkpointing works.
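A minimal sketch of the distinction above, assuming a recent PyTorch where `torch.utils.checkpoint.checkpoint` accepts `use_reentrant=False`. `Cube` and `_DummyCtx` are hypothetical toy names, not xformers code. Calling `apply()` registers the custom backward node; calling `forward()` directly bypasses the `autograd.Function` machinery, so the hand-written backward is never used.

```python
import torch
from torch.utils.checkpoint import checkpoint


class Cube(torch.autograd.Function):
    """Hypothetical toy op: y = x ** 3 with a hand-written backward."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x ** 3

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return grad_out * 3 * x ** 2


class _DummyCtx:
    """Hypothetical stand-in ctx for the anti-pattern; drops saved tensors."""

    def save_for_backward(self, *tensors):
        pass


torch.manual_seed(0)
x = torch.randn(4, dtype=torch.float64, requires_grad=True)

# Correct: apply() records a CubeBackward node in the autograd graph.
y_apply = Cube.apply(x)
print(type(y_apply.grad_fn).__name__)  # CubeBackward

# Anti-pattern: calling forward() directly bypasses autograd.Function, so the
# graph only contains the built-in pow node and Cube.backward is never used.
y_direct = Cube.forward(_DummyCtx(), x)
print(type(y_direct.grad_fn).__name__ != "CubeBackward")  # True

# With apply(), activation checkpointing recomputes the forward during
# backward and still reaches Cube.backward.
y_ckpt = checkpoint(Cube.apply, x, use_reentrant=False)
(g_ckpt,) = torch.autograd.grad(y_ckpt.sum(), x)
(g_ref,) = torch.autograd.grad(Cube.apply(x).sum(), x)
print(torch.allclose(g_ckpt, g_ref))  # True
```

In the real op the forward launches custom kernels rather than differentiable PyTorch ops, so a direct forward() call there leaves no usable graph at all, not just a different one.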

What does this PR do?

Fixes #1126


PR review

Anyone in the community is free to review the PR once the tests have passed. If we didn't discuss your PR in Github issues there's a high chance it will not be merged.

pansershrek commented 1 month ago

@danthe3rd, can you review this PR, please? This change fixes the integration with FSDP + activation_checkpointing.

warpuv commented 1 month ago

> Thanks for the PR!
>
> It seems you're basically trying to revert the change introduced in #706. I don't have full context on that old PR, but I wonder whether there are ways to achieve both goals at once.
>
> For example, we could take your PR and also replace the check if (ctx == nullptr) above with if (x.requires_grad). I believe this should get us everything we want?

@lw, thank you for your review and suggestions. Conditioning on x.requires_grad changes the behavior of the forward recomputation: x is detached during the recomputation phase, so x.requires_grad has a different value there, and as a result save_for_backward is not called. I have pushed an alternative solution using torch::GradMode::is_enabled(); I believe it achieves both goals.
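A rough Python analogue of branching on grad mode, assuming reentrant checkpointing (`use_reentrant=True`). `torch.is_grad_enabled()` is the Python counterpart of `torch::GradMode::is_enabled()`; `SiluMul` and `silu_mul` are hypothetical toys, not the xformers op. Under `torch.no_grad()` the wrapper takes an inference path that saves nothing, while the checkpoint recomputation runs with grad enabled and goes through `apply()`, so gradients still match the non-checkpointed reference.

```python
import torch
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class SiluMul(torch.autograd.Function):
    """Hypothetical toy op: out = silu(a) * b with a hand-written backward."""

    @staticmethod
    def forward(ctx, a, b):
        ctx.save_for_backward(a, b)
        return F.silu(a) * b

    @staticmethod
    def backward(ctx, grad_out):
        a, b = ctx.saved_tensors
        sig = torch.sigmoid(a)
        # d silu(a) / da = sig * (1 + a * (1 - sig)); silu(a) = a * sig
        return grad_out * b * sig * (1 + a * (1 - sig)), grad_out * a * sig


def silu_mul(a, b):
    # Branch on the global grad mode (Python counterpart of the C++
    # torch::GradMode::is_enabled()) rather than on a.requires_grad.
    if torch.is_grad_enabled():
        return SiluMul.apply(a, b)  # training path: saves a, b for backward
    return F.silu(a) * b  # inference path: nothing saved for backward


torch.manual_seed(0)
a = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)
b = torch.randn(3, 5, dtype=torch.float64, requires_grad=True)

# Inference: no autograd graph is built.
with torch.no_grad():
    y_inf = silu_mul(a, b)
print(y_inf.grad_fn)  # None

# Reference gradients without checkpointing.
ga_ref, gb_ref = torch.autograd.grad(silu_mul(a, b).sum(), (a, b))

# Reentrant checkpointing runs the first forward under no_grad (inference
# path) and recomputes it with grad enabled (apply path) during backward.
checkpoint(silu_mul, a, b, use_reentrant=True).sum().backward()
print(torch.allclose(a.grad, ga_ref), torch.allclose(b.grad, gb_ref))  # True True
```

Reentrant checkpointing only supports `.backward()`, not `torch.autograd.grad`, which is why the checkpointed pass accumulates into `a.grad` and `b.grad`.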

warpuv commented 1 month ago

Dear @zyan0, do you have any objections to this change?

pansershrek commented 4 weeks ago

@lw, @zyan0, is this solution OK, or do you have any objections to this change?

danthe3rd commented 4 weeks ago

Hi, I believe the first version of the fix https://github.com/facebookresearch/xformers/pull/1127/commits/46d282341cf978f7580c61d1e315702ff763f540 was simpler. Can you revert to that one? Then we can merge.

warpuv commented 4 weeks ago

@lw, my guess is that it was planned to optimize the inference path, but this was never done. The current implementation of dual_gemm_silu_identity_mul produces two additional intermediate tensors that are not part of the output but are used in the backward pass. I can imagine implementing the forward pass without producing these two intermediate tensors to improve speed in inference mode. @danthe3rd, do you still think the first solution is better? If so, I will revert to it. The second solution leaves room to optimize the inference path in the future.
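To illustrate what those intermediate tensors are, here is a plain-PyTorch sketch of a SwiGLU-style gate, out = silu(x W1ᵀ) * (x W2ᵀ), without biases. `DualGemmSiluMul` and `dual_gemm_silu_mul_inference` are hypothetical names, not the xformers kernels. The training path must save `x1` and `x2` for backward; the inference path keeps no reference to them, although only a fused kernel could avoid materializing them at all.

```python
import torch
import torch.nn.functional as F


class DualGemmSiluMul(torch.autograd.Function):
    """Hypothetical sketch: out = silu(x @ w1.T) * (x @ w2.T), no biases."""

    @staticmethod
    def forward(ctx, x, w1, w2):
        x1 = x @ w1.t()  # intermediate 1: not returned, needed by backward
        x2 = x @ w2.t()  # intermediate 2: not returned, needed by backward
        ctx.save_for_backward(x, w1, w2, x1, x2)
        return F.silu(x1) * x2

    @staticmethod
    def backward(ctx, grad_out):
        x, w1, w2, x1, x2 = ctx.saved_tensors
        sig = torch.sigmoid(x1)
        grad_x1 = grad_out * x2 * sig * (1 + x1 * (1 - sig))  # silu'(x1)
        grad_x2 = grad_out * x1 * sig  # silu(x1) = x1 * sig
        grad_x = grad_x1 @ w1 + grad_x2 @ w2
        grad_w1 = grad_x1.t() @ x
        grad_w2 = grad_x2.t() @ x
        return grad_x, grad_w1, grad_w2


def dual_gemm_silu_mul_inference(x, w1, w2):
    # Inference-only path: nothing is saved for backward, so x1 and x2 can be
    # freed as soon as the product is formed. A fused kernel could avoid
    # materializing them at all; plain PyTorch cannot.
    with torch.no_grad():
        return F.silu(x @ w1.t()) * (x @ w2.t())


torch.manual_seed(0)
x = torch.randn(3, 4, dtype=torch.float64, requires_grad=True)
w1 = torch.randn(5, 4, dtype=torch.float64, requires_grad=True)
w2 = torch.randn(5, 4, dtype=torch.float64, requires_grad=True)

out_train = DualGemmSiluMul.apply(x, w1, w2)
out_inf = dual_gemm_silu_mul_inference(x, w1, w2)
print(torch.allclose(out_train, out_inf))  # True
print(torch.autograd.gradcheck(DualGemmSiluMul.apply, (x, w1, w2)))  # True
```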

lw commented 4 weeks ago

I discussed it with @danthe3rd and we agree that it's ok to undo the separation of apply and forward that was introduced in #706. If we ever need it again we will evaluate other options.

warpuv commented 4 weeks ago

@lw Ok, I've uploaded the first version. Please check it.