
Torch Dynamo fails to trace `torch.nn.RReLU()` #119460

Open fynnsu opened 9 months ago

fynnsu commented 9 months ago

🐛 Describe the bug

Dynamo fails to produce a compiled graph for `RReLU`.

Compilation fails when AOTAutograd checks whether the traced graph below is functional.

opcode         name           target                      args                                 kwargs
-------------  -------------  --------------------------  -----------------------------------  ----------------------------------------------------------------------------------------------------------------------------------------------
placeholder    arg0_1         arg0_1                      ()                                   {}
call_function  empty          aten.empty.memory_format    ([5, 2],)                            {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cpu'), 'pin_memory': False, 'memory_format': torch.contiguous_format}
call_function  le             aten.le.Scalar              (arg0_1, 0)                          {}
call_function  uniform        aten.uniform.default        (arg0_1, 0.125, 0.3333333333333333)  {}
call_function  mul            aten.mul.Tensor             (arg0_1, uniform)                    {}
call_function  where          aten.where.self             (le, mul, arg0_1)                    {}
call_function  scalar_tensor  aten.scalar_tensor.default  (1,)                                 {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cpu')}
call_function  where_1        aten.where.self             (le, uniform, scalar_tensor)         {}
call_function  copy_          aten.copy_.default          (empty, where_1)                     {}
output         output         output                      ((where,),)                          {}

However, this graph seems to contain several unnecessary operations, including the `copy_` that triggers the failure. The graph below should be sufficient to compute the RReLU operation (an eager-mode sketch of the same computation follows the table).

opcode         name           target                      args                                 kwargs
-------------  -------------  --------------------------  -----------------------------------  --------
placeholder    arg0_1         arg0_1                      ()                                   {}
call_function  le             aten.le.Scalar              (arg0_1, 0)                          {}
call_function  uniform        aten.uniform.default        (arg0_1, 0.125, 0.3333333333333333)  {}
call_function  mul            aten.mul.Tensor             (arg0_1, uniform)                    {}
call_function  where          aten.where.self             (le, mul, arg0_1)                    {}
output         output         output                      ((where,),)                          {}
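
For reference, here is a minimal eager-mode sketch of what the reduced graph computes (illustrative only; the helper name and the use of `empty_like().uniform_()` in place of the functionalized `aten.uniform` are my own):

import torch

def rrelu_train_sketch(x, lower=1/8, upper=1/3):
    # Sample a per-element slope uniformly from [lower, upper],
    # mirroring aten.uniform.default in the graph above.
    noise = torch.empty_like(x).uniform_(lower, upper)
    # Elements <= 0 are scaled by the sampled slope (aten.le + aten.mul);
    # positive elements pass through unchanged (aten.where).
    return torch.where(x <= 0, x * noise, x)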

I initially thought the extra ops might be needed for the `inplace=True` graph, but they aren't used in that graph either (below).

opcode         name           target                      args                                 kwargs
-------------  -------------  --------------------------  -----------------------------------  ----------------------------------------------------------------------------------------------------------------------------------------------
placeholder    arg0_1         arg0_1                      ()                                   {}
call_function  empty          aten.empty.memory_format    ([5, 2],)                            {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cpu'), 'pin_memory': False, 'memory_format': torch.contiguous_format}
call_function  le             aten.le.Scalar              (arg0_1, 0)                          {}
call_function  uniform        aten.uniform.default        (arg0_1, 0.125, 0.3333333333333333)  {}
call_function  mul            aten.mul.Tensor             (arg0_1, uniform)                    {}
call_function  where          aten.where.self             (le, mul, arg0_1)                    {}
call_function  scalar_tensor  aten.scalar_tensor.default  (1,)                                 {'dtype': torch.float32, 'layout': torch.strided, 'device': device(type='cpu')}
call_function  where_1        aten.where.self             (le, uniform, scalar_tensor)         {}
call_function  copy_          aten.copy_.default          (empty, where_1)                     {}
call_function  copy__1        aten.copy_.default          (arg0_1, where)                      {}
output         output         output                      ((where,),)                          {}

Error logs

[2024-02-08 11:23:34,926] torch._dynamo.eval_frame: [DEBUG] Saving dynamo config and hash for new compiled object(s). Hash: a9446d0645a24f8e5db15f38d621b2a5
[2024-02-08 11:23:34,927] torch._dynamo.eval_frame: [DEBUG] skipping: helper (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/contextlib.py)
[2024-02-08 11:23:34,927] torch._dynamo.eval_frame: [DEBUG] skipping: __init__ (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/contextlib.py)
[2024-02-08 11:23:34,927] torch._dynamo.eval_frame: [DEBUG] skipping: __enter__ (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/contextlib.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: backend_cache_wrapper (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: _maybe_init_guarded_backend_cache (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: innermost_fn (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: _set_current_backend (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: __init__ (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/contextlib.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: __enter__ (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/contextlib.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: restore_guarded_dynamo_config (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: _maybe_init_guarded_config_cache (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: debug (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: isEnabledFor (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: _log (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: findCaller (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: <lambda> (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: normcase (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/posixpath.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: makeRecord (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: __init__ (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: __instancecheck__ (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/abc.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: getLevelName (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: basename (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/posixpath.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: _get_sep (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/posixpath.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] skipping: splitext (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/posixpath.py)
[2024-02-08 11:23:34,929] torch._dynamo.convert_frame: [DEBUG] skipping because no torch.* _splitext             /home/fynnsu/miniconda3/envs/uw/lib/python3.9/genericpath.py 121
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: current_thread (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/threading.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: name (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/threading.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: current_process (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/multiprocessing/process.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: name (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/multiprocessing/process.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: handle (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: filter (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: callHandlers (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: handle (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: acquire (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: emit (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: format (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: format (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_logging/_internal.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: getLogger (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: getLogger (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: _acquireLock (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: _releaseLock (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: getMessage (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,929] torch._dynamo.eval_frame: [DEBUG] skipping: formatTime (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: is_available (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/distributed/__init__.py)
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: is_initialized (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py)
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: WORLD (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py)
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: default_pg (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/distributed/distributed_c10d.py)
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: current_trace_id (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_guards.py)
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: try_get (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_guards.py)
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: <genexpr> (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_logging/_internal.py)
[2024-02-08 11:23:34,928] torch._dynamo.eval_frame: [DEBUG] Setting top-level compile config hash: a9446d0645a24f8e5db15f38d621b2a5
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: flush (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,930] torch._dynamo.eval_frame: [DEBUG] skipping: release (reason: in skipfiles, file: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/logging/__init__.py)
[2024-02-08 11:23:34,931] [0/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing inner /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/external_utils.py:15
[2024-02-08 11:23:34,931] [0/0] torch.fx.experimental.symbolic_shapes: [INFO] create_env
[2024-02-08 11:23:34,933] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/external_utils.py:15 in inner (wrap_inline)
[2024-02-08 11:23:34,933] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]         @functools.wraps(fn)
[2024-02-08 11:23:34,934] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG] TRACE starts_line /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/external_utils.py:17 in inner (wrap_inline.inner)
[2024-02-08 11:23:34,934] [0/0] torch._dynamo.symbolic_convert.__trace_source: [DEBUG]             return fn(*args, **kwargs)
[2024-02-08 11:23:34,934] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_DEREF fn []
[2024-02-08 11:23:34,934] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST args [LazyVariableTracker()]
[2024-02-08 11:23:34,934] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE BUILD_MAP 0 [LazyVariableTracker(), LazyVariableTracker()]
[2024-02-08 11:23:34,934] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE LOAD_FAST kwargs [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
[2024-02-08 11:23:34,934] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE DICT_MERGE 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable(), LazyVariableTracker()]
[2024-02-08 11:23:34,935] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE CALL_FUNCTION_EX 1 [LazyVariableTracker(), LazyVariableTracker(), ConstDictVariable()]
[2024-02-08 11:23:34,935] [0/0] torch._dynamo.output_graph: [DEBUG] create_graph_input L_args_0_ L['args'][0]
[2024-02-08 11:23:34,936] [0/0] torch._dynamo.variables.builder: [DEBUG] wrap_to_fake L['args'][0] (5, 2) [<DimDynamic.STATIC: 2>, <DimDynamic.STATIC: 2>] [None, None]
[2024-02-08 11:23:34,942] [0/0] torch._dynamo.symbolic_convert: [DEBUG] TRACE RETURN_VALUE None [TensorVariable()]
[2024-02-08 11:23:34,942] [0/0] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing inner (RETURN_VALUE)
[2024-02-08 11:23:34,942] [0/0] torch._dynamo.symbolic_convert: [DEBUG] RETURN_VALUE triggered compile
[2024-02-08 11:23:34,942] [0/0] torch._dynamo.output_graph: [DEBUG] COMPILING GRAPH due to GraphCompileReason(reason='return_value', user_stack=[<FrameSummary file /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/external_utils.py, line 17 in inner>], graph_break=False)
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] TRACED GRAPH
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  ===== __compiled_fn_0 =====
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]  <eval_with_key>.0 class GraphModule(torch.nn.Module):
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]     def forward(self, L_args_0_ : torch.Tensor):
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         l_args_0_ = L_args_0_
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         # File: /home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/external_utils.py:17, code: return fn(*args, **kwargs)
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         fn = self.fn(l_args_0_);  l_args_0_ = None
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         return (fn,)
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG]         
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph_code: [DEBUG] 
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] TRACED GRAPH
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph: [DEBUG]  __compiled_fn_0 <eval_with_key>.0 opcode       name       target     args          kwargs
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] -----------  ---------  ---------  ------------  --------
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] placeholder  l_args_0_  L_args_0_  ()            {}
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] call_module  fn         fn         (l_args_0_,)  {}
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] output       output     output     ((fn,),)      {}
[2024-02-08 11:23:34,944] [0/0] torch._dynamo.output_graph.__graph: [DEBUG] 
[2024-02-08 11:23:34,948] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] TRACED GRAPH TENSOR SIZES
[2024-02-08 11:23:34,948] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] ===== __compiled_fn_0 =====
[2024-02-08 11:23:34,948] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] l_args_0_: (5, 2)
[2024-02-08 11:23:34,948] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] fn: (5, 2)
[2024-02-08 11:23:34,948] [0/0] torch._dynamo.output_graph.__graph_sizes: [DEBUG] 
[2024-02-08 11:23:34,948] [0/0] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function inductor
[2024-02-08 11:23:34,968] torch._dynamo.eval_frame: [DEBUG] Unsetting top-level compile config hash: a9446d0645a24f8e5db15f38d621b2a5
Traceback (most recent call last):
  File "/home/fynnsu/***/scrap/rrelu.py", line 14, in <module>
    print(m(x))
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 489, in _fn
    return fn(*args, **kwargs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 655, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 727, in _convert_frame
    result = inner_convert(frame, cache_entry, hooks, frame_state)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 383, in _convert_frame_assert
    compiled_product = _compile(
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 646, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 562, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
    transformations(instructions, code_options)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 151, in _fn
    return fn(*args, **kwargs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/convert_frame.py", line 527, in transform
    tracer.run()
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2128, in run
    super().run()
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 818, in run
    and self.step()
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 781, in step
    getattr(self, inst.opname)(inst)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/symbolic_convert.py", line 2243, in RETURN_VALUE
    self.output.compile_subgraph(
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 919, in compile_subgraph
    self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1087, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1159, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/output_graph.py", line 1140, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
    compiled_gm = compiler_fn(gm, example_inputs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/__init__.py", line 1662, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 1168, in compile_fx
    return aot_autograd(
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/backends/common.py", line 55, in compiler_fn
    cg = aot_module_simplified(gm, example_inputs, **kwargs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 887, in aot_module_simplified
    compiled_fn = create_aot_dispatcher_function(
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 244, in time_wrapper
    r = func(*args, **kwargs)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 600, in create_aot_dispatcher_function
    compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 425, in aot_wrapper_dedupe
    return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 630, in aot_wrapper_synthetic_base
    return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 72, in aot_dispatch_base
    fw_module, updated_flat_args, maybe_subclass_meta = aot_dispatch_base_graph(  # type: ignore[misc]
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/dispatch_and_compile_graph.py", line 85, in aot_dispatch_base_graph
    copy_count = assert_functional_graph(
  File "/home/fynnsu/miniconda3/envs/uw/lib/python3.9/site-packages/torch/_functorch/_aot_autograd/functional_utils.py", line 331, in assert_functional_graph
    assert n.args[0] in placeholders
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
AssertionError: 

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

[2024-02-08 11:23:35,008] torch._dynamo.utils: [INFO] TorchDynamo compilation metrics:
[2024-02-08 11:23:35,008] torch._dynamo.utils: [INFO] Function                           Runtimes (s)
[2024-02-08 11:23:35,008] torch._dynamo.utils: [INFO] -------------------------------  --------------
[2024-02-08 11:23:35,008] torch._dynamo.utils: [INFO] _compile.<locals>.compile_inner               0
[2024-02-08 11:23:35,008] torch._dynamo.utils: [INFO] OutputGraph.call_user_compiler                0
[2024-02-08 11:23:35,008] torch._dynamo.utils: [INFO] create_aot_dispatcher_function                0

Minified repro

import torch
m = torch.compile(torch.nn.RReLU())
m(torch.randn(5, 2))

I also tried using the `torch.rrelu` function directly and ran into a different error (`rrelu_with_noise() missing 2 required positional arguments: 'lower' and 'upper'`), which might be a separate bug:

import torch

@torch.compile
def fn(x, lower, upper):
    return torch.rrelu(x, lower, upper)

fn(torch.randn(5, 2), 1 / 8, 1 / 3)

Versions

PyTorch version: 2.2.0+cu121
Is debug build: False
CUDA used to build PyTorch: 12.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 22.04.1 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: Could not collect
Libc version: glibc-2.35

Python version: 3.9.0 (default, Nov 15 2020, 14:28:56) [GCC 7.3.0] (64-bit runtime)
Python platform: Linux-6.2.0-37-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce GTX 1080 Ti
Nvidia driver version: 535.113.01
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 43 bits physical, 48 bits virtual
Byte Order: Little Endian
CPU(s): 16
On-line CPU(s) list: 0-15
Vendor ID: AuthenticAMD
Model name: AMD Ryzen 7 3800X 8-Core Processor
CPU family: 23
Model: 113
Thread(s) per core: 2
Core(s) per socket: 8
Socket(s): 1
Stepping: 0
Frequency boost: enabled
CPU max MHz: 4558.8862
CPU min MHz: 2200.0000
BogoMIPS: 7800.12
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx mmxext fxsr_opt pdpe1gb rdtscp lm constant_tsc rep_good nopl nonstop_tsc cpuid extd_apicid aperfmperf rapl pni pclmulqdq monitor ssse3 fma cx16 sse4_1 sse4_2 movbe popcnt aes xsave avx f16c rdrand lahf_lm cmp_legacy svm extapic cr8_legacy abm sse4a misalignsse 3dnowprefetch osvw ibs skinit wdt tce topoext perfctr_core perfctr_nb bpext perfctr_llc mwaitx cpb cat_l3 cdp_l3 hw_pstate ssbd mba ibpb stibp vmmcall fsgsbase bmi1 avx2 smep bmi2 cqm rdt_a rdseed adx smap clflushopt clwb sha_ni xsaveopt xsavec xgetbv1 cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local clzero irperf xsaveerptr rdpru wbnoinvd arat npt lbrv svm_lock nrip_save tsc_scale vmcb_clean flushbyasid decodeassists pausefilter pfthreshold avic v_vmsave_vmload vgif v_spec_ctrl umip rdpid overflow_recov succor smca sev sev_es
Virtualization: AMD-V
L1d cache: 256 KiB (8 instances)
L1i cache: 256 KiB (8 instances)
L2 cache: 4 MiB (8 instances)
L3 cache: 32 MiB (2 instances)
NUMA node(s): 1
NUMA node0 CPU(s): 0-15
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Mitigation; untrained return thunk; SMT enabled with STIBP protection
Vulnerability Spec rstack overflow: Mitigation; safe RET
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Retpolines, IBPB conditional, STIBP always-on, RSB filling, PBRSB-eIBRS Not affected
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected

Versions of relevant libraries:
[pip3] mypy==1.8.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.26.3
[pip3] torch==2.2.0
[pip3] torchaudio==2.2.0
[pip3] torchvision==0.17.0
[pip3] torchviz==0.0.2
[pip3] triton==2.2.0
[conda] numpy 1.26.3 pypi_0 pypi
[conda] torch 2.2.0 pypi_0 pypi
[conda] torchaudio 2.2.0 pypi_0 pypi
[conda] torchvision 0.17.0 pypi_0 pypi
[conda] torchviz 0.0.2 pypi_0 pypi
[conda] triton 2.2.0 pypi_0 pypi

cc @ezyang @msaroufim @bdhirsh @anijain2305 @zou3519

tringwald commented 9 months ago

The `rrelu_with_noise() missing 2 required positional arguments: 'lower' and 'upper'` problem is a duplicate of https://github.com/pytorch/pytorch/issues/115811 and should already be fixed. The torch.nn.RReLU version still doesn't work, though.

tringwald commented 9 months ago

The problem seems to be that torch.nn.RReLU defaults to training = True and therefore takes a different code path from torch.nn.functional.rrelu. That code path then copies in place into the RReLU noise tensor, which triggers the assert. I've put up a PR that sidesteps the problem, but it really needs to be checked carefully, as it violates the invariant mentioned in the code comments.
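
A sketch of the divergence (the eval-mode behavior at the end is an assumption based on the explanation above, not something verified in this issue):

import torch
import torch.nn.functional as F

x = torch.randn(5, 2)
m = torch.nn.RReLU()  # nn.Module instances start in train mode: m.training == True

# In train mode the randomized path is taken (rrelu_with_noise, which
# writes the sampled slopes into a noise tensor in place).
F.rrelu(x, training=True)

# In eval mode the slope is fixed at (lower + upper) / 2 and no noise
# tensor is involved, so the failing copy_ should never be traced.
m.eval()
torch.compile(m)(x)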

soulitzer commented 9 months ago

I think the issue may be that the decomp for torch._ops.aten.rrelu_with_noise.default contains an in-place op, copy_. Since FunctionalTensorMode runs above ProxyTensorMode (which runs the decomps), maybe it's reasonable to say that, in general, decomps for out-of-place ops shouldn't be allowed to contain in-place operations.
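
For context, a toy version of the check that fires (heavily simplified from torch/_functorch/_aot_autograd/functional_utils.py; not the real implementation):

import torch
from torch import fx

def assert_functional(gm: fx.GraphModule) -> None:
    # Collect the graph inputs.
    placeholders = {n for n in gm.graph.nodes if n.op == "placeholder"}
    for n in gm.graph.nodes:
        # An in-place copy is only tolerated when it writes back into a
        # graph input. In the RReLU graph above, copy_ writes into the
        # intermediate `empty` node, so `n.args[0] in placeholders` is
        # False and the assert fires.
        if n.op == "call_function" and n.target is torch.ops.aten.copy_.default:
            assert n.args[0] in placeholders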

tiru1930 commented 6 months ago

I have seen a similar issue with torch.ops.aten._scaled_dot_product_flash_attention_for_cpu in our custom backend; after removing it from the decomposition table, the error is gone.
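
For anyone hitting the same thing, a sketch of dropping an op from the decomposition table before handing it to a backend (the backend wiring itself is omitted and will differ per integration):

import torch
from torch._decomp import core_aten_decompositions

# Build the default decomposition table, then drop the op whose decomp
# introduces the problematic in-place copy before passing the table to
# a custom backend / AOTAutograd.
decomps = core_aten_decompositions()
decomps.pop(torch.ops.aten._scaled_dot_product_flash_attention_for_cpu.default, None)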