BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it combines the best of RNN and transformer: great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
Apache License 2.0
12.05k stars 827 forks

Exporting RWKV into ONNX #140

Closed: marty1885 closed this issue 1 year ago

marty1885 commented 1 year ago

Hi, sorry I had to open an issue.

I am looking into running RWKV on an RK3588 NPU as a hobby project. So far I have figured out that CUDA compilation must be disabled for the export to work. However, both the TorchScript and the ONNX export routes fail at some point (the TorchScript route still has to be exported to ONNX at the end). How can I get RWKV into ONNX format?

model.eval()
torch.onnx.export(model, [1], "RWKV-7B.onnx")
Traceback (most recent call last):
  File "/home/marty/anaconda3/envs/rwkv/lib/python3.11/site-packages/torch/onnx/utils.py", line 962, in _create_jit_graph
    graph = model.forward.graph  # type: ignore[attr-defined]
            ^^^^^^^^^^^^^^^^^^^
AttributeError: 'function' object has no attribute 'graph'

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/marty/Documents/ChatRWKV/v2/dump_model.py", line 138, in <module>
    torch.onnx.export(model, [1], "RWKV-7B.onnx")
  File "/home/marty/anaconda3/envs/rwkv/lib/python3.11/site-packages/torch/onnx/utils.py", line 506, in export
    _export(
  File "/home/marty/anaconda3/envs/rwkv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1548, in _export
    graph, params_dict, torch_out = _model_to_graph(
                                    ^^^^^^^^^^^^^^^^
  File "/home/marty/anaconda3/envs/rwkv/lib/python3.11/site-packages/torch/onnx/utils.py", line 1113, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
                                       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/marty/anaconda3/envs/rwkv/lib/python3.11/site-packages/torch/onnx/utils.py", line 964, in _create_jit_graph
    raise RuntimeError("'forward' method must be a script method") from e
RuntimeError: 'forward' method must be a script method

Thanks for your time.
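The RuntimeError above comes from the TorchScript-based exporter: when the object passed to `torch.onnx.export` is a `torch.jit.ScriptModule`, the exporter reads the graph from `model.forward.graph`, and that attribute only exists if `forward` was compiled as a script method. The sketch below reproduces that condition with hypothetical toy modules (`ToyRNN` and `ToyRNNPlain`, not RWKV itself), assuming the model class being exported subclasses `torch.jit.ScriptModule` while leaving `forward` as a plain Python method. If that assumption holds, exporting an ordinary `torch.nn.Module` instead (so the exporter traces `forward`), with the inputs given as a tuple of tensors rather than `[1]`, would sidestep this particular check; it does not by itself guarantee that every RWKV operator exports.

```python
import torch


class ToyRNN(torch.jit.ScriptModule):
    """Hypothetical stand-in: derives from ScriptModule, but forward() is NOT
    decorated with @torch.jit.script_method, so it is never compiled."""

    def __init__(self) -> None:
        super().__init__()

    # Plain Python method: no TorchScript graph is attached to it.
    def forward(self, x: torch.Tensor, state: torch.Tensor):
        new_state = 0.5 * state + x
        return new_state, new_state


class ToyRNNPlain(torch.nn.Module):
    """Hypothetical stand-in: same computation as an ordinary nn.Module,
    which the exporter would trace instead of reading forward.graph."""

    def forward(self, x: torch.Tensor, state: torch.Tensor):
        new_state = 0.5 * state + x
        return new_state, new_state


scripted_like = ToyRNN()
plain = ToyRNNPlain()

# The exporter's ScriptModule branch expects model.forward.graph to exist.
print(isinstance(scripted_like, torch.jit.ScriptModule))  # True
print(hasattr(scripted_like.forward, "graph"))             # False

# RNN-style step with the state passed explicitly as a tensor input.
x = torch.ones(1, 4)
state = torch.zeros(1, 4)
out, new_state = plain(x, state)
```

With the plain module, the export call would take the inputs as a tuple of tensors, e.g. `torch.onnx.export(plain, (x, state), "toy.onnx")`; that call is not run in the sketch. Passing the recurrent state as an explicit input and output is what lets an RNN-mode step be exported as a fixed graph.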

BlinkDL commented 1 year ago

See https://github.com/search?o=desc&q=rwkv+onnx&s=updated&type=repositories