pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Unexpected tensor behavior when moving module to XLA device #6861

Open YangFei1990 opened 3 months ago

YangFei1990 commented 3 months ago

🐛 Bug

I'm not sure whether this is a bug or expected behavior: calling module.to(xm.xla_device()) creates new parameter tensors instead of modifying the existing tensors in place.

To Reproduce

Running

import torch
import torch_xla.core.xla_model as xm

model = torch.nn.Linear(4, 4)
my_param = list(model.parameters())  # references taken before the move
model.to(xm.xla_device())
print([p.device for p in model.parameters()])
print([p.device for p in my_param])

We see

[device(type='xla', index=1), device(type='xla', index=1)]
[device(type='cpu'), device(type='cpu')]

Expected behavior

However, if we run the same test with CUDA:

Python 3.10.8 | packaged by conda-forge | (main, Nov 22 2022, 08:26:04) [GCC 10.4.0] on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> model = torch.nn.Linear(4,4)
>>> my_param = list(model.parameters())
>>> model.cuda()
Linear(in_features=4, out_features=4, bias=True)
>>> print([p.device for p in model.parameters()])
[device(type='cuda', index=0), device(type='cuda', index=0)]
>>> print([p.device for p in my_param])
[device(type='cuda', index=0), device(type='cuda', index=0)]
>>>

This shows that moving a module to a CUDA device modifies the parameter tensors in place, while moving it to an XLA device creates new tensors.

This behavior is error-prone for users, because any reference to the parameters taken before the move goes stale. For example:

# This would work
model = Model()
model.to(xm.xla_device())
opt = Optimizer(model.parameters(), ...)

# This will be wrong: opt still holds the original CPU parameters
model = Model()
opt = Optimizer(model.parameters(), ...)
model.to(xm.xla_device())
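
One way to catch the second ordering early is to check that every tensor the optimizer holds is still one of the module's current parameters. The following is only a minimal sketch: find_stale_params is a hypothetical helper, not part of torch or torch_xla, and because an XLA device may not be available, the parameter replacement is simulated on CPU by assigning fresh torch.nn.Parameter objects.

```python
import torch


def find_stale_params(optimizer, model):
    """Return optimizer tensors that are no longer parameters of `model`.

    Hypothetical helper for illustration; compares by object identity.
    """
    live = {id(p) for p in model.parameters()}
    return [
        p
        for group in optimizer.param_groups
        for p in group["params"]
        if id(p) not in live
    ]


model = torch.nn.Linear(4, 4)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
print(len(find_stale_params(opt, model)))  # nothing has been replaced yet

# Simulate what the issue reports for model.to(xm.xla_device()):
# the module ends up with new Parameter objects.
model.weight = torch.nn.Parameter(model.weight.detach().clone())
model.bias = torch.nn.Parameter(model.bias.detach().clone())

stale = find_stale_params(opt, model)
print(len(stale))  # the optimizer still holds the original tensors
```

If the helper returns a non-empty list after a device move, rebuilding the optimizer from model.parameters() is the straightforward remedy.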

Environment

@JackCaoG @jeffhataws @ezyang

JackCaoG commented 3 months ago

Looking at the nn.Module implementation, this happens because we fall into the else branch in https://github.com/pytorch/pytorch/blob/b5bef9bbfd3aa963645d915c0452f5e342b60039/torch/nn/modules/module.py#L824-L826

            elif p_should_use_set_data:
                param.data = param_applied
                out_param = param
            else:
                assert isinstance(param, Parameter)
                assert param.is_leaf
                out_param = Parameter(param_applied, param.requires_grad)
                self._parameters[key] = out_param

p_should_use_set_data is false because the following check returns false for a CPU tensor and an XLA tensor:

bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) {
  return self.unsafeGetTensorImpl()->has_compatible_shallow_copy_type(
      from.key_set());
}

From the code, I think this suggests that we cannot shallow-copy an XLATensorImpl into a CPU TensorImpl.
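
The same else branch can be exercised without an XLA device. As far as I can tell from the _apply source linked above, p_should_use_set_data also depends on the global flag read by torch.__future__.get_overwrite_module_params_on_conversion(). The sketch below assumes a PyTorch version where that flag behaves as in the linked commit, and uses a dtype conversion on CPU as a stand-in for the device move; params_identical_after_double is a name made up for this example.

```python
import torch


def params_identical_after_double(overwrite):
    """Convert a fresh Linear to float64 and report whether the Parameter
    objects held before the conversion are still the module's parameters."""
    previous = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(overwrite)
    try:
        model = torch.nn.Linear(4, 4)
        held = list(model.parameters())
        model.double()
        return all(a is b for a, b in zip(held, model.parameters()))
    finally:
        # Restore the global flag so other code is unaffected.
        torch.__future__.set_overwrite_module_params_on_conversion(previous)


# Default path: `param.data = param_applied`, same Parameter objects.
in_place = params_identical_after_double(False)
# Overwrite path: new Parameter objects, as reported for XLA.
replaced = not params_identical_after_double(True)
print(in_place, replaced)
```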

JackCaoG commented 3 months ago

@bdhirsh in case you have some insights.

ezyang commented 3 months ago

This is unfortunately expected, because you can't transmute a C++ TensorImpl into a C++ XLATensorImpl.

Actually, this is fixable now. What we do is add a stub in ExtraMetadata for XLA to hang its extra info on, and then get rid of XLATensorImpl on XLA's side. This would make it possible to do an in-place move. I think you should seriously consider getting someone to do this.

JackCaoG commented 3 months ago

I took a look at https://github.com/pytorch/xla/blob/66ed39ba5fa6fb487790df03a9a68a6f62f2c957/torch_xla/csrc/tensor_impl.cpp; most of the code is boilerplate, except for some sym-size-related logic. XLATensorImpl is a wrapper around XLATensor, so if we can put that in the ExtraMetadata field of the native TensorImpl, I suppose this is doable.