pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

[torch.export] Detect internal constraints #136216

Open bhack opened 2 months ago

bhack commented 2 months ago

🐛 Describe the bug

I don't know whether this is a valid bug report or more of a feature request. But is it technically possible for Dynamo to detect the model's padding instead of enforcing input guards?

This minimal repro was created to debug the same issue in a different, bigger model.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.export as exp
from torch.export import Dim, ShapesCollection

class MinimalModel(nn.Module):
    def __init__(self):
        super().__init__()
        # A simple convolution layer expecting 1 input channel
        self.conv = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1)

    def forward(self, inputs):
        B, C, H, W = inputs.shape
        # Calculate necessary padding to make dimensions multiples of 32
        pad_H = (32 - H % 32) % 32
        pad_W = (32 - W % 32) % 32

        if pad_H > 0 or pad_W > 0:
            pad_top = pad_H // 2
            pad_bottom = pad_H - pad_top
            pad_left = pad_W // 2
            pad_right = pad_W - pad_left
            # Apply reflection padding
            inputs = F.pad(inputs, (pad_left, pad_right, pad_top, pad_bottom), mode='reflect')

        # Apply convolution
        outputs = self.conv(inputs)
        return outputs

# Initialize the minimal model
model = MinimalModel().eval()

# Example input (1-channel input with arbitrary size)
inputs = torch.randn(1, 1, 224, 224)  # 1-channel input with height and width of 224

# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
inputs = inputs.to(device)

# Define dynamic dimensions for export
height_dim = Dim("height_dim", min=224, max=1024)
width_dim = Dim("width_dim", min=224, max=1024)

# Use ShapesCollection to define dynamic height and width
dynamic_shapes = ShapesCollection()
dynamic_shapes[inputs] = (Dim.STATIC, Dim.STATIC, height_dim, width_dim)

# Attempt export
try:
    exported_model = exp.export(
        model,
        (inputs,),
        dynamic_shapes=dynamic_shapes
    )
    exp.save(exported_model, "exported_model.pt")
    print("Model exported successfully.")
except torch._dynamo.exc.UserError as e:
    print(f"Failed to export: {e}")
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0] Error while creating guard:
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0] Name: ''
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     Source: shape_env
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     Create Function: SHAPE_ENV
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     Guard Types: None
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     Code List: None
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     Object Weakref: None
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     Guarded Class Weakref: None
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0] Traceback (most recent call last):
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_guards.py", line 281, in create
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     return self.create_fn(builder, self)
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/guards.py", line 1844, in SHAPE_ENV
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     guards = output_graph.shape_env.produce_guards(
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/fx/experimental/symbolic_shapes.py", line 4194, in produce_guards
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]     raise ConstraintViolationError(
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0] torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (height_dim, width_dim)! For more information, run with TORCH_LOGS="+dynamic".
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]   - Not all values of height_dim = L['inputs'].size()[2] in the specified range 224 <= height_dim <= 1024 satisfy the generated guard Mod(32 - Mod(L['inputs'].size()[2], 32), 32) <= 0.
E0917 20:00:03.648000 545 site-packages/torch/_guards.py:283] [0/0]   - Not all values of width_dim = L['inputs'].size()[3] in the specified range 224 <= width_dim <= 1024 satisfy the generated guard Mod(32 - Mod(L['inputs'].size()[3], 32), 32) <= 0.
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0] Created at:
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 615, in transform
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0]     tracer = InstructionTranslator(
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2670, in __init__
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0]     output=OutputGraph(
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 317, in __init__
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0]     self.init_ambient_guards()
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0]   File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 463, in init_ambient_guards
E0917 20:00:03.652000 545 site-packages/torch/_guards.py:285] [0/0]     self.guards.add(ShapeEnvSource().make_guard(GuardBuilder.SHAPE_ENV))
Failed to export: Constraints violated (height_dim, width_dim)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of height_dim = L['inputs'].size()[2] in the specified range 224 <= height_dim <= 1024 satisfy the generated guard Mod(32 - Mod(L['inputs'].size()[2], 32), 32) <= 0.
  - Not all values of width_dim = L['inputs'].size()[3] in the specified range 224 <= width_dim <= 1024 satisfy the generated guard Mod(32 - Mod(L['inputs'].size()[3], 32), 32) <= 0.

Suggested fixes:
  _height_dim = Dim('_height_dim', min=7, max=32)
  _width_dim = Dim('_width_dim', min=7, max=32)
  height_dim = 32*_height_dim
  width_dim = 32*_width_dim
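
The generated guard Mod(32 - Mod(H, 32), 32) <= 0 is just pad_H <= 0 from forward(), i.e. the no-padding branch taken by the 224x224 sample, and it holds only when H is a multiple of 32. A minimal pure-Python sketch that checks this equivalence over the declared range 224..1024; `guard_holds` is a hypothetical helper written for illustration, and Python's % matches sympy's Mod here because the divisor is positive:

```python
def guard_holds(h: int) -> bool:
    # Hypothetical helper mirroring the generated guard
    # Mod(32 - Mod(h, 32), 32) <= 0, i.e. pad_H <= 0 in forward().
    return (32 - h % 32) % 32 <= 0


# Sizes in the declared range where the guard disagrees with "h is a multiple of 32".
violations = [h for h in range(224, 1025) if guard_holds(h) != (h % 32 == 0)]
print(violations)  # []
```

So the error is really saying: the traced graph is only valid for multiples of 32, but the declared range also contains other sizes.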

Versions

nightly

cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4

bhack commented 2 months ago

/cc @avikchaudhuri @ezyang

pianpwk commented 1 month ago

It looks like export is specializing on the sample input's height and width (224) being divisible by 32. Inside the model code, that traces the path where the padding is 0, so the traced model is just a plain conv. In this case, though, export is smart enough with the shape constraints to ask you to declare that height_dim and width_dim are multiples of 32, by deriving each of them from another Dim. Here's the program printed after applying the suggested fixes:

ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[8, 1, 3, 3]", p_conv_bias: "f32[8]", inputs: "f32[1, 1, 32*s2, 32*s3]"):
             # File: /data/users/pianpwk/pytorch/test-issue-136216.py:28 in forward, code: outputs = self.conv(inputs)
            conv2d: "f32[1, 8, 32*s2, 32*s3]" = torch.ops.aten.conv2d.default(inputs, p_conv_weight, p_conv_bias, [1, 1], [1, 1]);  inputs = p_conv_weight = p_conv_bias = None
            return (conv2d,)
...
Range constraints: {32*s2: VR[224, 1024], 32*s3: VR[224, 1024], s2: VR[7, 32], s3: VR[7, 32]}
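
The suggested fix `height_dim = 32*_height_dim` with `_height_dim` in [7, 32] describes exactly the multiples of 32 inside the original range 224..1024, which matches the range constraints above (`32*s2: VR[224, 1024]`, `s2: VR[7, 32]`). A small pure-Python sanity check of that arithmetic (not part of the export API):

```python
# Sizes expressible by the suggested derived dim: 32 * k for k in [7, 32].
derived = {32 * k for k in range(7, 33)}

# Sizes in the user's original range that take the no-padding branch.
no_padding = {h for h in range(224, 1025) if h % 32 == 0}

assert derived == no_padding
print(min(derived), max(derived), len(derived))  # 224 1024 26
```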

On the other hand, if you feed sample inputs that require padding, export will specialize on the path that pads, and it appears to specialize on the amount of padding too. For example, here's the relevant part of the error when exporting with size (1, 1, 225, 225):

Suggested fixes:
  _height_dim = Dim('_height_dim', min=8, max=32)
  height_dim = 32*_height_dim - 31
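
With a 225x225 sample, 225 % 32 == 1, so the traced path pads by 31. The suggested `height_dim = 32*_height_dim - 31` with `_height_dim` in [8, 32] covers exactly the sizes in 224..1024 that need that same padding amount, which is consistent with export specializing on the amount of padding. A pure-Python sketch reusing the model's padding formula (`pad_amount` is a hypothetical helper name):

```python
def pad_amount(h: int) -> int:
    # Hypothetical helper: same formula as pad_H in MinimalModel.forward()
    return (32 - h % 32) % 32


# Sizes expressible by the suggested derived dim: 32 * k - 31 for k in [8, 32].
derived = {32 * k - 31 for k in range(8, 33)}

# Sizes in 224..1024 that need the same padding as the 225 sample.
same_padding = {h for h in range(224, 1025) if pad_amount(h) == pad_amount(225)}

print(pad_amount(225))             # 31
print(min(derived), max(derived))  # 225 993
assert derived == same_padding
```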

For full flexibility (capturing the padding, and allowing dynamic or zero padding at runtime), I'd recommend exporting with sample inputs whose sizes are not divisible by 32, together with automatic dynamic shapes via Dim.AUTO:

height_dim = Dim.AUTO
width_dim = Dim.AUTO
...
# Attempt export
exported_model = exp.export(
    model,
    (inputs,),
    dynamic_shapes=dynamic_shapes
)
print(exported_model)
print(exported_model.module()(torch.randn(1, 1, 225, 225).to(device)).shape)
print(exported_model.module()(torch.randn(1, 1, 336, 336).to(device)).shape)
print(exported_model.module()(torch.randn(1, 1, 224, 224).to(device)).shape)
...
ExportedProgram:
    class GraphModule(torch.nn.Module):
        def forward(self, p_conv_weight: "f32[8, 1, 3, 3]", p_conv_bias: "f32[8]", inputs: "f32[1, 1, s0, s1]"):
             # 
            sym_size_int_6: "Sym(s0)" = torch.ops.aten.sym_size.int(inputs, 2)
            sym_size_int_7: "Sym(s1)" = torch.ops.aten.sym_size.int(inputs, 3)
            mod_4: "Sym(Mod(s1, 32))" = sym_size_int_7 % 32
            mul_10: "Sym(-Mod(s1, 32))" = -1 * mod_4;  mod_4 = None
            add_11: "Sym(32 - Mod(s1, 32))" = 32 + mul_10;  mul_10 = None
            mod_5: "Sym(Mod(32 - Mod(s1, 32), 32))" = add_11 % 32;  add_11 = None
            floordiv_2: "Sym(((Mod(32 - Mod(s1, 32), 32))//2))" = mod_5 // 2
            add_12: "Sym(s1 + ((Mod(32 - Mod(s1, 32), 32))//2))" = sym_size_int_7 + floordiv_2;  sym_size_int_7 = None
            lt_2: "Sym(Mod(32 - Mod(s1, 32), 32) < s1 + ((Mod(32 - Mod(s1, 32), 32))//2))" = mod_5 < add_12;  add_12 = None
            _assert_scalar_default = torch.ops.aten._assert_scalar.default(lt_2, "Runtime assertion failed for expression Mod(32 - Mod(s1, 32), 32) < s1 + ((Mod(32 - Mod(s1, 32), 32))//2) on node 'lt_2'");  lt_2 = _assert_scalar_default = None
            mod_6: "Sym(Mod(s0, 32))" = sym_size_int_6 % 32
            mul_11: "Sym(-Mod(s0, 32))" = -1 * mod_6;  mod_6 = None
            add_13: "Sym(32 - Mod(s0, 32))" = 32 + mul_11;  mul_11 = None
            mod_7: "Sym(Mod(32 - Mod(s0, 32), 32))" = add_13 % 32;  add_13 = None
            floordiv_3: "Sym(((Mod(32 - Mod(s0, 32), 32))//2))" = mod_7 // 2
            add_14: "Sym(s0 + ((Mod(32 - Mod(s0, 32), 32))//2))" = sym_size_int_6 + floordiv_3;  sym_size_int_6 = None
            lt_3: "Sym(Mod(32 - Mod(s0, 32), 32) < s0 + ((Mod(32 - Mod(s0, 32), 32))//2))" = mod_7 < add_14;  add_14 = None
            _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(lt_3, "Runtime assertion failed for expression Mod(32 - Mod(s0, 32), 32) < s0 + ((Mod(32 - Mod(s0, 32), 32))//2) on node 'lt_3'");  lt_3 = _assert_scalar_default_1 = None

             # File: /data/users/pianpwk/pytorch/test-issue-136216.py:21 in forward, code: pad_bottom = pad_H - pad_top
            sub: "Sym(-((Mod(32 - Mod(s0, 32), 32))//2) + Mod(32 - Mod(s0, 32), 32))" = mod_7 - floordiv_3;  mod_7 = None

             # File: /data/users/pianpwk/pytorch/test-issue-136216.py:23 in forward, code: pad_right = pad_W - pad_left
            sub_1: "Sym(-((Mod(32 - Mod(s1, 32), 32))//2) + Mod(32 - Mod(s1, 32), 32))" = mod_5 - floordiv_2;  mod_5 = None

             # File: /data/users/pianpwk/pytorch/torch/nn/functional.py:5096 in pad, code: return torch._C._nn.pad(input, pad, mode, value)
            pad: "f32[1, 1, s0 + Mod(32 - Mod(s0, 32), 32), s1 + Mod(32 - Mod(s1, 32), 32)]" = torch.ops.aten.pad.default(inputs, [floordiv_2, sub_1, floordiv_3, sub], 'reflect');  inputs = floordiv_2 = sub_1 = floordiv_3 = sub = None

             # File: /data/users/pianpwk/pytorch/test-issue-136216.py:28 in forward, code: outputs = self.conv(inputs)
            conv2d: "f32[1, 8, s0 + Mod(32 - Mod(s0, 32), 32), s1 + Mod(32 - Mod(s1, 32), 32)]" = torch.ops.aten.conv2d.default(pad, p_conv_weight, p_conv_bias, [1, 1], [1, 1]);  pad = p_conv_weight = p_conv_bias = None
            return (conv2d,)

Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='inputs'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='conv2d'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo]}

[W919 23:04:50.075243403 NNPACK.cpp:61] Could not initialize NNPACK! Reason: Unsupported hardware.
torch.Size([1, 8, 256, 256])
torch.Size([1, 8, 352, 352])
torch.Size([1, 8, 224, 224])
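
The three output shapes follow directly from the padding formula: each spatial size is rounded up to the next multiple of 32 (a multiple of 32 is left unchanged), and the conv preserves the spatial size because kernel_size=3, padding=1 and stride=1. A pure-Python check of that arithmetic; `padded_size` and `conv_out` are illustrative helpers, the latter assuming the standard Conv2d output-size formula with dilation=1:

```python
def padded_size(h: int) -> int:
    # H + pad_H, where pad_H = (32 - H % 32) % 32 as in MinimalModel.forward()
    return h + (32 - h % 32) % 32


def conv_out(h: int, k: int = 3, p: int = 1, s: int = 1) -> int:
    # Conv2d output size with dilation=1: floor((h + 2p - k) / s) + 1
    return (h + 2 * p - k) // s + 1


for h in (225, 336, 224):
    print(h, "->", conv_out(padded_size(h)))
# 225 -> 256
# 336 -> 352
# 224 -> 224
```

The 224 case also shows why the `if` in forward() is not needed for correctness: zero padding leaves the size unchanged.
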

bhack commented 1 month ago

Isn't this a little counterintuitive, given that we have:

height_dim = Dim("height_dim", min=224, max=1024)
width_dim = Dim("width_dim", min=224, max=1024)
# Use ShapesCollection to define dynamic height and width
dynamic_shapes = ShapesCollection()
dynamic_shapes[inputs] = (Dim.STATIC, Dim.STATIC, height_dim, width_dim)

It seems clear that this range covers both padding code paths.

Also, a user might expect that restricting the input to a range lets export apply better performance optimizations than Dim.AUTO, especially because discrete dynamic inputs are not supported (see https://github.com/pytorch/pytorch/issues/136119).
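
To make the point concrete: within 224..1024 only the 26 multiples of 32 skip padding, and every padding amount from 0 to 31 occurs somewhere in the range. A pure-Python sketch using the model's padding formula:

```python
from collections import Counter

# pad_H for every height in the declared range 224..1024
pads = Counter((32 - h % 32) % 32 for h in range(224, 1025))

print(pads[0])                          # 26: heights that skip padding
print(sum(pads.values()) - pads[0])     # 775: heights that take the padding branch
print(sorted(pads) == list(range(32)))  # True: every amount 0..31 occurs
```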

bhack commented 1 month ago

One more point: in any case, the documentation for Dim.AUTO at https://pytorch.org/docs/main/export.html#module-torch.export also needs to be updated.

pianpwk commented 1 month ago

I agree it's a bit unintuitive, but this is by design: model tracing follows the path that the sample inputs select and generates guards (shape constraints) for that path. The specified Dims add dynamism to the traced graph and are checked to ensure that the declared ranges and relations match the emitted guards.

When you have if/else statements, the compiler has to choose a path, and it chooses based on the sample inputs. If your Dim range covers both paths, export asks you to refine the range to match the path taken. What you might want instead is to use torch.cond in your code, so that the branching on padding is captured in the exported graph.

Adding min/max bounds to AUTO is also planned for the near future.

bhack commented 1 month ago

The most confusing part of the API is that a Dim range is not the same concept as AUTO with range bounds.

Also, why isn't torch.cond injected automatically?

bhack commented 3 weeks ago

@ezyang Since we briefly discussed this in https://github.com/pytorch/pytorch/issues/137520, what is your opinion about this thread, and in particular about notifying the user to rewrite a condition with torch.cond?

I find it very unintuitive that we introduce a guard just because the compiler took a specific code path, when we are specifying a dynamic range whose values are already handled internally by the transformation we apply to make the input divisible by 32.

I am totally OK with rewriting the condition with torch.cond, but I would expect to be notified about it.

ezyang commented 3 weeks ago

Whenever a guard occurs, torch.cond is potentially a way to eliminate it. But it really depends, and you should spend some time investigating if you can get rid of it some other way first. We can... ahem... remind you in the error message that torch.cond might work, but so might X other interventions, might as well just link you to https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit

bhack commented 3 weeks ago

Yes, but in this specific case it is honestly not intuitive that the example input decides which branch of the code is taken, when a few lines later we tell the exporter that we want to work with a range of input values that obviously covers all the branches.

There is something really counterintuitive in this API.

ezyang commented 3 weeks ago

I definitely agree there's probably some juice in "you told me it has to work for this entire range, but this condition seems to be splitting on it"
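
One way to read that suggestion: given the user's declared range and the branch condition, find one value on each side of the branch and report both, instead of only asking for a refinement. A hypothetical pure-Python sketch of such a check; `split_witnesses` is illustrative only and not part of torch.export, and a real implementation would reason symbolically over the range rather than enumerate it:

```python
from typing import Callable, Optional, Tuple


def split_witnesses(
    lo: int, hi: int, cond: Callable[[int], bool]
) -> Optional[Tuple[int, int]]:
    """Hypothetical diagnostic: return (a value where cond is True,
    a value where cond is False) in [lo, hi], or None if cond is constant."""
    true_val = next((v for v in range(lo, hi + 1) if cond(v)), None)
    false_val = next((v for v in range(lo, hi + 1) if not cond(v)), None)
    if true_val is None or false_val is None:
        return None
    return true_val, false_val


def needs_padding(h: int) -> bool:
    # The `pad_H > 0` branch condition from MinimalModel.forward()
    return (32 - h % 32) % 32 > 0


print(split_witnesses(224, 1024, needs_padding))  # (225, 224)
```

An error message built on this could say something like "height_dim's range 224..1024 contains both 225 (pads) and 224 (does not pad)".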

bhack commented 3 weeks ago

Also because discrete values are not supported. See https://github.com/pytorch/pytorch/issues/136119