pytorch / functorch

functorch is JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

Improve error message when the wrong number of arguments is passed into a `make_fx` function. #84

Open Chillee opened 3 years ago

Chillee commented 3 years ago

Root cause of https://github.com/facebookresearch/functorch/issues/82 :P
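Here is a minimal sketch of the kind of check the issue title asks for. It assumes a hypothetical helper, `check_traced_call`, which is not part of functorch, that a `make_fx`-style wrapper could run before tracing starts. The helper uses only the standard library's `inspect.signature(...).bind(...)` to detect an arity mismatch, then raises a `TypeError` that names the function and its signature. Callables that expose no Python signature are skipped, so this would not catch every case.

```python
import inspect
from typing import Any, Callable, Sequence


def check_traced_call(f: Callable[..., Any], args: Sequence[Any]) -> None:
    """Hypothetical pre-trace check (not a functorch API): fail with a readable
    message when `args` cannot be bound to the parameters of `f`."""
    try:
        sig = inspect.signature(f)
    except (TypeError, ValueError):
        # Some callables (e.g. C extensions) expose no Python signature;
        # skip the check and let the tracer report any problem itself.
        return
    try:
        sig.bind(*args)
    except TypeError as err:
        name = getattr(f, "__name__", type(f).__name__)
        raise TypeError(
            f"make_fx: {name}{sig} was called with {len(args)} "
            f"positional argument(s): {err}"
        ) from None


def add(x, y):
    return x + y


check_traced_call(add, (1.0, 2.0))  # OK: two inputs for two parameters
try:
    check_traced_call(add, (1.0,))  # one input is missing
except TypeError as e:
    print(e)
```

Binding the arguments against the signature, rather than just comparing counts, also covers defaults, keyword-only parameters and `*args`. In each case the standard library supplies the specific reason, such as which parameter is missing.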

ezyang commented 2 years ago

Was bitten by this. The error message looks like this:

Traceback (most recent call last):
  File "/data/users/ezyang/functorch/functorch/_src/aten2torch.py", line 48, in test_add
    graph = make_fx(TABLE[aten.add.Tensor])(torch.randn(3))
  File "/data/users/ezyang/functorch/functorch/_src/python_key.py", line 198, in wrapped
    t = pythonkey_trace(wrap_key(f, args), concrete_args=tuple(phs))
  File "/data/users/ezyang/functorch/functorch/_src/python_key.py", line 164, in pythonkey_trace
    graph = tracer.trace(root, concrete_args)
  File "/data/users/ezyang/pytorch-tmp/torch/fx/_symbolic_trace.py", line 549, in trace
    fn, args = self.create_args_for_root(fn, isinstance(root, torch.nn.Module), concrete_args)
  File "/data/users/ezyang/pytorch-tmp/torch/fx/_symbolic_trace.py", line 437, in create_args_for_root
    assert(len(arg_names) == len(concrete_args))
AssertionError
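The traceback ends in a bare `assert` that reports no counts and no names. Python also strips bare asserts entirely under `python -O`. The sketch below replaces that check with an explicit error and assumes a hypothetical helper, `check_concrete_args`, which is not FX's actual code. The argument names in the usage lines are illustrative only. They mirror the shape of the failing call, where a function with two positional parameters received a single input.

```python
from typing import Any, Sequence


def check_concrete_args(arg_names: Sequence[str], concrete_args: Sequence[Any]) -> None:
    """Hypothetical replacement for `assert(len(arg_names) == len(concrete_args))`
    that says what was expected and what was actually received."""
    if len(arg_names) != len(concrete_args):
        raise TypeError(
            f"Traced function takes {len(arg_names)} argument(s) {list(arg_names)}, "
            f"but {len(concrete_args)} were provided. Check that the inputs passed "
            f"to the make_fx-wrapped function match its signature."
        )


# Illustrative names: two expected parameters, only one input supplied.
try:
    check_concrete_args(["self", "other"], [object()])
except TypeError as e:
    print(e)
```

Raising `TypeError` matches what Python itself raises when a plain function call has the wrong number of arguments.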