syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License

Info/Documentation on chunkwise training #30

Open pkpro opened 9 months ago

pkpro commented 9 months ago

Hi there. I want to understand how to use RetNet to train a model with a longer context. It is not clear from the available documentation how to train the model for a large context. There are no parameters on Trainer or TrainingArguments for this, so how does one actually pass a 100k-token text with a 512 chunk size into the model during training? Should one chunk the text oneself, with overlapping text fragments? Or is there a way to pass the 100k text and have it processed automatically?

What exactly these outputs can be used for?

```python
chunk_outputs = model(input_ids, forward_impl='chunkwise', use_cache=True, recurrent_chunk_size=4)
chunk_state = chunk_outputs.last_hidden_state
chunk_cache = chunk_outputs.past_key_values
```

Should one mix these outputs into the beginning of the next input during data preparation?

Please help.

syncdoth commented 9 months ago

Hi there!

If you want a very long ctx len, say 16k, you will have to create batches of shape [bs, ctxlen] in the dataloader and pass them to the trainer. Then, when you create the RetNetConfig, set forward_impl="chunkwise", recurrent_chunk_size=512. You can set the chunk size to whatever you like. This sets the default forward mode of RetNet to chunkwise, and the input ids will be chunked within the forward. Take a look at chunkwise_forward.
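To make the chunking concrete, here is a minimal pure-Python sketch of what "chunk the input ids within the forward" means conceptually; `split_into_chunks` is a hypothetical helper for illustration, not the repo's actual code:

```python
def split_into_chunks(input_ids, chunk_size):
    """Split one sequence of token ids into consecutive chunks.

    Conceptually, chunkwise forward processes each chunk in parallel
    and carries a recurrent state between chunks, so the user only
    supplies the full-length sequence.
    """
    return [input_ids[i:i + chunk_size] for i in range(0, len(input_ids), chunk_size)]

# a 16k-token context with recurrent_chunk_size=512 yields 32 chunks of 512
chunks = split_into_chunks(list(range(16384)), 512)
```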

pkpro commented 9 months ago

Hello @syncdoth, Thank you for the swift answer. I have several followup questions.

  1. When I set "forward_impl": "chunkwise" in the config file, or modify the loaded config with config.forward_impl="chunkwise", printing the config does not include forward_impl. However, the amount of VRAM required for "chunkwise" is much larger, which indicates the configuration change did take effect.
  2. In your example [bs, ctxlen], I assume that ctxlen is the new context length, i.e. 16k in this case. Right?
  3. What would be the effect of changing forward_impl for a pretrained model? Suppose I trained the model with forward_impl="parallel" and then want to continue training with a larger context by setting forward_impl="chunkwise". Will this work? The weights are the same; only the training pipeline changes.
  4. Can I switch back to forward_impl="parallel" for another dataset? What change in behavior should I expect?
  5. What about inference? Do I need to run the model in forward_impl="chunkwise" mode for the larger context to be available?

Thanks a lot for your answers; I really appreciate your work, and thank you for continuing to support your implementation.

syncdoth commented 9 months ago

Questions 3, 4, 5 - the answer is yes. You can train with any mode and then load the model for further training or inference with any mode; the output should be the same. The point of retentive networks is the equivalence of the three forward modes given the same model parameters.

Question 1 - does the chunkwise forward use more VRAM with the same seqlen? That might be a bug I should track down.

Question 2 - yes, I mean the length of the full sequence, not each chunk.

pkpro commented 9 months ago

@syncdoth Thank you again, appreciate it very much.

I checked with the same context length, and the chunkwise forward_impl does use more VRAM.

With 180-token sequences, this is what I can fit on 16G VRAM:

Network config (150M):

```json
{
  "vocab_size": 32024,
  "pad_token_id": 2,
  "eos_token_id": 2,
  "unk_token_id": 0,
  "bos_token_id": 1,
  "decoder_embed_dim": 512,
  "decoder_value_embed_dim": 1024,
  "decoder_ffn_embed_dim": 1024,
  "decoder_layers": 32,
  "decoder_retention_heads": 16,
  "initializer_range": 0.02,
  "layernorm_eps": 1.0e-8,
  "activation_fn": "swish",
  "dropout": 0.0,
  "activation_dropout": 0.0,
  "drop_path_rate": 0.0,
  "decoder_normalize_before": true,
  "layernorm_embedding": true,
  "no_scale_embedding": false,
  "recurrent_chunk_size": 512,
  "use_lm_decay": false,
  "deepnorm": false,
  "subln": true,
  "torch_dtype": "float32",
  "output_retentions": false,
  "forward_impl": "parallel",
  "tie_word_embeddings": false,
  "is_decoder": true,
  "transformers_version": "4.36.0.dev0",
  "use_cache": true,
  "model_type": "retnet"
}
```

Trainer args:

```
--gradient_accumulation_steps 2 --bf16 --optim adamw_torch_fused --do_train --evaluation_strategy epoch --num_train_epochs 1 --lr_scheduler_type cosine_with_restarts --save_strategy epoch --learning_rate 3e-3 --warmup_steps 1000 --include_tokens_per_second True --bf16_full_eval True
```

syncdoth commented 9 months ago

Thanks for raising this issue. I think I might know the reason:

During parallel forward, it computes q @ k matrix, which is of size (bs, num_head, seqlen, seqlen).

During chunkwise, it must compute k * v, which is of size (bs, num_head, chunk_size, qk_dim / num_head, v_dim / num_head).

  1. You have a recurrent chunk size of 512, which is longer than 180. This means the sequence will be padded to the nearest multiple of 512 (i.e. 512) to perform chunkwise forward.

  2. For your configuration, seqlen should be at least 1024 for chunkwise to use the same amount of memory as parallel. In my experience, chunkwise is beneficial only when training with 4k+ seqlen.
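Plugging the thread's numbers into the two intermediate-tensor shapes described above gives a back-of-envelope check of the 1024 break-even point (this assumes decoder_embed_dim is the qk dimension and decoder_value_embed_dim the v dimension; it is rough arithmetic, not profiler output):

```python
def parallel_elems(seqlen):
    # per-batch, per-head element count of the q @ k^T score matrix
    return seqlen * seqlen

def chunkwise_elems(chunk_size, qk_dim, v_dim, num_heads):
    # per-batch, per-head element count of the k*v state within one chunk
    return chunk_size * (qk_dim // num_heads) * (v_dim // num_heads)

# config from the thread: qk_dim=512, v_dim=1024, 16 heads, chunk_size=512
chunk = chunkwise_elems(512, 512, 1024, 16)   # 512 * 32 * 64 = 1_048_576
par = parallel_elems(180)                     # 180 * 180    =     32_400
# break-even: parallel matches chunkwise when seqlen = sqrt(1_048_576) = 1024
```

So at 180 (padded to 512) tokens, the chunkwise intermediate is roughly 32x larger per head than the parallel one, which is consistent with the higher VRAM usage observed above.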