unslothai / unsloth

Finetune Llama 3.2, Mistral, Phi, Qwen 2.5 & Gemma LLMs 2-5x faster with 80% less memory
https://unsloth.ai
Apache License 2.0

torch.compile fails #1175

Closed: fzyzcjy closed this issue 1 month ago

fzyzcjy commented 1 month ago

Hi, thanks for the lib! When I set `torch_compile=True` in the trainer for Llama 3.2 1B, I see:

```
[01:17:10.938]: W1024 01:17:10.936000 140404233111360 torch/_dynamo/convert_frame.py:1009] Traceback (most recent call last):
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 948, in __call__
    result = self._inner_convert(
             ^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 472, in __call__
    return _compile(
           ^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_utils_internal.py", line 84, in wrapper_function
    return StrobelightCompileTimeProfiler.profile_compile_time(
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_strobelight/compile_time_profiler.py", line 129, in profile_compile_time
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 817, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 636, in compile_inner
    out_code = transform_code_object(code, transform)
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/bytecode_transformation.py", line 1185, in transform_code_object
    transformations(instructions, code_options)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 178, in _fn
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/convert_frame.py", line 582, in transform
    tracer.run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2451, in run
    super().run()
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 893, in run
    while self.step():
          ^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 805, in step
    self.dispatch_table[inst.opcode](self, inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2642, in RETURN_VALUE
    self._return(inst)
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/symbolic_convert.py", line 2627, in _return
    self.output.compile_subgraph(
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1123, in compile_subgraph
    self.compile_and_call_fx_graph(tx, pass2.graph_output_vars(), root)
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1318, in compile_and_call_fx_graph
    compiled_fn = self.call_user_compiler(gm)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1409, in call_user_compiler
    raise BackendCompilerFailed(self.compiler_fn, e).with_traceback(
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/output_graph.py", line 1390, in call_user_compiler
    compiled_fn = compiler_fn(gm, self.example_inputs())
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/repro/after_dynamo.py", line 129, in __call__
    compiled_gm = compiler_fn(gm, example_inputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/__init__.py", line 1951, in __call__
    return compile_fx(model_, inputs_, config_patches=self.config)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1505, in compile_fx
    return aot_autograd(
           ^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/backends/common.py", line 69, in __call__
    cg = aot_module_simplified(gm, example_inputs, **self.kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 954, in aot_module_simplified
    compiled_fn, _ = create_aot_dispatcher_function(
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/aot_autograd.py", line 687, in create_aot_dispatcher_function
    compiled_fn, fw_metadata = compiler_fn(
                               ^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_functorch/_aot_autograd/jit_compile_runtime_wrappers.py", line 168, in aot_dispatch_base
    compiled_fw = compiler(fw_module, updated_flat_args)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 1410, in fw_compiler_base
    return inner_compile(
           ^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/repro/after_aot.py", line 84, in debug_wrapper
    inner_compiled_fn = compiler_fn(gm, example_inputs)
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/debug.py", line 304, in inner
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 527, in compile_fx_inner
    compiled_graph = fx_codegen_and_compile(
                     ^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/contextlib.py", line 81, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/compile_fx.py", line 831, in fx_codegen_and_compile
    compiled_fn = graph.compile_to_fn()
                  ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1749, in compile_to_fn
    return self.compile_to_module().call
           ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_dynamo/utils.py", line 231, in time_wrapper
    r = func(*args, **kwargs)
        ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/graph.py", line 1699, in compile_to_module
    mod = PyCodeCache.load_by_key_path(
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3062, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_jovyan/6e/c6eiyvjd7ahfvedk55xe3y5woqx6hmck3qxao6c7hnuvhn7ljkys.py", line 31, in <module>
    _rope_embedding_0 = async_compile.triton('_rope_embedding', '''
                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/async_compile.py", line 173, in triton
    kernel = TritonCodeCache.load(kernel_name, source_code)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3112, in load
    return _module_to_triton_kernel(PyCodeCache.load(source_code), kernel_name)
                                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3049, in load
    return cls.load_by_key_path(key, path, linemap, attrs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/codecache.py", line 3062, in load_by_key_path
    mod = _reload_python_module(key, path)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/torch/_inductor/runtime/compile_tasks.py", line 45, in _reload_python_module
    exec(code, mod.__dict__, mod.__dict__)
  File "/tmp/torchinductor_jovyan/dh/cdhxmnxvtwxyu7omwappfbadvtpd5efmpljibtjbpyrbw7j5mkm4.py", line 14, in <module>
    triton_meta={'signature': {0: '*bf16', 1: 'i32', 2: '*bf16', 3: 'i32', 4: '*bf16', 5: 'i32', 6: 'i32'}, 'device': DeviceProperties(type='cuda', index=0, cc=89, major=8, regs_per_multiprocessor=65536, max_threads_per_multi_processor=1536, multi_processor_count=114), 'constants': {7: 64, 8: s4, 9: False, 10: 32}, 'configs': [AttrsDescriptor(divisible_by_16=(0, 2, 3, 4, 5), equal_to_1=())]},
torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
NameError: name 's4' is not defined

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```
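The last frames show where this breaks: Inductor writes generated Python source under `/tmp/torchinductor_jovyan/` and loads it with `exec(code, mod.__dict__, mod.__dict__)`. Line 14 of that source has `8: s4` in the Triton `constants` dict, and nothing in the module binds the name `s4`. One plausible reading (an interpretation, not confirmed in this thread) is that `s4` is one of the symbolic sizes Inductor uses for dynamic shapes, written out as a bare name where the `_rope_embedding` kernel expected a concrete compile-time constant. The stdlib-only sketch below reproduces just that failure mode; `GENERATED_SOURCE`, `load_generated_module` and `try_load` are hypothetical names, and the source string is a drastically trimmed stand-in for the real generated file.

```python
import types

# Hypothetical, heavily trimmed stand-in for the generated wrapper source.
# Only the shape of the bug is kept: a symbol (`s4`) is emitted as a bare
# name instead of a concrete integer, and nothing in the module defines it.
GENERATED_SOURCE = """
triton_meta = {
    'signature': {0: '*bf16', 1: 'i32'},
    'constants': {7: 64, 8: s4, 9: False, 10: 32},
}
"""


def load_generated_module(source: str, name: str = "generated_kernel") -> types.ModuleType:
    """Execute source in a fresh module namespace, mirroring the
    exec(code, mod.__dict__, mod.__dict__) call seen in the traceback."""
    mod = types.ModuleType(name)
    code = compile(source, f"<{name}>", "exec")
    exec(code, mod.__dict__, mod.__dict__)
    return mod


def try_load(source: str) -> str:
    """Return 'loaded' on success, or the NameError message on failure."""
    try:
        load_generated_module(source)
    except NameError as exc:
        return str(exc)
    return "loaded"


print(try_load(GENERATED_SOURCE))
# Substituting an arbitrary concrete value for the symbol lets the same source load.
print(try_load(GENERATED_SOURCE.replace("s4", "16")))
```

The value 16 is arbitrary; this only illustrates why `exec` raises the `NameError`, it is not a fix for the Unsloth kernel or for Inductor's code generation.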
danielhanchen commented 1 month ago

@fzyzcjy Oh wait, try not to use `torch_compile=True` :) It won't make things much faster, and sadly it's broken :(

fzyzcjy commented 1 month ago

OK :( I'd heard people say compiling improves speed, so I wanted to give it a try.

danielhanchen commented 1 month ago

Oh, Unsloth already makes training about 2x faster, which is close to the hardware limit, so torch.compile might only add another 1-2%. No need for it :)