llama models fail to torch.compile - errs out re: 2 compatible backends for target (cuda) ([<class 'c.CUDABackend'>, <class 'nvi.CUDABackend'>]). There should only be one. #55
With the latest nightly I am unable to compile the llama models (any size):
[rank0]:2024-02-12 10:00:40,026 - root - INFO - Compiling model llama with torch.compile...
[rank0]:[rank0]:[2024-02-12 10:00:42,377] [0/0] torch._dynamo.variables.torch: [WARNING] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
[rank0]:NCCL version 2.19.3+cuda12.3
[rank0]:/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/lowering.py:1697: UserWarning: Torchinductor does not support code generation for complex operators. Performance may be worse than eager.
[rank0]: warnings.warn(
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]: File "/data/users/less/local/torchtrain/train.py", line 288, in <module>
[rank0]:[rank0]: main(args)
[rank0]:[rank0]: File "/data/users/less/local/torchtrain/train.py", line 161, in main
[rank0]:[rank0]: pred = model(input_ids)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 454, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/external_utils.py", line 25, in inner
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 853, in forward
[rank0]:[rank0]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/data/users/less/local/torchtrain/torchtrain/models/llama/model.py", line 485, in forward
[rank0]:[rank0]: h, freqs_cis = self.embeddings(tokens)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
[rank0]:[rank0]: return self._call_impl(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
[rank0]:[rank0]: return forward_call(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 904, in catch_errors
[rank0]:[rank0]: return callback(frame, cache_entry, hooks, frame_state, skip=1)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 769, in _convert_frame
[rank0]:[rank0]: result = inner_convert(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 398, in _convert_frame_assert
[rank0]:[rank0]: return _compile(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 669, in _compile
[rank0]:[rank0]: guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 542, in compile_inner
[rank0]:[rank0]: out_code = transform_code_object(code, transform)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1033, in transform_code_object
[rank0]:[rank0]: transformations(instructions, code_options)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 163, in _fn
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 507, in transform
[rank0]:[rank0]: tracer.run()
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2130, in run
[rank0]:[rank0]: super().run()
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 793, in run
[rank0]:[rank0]: and self.step()
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 756, in step
[rank0]:[rank0]: getattr(self, inst.opname)(inst)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1302, in STORE_ATTR
[rank0]:[rank0]: return self.store_attr_graph_break(inst)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 1326, in store_attr_graph_break
[rank0]:[rank0]: self.output.compile_subgraph(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 957, in compile_subgraph
[rank0]:[rank0]: self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1102, in compile_and_call_fx_graph
[rank0]:[rank0]: compiled_fn = self.call_user_compiler(gm)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1175, in call_user_compiler
[rank0]:[rank0]: raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/output_graph.py", line 1156, in call_user_compiler
[rank0]:[rank0]: compiled_fn = compiler_fn(gm, self.example_inputs())
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/repro/after_dynamo.py", line 117, in debug_wrapper
[rank0]:[rank0]: compiled_gm = compiler_fn(gm, example_inputs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/__init__.py", line 1730, in __call__
[rank0]:[rank0]: return compile_fx(model_, inputs_, config_patches=self.config)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1321, in compile_fx
[rank0]:[rank0]: return aot_autograd(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/backends/common.py", line 57, in compiler_fn
[rank0]:[rank0]: cg = aot_module_simplified(gm, example_inputs, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 879, in aot_module_simplified
[rank0]:[rank0]: compiled_fn = create_aot_dispatcher_function(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/aot_autograd.py", line 604, in create_aot_dispatcher_function
[rank0]:[rank0]: compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config, fw_metadata=fw_metadata)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 434, in aot_wrapper_dedupe
[rank0]:[rank0]: return compiler_fn(flat_fn, leaf_flat_args, aot_config, fw_metadata=fw_metadata)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 639, in aot_wrapper_synthetic_base
[rank0]:[rank0]: return compiler_fn(flat_fn, flat_args, aot_config, fw_metadata=fw_metadata)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 297, in aot_dispatch_autograd
[rank0]:[rank0]: compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 1249, in fw_compiler_base
[rank0]:[rank0]: return inner_compile(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/repro/after_aot.py", line 83, in debug_wrapper
[rank0]:[rank0]: inner_compiled_fn = compiler_fn(gm, example_inputs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/debug.py", line 304, in inner
[rank0]:[rank0]: return fn(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/contextlib.py", line 79, in inner
[rank0]:[rank0]: return func(*args, **kwds)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 423, in compile_fx_inner
[rank0]:[rank0]: compiled_graph = fx_codegen_and_compile(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/compile_fx.py", line 689, in fx_codegen_and_compile
[rank0]:[rank0]: compiled_fn = graph.compile_to_fn()
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1224, in compile_to_fn
[rank0]:[rank0]: return self.compile_to_module().call
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 250, in time_wrapper
[rank0]:[rank0]: r = func(*args, **kwargs)
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/graph.py", line 1176, in compile_to_module
[rank0]:[rank0]: mod = PyCodeCache.load_by_key_path(
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2053, in load_by_key_path
[rank0]:[rank0]: exec(code, mod.__dict__, mod.__dict__)
[rank0]:[rank0]: File "/tmp/torchinductor_less/rf/crfqzglf5slafezlc46mvfhf7rn5xkcdjgxfhjfyek4ffjh5mbpo.py", line 72, in <module>
[rank0]:[rank0]: async_compile.wait(globals())
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2593, in wait
[rank0]:[rank0]: scope[key] = result.result()
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/site-packages/torch/_inductor/codecache.py", line 2400, in result
[rank0]:[rank0]: self.future.result()
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/concurrent/futures/_base.py", line 445, in result
[rank0]:[rank0]: return self.__get_result()
[rank0]:[rank0]: File "/home/less/local/miniconda3/envs/triton/lib/python3.10/concurrent/futures/_base.py", line 390, in __get_result
[rank0]:[rank0]: raise self._exception
[rank0]:[rank0]: torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
[rank0]:[rank0]: RuntimeError: 2 compatible backends for target (cuda) ([<class 'c.CUDABackend'>, <class 'nvi.CUDABackend'>]). There should only be one.
[rank0]:
[rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
[rank0]:
[rank0]:
[rank0]:[rank0]: You can suppress this exception and fall back to eager by setting:
[rank0]:[rank0]: import torch._dynamo
[rank0]:[rank0]: torch._dynamo.config.suppress_errors = True
[rank0]:
[2024-02-12 10:00:55,905] torch.distributed.elastic.multiprocessing.api: [WARNING] Sending process 1864627 closing signal SIGTERM
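For context, the failure is not in Inductor itself but in Triton's backend discovery, which requires exactly one registered backend per target. A minimal sketch of that kind of uniqueness check (function and field names here are hypothetical, modeled on the error message, not Triton's actual internals):

```python
# Sketch (hypothetical names) of the uniqueness check behind the error above:
# backend discovery collects every backend registered for a target and
# requires exactly one match, so a stale duplicate install breaks compilation.
def select_backend(target, registered_backends):
    compatible = [b for b in registered_backends if b["target"] == target]
    if len(compatible) != 1:
        raise RuntimeError(
            f"{len(compatible)} compatible backends for target ({target}) "
            f"({[b['name'] for b in compatible]}). There should only be one."
        )
    return compatible[0]

# Two CUDA backends (e.g. one from a stale install plus one from the
# nightly) reproduce the failure; a clean environment leaves exactly one.
backends = [
    {"name": "c.CUDABackend", "target": "cuda"},
    {"name": "nvi.CUDABackend", "target": "cuda"},
]
```

This matches the symptom: the environment ended up with two CUDA backend classes visible to the same target, and compilation aborts rather than picking one arbitrarily.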
This turned out to be related to CUDA kernel work in the environment. I was not able to isolate the root cause, but after creating a new conda env everything went back to normal, so I'm closing this.
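For anyone hitting the same error, the workaround that resolved it can be sketched as recreating the environment from scratch so only one Triton CUDA backend is installed (the env name and CUDA version below are illustrative, not from the original report):

```shell
# Workaround sketch: a fresh env removes the duplicate Triton backend.
# "triton-clean" and cu121 are illustrative choices.
conda create -n triton-clean python=3.10 -y
conda activate triton-clean
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121
```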