llmexperiment opened this issue 4 months ago
We have no experience with ONNX. Do you have ideas on how to generate ONNX for custom operations? If so, would you like to contribute?
Thanks @tridao! I am working in that direction (here is how to do it: https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md).
Could you let me know which other custom operators, if any, are used besides scan?
If I am not mistaken, the code for scan is here: https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.cpp
Is there any chance you have a standalone implementation of scan?
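For reference while this gets answered: a sequential (non-fused) selective scan is small enough to write with plain array ops. The NumPy sketch below follows my understanding of the shape conventions of `selective_scan_ref` in mamba_ssm/ops/selective_scan_interface.py (u, delta: (batch, dim, L); A: (dim, N); B, C: (batch, N, L); D: (dim,)). The z gate and delta_bias are left out, and it is meant as a readable reference, not a fast or verified drop-in.

```python
import numpy as np


def softplus(x):
    return np.log1p(np.exp(x))


def selective_scan_ref(u, delta, A, B, C, D=None, delta_softplus=False):
    """Sequential selective scan (sketch; z gating and delta_bias omitted).

    u, delta: (batch, dim, L); A: (dim, N); B, C: (batch, N, L); D: (dim,).
    """
    if delta_softplus:
        delta = softplus(delta)
    batch, dim, L = u.shape
    # Discretize per time step: dA = exp(delta * A), dBu = delta * B * u
    dA = np.exp(np.einsum("bdl,dn->bdln", delta, A))
    dBu = np.einsum("bdl,bnl,bdl->bdln", delta, B, u)
    h = np.zeros((batch, dim, A.shape[1]), dtype=u.dtype)
    ys = []
    for t in range(L):
        h = dA[:, :, t] * h + dBu[:, :, t]                 # state update
        ys.append(np.einsum("bdn,bn->bd", h, C[:, :, t]))  # readout with C_t
    y = np.stack(ys, axis=-1)                              # (batch, dim, L)
    if D is not None:
        y = y + u * D[:, None]                             # skip connection
    return y
```

Since it only uses exp, multiply, add and a loop over L, an equivalent written in PyTorch should trace to standard ONNX ops (unrolled for a fixed L), at the cost of speed.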
@llmexperiment You may be interested in looking at the HF transformers implementation (PR here), which supports a fallback if causal-conv1d is not found in the environment. I've also been trying to convert Mamba models to ONNX for transformers.js, but I've been running into a few issues. If I figure something out, I'll update the thread.
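For anyone hitting the same causal-conv1d dependency: the operation itself is a depthwise 1D convolution that only looks backwards in time, so it can be expressed with plain ops. A NumPy sketch, assuming x is (batch, dim, L) and weight is (dim, width) with the last tap applied to the current step. As I understand it, this matches a depthwise nn.Conv1d with padding=width-1 truncated to the first L outputs, which is what a non-CUDA fallback has to reproduce. The function name is hypothetical.

```python
import numpy as np


def causal_conv1d_ref(x, weight, bias=None, silu=False):
    """Depthwise causal conv (hypothetical helper).

    x: (batch, dim, L); weight: (dim, width); bias: (dim,).
    """
    batch, dim, L = x.shape
    width = weight.shape[1]
    # Left-pad with width-1 zeros so output t only sees x[..., :t+1]
    xp = np.pad(x, ((0, 0), (0, 0), (width - 1, 0)))
    y = np.zeros(x.shape, dtype=float)
    for k in range(width):
        # weight[:, k] scales the input that is (width-1-k) steps in the past
        y += weight[None, :, k, None] * xp[:, :, k:k + L]
    if bias is not None:
        y += bias[None, :, None]
    if silu:
        y = y / (1.0 + np.exp(-y))  # y * sigmoid(y)
    return y
```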
Hey, I'm interested in converting the Vision Mamba (Vim) model to ONNX but have not had success so far. I decided to start with the Mamba layers first and then proceed from there.
This is the current stack trace from my export script:
root@ubuntu:/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim# python3 demo_export.py
Fetching 4 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 28149.69it/s]
/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py:57: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert H == self.img_size[0] and W == self.img_size[1], \
/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py:420: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if if_random_token_rank:
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:133: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
assert weight.shape == (N,)
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:150: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
if N > BLOCK_N:
Traceback (most recent call last):
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1124, in ast_to_ttir
generator.visit(fn.parse())
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/usr/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 293, in visit_Module
ast.NodeVisitor.generic_visit(self, node)
File "/usr/lib/python3.8/ast.py", line 379, in generic_visit
self.visit(item)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/usr/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
self.visit_compound_statement(node.body)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 288, in visit_compound_statement
ret_type = self.visit(stmt)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/usr/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 414, in visit_Assign
values = self.visit(node.value)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
ret = super().visit(node)
File "/usr/lib/python3.8/ast.py", line 371, in visit
return visitor(node)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 946, in visit_Call
return fn(*args, **extra_kwargs, **kws)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/core.py", line 30, in wrapper
return fn(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/core.py", line 813, in arange
return semantic.arange(start, end, _builder)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/semantic.py", line 485, in arange
raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "demo_export.py", line 61, in <module>
torch.onnx.export(model, dummy_input, "vim_s_midclstok_ft_81p6acc_fp16.onnx", input_names=["input"],
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 516, in export
_export(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1596, in _export
graph, params_dict, torch_out = _model_to_graph(
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
outs = ONNXTracedModule(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 133, in forward
graph, out = torch._C._create_graph_by_tracing(
File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 124, in wrapper
outs.append(self.inner(*trace_inputs))
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 541, in forward
x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 478, in forward_features
hidden_states, residual = layer(
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
result = self.forward(*input, **kwargs)
File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 115, in forward
hidden_states, residual = fused_add_norm_fn(
File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 100, in run
timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 100, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 83, in _bench
return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/testing.py", line 104, in do_bench
fn()
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 81, in kernel_call
self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/compiler.py", line 476, in compile
next_module = compile_kernel(module)
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/compiler.py", line 381, in <lambda>
lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 31:24: HAS_BIAS: tl.constexpr,
):
# Map the program id to the row of X and Y it should compute.
row = tl.program_id(0)
X += row * stride_x_row
Y += row * stride_y_row
if HAS_RESIDUAL:
RESIDUAL += row * stride_res_row
if STORE_RESIDUAL_OUT:
RESIDUAL_OUT += row * stride_res_out_row
# Compute mean and variance
cols = tl.arange(0, BLOCK_N)
^
ValueError("arange's arguments must be of type tl.constexpr")
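From what I understand, tl.arange inside the Triton kernel requires compile-time (tl.constexpr) arguments, but during tracing BLOCK_N, which is computed from the input shape N, becomes a traced value (see the TracerWarnings for layernorm.py:150-151 above), and Triton is not happy with that.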
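One way around this, sketched under assumptions: for export, replace the fused Triton call with the same math in plain ops. My reading of rms_norm_fn with prenorm=True (the call in the trace above) is that it adds the residual, RMS-normalizes over the last axis, scales by the weight, and returns both the normalized output and the new residual. The NumPy sketch below shows that computation; the function name is hypothetical and residual_in_fp32 handling is omitted. A PyTorch version of the same few lines should trace without touching Triton.

```python
import numpy as np


def add_rms_norm_ref(x, weight, residual=None, eps=1e-5, bias=None):
    """Unfused prenorm add + RMSNorm (hypothetical helper).

    Returns (normed, residual_out); normalization is over the last axis.
    """
    residual_out = x if residual is None else x + residual
    # 1 / sqrt(mean(x^2) + eps), kept with a trailing axis for broadcasting
    rstd = 1.0 / np.sqrt(np.mean(residual_out ** 2, axis=-1, keepdims=True) + eps)
    y = residual_out * rstd * weight
    if bias is not None:
        y = y + bias
    return y, residual_out
```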
Dear @tridao, @albertfgu,
It looks like it is not straightforward to generate an ONNX file with torch.onnx.export, for the following reasons:
1) The underlying operators are custom kernels rather than standard PyTorch ops: the scan is a CUDA extension (csrc/selective_scan), and other pieces, such as the fused add + norm where the export above fails, are written in Triton.
2) Inference needs the recurrent (token-by-token) version of the scan, which I believe is in lines 119 to 133 (133 being the return) here: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119
Based on my understanding, these two points prevent generating an ONNX file. It would be great to have an ONNX file for the inference path of the smallest model.
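On the recurrent path: it processes one token at a time with a small fixed-size state, which is what makes it attractive for an inference-only graph. Below is a NumPy sketch of my understanding of the two state updates in Mamba.step, the conv window shift and the SSM state update. Assumptions: the in/out projections, the x_proj/dt_proj split, softplus on dt and the silu/z gating are omitted, A is passed already negative (i.e. -exp(A_log)), and the function names are hypothetical.

```python
import numpy as np


def conv_step(conv_state, x_t, weight, bias=None):
    """One causal-conv step (hypothetical helper).

    conv_state: (batch, dim, width) window of recent inputs; x_t: (batch, dim).
    """
    conv_state = np.roll(conv_state, shift=-1, axis=-1)  # drop the oldest input
    conv_state[:, :, -1] = x_t                           # append the newest input
    out = np.sum(conv_state * weight[None], axis=-1)     # (batch, dim)
    if bias is not None:
        out = out + bias
    return conv_state, out


def ssm_step(ssm_state, x_t, dt, A, B_t, C_t, D):
    """One recurrent SSM update (hypothetical helper; gating omitted).

    ssm_state: (batch, dim, N); x_t, dt: (batch, dim); A: (dim, N), already
    negative; B_t, C_t: (batch, N); D: (dim,).
    """
    dA = np.exp(dt[:, :, None] * A[None])   # (batch, dim, N)
    dB = dt[:, :, None] * B_t[:, None, :]   # (batch, dim, N)
    ssm_state = ssm_state * dA + x_t[:, :, None] * dB
    y = np.einsum("bdn,bn->bd", ssm_state, C_t) + D * x_t
    return ssm_state, y
```

Running ssm_step over a sequence token by token should reproduce the full scan, which is a handy check when comparing an exported step graph against the PyTorch model.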
Any suggestions on how we can generate an ONNX file for inference (and, separately, for training)?