pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[export] tensor creation ops burn in device #104998

Open · suo opened this issue 1 year ago

suo commented 1 year ago
import torch
import torch._export

def foo(x):
    return x + torch.ones(2, 2)

# Export foo and print the traced graph.
e = torch._export.export(foo, (torch.ones(2, 2),))
print(e.graph_module.graph)

produces:

graph():
    %arg0_1 : [num_users=3] = placeholder[target=arg0_1]
    %sym_size_int : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 0), kwargs = {})
    %sym_size_int_1 : [num_users=1] = call_function[target=torch.ops.aten.sym_size.int](args = (%arg0_1, 1), kwargs = {})
    %eq : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int_1, 2), kwargs = {})
    %scalar_tensor_default : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%eq,), kwargs = {})
    %_assert_async_msg : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%scalar_tensor_default, Input arg0_1.shape[1] is specialized at 2), kwargs = {})
    %eq_1 : [num_users=1] = call_function[target=operator.eq](args = (%sym_size_int, 2), kwargs = {})
    %scalar_tensor_default_1 : [num_users=1] = call_function[target=torch.ops.aten.scalar_tensor.default](args = (%eq_1,), kwargs = {})
    %_assert_async_msg_1 : [num_users=0] = call_function[target=torch.ops.aten._assert_async.msg](args = (%scalar_tensor_default_1, Input arg0_1.shape[0] is specialized at 2), kwargs = {})
    %full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 2], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})
    %add_tensor : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%arg0_1, %full_default), kwargs = {})
    return (add_tensor,)

Note the following line burns in device=cpu.

    %full_default : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([2, 2], 1), kwargs = {dtype: torch.float32, layout: torch.strided, device: cpu, pin_memory: False})

These kwargs are optional; I'd like them not to show up in the graph if the user didn't specify them.

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519 @chauhang @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @ydwu4 @wconstab

suo commented 1 year ago

@bdhirsh I thought we fixed this for some variation of tensor creation ops already? Is this a regression or a different thing?

suo commented 1 year ago

Ah, it was just for the *_like factories: https://github.com/pytorch/pytorch/pull/97564. Could a similar strategy be applied to these tensor creation functions?
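
For comparison, a rough sketch (my own illustration, not the mechanism from that PR) of why the *_like case is easier: a *_like factory has an input tensor whose device it can follow at runtime, while a direct creation op has nothing to anchor to, so tracing resolves the default device on the spot.

import torch

def uses_like_factory(x):
    # ones_like follows x's device (and dtype) at runtime, so no concrete
    # device needs to be recorded in the exported graph.
    return x + torch.ones_like(x)

def uses_direct_factory(x):
    # ones has no input tensor to borrow a device from; the default device
    # is looked up when the call happens, i.e. at trace time during export.
    return x + torch.ones(2, 2)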

ezyang commented 1 year ago

Direct tensor creation cannot easily be done this way, because there is no "tensor" to copy the device from. So you need a symbolic device or a post facto fixup.
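
To make "post facto fixup" concrete, a minimal sketch of what such a pass could look like (hypothetical; strip_burned_in_device and the set of ops it touches are my own names, not an existing PyTorch pass):

import torch

def strip_burned_in_device(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    # Hypothetical cleanup: for tensor-creation nodes, drop the concrete
    # device kwarg that tracing burned in, so the default-device lookup
    # happens again when the graph is actually run.
    creation_ops = {torch.ops.aten.full.default, torch.ops.aten.empty.memory_format}
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in creation_ops and "device" in node.kwargs:
            node.kwargs = {k: v for k, v in node.kwargs.items() if k != "device"}
    gm.recompile()
    return gm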

suo commented 1 year ago

But conceptually if I call torch.empty((2, 3)), I would not expect the graph to have device="cpu", since I am declaring my intent only to create a tensor on the default device, which only happens to be the cpu right now.

I would expect the graph to just have torch.ops.aten.empty((2, 3)) without a device specified, which when invoked will do the same default-device lookup (which may return a different result depending on the runtime state).
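
For illustration, assuming a torch build with torch.set_default_device (2.0+), the runtime state can change what that lookup returns, which is exactly why freezing device="cpu" into the graph is surprising:

import torch

print(torch.empty((2, 3)).device)   # cpu, the current default device

# The same call allocates somewhere else once the runtime default changes.
torch.set_default_device("meta")
print(torch.empty((2, 3)).device)   # meta
torch.set_default_device("cpu")     # restore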

ezyang commented 1 year ago

Oh, this is the thing where we apply the default device in the torch layer, so you go from device=None to device="cpu" by the time you get to the dispatcher. Hmm, yeah, this should be fixable.

zhxchen17 commented 6 months ago

@suo I think in practice this issue will go away when we enable the pre-dispatch IR, but we will also add a post-export check to detect device burn-in.
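
As a rough sketch of what a post-export check for device burn-in might scan for (illustrative only; find_device_burn_in is my own name, not the planned check):

import torch

def find_device_burn_in(gm: torch.fx.GraphModule):
    # Report call_function nodes whose kwargs carry a concrete device,
    # e.g. aten.full.default(..., device='cpu') in the graph above.
    return [
        (node.name, node.kwargs["device"])
        for node in gm.graph.nodes
        if node.op == "call_function" and node.kwargs.get("device") is not None
    ]

Run over the graph from the original report, this would flag the full_default node with its burned-in device=cpu.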