Open bhack opened 2 months ago
/cc @avikchaudhuri @ezyang
It looks like the export call is specializing on the input height & width (224) being divisible by 32, which inside the model code traces the path where padding is 0, so the exported model is just a simple conv. However, in this case export is smart enough with the shape constraints to ask you to specify that height_dim and width_dim are multiples of 32, deriving them from another dim. Here's the program printed out after applying the suggested fixes:
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[8, 1, 3, 3]", p_conv_bias: "f32[8]", inputs: "f32[1, 1, 32*s2, 32*s3]"):
# File: /data/users/pianpwk/pytorch/test-issue-136216.py:28 in forward, code: outputs = self.conv(inputs)
conv2d: "f32[1, 8, 32*s2, 32*s3]" = torch.ops.aten.conv2d.default(inputs, p_conv_weight, p_conv_bias, [1, 1], [1, 1]); inputs = p_conv_weight = p_conv_bias = None
return (conv2d,)
...
Range constraints: {32*s2: VR[224, 1024], 32*s3: VR[224, 1024], s2: VR[7, 32], s3: VR[7, 32]}
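For reference, a derived-dim spec matching these range constraints would look roughly like the following (a sketch; model and inputs are assumed to be the repro's module and the 224x224 sample):

from torch.export import Dim, ShapesCollection, export

# Derived dims: s2, s3 in [7, 32]  =>  height/width = 32*s in [224, 1024]
_height_dim = Dim("_height_dim", min=7, max=32)
_width_dim = Dim("_width_dim", min=7, max=32)
height_dim = 32 * _height_dim
width_dim = 32 * _width_dim

dynamic_shapes = ShapesCollection()
dynamic_shapes[inputs] = (Dim.STATIC, Dim.STATIC, height_dim, width_dim)
exported_model = export(model, (inputs,), dynamic_shapes=dynamic_shapes)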
On the other hand, if you feed sample inputs that require padding, export will specialize on the path requiring padding, and it seems like it specializes on the amount of padding too. For example, here's the error from exporting with size (1, 1, 225, 225):
Suggested fixes:
_height_dim = Dim('_height_dim', min=8, max=32)
height_dim = 32*_height_dim - 31
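Presumably the width dim gets an analogous suggested fix, which would then replace height_dim/width_dim in the same ShapesCollection spec as before (a sketch):

_width_dim = Dim('_width_dim', min=8, max=32)
width_dim = 32*_width_dim - 31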
For full flexibility (capturing padding, and allowing for dynamic or 0 padding during runtime), I'd recommend exporting with inputs not divisible by 32, and automatic dynamic shapes, with Dim.AUTO:
height_dim = Dim.AUTO
width_dim = Dim.AUTO
...
# Attempt export
exported_model = exp.export(
    model,
    (inputs,),
    dynamic_shapes=dynamic_shapes,
)
print(exported_model)
print(exported_model.module()(torch.randn(1, 1, 225, 225).cuda()).shape)
print(exported_model.module()(torch.randn(1, 1, 336, 336).cuda()).shape)
print(exported_model.module()(torch.randn(1, 1, 224, 224).cuda()).shape)
...
ExportedProgram:
class GraphModule(torch.nn.Module):
def forward(self, p_conv_weight: "f32[8, 1, 3, 3]", p_conv_bias: "f32[8]", inputs: "f32[1, 1, s0, s1]"):
#
sym_size_int_6: "Sym(s0)" = torch.ops.aten.sym_size.int(inputs, 2)
sym_size_int_7: "Sym(s1)" = torch.ops.aten.sym_size.int(inputs, 3)
mod_4: "Sym(Mod(s1, 32))" = sym_size_int_7 % 32
mul_10: "Sym(-Mod(s1, 32))" = -1 * mod_4; mod_4 = None
add_11: "Sym(32 - Mod(s1, 32))" = 32 + mul_10; mul_10 = None
mod_5: "Sym(Mod(32 - Mod(s1, 32), 32))" = add_11 % 32; add_11 = None
floordiv_2: "Sym(((Mod(32 - Mod(s1, 32), 32))//2))" = mod_5 // 2
add_12: "Sym(s1 + ((Mod(32 - Mod(s1, 32), 32))//2))" = sym_size_int_7 + floordiv_2; sym_size_int_7 = None
lt_2: "Sym(Mod(32 - Mod(s1, 32), 32) < s1 + ((Mod(32 - Mod(s1, 32), 32))//2))" = mod_5 < add_12; add_12 = None
_assert_scalar_default = torch.ops.aten._assert_scalar.default(lt_2, "Runtime assertion failed for expression Mod(32 - Mod(s1, 32), 32) < s1 + ((Mod(32 - Mod(s1, 32), 32))//2) on node 'lt_2'"); lt_2 = _assert_scalar_default = None
mod_6: "Sym(Mod(s0, 32))" = sym_size_int_6 % 32
mul_11: "Sym(-Mod(s0, 32))" = -1 * mod_6; mod_6 = None
add_13: "Sym(32 - Mod(s0, 32))" = 32 + mul_11; mul_11 = None
mod_7: "Sym(Mod(32 - Mod(s0, 32), 32))" = add_13 % 32; add_13 = None
floordiv_3: "Sym(((Mod(32 - Mod(s0, 32), 32))//2))" = mod_7 // 2
add_14: "Sym(s0 + ((Mod(32 - Mod(s0, 32), 32))//2))" = sym_size_int_6 + floordiv_3; sym_size_int_6 = None
lt_3: "Sym(Mod(32 - Mod(s0, 32), 32) < s0 + ((Mod(32 - Mod(s0, 32), 32))//2))" = mod_7 < add_14; add_14 = None
_assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(lt_3, "Runtime assertion failed for expression Mod(32 - Mod(s0, 32), 32) < s0 + ((Mod(32 - Mod(s0, 32), 32))//2) on node 'lt_3'"); lt_3 = _assert_scalar_default_1 = None
# File: /data/users/pianpwk/pytorch/test-issue-136216.py:21 in forward, code: pad_bottom = pad_H - pad_top
sub: "Sym(-((Mod(32 - Mod(s0, 32), 32))//2) + Mod(32 - Mod(s0, 32), 32))" = mod_7 - floordiv_3; mod_7 = None
# File: /data/users/pianpwk/pytorch/test-issue-136216.py:23 in forward, code: pad_right = pad_W - pad_left
sub_1: "Sym(-((Mod(32 - Mod(s1, 32), 32))//2) + Mod(32 - Mod(s1, 32), 32))" = mod_5 - floordiv_2; mod_5 = None
# File: /data/users/pianpwk/pytorch/torch/nn/functional.py:5096 in pad, code: return torch._C._nn.pad(input, pad, mode, value)
pad: "f32[1, 1, s0 + Mod(32 - Mod(s0, 32), 32), s1 + Mod(32 - Mod(s1, 32), 32)]" = torch.ops.aten.pad.default(inputs, [floordiv_2, sub_1, floordiv_3, sub], 'reflect'); inputs = floordiv_2 = sub_1 = floordiv_3 = sub = None
# File: /data/users/pianpwk/pytorch/test-issue-136216.py:28 in forward, code: outputs = self.conv(inputs)
conv2d: "f32[1, 8, s0 + Mod(32 - Mod(s0, 32), 32), s1 + Mod(32 - Mod(s1, 32), 32)]" = torch.ops.aten.conv2d.default(pad, p_conv_weight, p_conv_bias, [1, 1], [1, 1]); pad = p_conv_weight = p_conv_bias = None
return (conv2d,)
Graph signature: ExportGraphSignature(input_specs=[InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_weight'), target='conv.weight', persistent=None), InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='p_conv_bias'), target='conv.bias', persistent=None), InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='inputs'), target=None, persistent=None)], output_specs=[OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='conv2d'), target=None)])
Range constraints: {s0: VR[2, int_oo], s1: VR[2, int_oo]}
[W919 23:04:50.075243403 NNPACK.cpp:61] Could not initialize NNPACK! Reason: Unsupported hardware.
torch.Size([1, 8, 256, 256])
torch.Size([1, 8, 352, 352])
torch.Size([1, 8, 224, 224])
Isn't this a little bit counterintuitive? Here we have:
height_dim = Dim("height_dim", min=224, max=1024)
width_dim = Dim("width_dim", min=224, max=1024)
# Use ShapesCollection to define dynamic height and width
dynamic_shapes = ShapesCollection()
dynamic_shapes[inputs] = (Dim.STATIC, Dim.STATIC, height_dim, width_dim)
I think it is clear that within this range we could hit both padding code paths.
Also, a user could expect better performance from export when restricting the input to a range, especially because we don't support discrete dynamic inputs (see https://github.com/pytorch/pytorch/issues/136119), versus Dim.AUTO.
Just an extra point: in any case we also need to update the documentation for Dim.AUTO at https://pytorch.org/docs/main/export.html#module-torch.export.
I agree it's a bit unintuitive, but the design is that model tracing follows the path the sample inputs take, generating guards/shape constraints along that path; the Dims you specify are there to add dynamism to the traced graph and to ensure that the specified ranges/relations match up with the emitted guards.
When you have if/else statements, the compiler has to choose a path - it does so based on the sample inputs - and if your Dim range spans both paths, we'll ask you to refine it according to the path taken. What you might want is to apply torch.cond in your code, so the branching on padding is preserved in the exported graph.
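For illustration, here is a minimal sketch of what a torch.cond rewrite looks like. This is a toy, not a drop-in rewrite of the repro: torch.cond requires both branches to return tensors with the same metadata (shape, dtype), so a literal pad-vs-no-pad split, where the two branches would produce different output sizes, may need extra restructuring.

import torch

class Toy(torch.nn.Module):
    def forward(self, x):
        # With dynamic shapes this predicate stays symbolic, so export
        # captures both branches instead of specializing on the sample
        # input's size.
        needs_pad = (x.shape[-1] % 32) != 0

        def true_fn(x):
            return x + 1.0  # stand-in for the "needs padding" path

        def false_fn(x):
            return x - 1.0  # stand-in for the "already aligned" path

        return torch.cond(needs_pad, true_fn, false_fn, (x,))

# e.g. export with both branches captured:
# ep = torch.export.export(
#     Toy(), (torch.randn(1, 1, 224, 225),),
#     dynamic_shapes={"x": (None, None, torch.export.Dim.AUTO, torch.export.Dim.AUTO)},
# )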
Regarding adding min/max bounds to AUTO, this is also planned work in the near future.
The most confusing part of the API is that a Dim range is not the same concept as AUTO with range boundaries.
Also, why isn't torch.cond injected automatically?
@ezyang As we touched on this in https://github.com/pytorch/pytorch/issues/137520, what is your opinion about this thread, and possibly about notifying the user about enforcing/rewriting a condition with torch.cond?
I find it quite unintuitive that we introduce a guard just because the compiler took a specific code path, when we are specifying a dynamic range that is already constrained internally by the transformation we apply to ensure the input is divisible by 32.
I am totally OK with rewriting the condition with torch.cond, but I would expect to be notified about this.
Whenever a guard occurs, torch.cond is potentially a way to eliminate it. But it really depends, and you should spend some time investigating if you can get rid of it some other way first. We can... ahem... remind you in the error message that torch.cond might work, but so might X other interventions, might as well just link you to https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit
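For what it's worth, one such alternative intervention for this particular pattern (a sketch assuming the repro's structure of reflect-padding up to a multiple of 32 followed by a conv; the module and names here are illustrative, not the original repro) is to drop the Python if statement and always call F.pad with the computed, possibly zero, padding, so there is a single code path and nothing for the tracer to branch on:

import torch
import torch.nn as nn
import torch.nn.functional as F

class PadConv(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 8, kernel_size=3, padding=1)

    def forward(self, inputs):
        H, W = inputs.shape[-2], inputs.shape[-1]
        # Amount needed to reach the next multiple of 32 (0 if already aligned).
        pad_H = (32 - H % 32) % 32
        pad_W = (32 - W % 32) % 32
        pad_top, pad_left = pad_H // 2, pad_W // 2
        pad_bottom, pad_right = pad_H - pad_top, pad_W - pad_left
        # Always pad: when the size is already a multiple of 32 the pad
        # amounts are zero and this is a no-op, so there is no Python `if`
        # for export to specialize on.
        inputs = F.pad(inputs, [pad_left, pad_right, pad_top, pad_bottom], mode="reflect")
        return self.conv(inputs)

Exported with Dim.AUTO, this should produce essentially the graph shown earlier: the runtime assertions on the reflect-padding bounds remain, but there is no divisibility guard and no specialization on the padding amount.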
Yes, but in this specific case it is honestly not intuitive that the sample input has to take a specific branch of the code when, a few lines later, we tell the exporter that we want to work with a range of input values that will of course involve all the branches.
There is something really counterintuitive in this API.
I definitely agree there's probably some juice in "you told me it has to work for this entire range, but this condition seems to be splitting on it"
Also because discrete values are not supported; see https://github.com/pytorch/pytorch/issues/136119.
🐛 Describe the bug
I don't know whether this is a valid bug report or more of a feature request, but is it technically possible for Dynamo to detect the model padding instead of enforcing input guards?
This minimal repro was created to debug the same issue on a different, bigger model.
Versions
nightly
cc @ezyang @chauhang @penguinwu @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi @suo @ydwu4