pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

module.to(device) does not work under FakeTensorMode #119665

Open mksit opened 7 months ago

mksit commented 7 months ago

🐛 Describe the bug

I tried to create a model and move it to CUDA inside a FakeTensorMode context, but cuda() appears to have no effect: the model's parameters are still on the CPU afterwards.

import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

class MLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 20, bias=False)
        self.fc2 = nn.Linear(20, 30, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.fc2(x)
        return x

x = torch.randn(10, 10, device="cuda")

with FakeTensorMode() as fake_mode:
    fake_x = fake_mode.from_tensor(x)
    model = MLP().cuda()
    # print(model.fc1.weight.device) # cpu
    y = model(fake_x)

The forward pass failed with:

  File "/home/mankit/workspace/cpt/tests/test_analyzer.py", line 26, in test_trace_model
    y = model(fake_x)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mankit/workspace/cpt/tests/test_analyzer.py", line 17, in forward
    x = self.fc1(x)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1392, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1717, in dispatch
    return self.wrap_meta_outputs_with_default_device_logic(
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1820, in wrap_meta_outputs_with_default_device_logic
    return tree_map(wrap, r)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/utils/_pytree.py", line 602, in tree_map
    return tree_unflatten([func(i) for i in flat_args], spec)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/utils/_pytree.py", line 602, in <listcomp>
    return tree_unflatten([func(i) for i in flat_args], spec)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1795, in wrap
    ) = FakeTensor._find_common_device(func, flat_args)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1271, in _find_common_device
    merge_devices(arg)
  File "/mnt/data/mksit/anaconda3/envs/cpt/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1266, in merge_devices
    raise RuntimeError(
RuntimeError: Unhandled FakeTensor Device Propagation for aten.mm.default, found two different devices cuda:0, cpu

Versions

2.2.0+cu121

cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @eellison

eellison commented 7 months ago

Root cause: assigning to .data does not work on a FakeTensor:

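import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as fake_mode:
    x = torch.randn(10, 10, device="cuda")
    y = torch.rand(10, 10)
    y.data = x  # should move y to x's device
    print(y.device)  # cpu (cuda:0 expected)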

mksit commented 7 months ago

Why is .data not supported by Fake Tensor?

I tried to bypass it by setting _overwrite_module_params_on_conversion to True before calling .cuda(), which forces the old parameters to be overwritten. Is this a correct workaround?

jbschlosser commented 7 months ago

In general, use of .data is an anti-pattern. It turns out that nn.Module is one of the remaining places where it's used, primarily for backward compatibility (BC).

I'd expect the workaround to function correctly, yes:

torch.__future__.set_overwrite_module_params_on_conversion(True)
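
To illustrate what the flag changes, here is a minimal sketch that runs on CPU and does not use FakeTensorMode. It assumes that nn.Module conversions such as .double() take the same code path as .cuda() and .to(). The helper name params_replaced_on_conversion is hypothetical and used only for this illustration. With the flag off, the existing Parameter object is kept and updated in place (via .data); with the flag on, a new Parameter is installed instead.

```python
import torch
import torch.nn as nn
import torch.__future__  # explicit import of the module holding the conversion flags


def params_replaced_on_conversion(overwrite: bool) -> bool:
    """Hypothetical helper: True if a conversion installs a new Parameter object."""
    previous = torch.__future__.get_overwrite_module_params_on_conversion()
    torch.__future__.set_overwrite_module_params_on_conversion(overwrite)
    try:
        layer = nn.Linear(10, 20, bias=False)
        before = layer.weight
        layer.double()  # dtype conversion; assumed to share the path used by .cuda()/.to()
        return layer.weight is not before
    finally:
        # Restore the global flag so unrelated code is not affected.
        torch.__future__.set_overwrite_module_params_on_conversion(previous)


replaced_default = params_replaced_on_conversion(False)
replaced_with_flag = params_replaced_on_conversion(True)
print(replaced_default, replaced_with_flag)
```

Because the flag is global, any code that holds a reference to the old parameters (for example an optimizer created before the conversion) would still point at the old objects once the flag is on, so it is probably safest to set it before building anything that captures parameters.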

eellison commented 7 months ago

Should we just return False for tensor subclasses in _has_compatible_shallow_copy_type?

jbschlosser commented 7 months ago

> Should we just return False for tensor subclasses in _has_compatible_shallow_copy_type?

Possibly, but I'm a little scared of BC issues from this, specifically when using subclasses as module params. cc @mikaylagawarecki / @albanD opinions?
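
For readers following along, the idea under discussion can be sketched in Python roughly as follows. This is only an illustration of the proposal, not how PyTorch implements the check: can_use_set_data and TaggedTensor are hypothetical names, and torch._has_compatible_shallow_copy_type is a private function whose behavior may differ between releases.

```python
import torch


class TaggedTensor(torch.Tensor):
    """Hypothetical trivial subclass, used only to exercise the check."""


def can_use_set_data(param: torch.Tensor, applied: torch.Tensor) -> bool:
    """Hypothetical helper: refuse the .data path whenever a subclass is involved."""
    plain_types = (torch.Tensor, torch.nn.Parameter)
    if type(param) not in plain_types or type(applied) not in plain_types:
        return False
    # Private PyTorch function; defer to it for plain tensors and Parameters.
    return torch._has_compatible_shallow_copy_type(param, applied)


plain = torch.randn(2, 2)
tagged = torch.randn(2, 2).as_subclass(TaggedTensor)

plain_ok = can_use_set_data(plain, plain.double())
subclass_ok = can_use_set_data(plain, tagged)
print(plain_ok, subclass_ok)
```

Under this sketch, a module whose parameters are tensor subclasses would always get new Parameter objects on conversion, which is the kind of behavior change the BC concern refers to.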

albanD commented 7 months ago

This is WIP, yes. I'll let @mikaylagawarecki give the exact timeline here.