pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

`aot_module` with `nn.Linear` leads to runtime error #595

Open · adamgayoso opened this issue 2 years ago

adamgayoso commented 2 years ago

functorch looks super cool and helpful! I was just playing around with it this morning, trying to AOT-compile `nn.Linear`. Maybe I'm misunderstanding the API, but I'm getting a RuntimeError.

import torch
from torch import nn
from functorch.compile import ts_compile, aot_module
In [51]: b = nn.Linear(100, 200)

In [52]: b(torch.zeros(128, 100))
Out[52]: 
tensor([[-0.0528,  0.0509,  0.0053,  ...,  0.0059,  0.0502,  0.0504],
        [-0.0528,  0.0509,  0.0053,  ...,  0.0059,  0.0502,  0.0504],
        [-0.0528,  0.0509,  0.0053,  ...,  0.0059,  0.0502,  0.0504],
        ...,
        [-0.0528,  0.0509,  0.0053,  ...,  0.0059,  0.0502,  0.0504],
        [-0.0528,  0.0509,  0.0053,  ...,  0.0059,  0.0502,  0.0504],
        [-0.0528,  0.0509,  0.0053,  ...,  0.0059,  0.0502,  0.0504]],
       grad_fn=<AddmmBackward0>)

In [53]: c = aot_module(b, ts_compile, ts_compile)

In [54]: c(torch.zeros(128, 100))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [54], in <module>
----> 1 c(torch.zeros(128, 100))

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/functorch/_src/aot_autograd.py:521, in aot_module.<locals>.AOTModule.forward(self, *args, **kwargs)
    520 def forward(self, *args, **kwargs):
--> 521     return compiled_f(
    522         dict(_named_parameters(mod, remove_duplicate=False)),
    523         dict(_named_buffers(mod, remove_duplicate=False)),
    524         *args,
    525         **kwargs,
    526     )

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/functorch/_src/aot_autograd.py:459, in aot_function.<locals>.returned_function(*args, **kwargs)
    448     compile_cache.insert(
    449         fn_id,
    450         fw_compiler_id,
   (...)
    455         *flat_args_for_cache,
    456     )
    458 cached_fn, out_spec = cached_res
--> 459 out = cached_fn(*flat_tensor_args)
    460 return out_spec.unflatten(out)

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/functorch/_src/aot_autograd.py:136, in create_aot_autograd_function.<locals>.CompiledFunction.forward(ctx, *flat_tensor_args)
    134 if compiled_fw is None:
    135     with torch.set_grad_enabled(grad_state):
--> 136         out = flat_fn(*flat_tensor_args)
    137     out = pytree.tree_map(
    138         lambda x: x.detach() if isinstance(x, Tensor) else x, out
    139     )
    141     if isinstance(out, (list, tuple)):

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/functorch/_src/aot_autograd.py:417, in aot_function.<locals>.returned_function.<locals>.flat_fn(*flat_tensor_args)
    415 else:
    416     args = rearrange(tensor_args, static_args, static_argnums)
--> 417 tree_out = fn(*args, **kwargs)
    418 flat_out, spec = pytree.tree_flatten(tree_out)
    419 for i in flat_out:

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/functorch/_src/aot_autograd.py:511, in aot_module.<locals>.functional_call(named_params, named_buffers, *args, **kwargs)
    509 def functional_call(named_params, named_buffers, *args, **kwargs):
    510     params_and_buffers = {**named_params, **named_buffers}
--> 511     return _stateless.functional_call(mod, params_and_buffers, args, kwargs)

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/torch/nn/utils/_stateless.py:117, in functional_call(module, parameters_and_buffers, args, kwargs)
    115 with reparametrize_module(module, parameters_and_buffers):
    116     if isinstance(args, tuple):
--> 117         out = module(*args, **kwargs)
    118     else:
    119         out = module(args, **kwargs)

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/torch/nn/modules/module.py:1110, in Module._call_impl(self, *input, **kwargs)
   1106 # If we don't have any hooks, we want to skip the rest of the logic in
   1107 # this function, and just call forward.
   1108 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110     return forward_call(*input, **kwargs)
   1111 # Do not call functions when jit is used
   1112 full_backward_hooks, non_full_backward_hooks = [], []

File ~/miniconda3/envs/scvi-tools-dev/lib/python3.9/site-packages/torch/nn/modules/linear.py:103, in Linear.forward(self, input)
    102 def forward(self, input: Tensor) -> Tensor:
--> 103     return F.linear(input, self.weight, self.bias)

RuntimeError: mat2 must be a matrix, got 1-D tensor
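
For context: on a 2-D input, `F.linear(input, weight, bias)` effectively dispatches to `torch.addmm(bias, input, weight.t())`, and `addmm` requires its `mat2` argument to be 2-D. The message therefore points at a 1-D tensor (the shape of the bias) landing where the 2-D weight belongs. A minimal sketch of the same failure, using the parameter shapes of `nn.Linear(100, 200)`:

import torch

x = torch.zeros(128, 100)
weight = torch.zeros(200, 100)  # nn.Linear(100, 200) stores weight as (out_features, in_features)
bias = torch.zeros(200)         # 1-D bias of length out_features

torch.addmm(bias, x, weight.t())  # OK: result has shape (128, 200)
torch.addmm(weight, x, bias.t())  # RuntimeError: mat2 must be a matrix, got 1-D tensor
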
Chillee commented 2 years ago

@adamgayoso I can't seem to replicate this - are you running on nightly or with 1.11?

I'll try it out on nightly in a bit.

adamgayoso commented 2 years ago

I'm running PyTorch v1.11 and `functorch==0.1.0`.

adamgayoso commented 2 years ago

Hmmm, now it works. I had done some things earlier in that session that failed, though, and it seems the code I posted fails only when it runs after previous code that failed.
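
If a compilation that failed earlier leaves a stale entry behind, functorch's global compile cache (visible in the `compile_cache.insert(...)` frames of the traceback above) would be one place for state to linger. A hedged workaround sketch, assuming your functorch build exposes the cache-clearing helper:

from functorch.compile import clear_compile_cache

# Clear any entries a previously failed compilation may have left behind,
# then re-run the aot_module example from the original report.
clear_compile_cache()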

adamgayoso commented 2 years ago

In the case where it doesn't work, the weight and the bias somehow get swapped when I access attributes of `orig.module`.
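
That would fit the final frame of the traceback, `F.linear(input, self.weight, self.bias)`: swapping the two parameters reproduces the exact error message. A quick sketch to confirm the shapes and the failure mode, reusing the `nn.Linear(100, 200)` setup from the repro:

import torch
import torch.nn.functional as F
from torch import nn

b = nn.Linear(100, 200)
print(b.weight.shape, b.bias.shape)  # torch.Size([200, 100]) torch.Size([200])

x = torch.zeros(128, 100)
F.linear(x, b.weight, b.bias)  # OK: shape (128, 200)
F.linear(x, b.bias, b.weight)  # RuntimeError: mat2 must be a matrix, got 1-D tensor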

Chillee commented 2 years ago

Interesting... I'll take a look at that module. Do you have any example code snippets showing what's not working? Or is it not super clear how to come up with a repro?