mksit opened this issue 7 months ago
Root cause: `.data` assignment does not work on a fake tensor:

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode() as fake_mode:
    x = torch.randn(10, 10, device="cuda")
    y = torch.rand(10, 10)
    y.data = x
    print(y.device)  # cpu
```
Why is `.data` not supported by FakeTensor?
I tried to bypass it by setting `_overwrite_module_params_on_conversion` to `True` before calling `.cuda()` to force overwriting the old parameters. Is this a correct workaround?
In general, use of `.data` is an anti-pattern. It turns out that `nn.Module` is one of the remaining places where it is still used, primarily for BC.
I'd expect the workaround to function correctly, yes:

```python
torch.__future__.set_overwrite_module_params_on_conversion(True)
```
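As a rough sketch of how the workaround fits together (assuming the module is built under the mode; `nn.Linear` and the shapes here are just placeholders, not from the original report):

```python
import torch
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

# Make nn.Module conversions such as .cuda() swap in newly created
# parameters instead of writing through the old ones via `.data = ...`.
torch.__future__.set_overwrite_module_params_on_conversion(True)

with FakeTensorMode():
    model = nn.Linear(10, 10)
    model.cuda()
    print(model.weight.device)  # expected: cuda:0
```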
Should we just return `False` for tensor subclasses in `_has_compatible_shallow_copy_type`?
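As a quick probe of what that check currently reports (shapes and devices are arbitrary; the result for fake tensors is exactly what's in question here):

```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(10, 10, device="cuda")
    y = torch.rand(10, 10)
    # nn.Module._apply only takes the `param.data = param_applied` path
    # when this check passes; returning False for tensor subclasses
    # would force the overwrite path instead.
    print(torch._has_compatible_shallow_copy_type(y, x))
```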
> Should we just return `False` for tensor subclasses in `_has_compatible_shallow_copy_type`?
Possibly, but I'm a little scared of BC issues from this, specifically when using subclasses as module params. cc @mikaylagawarecki / @albanD opinions?
This is WIP, yes. I'll let @mikaylagawarecki give the exact timeline here.
🐛 Describe the bug
I tried to create a model and convert it to CUDA inside a `FakeTensorMode` context, but it seems that `cuda()` didn't work and the model was still on CPU afterwards.
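A minimal version of what I was doing (the exact model doesn't matter; `nn.Linear` is a stand-in here):

```python
import torch.nn as nn
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    model = nn.Linear(10, 10).cuda()
    print(model.weight.device)  # still reports cpu
```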
It failed with
Versions
2.2.0+cu121
cc @albanD @mruberry @jbschlosser @walterddr @mikaylagawarecki @eellison