pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Migrate PyTorch/XLA's gradient checkpointing to the upstream implementation #7024

Open JackCaoG opened 4 months ago

JackCaoG commented 4 months ago

🚀 Feature

Today, PyTorch/XLA asks users to use its own version of gradient checkpointing in https://github.com/pytorch/xla/blob/d1235858628417ed7abc0d61e6e9be50df3e1a87/torch_xla/utils/checkpoint.py#L145-L146. We should extend upstream's API instead of asking users to use our version.

Motivation

Upstream gradient checkpointing doesn't work because XLA's CSE (common subexpression elimination) pass undoes the gradient checkpointing. More details are in https://github.com/pytorch/xla/issues/5766#issuecomment-1792913756. As a result, I copied the upstream checkpointing and added an optimization_barrier_ on the inputs of the backward recompute. This is bad because:
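A toy model of the problem (illustrative only: the op encoding, the `OPAQUE` set, and the `cse` helper are made up here, and XLA's real pass works on HLO): the backward recompute emits ops that are structurally identical to the forward ones, so a CSE pass folds them back into the forward ops and the forward activations stay alive anyway. Routing the recompute's inputs through an op that CSE treats as opaque, which is the role optimization_barrier_ plays, keeps the two copies apart.

```python
# Toy CSE over a list of ops; each op is (name, *operand_indices).
OPAQUE = {"param", "barrier"}  # ops that CSE must never merge


def cse(ops):
    """Return (kept_ops, remap), where remap maps each original op index to its kept index."""
    seen = {}
    remap = {}
    kept = []
    for i, (name, *operands) in enumerate(ops):
        key = (name, *(remap[o] for o in operands))
        if name not in OPAQUE and key in seen:
            remap[i] = seen[key]  # identical op already exists: reuse its result
            continue
        remap[i] = len(kept)
        kept.append(key)
        seen[key] = remap[i]
    return kept, remap


# Forward mul, then the "recomputed" mul in backward reading the same input.
no_barrier = [("param",), ("mul", 0, 0), ("mul", 0, 0)]
kept, remap = cse(no_barrier)
print(len(kept), remap[2] == remap[1])  # recompute merged into the forward op

# Same graph, but the recompute reads its input through a barrier.
with_barrier = [("param",), ("mul", 0, 0), ("barrier", 0), ("mul", 2, 2)]
kept, remap = cse(with_barrier)
print(len(kept), remap[3] == remap[1])  # recompute survives as its own op
```

In the first graph the recomputed `mul` collapses onto the forward one; in the second it remains a separate op.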

  1. Our implementation gets outdated very quickly
  2. It is difficult for users to discover our version of gradient checkpointing

Pitch

I chatted with @soulitzer, and there is a way to pass a context manager to extend the gradient checkpointing behavior. @soulitzer even went ahead and wrote a draft:

import contextlib

import torch
from torch.overrides import TorchFunctionMode
from torch.utils._pytree import tree_map_only
from torch.utils.checkpoint import checkpoint
from torch.utils.weak import WeakTensorKeyDictionary

class MarkInputsToRegion(TorchFunctionMode):
    def __init__(self, mark_fn):
        super().__init__()
        # tensor -> bool: tensors already seen inside the region
        self.is_marked = WeakTensorKeyDictionary()
        self.mark_fn = mark_fn

    def __torch_function__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def mark_input(x):
            # Only tensors coming from outside the region reach mark_fn
            if x not in self.is_marked:
                self.mark_fn(x)
                self.is_marked[x] = True

        def mark_output(x):
            # Tensors produced inside the region are remembered, never wrapped
            self.is_marked[x] = True

        tree_map_only(torch.Tensor, mark_input, (args, kwargs))
        out = func(*args, **kwargs)
        tree_map_only(torch.Tensor, mark_output, out)
        return out

def context_fn():
    def mark_fn(x):
        print("input to region: ", x)
    # (forward context, recompute context): only the backward recompute is marked
    return contextlib.nullcontext(), MarkInputsToRegion(mark_fn)

# Test a tensor that is closed over
y = torch.tensor([2.], requires_grad=True)
x = torch.tensor([1.], requires_grad=True)

def func(x):
    # the output of this mul or this clone should not be wrapped
    out = x * y
    return out.clone()

out = checkpoint(func, x, context_fn=context_fn, use_reentrant=False)
out.sum().backward()

What we need to verify is that optimization_barrier_ is applied only to the inputs of the backward recompute, not to the entire backward pass. I think we should take the above code, experiment with it, and verify whether we can use this approach to extend gradient checkpointing.

On top of the optimization_barrier_, we also do some PyTorch/XLA RNG seed state management in https://github.com/pytorch/xla/blob/d1235858628417ed7abc0d61e6e9be50df3e1a87/torch_xla/utils/checkpoint.py#L151-L155. We should think about how to handle this part in the extension as well.

cc @alanwaketan @jonb377 @albanD

alanwaketan commented 4 months ago

This sounds exciting!

jonb377 commented 4 months ago

Awesome! One thing we may need to handle is autocast state with gradient checkpointing: the upstream implementation restores state using device modules (e.g. torch.cuda or torch.cpu), and it fetches the device module using getattr(torch, device), which won't work out of the box for us.

We can probably just extend the _get_device_module logic in upstream to support torch_xla.
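A sketch of that extension, assuming upstream's lookup stays `getattr(torch, device)` as described above: a small registry (the names `register_device_module` and `get_device_module` are hypothetical) that a backend such as torch_xla could populate, with a fallback to the upstream behavior. Which module PyTorch/XLA would actually register is left open; the usage below registers a placeholder module.

```python
import types

import torch

# Hypothetical registry for backends that torch does not expose as torch.<device>.
_EXTRA_DEVICE_MODULES: dict[str, types.ModuleType] = {}


def register_device_module(device: str, module: types.ModuleType) -> None:
    _EXTRA_DEVICE_MODULES[device] = module


def get_device_module(device: str = "cuda") -> types.ModuleType:
    # Registered backends win; otherwise keep upstream's getattr(torch, device).
    if device in _EXTRA_DEVICE_MODULES:
        return _EXTRA_DEVICE_MODULES[device]
    module = getattr(torch, device, None)
    if module is None:
        raise RuntimeError(f"no device module for {device!r}; register one first")
    return module


fake_xla = types.ModuleType("fake_xla")  # placeholder for whatever torch_xla provides
register_device_module("xla", fake_xla)
print(get_device_module("cpu") is torch.cpu, get_device_module("xla") is fake_xla)
```

PyTorch also has a private `torch._register_device_module(device_type, module)` that makes `getattr(torch, device_type)` resolve to the registered module; whether relying on a private hook is acceptable here would need discussion.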

albanD commented 4 months ago

@jonb377 we should update the checkpoint code to use the new device-generic AMP API in https://github.com/pytorch/pytorch/pull/124479

miladm commented 1 month ago

cc @tengyifei