Blealtan / RWKV-LM-LoRA

RWKV is an RNN with transformer-level LLM performance. It can be trained directly like a GPT (parallelizable), so it combines the best of RNNs and transformers: great performance, fast inference, low VRAM usage, fast training, "infinite" ctx_len, and free sentence embeddings.
Apache License 2.0

Gradient checkpointing requires JIT off; is this also needed in the original RWKV? #4

Closed: tiendung closed this issue 1 year ago.

tiendung commented 1 year ago

https://github.com/Blealtan/RWKV-LM-LoRA/blob/df5689bc88fc2f3334fbbc0117369817b0558b2b/RWKV-v4neo/train.py#L260

While studying RWKV-LoRA, I found that if args.grad_cp == 1, then RWKV_JIT_ON must be set to 0. Does this apply only to LoRA, or does the original RWKV need it as well?

Thanks.
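The condition described in this question can be sketched in Python as follows. This is a minimal sketch, assuming (as the question states) that the flag is the RWKV_JIT_ON environment variable, and further assuming that the model code reads it once at import time. The helper name configure_jit is hypothetical and is not taken from train.py; the actual code at the linked line may differ.

```python
import os


def configure_jit(grad_cp: int) -> str:
    """Hypothetical helper, not taken from train.py.

    Sets RWKV_JIT_ON to "0" when gradient checkpointing is enabled
    (grad_cp == 1) and to "1" otherwise, then returns the value it set.
    Assumption: the model module reads RWKV_JIT_ON at import time, so
    this must run before that module is imported.
    """
    jit_flag = "0" if grad_cp == 1 else "1"
    os.environ["RWKV_JIT_ON"] = jit_flag
    return jit_flag


# Example: gradient checkpointing is on, so the JIT is turned off.
configure_jit(grad_cp=1)
print("RWKV_JIT_ON =", os.environ["RWKV_JIT_ON"])
```

Deciding the flag in the launcher keeps the choice in one place. Under the assumed import-time read, changing the variable after the model module has already been imported would have no effect.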

Blealtan commented 1 year ago

I suspect the original RWKV also needs it, but I have no direct evidence. I only tested this repo with LoRA off and checkpointing on, and that combination doesn't work. Since this repo with LoRA off should be identical to the original, I assume the original has the same problem. I may look into it further later.

tiendung commented 1 year ago

Nice find, and thank you for implementing LoRA for RWKV :)