RWKV is an RNN with transformer-level LLM performance. It can be directly trained like a GPT (parallelizable). So it's combining the best of RNN and transformer - great performance, fast inference, saves VRAM, fast training, "infinite" ctx_len, and free sentence embedding.
I tried the training example (train a simple L6-D512 RWKV from scratch on enwik8):
RWKV-LM/RWKV-v4neo$ python train.py --proj_dir "out" --data_file "../../data/enwik8" --data_type "utf-8" --vocab_size 0 --ctx_len 512 --epoch_steps 5000 --epoch_count 500 --epoch_begin 0 --epoch_save 5 --micro_bsz 12 --n_layer 6 --n_embd 512 --pre_ffn 0 --head_qk 0 --lr_init 8e-4 --lr_final 1e-5 --warmup_steps 0 --beta1 0.9 --beta2 0.99 --adam_eps 1e-8 --accelerator gpu --devices 1 --precision bf16 --strategy ddp_find_unused_parameters_false --grad_cp 0
Then I got:
  | Name | Type | Params
0 | emb | Embedding | 3.1 M
1 | blocks | ModuleList | 20.5 M
2 | ln_out | LayerNorm | 1.0 K
3 | head | Linear | 3.1 M
26.7 M Trainable params
0 Non-trainable params
26.7 M Total params
106.770 Total estimated model params size (MB)
Epoch 0: 0%| | 0/5000 [00:00<?, ?it/s]Traceback (most recent call last):
File "/home/exat500g/RWKV-LM/RWKV-v4neo/train.py", line 340, in <module>
trainer.fit(model, data_loader)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
self._call_and_handle_interrupt(
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 648, in _call_and_handle_interrupt
return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
return function(*args, **kwargs)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
results = self._run(model, ckpt_path=self.ckpt_path)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
results = self._run_stage()
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
return self._run_train()
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1283, in _run_train
self.fit_loop.run()
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 271, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
self.advance(*args, **kwargs)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 174, in advance
batch = next(data_fetcher)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in __next__
return self.fetching_function()
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 263, in fetching_function
self._fetch_next_batch(self.dataloader_iter)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 277, in _fetch_next_batch
batch = next(iterator)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/trainer/supporters.py", line 557, in __next__
return self.request_next_batch(self.loader_iters)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/trainer/supporters.py", line 569, in request_next_batch
return apply_to_collection(loader_iters, Iterator, next)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/pytorch_lightning/utilities/apply_func.py", line 99, in apply_to_collection
return function(data, *args, **kwargs)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 628, in __next__
data = self._next_data()
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1333, in _next_data
return self._process_data(data)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1359, in _process_data
data.reraise()
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/torch/_utils.py", line 543, in reraise
raise exception
UnboundLocalError: Caught UnboundLocalError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/torch/utils/data/_utils/worker.py", line 302, in _worker_loop
data = fetcher.fetch(index)
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/exat500g/miniconda3/envs/pytorch113/lib/python3.10/site-packages/torch/utils/data/_utils/fetch.py", line 58, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/home/exat500g/RWKV-LM/RWKV-v4neo/src/dataset.py", line 208, in __getitem__
dix = [self.stoi[s] for s in data[i : i + req_len]]
UnboundLocalError: local variable 'i' referenced before assignment
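
The last frame suggests that `i` is only assigned on some code path inside `__getitem__`, and that the path taken with `--data_type "utf-8"` skips the assignment. I have not confirmed this against `src/dataset.py`, so the following is only a minimal sketch of that failure mode and one possible fix. `ToyCharDataset`, `getitem_buggy`, `getitem_fixed`, and the `"binidx"` branch condition are hypothetical names made up for illustration, not the real code.

```python
import random


class ToyCharDataset:
    """Hypothetical stand-in for a char-level dataset; NOT the real src/dataset.py."""

    def __init__(self, text, ctx_len, data_type):
        self.data = text
        self.ctx_len = ctx_len
        self.data_type = data_type
        vocab = sorted(set(text))
        self.stoi = {ch: idx for idx, ch in enumerate(vocab)}

    def getitem_buggy(self):
        req_len = self.ctx_len + 1
        if self.data_type == "binidx":  # assumed branch condition, for illustration
            # `i` is only assigned on this branch ...
            i = random.randint(0, len(self.data) - req_len)
        # ... so any other data_type reaches this line with `i` unbound.
        return [self.stoi[s] for s in self.data[i : i + req_len]]

    def getitem_fixed(self):
        req_len = self.ctx_len + 1
        # Assign the start offset unconditionally, before anything uses it.
        i = random.randint(0, len(self.data) - req_len)
        return [self.stoi[s] for s in self.data[i : i + req_len]]


ds = ToyCharDataset("hello world, hello rwkv", ctx_len=4, data_type="utf-8")

try:
    ds.getitem_buggy()
    raised = False
except UnboundLocalError:
    raised = True

sample = ds.getitem_fixed()
print(raised, len(sample))
```

If the real `__getitem__` has the same shape, moving the assignment of `i` in front of the data-type branches (or adding it to the missing branch) should avoid the error. Setting `num_workers=0` on the DataLoader while debugging also makes the traceback shorter, since the exception is then not re-raised from a worker process.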