state-spaces / mamba

Mamba SSM architecture
Apache License 2.0

Generating an ONNX file for Mamba inference? #200

Open llmexperiment opened 4 months ago

llmexperiment commented 4 months ago

Dear @tridao , @albertfgu ,

It looks like it is not straightforward to generate an ONNX file with torch.onnx.export, for the following reasons:

1) It looks like the underlying scan operator is implemented as a custom kernel (Triton/CUDA) rather than in plain PyTorch.
2) Inference needs the recurrent version of the scan, which I believe is in lines 119 to 133 (line 133 being the return) here: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/modules/mamba_simple.py#L119

As far as I understand, these two issues prevent generating an ONNX file. It would be great to have an ONNX file for inference with the smallest model.

Any suggestions on how we can generate an ONNX file for inference (and, separately, for training)?
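For reference, the recurrent update that inference needs (one token at a time) is small. Below is a minimal pure-Python sketch for a single channel, modeled on the non-fused `step` path in mamba_simple.py (decay `exp(dt*A)`, input `dt*B*x`, readout `<h, C> + D*x`). The function names, the list-based layout, and the assumption that `dt` has already been passed through softplus with its bias are illustrative choices, not the repository's API; the output gate `y * silu(z)` is left out.

```python
import math
from typing import List, Tuple


def selective_scan_step(
    h: List[float],   # SSM state for one channel, length d_state
    x: float,         # input for this channel at the current token
    dt: float,        # step size, assumed already softplus(dt + bias)
    A: List[float],   # per-state decay rates (negative in Mamba)
    B: List[float],   # input projection for this token, length d_state
    C: List[float],   # output projection for this token, length d_state
    D: float,         # skip-connection weight
) -> Tuple[List[float], float]:
    """One recurrent update: h <- exp(dt*A) * h + dt*B*x, y = <h, C> + D*x."""
    new_h = [math.exp(dt * a) * h_n + dt * b * x for h_n, a, b in zip(h, A, B)]
    y = sum(h_n * c for h_n, c in zip(new_h, C)) + D * x
    return new_h, y


def run_recurrent(xs, dts, A, Bs, Cs, D):
    """Feed a sequence token by token, as autoregressive decoding would."""
    h = [0.0] * len(A)  # state starts at zero
    ys = []
    for x, dt, B, C in zip(xs, dts, Bs, Cs):
        h, y = selective_scan_step(h, x, dt, A, B, C, D)
        ys.append(y)
    return ys
```

With `A = [0.0]` there is no decay, so for `B = C = [1.0]`, `D = 0` and `dt = 1` the state is a running sum: `run_recurrent([1.0, 2.0, 3.0], [1.0] * 3, [0.0], [[1.0]] * 3, [[1.0]] * 3, 0.0)` gives `[1.0, 3.0, 6.0]`. Since each step only needs exp, multiply, add and a reduction, a single decoding step written with ordinary tensor ops is a plausible unit to export.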

tridao commented 4 months ago

We have no experience with ONNX. Do you have ideas on how to generate onnx for custom operations? If so would you like to contribute?

llmexperiment commented 4 months ago


Thanks @tridao! I am working in that direction (here is a tutorial on how to do it: https://github.com/onnx/tutorials/blob/master/PyTorchCustomOperator/README.md).

Could you let me know which other custom operators, if any, are used besides scan?

If I am not mistaken, the code for scan is here: https://github.com/state-spaces/mamba/blob/main/csrc/selective_scan/selective_scan.cpp

Is there any chance you have a standalone implementation of scan?
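One option worth considering for a standalone scan (my own assumption, not something the repository provides) is to avoid the sequential loop altogether. Because the recurrence h_t = exp(dt_t*A) * h_{t-1} + dt_t*B_t*x_t is linear in the state, unrolling it from h_{-1} = 0 gives h_t = sum_{s<=t} exp(A * (S_t - S_s)) * dt_s*B_s*x_s, where S is the prefix sum of dt. That only needs a cumulative sum, exp and a causally masked weighted sum, which have standard ONNX counterparts (CumSum, Exp, MatMul). The pure-Python sketch below (hypothetical function name, one channel, lists instead of tensors) shows the math; it costs O(L^2 * N) and can lose precision for long sequences, so treat it as a reference rather than a replacement for the kernel.

```python
import math
from itertools import accumulate
from typing import List


def selective_scan_closed_form(
    xs: List[float],        # inputs for one channel, length L
    dts: List[float],       # positive step sizes, length L
    A: List[float],         # per-state decay rates, length N
    Bs: List[List[float]],  # B_t for each token, shape L x N
    Cs: List[List[float]],  # C_t for each token, shape L x N
    D: float,               # skip-connection weight
) -> List[float]:
    """Evaluate the selective-scan recurrence without a sequential state update.

    h_t = sum_{s<=t} exp(A * (S_t - S_s)) * dt_s * B_s * x_s, with S = cumsum(dt),
    and y_t = <C_t, h_t> + D * x_t.
    """
    S = list(accumulate(dts))  # S_t = dt_0 + ... + dt_t
    ys = []
    for t in range(len(xs)):
        y = D * xs[t]
        for n, a in enumerate(A):
            # Causal mask: only tokens s <= t contribute to the state at t.
            h_tn = sum(
                math.exp(a * (S[t] - S[s])) * dts[s] * Bs[s][n] * xs[s]
                for s in range(t + 1)
            )
            y += Cs[t][n] * h_tn
        ys.append(y)
    return ys
```

In tensor form the double loop becomes an (L x L) lower-triangular weight matrix per state dimension, which is what makes this variant traceable without a custom operator.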

xenova commented 3 months ago

@llmexperiment You may be interested in looking at the HF transformers implementation (PR here), which supports a fallback if causal-conv1d is not found in the environment. I've also been trying to convert Mamba models to ONNX for transformers.js, but I've been running into a few issues. If I figure something out, I'll update the thread.
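The fallback idea is roughly this: try to import the fused kernel and, if it is missing, compute the same causal convolution with ordinary ops. The `try/except ImportError` around `causal_conv1d_fn` mirrors what mamba_simple.py itself does; its non-fused path computes `self.conv1d(x)[..., :seqlen]`, i.e. a Conv1d padded by K-1 and truncated to the first L outputs. The reference function below is a hypothetical, single-channel pure-Python sketch of that computation, not the transformers code.

```python
from typing import List

try:
    # Fused CUDA kernel from the causal-conv1d package, if it is installed.
    from causal_conv1d import causal_conv1d_fn
except ImportError:
    causal_conv1d_fn = None  # fall back to the reference path below


def causal_conv1d_ref(x: List[float], w: List[float]) -> List[float]:
    """Depthwise causal convolution for a single channel.

    Matches a Conv1d with padding=K-1 whose output is cut to the first L
    positions: y[t] = sum_k w[k] * x[t - (K - 1) + k], treating x[i] = 0 for
    i < 0, so y[t] never depends on future inputs.
    """
    K = len(w)
    y = []
    for t in range(len(x)):
        acc = 0.0
        for k in range(K):
            i = t - (K - 1) + k
            if i >= 0:  # implicit left zero-padding
                acc += w[k] * x[i]
        y.append(acc)
    return y
```

For example, `causal_conv1d_ref([1.0, 2.0, 3.0], [1.0, 2.0])` gives `[2.0, 5.0, 8.0]`. Written with a standard Conv1d, this path is exportable, which is presumably why a fallback helps with ONNX.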

IamShubhamGupto commented 3 months ago

Hey, I'm interested in converting the Vision Mamba (Vim) model to ONNX but have not had success yet. I decided to start with the Mamba layers first and then proceed from there.

This is the current state of my export attempt, with the stack trace:

root@ubuntu:/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim# python3 demo_export.py
Fetching 4 files: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 4/4 [00:00<00:00, 28149.69it/s]
/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py:57: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert H == self.img_size[0] and W == self.img_size[1], \
/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py:420: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if if_random_token_rank:
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:133: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  assert weight.shape == (N,)
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:150: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py:151: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  if N > BLOCK_N:
Traceback (most recent call last):
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1124, in ast_to_ttir
    generator.visit(fn.parse())
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 293, in visit_Module
    ast.NodeVisitor.generic_visit(self, node)
  File "/usr/lib/python3.8/ast.py", line 379, in generic_visit
    self.visit(item)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 362, in visit_FunctionDef
    self.visit_compound_statement(node.body)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 288, in visit_compound_statement
    ret_type = self.visit(stmt)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 414, in visit_Assign
    values = self.visit(node.value)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1017, in visit
    ret = super().visit(node)
  File "/usr/lib/python3.8/ast.py", line 371, in visit
    return visitor(node)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 946, in visit_Call
    return fn(*args, **extra_kwargs, **kws)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/core.py", line 30, in wrapper
    return fn(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/core.py", line 813, in arange
    return semantic.arange(start, end, _builder)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/language/semantic.py", line 485, in arange
    raise ValueError("arange's arguments must be of type tl.constexpr")
ValueError: arange's arguments must be of type tl.constexpr

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "demo_export.py", line 61, in <module>
    torch.onnx.export(model, dummy_input, "vim_s_midclstok_ft_81p6acc_fp16.onnx", input_names=["input"],
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 516, in export
    _export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1596, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1135, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1011, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 915, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1285, in _get_trace_graph
    outs = ONNXTracedModule(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 133, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 124, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 541, in forward
    x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
  File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 478, in forward_features
    hidden_states, residual = layer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1508, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/home/xavier01/Documents/workspace/embedded-vpr/Vim/vim/models_mamba.py", line 115, in forward
    hidden_states, residual = fused_add_norm_fn(
  File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 478, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/usr/local/lib/python3.8/dist-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 411, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/usr/local/lib/python3.8/dist-packages/mamba_ssm-1.1.1-py3.8-linux-aarch64.egg/mamba_ssm/ops/triton/layernorm.py", line 155, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 100, in run
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 100, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 83, in _bench
    return do_bench(kernel_call, warmup=self.warmup, rep=self.rep, quantiles=(0.5, 0.2, 0.8))
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/testing.py", line 104, in do_bench
    fn()
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/runtime/autotuner.py", line 81, in kernel_call
    self.fn.run(*args, num_warps=config.num_warps, num_stages=config.num_stages, **current)
  File "<string>", line 63, in _layer_norm_fwd_1pass_kernel
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/compiler.py", line 476, in compile
    next_module = compile_kernel(module)
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/compiler.py", line 381, in <lambda>
    lambda src: optimize_ttir(ast_to_ttir(src, signature, configs[0], constants, debug=debug, arch=arch), arch))
  File "/usr/local/lib/python3.8/dist-packages/triton-2.1.0-py3.8-linux-aarch64.egg/triton/compiler/code_generator.py", line 1133, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 31:24:    HAS_BIAS: tl.constexpr,
):
    # Map the program id to the row of X and Y it should compute.
    row = tl.program_id(0)
    X += row * stride_x_row
    Y += row * stride_y_row
    if HAS_RESIDUAL:
        RESIDUAL += row * stride_res_row
    if STORE_RESIDUAL_OUT:
        RESIDUAL_OUT += row * stride_res_out_row
    # Compute mean and variance
    cols = tl.arange(0, BLOCK_N)
                        ^
ValueError("arange's arguments must be of type tl.constexpr")

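From what I understand, the error comes from Triton's tl.arange rather than torch.arange: during tracing, the shape N is recorded as a tensor instead of a Python int (hence the TracerWarnings at layernorm.py:150-151), so BLOCK_N is no longer a compile-time constant, and tl.arange only accepts tl.constexpr arguments.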
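The failing call is the Triton fused residual-add + RMSNorm kernel (`rms_norm_fn` in mamba_ssm/ops/triton/layernorm.py). A possible workaround for export, which I have not verified against Vim, is to use the unfused path built from ordinary ops; the mamba_ssm `Block` exposes a `fused_add_norm` flag for this, though I have not checked every Vim configuration. The pure-Python sketch below (hypothetical function name, one row, no bias) shows what that kernel computes per row: add the residual, then scale by the reciprocal root-mean-square and the weight. It returns the summed residual too, mirroring the prenorm behaviour where the residual stream is handed to the next layer.

```python
import math
from typing import List, Optional, Tuple


def add_rms_norm_ref(
    x: List[float],
    weight: List[float],
    residual: Optional[List[float]] = None,
    eps: float = 1e-6,
) -> Tuple[List[float], List[float]]:
    """Residual add followed by RMSNorm over one row, in plain Python.

    Returns (normalized output, updated residual).
    """
    # "Add" step: the new residual stream is x + residual (or just x).
    res = list(x) if residual is None else [a + b for a, b in zip(x, residual)]
    # The Triton kernel tiles this reduction with BLOCK_N; here it is a plain sum.
    mean_sq = sum(v * v for v in res) / len(res)
    rstd = 1.0 / math.sqrt(mean_sq + eps)
    y = [v * rstd * w for v, w in zip(res, weight)]
    return y, res
```

Nothing here depends on a compile-time block size, which is the property the traced Triton launch is missing.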