rwkv-kit is a pure PyTorch inference framework for the RWKV large language model. The project aims to provide a flexible and easily extensible PyTorch implementation of the RWKV x060 model, with support for batch inference, parallel inference, ONNX export, and standalone training.
We support various hardware devices, including but not limited to:
Contributions for additional device support are welcome.
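If you are unsure which backend your local PyTorch build can use, a quick check such as the following (plain PyTorch, nothing rwkv-kit specific) prints the accelerators it detects:

```python
import torch

# Report which accelerators this PyTorch build can see.
print("CUDA / ROCm available:", torch.cuda.is_available())
print("Apple MPS available:", torch.backends.mps.is_available())
# Intel XPU support only exists in recent PyTorch builds, so guard the attribute.
print("Intel XPU available:", hasattr(torch, "xpu") and torch.xpu.is_available())
```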
Clone the repository:
git clone -b dev https://github.com/TorchRWKV/rwkv-kit.git
Install dependencies:
cd rwkv-kit
pip install -r requirements.txt
# Install triton and rwkv-fla if you want to use the Triton kernels
pip install rwkv-fla[cuda]  # or rwkv-fla[xpu] / rwkv-fla[rocm], depending on your hardware
Download the RWKV6 model from BlinkDL/rwkv-6-world and place the weights in the `weight` folder.
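As a quick sanity check after installation, a minimal load-and-generate sketch (assuming the 3B checkpoint above sits in `weight/` and the constructor defaults are acceptable) could look like this:

```python
from rwkvkit import rwkv6

# Point model_path at the checkpoint you downloaded into the weight folder.
model = rwkv6(model_path="weight/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth")

# Generate a short continuation to confirm the model loads and runs.
print(model.generate("Hello", 32, 1.0, 0.0, include_prompt=True))
```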
Benchmark (autoregressive decoding with native PyTorch):
import time
import torch
from rwkvkit import rwkv6, sample_logits

initial_string = """hello"""
batch_size = 128
TEMPERATURE = 1.0
TOP_P = 0.0
LENGTH_PER_TRIAL = 100

model = rwkv6(
    model_path="weight/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth",
    prefill_kernel="torch",  # torch, torch-manual, triton, triton-chunk
    use_jit=True,
    compile=False
)

# Warm-up: run one full prefill + generation pass so that JIT/compile overhead
# does not distort the timings below.
state = model.init_state(batch_size)
encoded_input = model.tokenizer.encode([initial_string] * batch_size)
token = torch.tensor(encoded_input).long().to(model.device)
state = None
out, state = model.forward(token, state)
for step in range(LENGTH_PER_TRIAL):
    token_sampled = sample_logits(out, TEMPERATURE, TOP_P)
    out, state = model.forward(token_sampled, state)

# Time the prefill of the prompt.
t1 = time.time()
state = None
out, state = model.forward(token, state)
t2 = time.time()
print(f"Time: {t2 - t1}")

# Time autoregressive generation.
start_time = time.time()
for step in range(LENGTH_PER_TRIAL):
    token_sampled = sample_logits(out, TEMPERATURE, TOP_P)
    out, state = model.forward(token_sampled, state)
end_time = time.time()

total_time = end_time - start_time
tokens_generated = LENGTH_PER_TRIAL * batch_size
speed = tokens_generated / total_time
print(f"\nTotal time: {total_time:.2f} seconds")
print(f"Tokens generated: {tokens_generated}")
print(f"Token generation speed: {speed:.2f} tokens/second")
Method | Batch Size | Token Length | Prefill Time (ms) | Token Generation Speed (tokens/second) | Notes |
---|---|---|---|---|---|
triton-chunk | 1 | 1024 | 132.50 | 42.83 | Suitable for inference and training, better speed |
triton | 1 | 1024 | 105.49 | - | Suitable for inference and training, high accuracy |
torch | 1 | 1024 | 595.22 | - | Suitable for inference on devices where Triton is unavailable |
torch-manual | 1 | 1024 | 2468.00 | - | Suitable for training on devices where Triton is unavailable, high accuracy |
- | 1 | - | - | 48.42 | Excluding prefill |
- | 64 | - | - | 1266.77 | Excluding prefill |
- | 128 | - | - | 1875.03 | Excluding prefill |
Notes:
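To reproduce the prefill timings in the table, one option is to loop over the `prefill_kernel` choices with the same 1024-token prompt length. This is a sketch that re-instantiates the model per kernel and assumes rwkv-fla is installed for the Triton variants:

```python
import time
import torch
from rwkvkit import rwkv6

# 1024 random token ids, batch size 1 (RWKV world models use a 65536-entry vocabulary).
prompt = torch.randint(0, 65536, (1, 1024))

for kernel in ["torch", "torch-manual", "triton", "triton-chunk"]:
    model = rwkv6(
        model_path="weight/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth",
        prefill_kernel=kernel,
    )
    tokens = prompt.to(model.device)
    out, state = model.forward(tokens, None)  # warm-up pass (JIT / kernel compilation)
    # Note: on CUDA you may want torch.cuda.synchronize() around the timing for accurate numbers.
    t0 = time.time()
    out, state = model.forward(tokens, None)
    print(f"{kernel}: prefill {(time.time() - t0) * 1000:.2f} ms")
```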
For normal use:
import os
import torch
from rwkvkit import rwkv6, sample_logits

model = rwkv6(model_path="weight/RWKV-x060-World-3B-v2.1-20240417-ctx4096.pth")

initial_string = """User: 你好! 请问你是什么模型?"""
batch_size = 2
state = None
TEMPERATURE = 1.0
TOP_P = 0.0
LENGTH_PER_TRIAL = 100

encoded_input = model.tokenizer.encode([initial_string] * batch_size)
token = torch.tensor(encoded_input).long().to(model.device)
token_all = torch.tensor(encoded_input).long().to(model.device)

for step in range(LENGTH_PER_TRIAL):
    out, state = model.forward(token, state)
    token = sample_logits(out, TEMPERATURE, TOP_P)
    token_all = torch.cat((token_all, token.unsqueeze(1)), 1)
    os.system('cls' if os.name == 'nt' else 'clear')
    decoded_sequences = model.tokenizer.decode(token_all.cpu().tolist())
    for i, seq in enumerate(decoded_sequences):
        print(f"Batch {i+1}: {seq}")
You can also try:
print(model.generate(initial_string, LENGTH_PER_TRIAL, TEMPERATURE, TOP_P, include_prompt=True))
print(model.chat([{"role": "user", "content": "你是什么模型?"}], 500, TEMPERATURE, TOP_P))
for i in model.chat([{"role": "user", "content": "你好呀"}], 500, TEMPERATURE, TOP_P, stream=True):
    print(i, end="", flush=True)
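`model.chat` takes an OpenAI-style message list, so a multi-turn conversation can be built by appending each reply before the next call. A sketch, assuming `chat` returns the assistant's reply as a string when `stream` is not set (as the `print` example above suggests):

```python
messages = [{"role": "user", "content": "Introduce RWKV in one sentence."}]
reply = model.chat(messages, 500, TEMPERATURE, TOP_P)
print(reply)

# Feed the answer back as context and ask a follow-up question.
messages.append({"role": "assistant", "content": reply})
messages.append({"role": "user", "content": "How does it differ from a Transformer?"})
print(model.chat(messages, 500, TEMPERATURE, TOP_P))
```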
Edit onnx_export.py to select the model you want to export, then run:
python onnx_export.py
To simplify the exported model, create an output directory and run the simplifier:
mkdir ONNX_Simplified
python simplify_large_onnx.py -m onnx/{model name}.onnx -o ONNX_Simplified/{model name}.onnx
To test inference, edit the model path in onnx_infer.py and run:
python onnx_infer.py
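To inspect the exported graph before wiring it into your own pipeline, you can load it with onnxruntime and list its inputs and outputs. This is a sketch; the file name and tensor names depend on what `onnx_export.py` produced:

```python
import onnxruntime as ort

# Adjust the path to your exported (and optionally simplified) model file.
session = ort.InferenceSession("ONNX_Simplified/model.onnx", providers=["CPUExecutionProvider"])

# Print the expected inputs and outputs so you know what to feed the graph.
for inp in session.get_inputs():
    print("input:", inp.name, inp.shape, inp.type)
for out in session.get_outputs():
    print("output:", out.name, out.shape, out.type)
```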
To start an OpenAI-compatible API server:
python -m rwkvkit.openai_server --model model_path --state state_path (optional) --host 0.0.0.0 --port 8848
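Since the server exposes an OpenAI-compatible API, any standard client should work against it. Below is a minimal example with the `openai` Python package, assuming the default `/v1` routes, no enforced API key, and a placeholder model name:

```python
from openai import OpenAI

# Point the client at the local server started above.
client = OpenAI(base_url="http://127.0.0.1:8848/v1", api_key="not-needed")

response = client.chat.completions.create(
    model="rwkv",  # placeholder; the server serves whatever weights it was started with
    messages=[{"role": "user", "content": "Hello! What model are you?"}],
)
print(response.choices[0].message.content)
```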
This framework currently supports only RWKV v6 models, specifically version x060.
We plan to adapt this project for the AI Pro development board launched by Xunlong Orange Pi, enabling inference of the RWKV model on the Ascend ecosystem.
Special thanks to:
Yang, S., & Zhang, Y. (2024). FLA: A Triton-Based Library for Hardware-Efficient Implementations of Linear Attention Mechanism (Version 0.0.1) [Computer software]. https://github.com/sustcsonglin/flash-linear-attention
We have used modified implementations based on their work in different kernels.
We welcome contributions from the community. Please feel free to submit PRs and raise Issues. Your input is valuable and helps improve the project for everyone.
Repositories used to optimize the model:
Zhiyuan Li, Yuunnn_w, WuTianyi, Null, Dejiao Zeng
We warmly invite everyone to contribute to the project by submitting PRs and raising Issues! Your input and contributions are highly valued and play a vital role in improving the project for the entire community. Let's collaborate and make this project even better together!