pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Silent incorrectness for vmap x conv on M1 Macs #1023

Closed: zou3519 closed this issue 1 year ago

zou3519 commented 1 year ago

test_vmap_exhaustive_nn_functional_conv2d_cpu_float32 fails with the following:

Traceback (most recent call last):
  File "/Users/ec2-user/runner/_work/_temp/conda_environment_3084529285/lib/python3.9/site-packages/torch/testing/_internal/common_device_type.py", line 378, in instantiated_test
    result = test(self, **param_kwargs)
  File "/Users/ec2-user/runner/_work/_temp/conda_environment_3084529285/lib/python3.9/site-packages/torch/testing/_internal/common_device_type.py", line 815, in test_wrapper
    return test(*args, **kwargs)
  File "/Users/ec2-user/runner/_work/pytorch/pytorch/functorch/test/test_vmap.py", line 3287, in test_vmap_exhaustive
    self.opinfo_vmap_test(device, dtype, op, check_has_batch_rule=False,
  File "/Users/ec2-user/runner/_work/pytorch/pytorch/functorch/test/test_vmap.py", line 3186, in opinfo_vmap_test
    test()
  File "/Users/ec2-user/runner/_work/pytorch/pytorch/functorch/test/test_vmap.py", line 3175, in test
    self.vmap_outplace_test(func, args, kwargs, in_dims, check_shape_only, postprocess_fn)
  File "/Users/ec2-user/runner/_work/pytorch/pytorch/functorch/test/test_vmap.py", line 3129, in vmap_outplace_test
    self.assertEqual(vmap_out, loop_out)
  File "/Users/ec2-user/runner/_work/_temp/conda_environment_3084529285/lib/python3.9/site-packages/torch/testing/_internal/common_utils.py", line 2435, in assertEqual
    assert_equal(
  File "/Users/ec2-user/runner/_work/_temp/conda_environment_3084529285/lib/python3.9/site-packages/torch/testing/_comparison.py", line 1093, in assert_equal
    raise error_metas[0].to_error(msg)
AssertionError: Tensor-likes are not close!

Mismatched elements: 160 / 256 (62.5%)
Greatest absolute difference: 16.557510375976562 at index (0, 1, 3, 2, 3) (up to 0.0001 allowed)
Greatest relative difference: 975.8199190647482 at index (0, 1, 3, 3, 3) (up to 0.0001 allowed)

This is either a problem with our batching rule for vmap, or a problem in some convolution edge case that we exercise only in this batching rule.

We should try to come up with a minimal repro for this and dig into it.

zou3519 commented 1 year ago

cc @kshitij12345 @samdow -- since you folks have some access to an M1 machine, could one of you please take a look?

kshitij12345 commented 1 year ago

I am able to reproduce the test failure locally; I will try to produce a minimal repro.

functorch/test/test_vmap.py::TestVmapOperatorsOpInfoCPU::test_vmap_exhaustive_nn_functional_conv2d_cpu_float32 FAILED

============================================================================================ FAILURES =============================================================================================
________________________________________________________ TestVmapOperatorsOpInfoCPU.test_vmap_exhaustive_nn_functional_conv2d_cpu_float32 _________________________________________________________

self = <test_vmap.TestVmapOperatorsOpInfoCPU testMethod=test_vmap_exhaustive_nn_functional_conv2d_cpu_float32>, device = 'cpu', dtype = torch.float32
op = OpInfo(name='nn.functional.conv2d', ref=None, aliases=(<torch.testing._internal.opinfo.core.AliasInfo object at 0x1262...les=True, test_neg_view=True, assert_jit_shape_analysis=True, supports_expanded_weight=True, is_factory_function=False)

    @ops(op_db + additional_op_db, allowed_dtypes=(torch.float,))
    @opsToleranceOverride('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', (
        tol1('linalg.det',
             {torch.float32: tol(atol=1e-04, rtol=1e-04)}, device_type='cuda'),
        # The following is often flaky, but just on windows.
        # We should investigate if it's actually a problem or not.
        tol1('nn.functional.conv_transpose3d',
             {torch.float32: tol(atol=1e-04, rtol=1e-02)}, device_type='cuda'),
    ))
    @toleranceOverride({torch.float32: tol(atol=1e-04, rtol=1e-04)})
    @skipOps('TestVmapOperatorsOpInfo', 'test_vmap_exhaustive', vmap_fail.union({
        xfail('cat'),
    }))
    def test_vmap_exhaustive(self, device, dtype, op):
        # needs to be fixed
        inplace_failure_list = (
        )
>       self.opinfo_vmap_test(device, dtype, op, check_has_batch_rule=False,
                              skip_inplace=inplace_failure_list)

functorch/test/test_vmap.py:3291: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
functorch/test/test_vmap.py:3186: in opinfo_vmap_test
    test()
functorch/test/test_vmap.py:3175: in test
    self.vmap_outplace_test(func, args, kwargs, in_dims, check_shape_only, postprocess_fn)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = <test_vmap.TestVmapOperatorsOpInfoCPU testMethod=test_vmap_exhaustive_nn_functional_conv2d_cpu_float32>, func = <built-in method conv2d of type object at 0x10832e130>
args = (tensor([[[[ 4.7265,  5.5618, -2.2841, -5.6017, -3.4699, -0.2500],
          [ 6.7974,  6.0568, -6.6377, -3.1288,  5.2...]]]), tensor([[ 5.3820,  5.3820],
        [-2.7438, -2.7438],
        [ 1.5609,  1.5609],
        [-1.3622, -1.3622]]))
kwargs = {'groups': 4}, in_dims = (None, None, -1), check_shape_only = False, postprocess_fn = None

    def vmap_outplace_test(self, func, args, kwargs, in_dims, check_shape_only=False,
                           postprocess_fn=None):
        for loop_out, vmap_out in compute_quantities_for_vmap_test(func, args, kwargs, in_dims):
            if postprocess_fn is not None:
                loop_out = postprocess_fn(loop_out)
                vmap_out = postprocess_fn(vmap_out)
            if check_shape_only:
                self.assertEqual(vmap_out.shape, loop_out.shape)
                continue
>           self.assertEqual(vmap_out, loop_out)
E           AssertionError: Tensor-likes are not close!
E           
E           Mismatched elements: 160 / 256 (62.5%)
E           Greatest absolute difference: 8.125823974609375 at index (0, 0, 1, 3, 1) (up to 0.0001 allowed)
E           Greatest relative difference: 4.8449992721582325 at index (0, 1, 1, 2, 0) (up to 0.0001 allowed)

functorch/test/test_vmap.py:3129: AssertionError

kshitij12345 commented 1 year ago

It fails on the first sample that uses a non-default groups value (groups=4), with in_dims=(None, None, -1), i.e. only the third argument (the bias) is batched, along its last dimension.

Sample reference: https://github.com/pytorch/pytorch/blob/61b4e8a7bfb69954680013e2e34fc099db900736/torch/testing/_internal/common_methods_invocations.py#L3367

Relevant batching-rule code that gets executed (at a quick glance, nothing looked fishy): https://github.com/pytorch/pytorch/blob/61b4e8a7bfb69954680013e2e34fc099db900736/aten/src/ATen/functorch/BatchRulesConvolution.cpp#L105-L121
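A standalone repro of that sample might look like the sketch below. It assumes PyTorch 2.x, where functorch's vmap is exposed as torch.func.vmap, and the input/weight/bias shapes are hand-picked assumptions (not the exact OpInfo sample) chosen so the output has 256 elements like the failing case. The loop reference mirrors what vmap_outplace_test compares against: one conv per bias slice, stacked along a new leading dimension.

```python
import torch
import torch.nn.functional as F
from torch.func import vmap  # assumption: PyTorch 2.x, where functorch's vmap lives in torch.func

torch.manual_seed(0)

# Assumed shapes (not taken from the OpInfo sample): groups=4 needs C_in divisible by 4,
# and the weight is (C_out, C_in // groups, kH, kW).
x = torch.randn(2, 4, 6, 6)   # input, unbatched (in_dim None)
w = torch.randn(4, 1, 3, 3)   # weight, unbatched (in_dim None)
b = torch.randn(4, 2)         # bias for 4 output channels, batched along its last dim

def conv(x, w, b):
    return F.conv2d(x, w, b, groups=4)

# vmap over the last dim of the bias only, as in the failing sample.
vmap_out = vmap(conv, in_dims=(None, None, -1))(x, w, b)

# Loop reference: one conv per bias slice, stacked along a new leading dim.
loop_out = torch.stack([conv(x, w, b[:, i]) for i in range(b.shape[-1])])

print(vmap_out.shape)                     # torch.Size([2, 2, 4, 4, 4])
print((vmap_out - loop_out).abs().max())  # should be ~0 where conv is correct
```

On a platform where conv is correct, the two outputs should agree within the test's 1e-4 tolerances.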

zou3519 commented 1 year ago

Is it possible to use make_fx(vmap(f))(...) to get a trace of the operations the batching rule runs? I'm wondering whether conv itself is incorrect for some case on M1 Macs, or whether our batching rule is wrong.

kshitij12345 commented 1 year ago

With make_fx, the traced graph is:

class wrapped(torch.nn.Module):
    def forward(self, flat_args):
        flat_args_1, flat_args_2, flat_args_3, flat_args_4, = fx_pytree.tree_flatten_spec([flat_args], self._in_spec)

        # No stacktrace found for following nodes 
        convolution = torch.ops.aten.convolution.default(flat_args_1, flat_args_2, None, [1, 1], [0, 0], [1, 1], False, [0, 0], 4);  flat_args_1 = flat_args_2 = None
        permute = torch.ops.aten.permute.default(flat_args_3, [1, 0]);  flat_args_3 = None
        unsqueeze = torch.ops.aten.unsqueeze.default(permute, -1);  permute = None
        unsqueeze_1 = torch.ops.aten.unsqueeze.default(unsqueeze, -1);  unsqueeze = None
        view = torch.ops.aten.view.default(unsqueeze_1, [2, 1, 4, 1, 1]);  unsqueeze_1 = None
        add = torch.ops.aten.add.Tensor(convolution, view);  convolution = view = None
        return pytree.tree_unflatten([add], self._out_spec)

Patch used to print the graph:

diff --git a/functorch/test/common_utils.py b/functorch/test/common_utils.py
index cfaa206619..a867e8ff6f 100644
--- a/functorch/test/common_utils.py
+++ b/functorch/test/common_utils.py
@@ -272,6 +272,11 @@ def compute_quantities_for_vmap_test(
     # t = make_fx(vmap(f, in_dims=in_dims, out_dims=out_dim))(*batched_args, **kwarg_values)
     # print(in_dims, [arg.shape for arg in batched_args], kwarg_values)
     batched_args, kwarg_values = maybe_clone_inputs()
+    from torch.fx.experimental.proxy_tensor import make_fx, wrapper_and_args_for_make_fx
+    wrapped, all_args = wrapper_and_args_for_make_fx(vmap(op, in_dims=in_dims, out_dims=out_dim), batched_args, kwarg_values)
+    fx_graph = functorch.make_fx(wrapped)(all_args)
+    # print(fx_graph)
+    fx_graph.print_readable()
     batched_out = vmap(op, in_dims=in_dims, out_dims=out_dim)(*batched_args, **kwarg_values)
     yield (loop_out, batched_out)

Corresponding common_utils code: https://github.com/pytorch/pytorch/blob/9baf6770bcd67272a2cb9212c49e3bb95f0679c3/functorch/test/common_utils.py#L274-L275
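The traced graph runs the convolution once without a bias and then adds the batched bias, reshaped so it broadcasts against the output. Since conv2d(x, w, b) equals conv2d(x, w) plus b broadcast over the channel dimension, this should match the per-slice loop. The NumPy sketch below replays the permute/unsqueeze/view/add steps from the trace; NumPy itself, the random stand-in for the bias-free convolution output, and N=2, H=W=4 are assumptions for illustration, while B=2 and C=4 come from the view([2, 1, 4, 1, 1]) call.

```python
import numpy as np

rng = np.random.default_rng(0)

# Stand-in for convolution(x, w, bias=None): shape (N, C, H, W).
# N=2 and H=W=4 are assumptions; C=4 matches the groups=4 sample.
y = rng.standard_normal((2, 4, 4, 4)).astype(np.float32)

# Bias of shape (C, B) = (4, 2), batched along its last dim (in_dims=-1).
bias = rng.standard_normal((4, 2)).astype(np.float32)

# Replay of the traced batching rule.
b = bias.transpose(1, 0)        # permute([1, 0])        -> (2, 4)
b = b[..., None]                # unsqueeze(-1)          -> (2, 4, 1)
b = b[..., None]                # unsqueeze(-1)          -> (2, 4, 1, 1)
b = b.reshape(2, 1, 4, 1, 1)    # view([2, 1, 4, 1, 1])  -> (2, 1, 4, 1, 1)
batched = y + b                 # add broadcasts         -> (2, 2, 4, 4, 4)

# Loop reference: add each bias slice per channel, then stack along dim 0.
loop = np.stack([y + bias[:, i][:, None, None] for i in range(bias.shape[1])])

print(batched.shape)                  # (2, 2, 4, 4, 4)
print(np.array_equal(batched, loop))  # True
```

This only checks the reshaping logic of the rule, not the convolution kernel itself.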

kshitij12345 commented 1 year ago

Interestingly, the vmap and loop outputs from Linux match the vmap output from the M1 Mac, but not the M1 loop output. So the error seems to be in the loop computation on M1, not in the vmap version.
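That cross-check amounts to finding the odd one out: when three outputs agree and one does not, the disagreeing one is the suspect. Below is a small hypothetical helper (not part of the functorch test suite) that flags outputs which are not close to any other output, using the same 1e-4 tolerances as the test; the arrays in the usage lines are synthetic, not the actual results from this thread.

```python
import numpy as np

def find_outlier(outputs, atol=1e-4, rtol=1e-4):
    """Return the labels of outputs that are not close to any other output.

    `outputs` maps a label (e.g. "linux/vmap") to an array. Hypothetical helper
    for illustration only.
    """
    flagged = []
    for name, value in outputs.items():
        others = [other for other in outputs if other != name]
        if not any(np.allclose(value, outputs[o], atol=atol, rtol=rtol) for o in others):
            flagged.append(name)
    return flagged

# Synthetic data: three outputs agree, the "m1/loop" one is deliberately corrupted.
ref = np.linspace(-1, 1, 8, dtype=np.float32)
outputs = {
    "linux/loop": ref,
    "linux/vmap": ref.copy(),
    "m1/vmap": ref.copy(),
    "m1/loop": ref + 5.0,
}
print(find_outlier(outputs))  # ['m1/loop']
```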

zou3519 commented 1 year ago

Interesting, thanks for the find.

zou3519 commented 1 year ago

https://github.com/pytorch/pytorch/pull/85711 will fix this

kshitij12345 commented 1 year ago

https://github.com/pytorch/pytorch/pull/85711 is merged. Closing the issue.