BlinkDL / RWKV-LM

RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), so it combines the best of RNN and transformer: great performance, fast inference, lower VRAM use, fast training, "infinite" ctx_len, and free sentence embedding.
Apache License 2.0
12.05k stars, 827 forks

Proposal: add a flag to recognize model version #181

Closed: AsakusaRinne closed this issue 9 months ago

AsakusaRinne commented 10 months ago

When I load a model converted with ChatRWKV using torch.load, there are only three non-tensor objects: _strategy, _rescale_layer, and _version. However, both the latest v4 World model and the v5 0.1B World model have _version = 0.7. I'm looking for something like version=4 or version=5 to identify the RWKV model version, but nothing like that exists yet. If I load the original weights directly, there are no non-tensor objects at all.
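To see which entries of a loaded checkpoint are metadata rather than weights, you can filter on whether each value looks like a tensor. This is a minimal sketch that uses duck typing (anything with a `.shape` attribute counts as a weight) so it runs without PyTorch; the helper name `split_checkpoint` and the file path in the usage comment are hypothetical.

```python
from typing import Any, Dict, Tuple


def split_checkpoint(state: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a loaded checkpoint dict into (tensor-like entries, metadata entries).

    Values with a `.shape` attribute (torch.Tensor, numpy arrays) are treated
    as weights; everything else, such as `_strategy`, `_rescale_layer` and
    `_version` in ChatRWKV-converted files, is treated as metadata.
    """
    tensors = {k: v for k, v in state.items() if hasattr(v, "shape")}
    metadata = {k: v for k, v in state.items() if not hasattr(v, "shape")}
    return tensors, metadata


# Usage with PyTorch (the path is a placeholder):
#   import torch
#   state = torch.load("converted-model.pth", map_location="cpu")
#   _, meta = split_checkpoint(state)
#   print(meta)  # shows only the non-tensor entries
```

Note that `_version` here describes the converted file, not the RWKV architecture, which is why v4 and v5 models can share the same value.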

I think there needs to be a way to recognize the model version, because some libraries have to support several versions. For example, a flag _rwkv_version could be added to the model dict.
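The proposed flag could look roughly like this. It is only a sketch of the proposal: the `_rwkv_version` key comes from the suggestion above and is not part of any released RWKV or ChatRWKV format, and the helpers `add_version_flag` and `read_version_flag` are hypothetical.

```python
from typing import Any, Dict, Optional

# Hypothetical key from the proposal; not an existing RWKV/ChatRWKV field.
VERSION_KEY = "_rwkv_version"


def add_version_flag(state: Dict[str, Any], version: str) -> Dict[str, Any]:
    """Return a copy of `state` with the architecture version recorded."""
    existing = state.get(VERSION_KEY)
    if existing is not None and existing != version:
        raise ValueError(f"checkpoint already tagged as {existing!r}")
    tagged = dict(state)  # shallow copy; the original dict is left untouched
    tagged[VERSION_KEY] = version
    return tagged


def read_version_flag(state: Dict[str, Any]) -> Optional[str]:
    """Return the recorded version, or None for untagged checkpoints."""
    return state.get(VERSION_KEY)
```

A loader using such a flag would still need a fallback for untagged checkpoints, for example inspecting parameter names and shapes when `read_version_flag` returns None.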

BlinkDL commented 9 months ago

You can tell the model version from its parameter names and dimensions.

Check how self.version is computed in https://github.com/BlinkDL/ChatRWKV/blob/main/rwkv_pip_package/src/rwkv/model.py
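The idea of reading the version off the weights can be sketched as follows. The authoritative rules are the `self.version` logic in the linked model.py; this sketch rests on an assumption you should verify against that file and your own checkpoints: that v5 checkpoints contain an attention group-norm parameter whose name includes `ln_x`, which v4 checkpoints lack. It also assumes the usual `blocks.<i>.` key prefix and an `emb.weight` of shape `(vocab_size, n_embd)`. Later sub-versions are not distinguished, and the function name `infer_rwkv_config` is hypothetical.

```python
import re
from typing import Dict, Tuple

Shapes = Dict[str, Tuple[int, ...]]


def infer_rwkv_config(shapes: Shapes) -> Dict[str, int]:
    """Guess architecture version, layer count and width from parameter names/shapes.

    ASSUMPTION: any key containing "ln_x" marks a v5 model; otherwise v4.
    """
    version = 5 if any("ln_x" in name for name in shapes) else 4

    # Collect layer indices from keys such as "blocks.11.ffn.key.weight".
    layer_ids = set()
    for name in shapes:
        m = re.match(r"blocks\.(\d+)\.", name)
        if m:
            layer_ids.add(int(m.group(1)))
    n_layer = max(layer_ids) + 1 if layer_ids else 0

    n_embd = shapes["emb.weight"][1]  # assumed layout: (vocab_size, n_embd)
    return {"version": version, "n_layer": n_layer, "n_embd": n_embd}


# Usage with PyTorch (the path is a placeholder):
#   import torch
#   state = torch.load("model.pth", map_location="cpu")
#   shapes = {k: tuple(v.shape) for k, v in state.items() if hasattr(v, "shape")}
#   print(infer_rwkv_config(shapes))
```

Working on a name-to-shape mapping rather than on tensors keeps the check cheap and independent of the storage format.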

AsakusaRinne commented 9 months ago

Thank you!