pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Autograd discrepancy in `nn.Linear` (`torch.nn.functional.linear`) between native PyTorch and PyTorch/XLA #3811

Open ronghanghu opened 2 years ago

ronghanghu commented 2 years ago

🐛 Bug

There seems to be a discrepancy (in addition to https://github.com/pytorch/xla/issues/3718) in how torch.nn.Linear (torch.nn.functional.linear) is implemented and dispatched between the native PyTorch and PyTorch/XLA. In particular, the XLA version of linear doesn't refer to the weight tensors in its backward pass.

This discrepancy makes PyTorch autograd behave differently on linear layers (arguably the most important layer in neural networks) between native PyTorch and PyTorch/XLA, which affects gradient checkpointing, FSDP, and other use cases.

To Reproduce

Consider the following simple example, where we try to explicitly release the nn.Linear's weight parameter (that should be saved as a reference by autograd) via l1.weight.data = l1.weight.data.new_zeros(1) before the backward pass -- this should crash the program:

import os
os.environ["XLA_IR_DEBUG"] = "1"
import torch
import torch_xla
import torch_xla.core.xla_model as xm

use_xla = True
# use_xla = False

# model
device = xm.xla_device() if use_xla else torch.device("cpu")
l1 = torch.nn.Linear(512, 256, device=device, bias=False)

# forward
x = torch.ones(22, 512, device=device, requires_grad=True)
y = l1(x)
loss = y.sum()

# backward
l1.weight.data = l1.weight.data.new_zeros(1)  # try releasing the linear weight
loss.backward()

if use_xla:
    ir_txt = torch_xla._XLAC._get_xla_tensors_text([loss, x.grad])
    print(f"forward & backward to compute loss & x.grad: {ir_txt}")
    xm.mark_step()

in native PyTorch (CPU)

With use_xla = False, the code above crashes on CPU with the error below. This is expected, since l1.weight.data = l1.weight.data.new_zeros(1) frees the weight parameter that the backward pass requires, i.e. it is now a single-element tensor of shape [1]:

Traceback (most recent call last):
  File "/home/ronghanghu/workspace/xla_fsdp_clean_up/test/test_linear_bwd.py", line 21, in <module>
    loss.backward()
  File "/usr/local/lib/python3.8/dist-packages/torch/_tensor.py", line 396, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: output with shape [1] doesn't match the broadcast shape [256, 512]

This means that in native PyTorch, the backward pass of nn.Linear (based on torch.nn.functional.linear) actually refers to the weight parameter l1.weight -- a tensor whose reference handle is saved during the forward pass and then used in the backward pass. This is the expected behavior of nn.Linear.

Native PyTorch on GPUs (CUDA) also reproduces this expected crash.

in PyTorch/XLA (v3-8 TPU)

On the other hand, with use_xla = True, the code (magically) still runs on a v3-8 TPU with PyTorch/XLA and prints the following IR:

forward & backward to compute loss & x.grad: IR {
  %0 = s64[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=TPU:0
  %1 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=214013
  %2 = s64[] aten::mul(%1, %0), location=kaiming_uniform_@init.py:412
  %3 = s64[] prim::Constant(), location=kaiming_uniform_@init.py:412, value=2531011
  %4 = s64[] aten::add(%3, %2), location=kaiming_uniform_@init.py:412
  %5 = f32[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=TPU:0
  %6 = f32[] xla::device_data(), location=kaiming_uniform_@init.py:412, device=TPU:0
  %7 = f32[256,512]{1,0} aten::uniform(%6, %5, %4), location=kaiming_uniform_@init.py:412
  %8 = f32[512,256]{0,1} aten::permute(%7), location=forward@linear.py:114, dims=(1, 0)
  %9 = f32[] prim::Constant(), location=<module>@test_linear_bwd.py:15, value=1
  %10 = f32[22,512]{1,0} aten::expand(%9), location=<module>@test_linear_bwd.py:15, size=(22, 512)
  %11 = f32[22,256]{1,0} aten::mm(%10, %8), location=forward@linear.py:114
  %12 = f32[] aten::sum(%11), location=<module>@test_linear_bwd.py:17, dimensions=(0, 1), keep_reduced_dimensions=0, dtype=6, ROOT=0
  %13 = f32[512,256]{0,1} aten::permute(%7), dims=(1, 0)
  %14 = f32[256,512]{1,0} aten::permute(%13), dims=(1, 0)
  %15 = f32[] prim::Constant(), location=_make_grads@__init__.py:68, value=1
  %16 = f32[22,256]{1,0} aten::expand(%15), size=(22, 256)
  %17 = f32[22,512]{1,0} aten::mm(%16, %14), ROOT=1
}

This is weird because we're replacing the weight parameter's IR via l1.weight.data = l1.weight.data.new_zeros(1). If autograd in this case faithfully looked up the previously saved handle of l1.weight, it should then try to use this new single-element tensor and give the same crash as in the CPU case above.

However, the fact that it doesn't crash in PyTorch/XLA means that either torch.nn.functional.linear is implemented or dispatched differently in PyTorch/XLA, or that autograd traces XLA tensors somehow differently.

In several cases such as gradient checkpointing or FSDP, the mechanisms to save parameters rely on exactly how autograd works on certain layers or functionals (and arguably nn.Linear is the most important op, especially in transformers). This discrepancy (XLA's torch.nn.functional.linear doesn't use its weight input during the backward pass, but perhaps uses some other saved intermediate tensors) could break gradient checkpointing and FSDP implementations.

(For example, in FSDP, the mechanism to free full parameters relies on the backward pass referring to the weight parameter itself, rather than to another intermediate tensor that cannot be manipulated and freed by the FSDP wrapper class.)

Environment

Additional context

This issue was the root cause of another FSDP-related issue reported by @hjm-aws. It seems also to be related to the bfloat16 cast issue in the backward pass as mentioned in https://github.com/pytorch/xla/issues/3718.

cc: @JackCaoG @bdhirsh @hjm-aws @alanwaketan

ronghanghu commented 2 years ago

A workaround to address both this issue and https://github.com/pytorch/xla/issues/3718 is to add the following snippet before the model definition code (in distributed training, it needs to be added to each spawned training process):

import torch
from xla_patched_linear import xla_patched_linear
torch.nn.functional.linear = xla_patched_linear

where xla_patched_linear is in a simple module in https://gist.github.com/ronghanghu/d82ede74c434f8c12ae3ffb65ec84b45 that explicitly defines the autograd behavior of torch.nn.functional.linear.
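For readers who cannot open the gist, the idea of "explicitly defining the autograd behavior" can be sketched with a custom torch.autograd.Function. This is a minimal illustration, not a copy of the gist: the names ExplicitLinear and explicit_linear are hypothetical, and the only assumption is that the weight tensor itself (rather than an intermediate such as weight.t()) is what gets saved for the backward pass.

```python
import torch


class ExplicitLinear(torch.autograd.Function):
    """Hypothetical linear op with a hand-written backward (illustration only)."""

    @staticmethod
    def forward(ctx, input, weight, bias):
        # Save the weight tensor itself, not its transpose.
        ctx.save_for_backward(input, weight)
        ctx.has_bias = bias is not None
        output = input.matmul(weight.t())
        if bias is not None:
            output = output + bias
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input, weight = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = None
        # Flatten leading dimensions so the same code covers 1-D, 2-D and N-D inputs.
        grad_2d = grad_output.reshape(-1, grad_output.shape[-1])
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.matmul(weight)
        if ctx.needs_input_grad[1]:
            input_2d = input.reshape(-1, input.shape[-1])
            grad_weight = grad_2d.t().mm(input_2d)
        if ctx.has_bias and ctx.needs_input_grad[2]:
            grad_bias = grad_2d.sum(0)
        return grad_input, grad_weight, grad_bias


def explicit_linear(input, weight, bias=None):
    # Same call signature as torch.nn.functional.linear.
    return ExplicitLinear.apply(input, weight, bias)
```

A function like this could be installed the same way as the patch above (torch.nn.functional.linear = explicit_linear). Whether it behaves like the gist's version on XLA devices is an assumption that would need to be verified on a TPU.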

ronghanghu commented 2 years ago

I guess the behavior difference above could be related to the different implementations of aten::t between native PyTorch and PyTorch/XLA, which is used in weight.t() in the aten::linear implementation underlying torch.nn.functional.linear.

In PyTorch/XLA, aten::t seems to be dispatched differently as an XLA op in https://github.com/pytorch/xla/blob/b3342319e96a0becd139019620d8665605b78475/torch_xla/csrc/aten_xla_type.cpp#L3058-L3062, which might cause PyTorch autograd of aten::linear to save a reference to the transposed tensor weight.t() for the backward pass, instead of holding a reference to the weight tensor itself and using it directly in backward. I guess this is probably also the underlying cause of the cast back to fp32 around the permute op in the backward pass (likely coming from weight.t()'s gradients) in https://github.com/pytorch/xla/issues/3718#issuecomment-1191005431?

(This is just my speculation -- I'm not an expert in autograd)

miladm commented 2 years ago

Thanks for the detailed repro @ronghanghu. We will look into this issue and circle back.

ronghanghu commented 2 years ago

A follow-up on this: the underlying cause here seems to be that autograd treats .t() + .mm() differently between native PyTorch and PyTorch/XLA.

import os
os.environ["XLA_IR_DEBUG"] = "1"
import torch
import torch_xla
import torch_xla.core.xla_model as xm

# use_xla = True
use_xla = False

device = xm.xla_device() if use_xla else torch.device("cpu")

# forward
x = torch.ones(22, 32, device=device, requires_grad=True)
w = torch.ones(16, 32, device=device, requires_grad=True)

# this is essentially what `torch.nn.functional.linear` does
w_t = w.t()
y = x.mm(w_t)

loss = y.sum()

# backward
w.data = w.data.new_zeros(1)  # try releasing the linear weight
loss.backward()

if use_xla:
    ir_txt = torch_xla._XLAC._get_xla_tensors_text([loss, x.grad])
    print(f"forward & backward to compute loss & x.grad: {ir_txt}")
    xm.mark_step()

Similar to the original example, the example above also crashes with use_xla = False (native PyTorch) but runs fine with use_xla = True (PyTorch/XLA). I guess the deeper cause is that aten::t is implemented differently between native PT and PT/XLA, as mentioned above in https://github.com/pytorch/xla/issues/3811#issuecomment-1200556731.

Not sure if this autograd discrepancy should be considered a bug -- I guess it should be seen as a bug since it has likely also caused the unnecessary cast to float32 in https://github.com/pytorch/xla/issues/3718#issuecomment-1191005431?

This discrepancy (an nn.Linear module doesn't actually use its .weight parameter in its backward pass in PT/XLA) breaks a few use cases such as selective activation checkpointing or FSDP, although this can be worked around with the patch above in https://github.com/pytorch/xla/issues/3811#issuecomment-1200486615
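One way to check which tensors autograd actually keeps for the .t() + .mm() pattern is torch.autograd.graph.saved_tensors_hooks, which calls a pack function for every tensor saved for backward. The sketch below runs on CPU only. The helper name saved_tensor_report is hypothetical, and the data_ptr() comparison is an assumption that holds for CPU tensors and may not be available on XLA devices.

```python
import torch


def saved_tensor_report(fn, *inputs):
    """Run fn(*inputs) and collect every tensor autograd saves for backward."""
    saved = []

    def pack(t):
        # Called once per tensor that autograd stores for the backward pass.
        saved.append(t)
        return t

    def unpack(t):
        # Hand the stored tensor back unchanged when backward needs it.
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        out = fn(*inputs)
    return out, saved


x = torch.ones(22, 32, requires_grad=True)
w = torch.ones(16, 32, requires_grad=True)

# Same pattern as in the repro: transpose the weight, then matrix-multiply.
y, saved = saved_tensor_report(lambda a, b: a.mm(b.t()), x, w)

for t in saved:
    shares_storage = t.data_ptr() == w.data_ptr()
    print(tuple(t.shape), "shares storage with w:", shares_storage)
```

On CPU, w.t() is a view, so the transposed tensor saved for mm shares its storage with w. Running the same hooks on an XLA device (and comparing shapes rather than data pointers) might help show where the two backends diverge, though this has not been verified on a TPU.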

JackCaoG commented 2 years ago

Thanks ronghang, I will try to take a look soon.

ronghanghu commented 2 years ago

Thanks, Jack! For now, we can patch this discrepancy via torch.nn.functional.linear = xla_patched_linear in https://gist.github.com/ronghanghu/d82ede74c434f8c12ae3ffb65ec84b45, so it's not blocking us at this moment.

(And https://github.com/pytorch/xla/pull/3830 tries to automatically turn on this patch in FSDP as a workaround to this).

JackCaoG commented 2 years ago

Right, ideally we need to at least understand why and have a plan to fix it. Having to patch linear layer sounds bad 😄

gkroiz commented 1 year ago

Hi @ronghanghu, I wanted to confirm whether the autograd discrepancy in nn.Linear() is specific to FSDP or applies to all instances of nn.Linear() when using PyTorch/XLA. cc @JackCaoG

ronghanghu commented 1 year ago

Hi @gkroiz, this occurs for all instances of nn.Linear() when using PyTorch/XLA, as shown in the original example at the top of this issue.