Blealtan / RWKV-LM-LoRA

RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), so it combines the best of RNNs and transformers: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embedding.

Fix LoRA training for time-related parameters (Reopened PR) #9

Closed · thrfirs closed this 1 year ago

thrfirs commented 1 year ago

Issue

The issue with the original code is that, when the enable_time_finetune option is enabled, it only checks whether a module name contains ".time_". However, in RWKV-v4neo the only module whose name contains ".time_" is time_shift in RWKV_TimeMix, which is not trainable because it is an nn.ZeroPad2d and holds no parameters.

As a result, time-related trainable nn.Parameters such as time_first in RWKV_TimeMix and time_mix_k in RWKV_ChannelMix will never be set to requires_grad = True, since they are parameters rather than modules and therefore never appear in the module-name check.
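
A minimal sketch of the problem (the module structure, option name, and check are illustrative, following the description above rather than the repository's exact code):

```python
import torch
import torch.nn as nn

# Toy stand-in for RWKV_TimeMix: time_shift is a module (an nn.ZeroPad2d
# with no parameters), while time_first is a plain nn.Parameter.
class TimeMix(nn.Module):
    def __init__(self):
        super().__init__()
        self.time_shift = nn.ZeroPad2d((0, 0, 1, -1))
        self.time_first = nn.Parameter(torch.ones(8))

model = nn.Sequential(TimeMix())
for p in model.parameters():
    p.requires_grad = False  # freeze everything, as LoRA training does

enable_time_finetune = True

# Original logic: only *module* names are inspected. "0.time_shift" matches,
# but it holds no parameters; time_first belongs to module "0", whose name
# never contains ".time_", so it stays frozen.
for name, module in model.named_modules():
    if enable_time_finetune and ".time_" in name:
        for p in module.parameters():
            p.requires_grad = True

print(model[0].time_first.requires_grad)  # False
```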

Fix

To address this, this pull request also takes the parameter names from module.named_parameters() into account when the enable_time_finetune option is enabled, ensuring that gradients are enabled for time-related parameters.
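
Continuing the sketch above, the fixed logic might look like this (again illustrative, not the exact diff):

```python
# Fixed logic: inspect parameter names via named_parameters(), so
# time-related nn.Parameters are unfrozen even though they are not modules.
for module in model.modules():
    if enable_time_finetune:
        for pname, param in module.named_parameters(recurse=False):
            if pname.startswith("time_"):
                param.requires_grad = True

print(model[0].time_first.requires_grad)  # True
```

Using recurse=False matches each parameter by its local name at the module that owns it, so qualified prefixes like "0." never interfere with the "time_" check.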