pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT
BSD 3-Clause "New" or "Revised" License

full_like to full decomposition moving to decomposition.py for dynami… #3289

Open apbose opened 2 weeks ago

narendasan commented 1 week ago

@apbose do you have a test case?

peri044 commented 1 week ago

@apbose I see your comment: https://github.com/pytorch/TensorRT/issues/3140#issuecomment-2463481654. Can you provide more context on why this change is required?

apbose commented 1 week ago

@narendasan the test case already exists in https://github.com/pytorch/TensorRT/blob/main/tests/py/dynamo/lowering/test_decompositions.py#L424

apbose commented 1 week ago

@peri044 I removed replace_full_like_to_full and moved the logic into _decompositions.py instead. In the dynamic case, the full op tries to take its metadata from the metadata of the full_like input tensor, but because that tensor's shape is dynamic, the graph ends up with an undefined shape and the forward function fails. This is the graph node that the full_like node was lowered to, using the input's shape:

    full_default = torch.ops.aten.full.default([s0, 1, s2], 1, pin_memory = False, device = device(type='cuda', index=0), dtype = torch.float32)

where s0 and s2 are undefined. Now that full_like is decomposed as a whole during lowering, the graph instead has:

    %sym_size_int_11 : [num_users=10] = call_function[target=torch.ops.aten.sym_size.int](args = (%args0, 0), kwargs = {})
    %sym_size_int_12 : [num_users=11] = call_function[target=torch.ops.aten.sym_size.int](args = (%args1, 1), kwargs = {})

    %full : [num_users=1] = call_function[target=torch.ops.aten.full.default](args = ([%sym_size_int_11, 1, %sym_size_int_12], 1), kwargs = {dtype: torch.float32, device: cuda:0, pin_memory: False})

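For illustration, here is a minimal sketch of what decomposing full_like into full can look like in plain PyTorch. The function name full_like_to_full is hypothetical, and registering it in Torch-TensorRT's decomposition table is left out because that API is not shown in this thread. The point it shows is that the shape is read from the input tensor when the decomposition runs, rather than copied from the node's recorded metadata.

```python
from typing import Optional

import torch


def full_like_to_full(
    input: torch.Tensor,
    fill_value,
    *,
    dtype: Optional[torch.dtype] = None,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    # Hypothetical decomposition: read the shape from the input tensor at
    # call time. Under symbolic tracing, input.shape holds SymInts, so the
    # traced graph records aten.sym_size.int calls instead of a fixed
    # [s0, 1, s2] list with unbound symbols.
    return torch.full(
        input.shape,
        fill_value,
        # Mirror full_like's defaults: inherit dtype/device from the input.
        dtype=dtype if dtype is not None else input.dtype,
        device=device if device is not None else input.device,
    )
```

Because input.shape is read while tracing, symbolic tracing (for example torch.export with a dynamic dimension) should produce aten.sym_size.int nodes feeding aten.full.default, similar to the graph above. Unlike full_like, this sketch always returns a contiguous tensor and ignores memory_format.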
apbose commented 1 week ago

Oh, the PR is now failing because the graph after lowering is empty.

The graph is now:

    graph():
        %arg0_1 : [num_users=0] = placeholder[target=arg0_1]
        %_frozen_param0 : [num_users=1] = get_attr[target=_frozen_param0]
        return (_frozen_param0,)

Modifying the test now to make it non-empty.
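The empty graph above never uses arg0_1 and only returns _frozen_param0, which suggests the test's output depended only on the input's shape and was folded into a constant. Assuming that is the cause, one option is to make the output also depend on the input's values. The module below is a hypothetical sketch (FullLikeModel is an illustrative name, not the existing test):

```python
import torch


class FullLikeModel(torch.nn.Module):
    # Hypothetical test module. With a static input shape, full_like(x, 1.0)
    # alone is a constant and can be folded into a frozen parameter. Adding
    # x makes the output depend on the input values, so an add node (and
    # the placeholder it consumes) stays in the lowered graph.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        y = torch.full_like(x, 1.0)
        return y + x
```

Even if the full_like result itself is folded, the add keeps real computation in the graph.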