facebookresearch / fairscale

PyTorch extensions for high performance and large scale training.

[FSDP] issue with different ways of doing activation checkpointing #491

Open min-xu-ai opened 3 years ago

min-xu-ai commented 3 years ago

unit tests are being added here: https://github.com/facebookresearch/fairscale/pull/476

but I don't have the big picture of what needs to be fixed yet. Some observations:

  1. we know that if the root FSDP instance is empty (zero params) and checkpointing is used, the backward pass can finish too early
  2. I have seen issues around checkpoint(FSDP()) vs. FSDP(checkpoint()); only the former works in vissl (see the sketch after this list)
  3. I have seen a crash in vissl when we checkpoint the AnyStage and let the inner blocks be FSDP-wrapped
  4. In the past, we have seen issues where a for-loop in the forward pass doesn't play well with checkpointing (i.e. the same module forwarded multiple times in the forward pass)
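
For concreteness, here is a minimal sketch of the patterns from points 2 and 4, assuming fairscale's `FullyShardedDataParallel` and `checkpoint_wrapper` with toy `nn.Linear` blocks and a single-process gloo group (both are illustrative assumptions, not vissl code; FSDP is normally exercised on GPU with NCCL):

```python
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.checkpoint import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# FSDP needs an initialized process group, even for a single process.
os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

# Point 2: the two wrapping orders. Per the observation above, only the
# first order is known to work in vissl.
ckpt_of_fsdp = checkpoint_wrapper(FSDP(nn.Linear(16, 16)))   # checkpoint(FSDP())
fsdp_of_ckpt = FSDP(checkpoint_wrapper(nn.Linear(16, 16)))   # FSDP(checkpoint())

# Point 4: the same checkpointed module forwarded multiple times in one
# forward pass, which has historically interacted badly with checkpointing.
class LoopModel(nn.Module):
    def __init__(self, num_repeats: int = 3):
        super().__init__()
        self.block = checkpoint_wrapper(nn.Linear(16, 16))
        self.num_repeats = num_repeats

    def forward(self, x):
        for _ in range(self.num_repeats):
            x = self.block(x)  # same parameters reused; multiple recompute hooks
        return x

# Exercising forward + backward on these wrappings is how the failures above
# reproduce; the constructions themselves succeed. The known-good order, on
# this toy CPU setup:
out = ckpt_of_fsdp(torch.randn(4, 16, requires_grad=True)).sum()
out.backward()
```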

cc: @prigoyal @myleott

I am going to document a list of the issues we have found so far for tracking purposes.

anj-s commented 2 years ago

@min-xu-ai Is there an action item to follow up on here? Which of the above-listed issues still occur, and what is the priority to fix them?

min-xu-ai commented 2 years ago

I think the backward-firing cases have improved a lot since then. @zhaojuanmao

Different wrapping orders may still have issues since we don't test all the combinations exhaustively; a sketch of such a sweep follows below.
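
To make "test all the combinations" concrete, here is a hypothetical sweep over wrapping orders; `build_model` and the combination list are illustrative, not actual fairscale test code, and the try/except is there precisely because some combinations are expected to crash:

```python
import itertools
import os

import torch
import torch.distributed as dist
import torch.nn as nn

from fairscale.nn.checkpoint import checkpoint_wrapper
from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP

# Single-process gloo group so FSDP can be constructed in a sketch.
if not dist.is_initialized():
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)


def build_model(ckpt_outside: bool, wrap_inner: bool) -> nn.Module:
    # Inner block, optionally FSDP-wrapped (as vissl does for inner blocks).
    inner: nn.Module = nn.Linear(8, 8)
    if wrap_inner:
        inner = FSDP(inner)
    stage = nn.Sequential(inner, nn.Linear(8, 8))
    # Checkpoint either outside or inside the root FSDP wrapper.
    if ckpt_outside:
        return checkpoint_wrapper(FSDP(stage))
    return FSDP(checkpoint_wrapper(stage))


for ckpt_outside, wrap_inner in itertools.product([False, True], repeat=2):
    model = build_model(ckpt_outside, wrap_inner)
    try:
        model(torch.randn(2, 8, requires_grad=True)).sum().backward()
        status = "ok"
    except Exception as exc:  # surfacing crashes is the point of the sweep
        status = f"failed: {exc}"
    print(f"ckpt_outside={ckpt_outside} wrap_inner={wrap_inner}: {status}")
```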

anj-s commented 2 years ago

@min-xu-ai Does it make sense to take on any of the above issues, or should we figure out the larger issue behind the behaviors mentioned above? It seems like checkpoint_wrapper may not have the most consistent behavior with FSDP.