Closed: zou3519 closed this issue 1 year ago.
cc @kshitij12345 @samdow -- since you have access to an M1 machine, could one of you take a look, please?
I am able to reproduce the test failure locally; I will try to produce a minimal repro.
functorch/test/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_nn_functional_conv2d_cpu_float32 FAILED
============================================================================================ FAILURES =============================================================================================
________________________________________________________ TestVmapOperatorsOpInfoCPU.test_vmap_exhaustive_nn_functional_conv2d_cpu_float32 _________________________________________________________
self = <test_vmap.TestVmapOperatorsOpInfoCPU testMethod=test_vmap_exhaustive_nn_functional_conv2d_cpu_float32>, device = 'cpu', dtype = torch.float32
op = OpInfo(name='nn.functional.conv2d', ref=None, aliases=(<torch.testing._internal.opinfo.core.AliasInfo object at 0x1262...les=True, test_neg_view=True, assert_jit_shape_analysis=True, supports_expanded_weight=True, is_factory_function=False)
@ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
@opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', (
tol1('linalg.det',
{torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
# The following is often flaky, but just on windows.
# We should investigate if it's actually a problem or not.
tol1('nn.functional.conv_transpose3d',
{torch.float32: tol(atol=1e-04, rtol=1e-02)}, device_type='cuda'),
))
@toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
@skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
xfail('cat'),
}))
def test_vmap_exhaustive(self, device, dtype, op):
# needs to be fixed
inplace_failure_list = (
)
> self.opinfo_vmap_test(device, dtype, op, check_has_batch_rule=False,
skip_inplace=inplace_failure_list)
functorch/test/test_vmap.py:3291:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
functorch/test/test_vmap.py:3186: in opinfo_vmap_test
test()
functorch/test/test_vmap.py:3175: in test
self.vmap_outplace_test(func, args, kwargs, in_dims, check_shape_only, postprocess_fn)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <test_vmap.TestVmapOperatorsOpInfoCPU testMethod=test_vmap_exhaustive_nn_functional_conv2d_cpu_float32>, func = <built-in method conv2d of type object at 0x10832e130>
args = (tensor([[[[ 4.7265, 5.5618, -2.2841, -5.6017, -3.4699, -0.2500],
[ 6.7974, 6.0568, -6.6377, -3.1288, 5.2...]]]), tensor([[ 5.3820, 5.3820],
[-2.7438, -2.7438],
[ 1.5609, 1.5609],
[-1.3622, -1.3622]]))
kwargs = {'groups': 4}, in_dims = (None, None, -1), check_shape_only = False, postprocess_fn = None
def vmap_outplace_test(self, func, args, kwargs, in_dims, check_shape_only=False,
postprocess_fn=None):
for loop_out, vmap_out in compute_quantities_for_vmap_test(func, args, kwargs, in_dims):
if postprocess_fn is not None:
loop_out = postprocess_fn(loop_out)
vmap_out = postprocess_fn(vmap_out)
if check_shape_only:
self.assertEqual(vmap_out.shape, loop_out.shape)
continue
> self.assertEqual(vmap_out, loop_out)
E AssertionError: Tensor-likes are not close!
E
E Mismatched elements: 160 / 256 (62.5%)
E Greatest absolute difference: 8.125823974609375 at index (0, 0, 1, 3, 1) (up to 0.0001 allowed)
E Greatest relative difference: 4.8449992721582325 at index (0, 1, 1, 2, 0) (up to 0.0001 allowed)
functorch/test/test_vmap.py:3129: AssertionError
It fails on the first example with a non-default `groups` value (`groups=4`) and `in_dims=(None, None, -1)`.
Relevant batching code that is executed (at a quick glance, nothing looked fishy): https://github.com/pytorch/pytorch/blob/61b4e8a7bfb69954680013e2e34fc099db900736/aten/src/ATen/functorch/BatchRulesConvolution.cpp#L105-L121
Is it possible to use `make_fx(vmap(f))(...)` to get a trace of which operations run during the batching rule? I'm wondering whether conv itself is incorrect for some case on M1 Macs, or whether our batching rule is wrong.
With `make_fx`, the traced graph is:
class wrapped(torch.nn.Module):
def forward(self, flat_args):
flat_args_1, flat_args_2, flat_args_3, flat_args_4, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec)
# No stacktrace found for following nodes
convolution = torch.ops.aten.convolution.default(flat_args_1, flat_args_2, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 4); flat_args_1 = flat_args_2 = None
permute = torch.ops.aten.permute.default(flat_args_3, [1, 0]); flat_args_3 = None
unsqueeze = torch.ops.aten.unsqueeze.default(permute, -1); permute = None
unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, -1); unsqueeze = None
view = torch.ops.aten.view.default(unsqueeze_1, [2, 1, 4, 1, 1]); unsqueeze_1 = None
add = torch.ops.aten.add.Tensor(convolution, view); convolution = view = None
return pytree.tree_unflatten([add], self._out_spec)
Patch used to print the traced graph:
diff --git a/functorch/test/common_utils.py b/functorch/test/common_utils.py
index cfaa206619..a867e8ff6f 100644
--- a/functorch/test/common_utils.py
+++ b/functorch/test/common_utils.py
@@ -272,6 +272,11 @@ def compute_quantities_for_vmap_test(
# t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
# print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
batched_args, kwarg_values = maybe_clone_inputs()
+ from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
+ wrapped, all_args = wrapper_and_args_for_make_fx(vmap(op, in_dims=in_dims, out_dims=out_dim), batched_args, kwarg_values)
+ fx_graph = functorch.make_fx(wrapped)(all_args)
+ # print(fx_graph)
+ fx_graph.print_readable()
batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
yield (loop_out, batched_out)
common_utils code: https://github.com/pytorch/pytorch/blob/9baf6770bcd67272a2cb9212c49e3bb95f0679c3/functorch/test/common_utils.py#L274-L275
Interestingly, the vmapped and loop outputs from Linux match the vmapped output from the M1 Mac, but not the M1 Mac's loop output. So the error seems to be in the loop-version computation on M1, not in the vmap version.
Interesting, thanks for the find!
https://github.com/pytorch/pytorch/pull/85711 will fix this
https://github.com/pytorch/pytorch/pull/85711 is merged. Closing the issue.
`test_vmap_exhaustive_nn_functional_conv2d_cpu_float32` fails with the assertion error shown above. This is either a problem with our batching rule for vmap, or a problem in some convolution edge case that we exercise only in this batching rule.
We should try to come up with a minimal repro for this and dig into it.