[Open] RoiEXLab opened this issue 1 year ago
Thanks a lot for the report! Because of potential issues like these, I'm waiting until PyTorch 2.x is more stable before upgrading the entire repo. If the underlying issue is complex-number support in `torch.compile` (can you repro the error with a minimal example?), there's not much I can do, and it would be great to file an issue directly with PyTorch.
Thanks for the reply.
> can you repro the error with a minimal example?
It should be rather easy to reproduce, just by running a simple forward pass on the standalone sashimi model after it has been compiled. But I'll try to provide one within the next 24 hours if I get to it.
Right, although if the line you pointed out is indeed the problem, it would be helpful to see whether it also fails with a model more minimal than the full S4 layer. If you don't get to it, I'll keep this in mind when I get around to upgrading the library versions.
Ah I see. I'll see what I can do
@albertfgu Here you go: https://gist.github.com/RoiEXLab/5cc1630aca71b603528a574b2a2e3326
It turns out the `SSKernel` seems to be the issue. When running `python3 reproducer.py` (see the gist; make sure to install the required dependencies; some minor adjustments were made to the files to keep it as simple as possible), I get the following error:
CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%
[2023-03-22 16:52:30,027] torch._inductor.graph: [ERROR] Error from lowering
Traceback (most recent call last):
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 333, in call_function
out = lowerings[target](*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 225, in wrapped
out = decomp_fn(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 3633, in mul
return make_pointwise(fn)(a, b)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 329, in inner
loaders = [x.make_loader() for x in inputs]
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 329, in <listcomp>
loaders = [x.make_loader() for x in inputs]
AttributeError: 'complex' object has no attribute 'make_loader'
Traceback (most recent call last):
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 333, in call_function
out = lowerings[target](*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 225, in wrapped
out = decomp_fn(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 3633, in mul
return make_pointwise(fn)(a, b)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 329, in inner
loaders = [x.make_loader() for x in inputs]
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/lowering.py", line 329, in <listcomp>
loaders = [x.make_loader() for x in inputs]
AttributeError: 'complex' object has no attribute 'make_loader'
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 670, in call_user_compiler
compiled_fn = compiler_fn(gm, self.fake_example_inputs())
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 1055, in debug_wrapper
compiled_gm = compiler_fn(gm, example_inputs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/__init__.py", line 1390, in __call__
return compile_fx(model_, inputs_, config_patches=self.config)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 455, in compile_fx
return aot_autograd(
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/backends/common.py", line 48, in compiler_fn
cg = aot_module_simplified(gm, example_inputs, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2805, in aot_module_simplified
compiled_fn = create_aot_dispatcher_function(
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 2498, in create_aot_dispatcher_function
compiled_fn = compiler_fn(flat_fn, fake_flat_args, aot_config)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1713, in aot_wrapper_dedupe
return compiler_fn(flat_fn, leaf_flat_args, aot_config)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_functorch/aot_autograd.py", line 1326, in aot_dispatch_base
compiled_fw = aot_config.fw_compiler(fw_module, flat_args_with_views_handled)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 430, in fw_compiler
return inner_compile(
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/debug_utils.py", line 595, in debug_wrapper
compiled_fn = compiler_fn(gm, example_inputs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/debug.py", line 239, in inner
return fn(*args, **kwargs)
File "/usr/lib/python3.8/contextlib.py", line 75, in inner
return func(*args, **kwds)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/compile_fx.py", line 176, in compile_fx_inner
graph.run(*example_inputs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 194, in run
return super().run(*args)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/fx/interpreter.py", line 136, in run
self.env[node] = self.run_node(node)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 407, in run_node
result = super().run_node(n)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/fx/interpreter.py", line 177, in run_node
return getattr(self, n.op)(n.target, args, kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_inductor/graph.py", line 337, in call_function
raise LoweringException(e, target, args, kwargs) from e
torch._inductor.exc.LoweringException: AttributeError: 'complex' object has no attribute 'make_loader'
target: aten.mul.Tensor
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[64, 32], stride=[32, 1]))
))
args[1]: 1j
While executing %mul : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg1_1, 1j), kwargs = {})
Original traceback:
File "/home/roiex/s4-reproducer/s4.py", line 702, in _w
w = w_real + 1j * self.w_imag
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "reproducer.py", line 14, in <module>
y, _ = model(x)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 209, in _fn
return fn(*args, **kwargs)
File "/home/roiex/s4-reproducer/s4.py", line 1312, in forward
return self.kernel(state=state, L=L, rate=rate)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
return forward_call(*args, **kwargs)
File "/home/roiex/s4-reproducer/s4.py", line 717, in forward
if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
File "/home/roiex/s4-reproducer/s4.py", line 723, in <graph break in forward>
L = round(self.L.item() / rate)
File "/home/roiex/s4-reproducer/s4.py", line 736, in <graph break in forward>
w = self._w() # (n_ssm, N)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/eval_frame.py", line 337, in catch_errors
return callback(frame, cache_size, hooks)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 404, in _convert_frame
result = inner_convert(frame, cache_size, hooks)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 104, in _fn
return fn(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 262, in _convert_frame_assert
return _compile(
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 324, in _compile
out_code = transform_code_object(code, transform)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/bytecode_transformation.py", line 445, in transform_code_object
transformations(instructions, code_options)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/convert_frame.py", line 311, in transform
tracer.run()
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1726, in run
super().run()
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 576, in run
and self.step()
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 540, in step
getattr(self, inst.opname)(inst)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/symbolic_convert.py", line 1792, in RETURN_VALUE
self.output.compile_subgraph(
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 517, in compile_subgraph
self.compile_and_call_fx_graph(tx, list(reversed(stack_values)), root)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 588, in compile_and_call_fx_graph
compiled_fn = self.call_user_compiler(gm)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/utils.py", line 163, in time_wrapper
r = func(*args, **kwargs)
File "/home/roiex/s4-reproducer/venv/lib/python3.8/site-packages/torch/_dynamo/output_graph.py", line 675, in call_user_compiler
raise BackendCompilerFailed(self.compiler_fn, e) from e
torch._dynamo.exc.BackendCompilerFailed: debug_wrapper raised LoweringException: AttributeError: 'complex' object has no attribute 'make_loader'
target: aten.mul.Tensor
args[0]: TensorBox(StorageBox(
InputBuffer(name='arg1_1', layout=FixedLayout('cuda', torch.float32, size=[64, 32], stride=[32, 1]))
))
args[1]: 1j
While executing %mul : [#users=1] = call_function[target=torch.ops.aten.mul.Tensor](args = (%arg1_1, 1j), kwargs = {})
Original traceback:
File "/home/roiex/s4-reproducer/s4.py", line 702, in _w
w = w_real + 1j * self.w_imag
You can suppress this exception and fall back to eager by setting:
torch._dynamo.config.suppress_errors = True
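For context, the pattern the traceback points at is ordinary complex arithmetic that works fine in eager mode. A minimal sketch (assuming PyTorch 2.0.x) of what inductor is being asked to lower:

```python
import torch

# The same pattern as `w = w_real + 1j * self.w_imag` in s4.py's _w():
# a real float32 tensor multiplied by the complex Python scalar 1j.
w_real = torch.randn(64, 32)
w_imag = torch.randn(64, 32)

w = w_real + 1j * w_imag  # works in eager mode
print(w.dtype)            # torch.complex64

# Under torch.compile, this lowers to `aten.mul.Tensor(w_imag, 1j)`, and
# inductor fails on the complex scalar operand with the
# `'complex' object has no attribute 'make_loader'` LoweringException above.
```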
So it does indeed seem the issue is with complex numbers. Looking at the PyTorch repo there are a lot of open issues regarding complex numbers, but I'm not quite sure how well they apply to this exact problem. I also tried using different backends for compilation (see `import torch._dynamo; torch._dynamo.list_backends()`), but they didn't work out of the box either (I assume some need additional dependencies installed, but even the ones without extra dependencies didn't work).
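For anyone who wants to try other backends themselves, a quick sketch (backend names vary across PyTorch versions, and some require extra packages to be installed):

```python
import torch
import torch._dynamo

# List the compiler backends registered in this PyTorch install.
print(torch._dynamo.list_backends())

# A backend can then be selected by name when compiling, e.g.:
# model = torch.compile(model, backend="aot_eager")
```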
Same issue here. I import `SSMKernelDPLR` from `state_spaces.models.s4.s4` as a module of my custom model, and that causes the error when I try `torch.compile`. Any progress on this problem?
Unfortunately this is a missing functionality on PyTorch's end (in turn coming from lack of support in Triton): https://github.com/pytorch/pytorch/issues/98161. The PyTorch team is aware of this and may look to support it eventually, but it's unclear how long that would take.
I don't think the core state space kernels (`SSKernelDiag` or `SSKernelDPLR`) are a bottleneck for larger-scale models. The main benefit of compilation would be fusing the main computation pathway of the FFT convolution together with the surrounding linear layers. Unfortunately I don't see a way to do this at the moment.
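One possible stopgap, not discussed in this thread, is to exclude the kernel from tracing with `torch._dynamo.disable`, which avoids the crash at the cost of a graph break (the kernel itself then runs uncompiled in eager mode). A hedged sketch with a hypothetical stand-in kernel; the debug `eager` backend is used here only so the sketch runs without a working inductor toolchain:

```python
import torch
import torch._dynamo

class Kernel(torch.nn.Module):
    """Stand-in for SSKernelDiag/SSKernelDPLR: does complex arithmetic."""
    # Excluded from tracing: dynamo inserts a graph break here, so the
    # backend compiler never sees the complex ops it cannot lower.
    @torch._dynamo.disable
    def forward(self, x):
        w = x + 1j * x          # the pattern inductor chokes on
        return w.real

class Model(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.kernel = Kernel()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        # The layers around the kernel remain eligible for compilation.
        return self.linear(self.kernel(x))

model = torch.compile(Model(), backend="eager")
y = model(torch.randn(4, 8))
```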
Hi,

I'm using the sashimi model on my own dataset with reasonable success for a while now, and I wanted to see if I could use the recently released `torch.compile` function on it to speed up training for my experiments. Unfortunately, it doesn't work. The following line seems to fail (for reasons I don't understand): https://github.com/HazyResearch/state-spaces/blob/06dbbdfd0876501a7f12bf3262121badbc7658af/src/models/s4/s4.py#L703
On the PyTorch site there's some information on how to deal with these issues, so I hope the code can be extended in the future to run noticeably faster.
Thanks in advance.