syncdoth / RetNet

Huggingface compatible implementation of RetNet (Retentive Networks, https://arxiv.org/pdf/2307.08621.pdf) including parallel, recurrent, and chunkwise forward.
MIT License
226 stars, 24 forks

encountered nan while trying to train #6

Open liujuncn opened 1 year ago

liujuncn commented 1 year ago

[screenshot]

I am trying to compare RetNet with other transformer architectures, but as soon as training starts the gradients become NaN. Other transformer architectures trained with the same dataset and hyperparameters do not have this problem.

I don't know where the numerical stability problem comes from.

syncdoth commented 1 year ago

Could you provide a minimal reproducible example? I have tried the single-GPU training outlined in the repo's train.py with the Hugging Face Trainer API without encountering numerical instability.

liujuncn commented 1 year ago

I looked at train.py; I think you are training in 32-bit.

[screenshots]

I also switched to 32-bit mode and the training is fine, but the speed is very slow (same dataset, but fewer layers because of the memory cost). So the problem occurs under 16-bit AMP training.
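
For reference, this is roughly how the two precision regimes are selected with the Hugging Face Trainer API. The sketch below is illustrative only, not the repo's actual train.py, and the output directory is a placeholder.

```python
# Illustrative sketch only (not the repo's train.py): switching between full
# precision and 16-bit AMP with the Hugging Face Trainer's TrainingArguments.
from transformers import TrainingArguments

args_fp32 = TrainingArguments(output_dir="out")             # default: full 32-bit (stable but slow)
args_fp16 = TrainingArguments(output_dir="out", fp16=True)  # 16-bit AMP, where the NaN gradients appear
```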

jploski commented 1 year ago

FWIW, I also trained a mini-model (in 32-bit mode) and did not notice instability or particularly bad training performance.

daskol commented 1 year ago

@liujuncn Can you elaborate on your training setup? What model config? What dataset? What mixed-precision regime do you use (compute, accumulation, and parameter dtypes)? When does the NaN appear during training? In which layer?

jploski commented 1 year ago

Unfortunately I don't have a minimal example either, but I also encountered these NaN/infinity problems while training a bigger model. They happened during the forward pass, most commonly right before GroupNorm. My model's config was:

vocab_size: 65024
hidden_size: 4544
num_layers: 4
num_heads: 8
qk_dim: 4544
v_dim: 9088
ffn_proj_size: 9088
use_bias_in_msr: False
use_bias_in_mlp: True
use_bias_in_msr_out: False
use_default_gamma: False
initializer_range: 0.02
output_retentions: False
pad_token_id: 11
eos_token_id: 11
unk_token_id: 11

That model was special in that I copied all the embeddings from a different model (Falcon-7B, hence the hidden_size) without any rescaling, and also turned off gradients for them to prevent them from being trained along with the MLPs (my questionable rationale being that the embeddings were already pretrained, good enough, and should not be disturbed).
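
Roughly, that copy-and-freeze step amounts to something like the sketch below. This is illustrative only, not the code actually used; the retnet_model object is a placeholder assumed to share Falcon-7B's vocab_size and hidden_size.

```python
# Illustrative sketch of copying and freezing pretrained embeddings
# (not the actual code used; retnet_model is a placeholder model whose
# embedding matrix matches Falcon-7B's vocab_size and hidden_size).
import torch
from transformers import AutoModelForCausalLM

donor = AutoModelForCausalLM.from_pretrained("tiiuae/falcon-7b")

with torch.no_grad():
    retnet_model.get_input_embeddings().weight.copy_(donor.get_input_embeddings().weight)

# Freeze the copied embeddings so they are not updated during training.
retnet_model.get_input_embeddings().weight.requires_grad_(False)
```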

It was mentioned in another issue that there was a fix to torchscale's implementation aimed at improving numerical stability (https://github.com/microsoft/torchscale/issues/47), so there may well be a similar problem here.

However, I did not manage to port that fix without compromising the consistency of the recurrent and parallel passes, so I gave up on it.

In the end, to get out of the instability, I applied a dirty hack of inserting torch.nan_to_num at various vulnerable locations. Miraculously, the weights converged to a more numerically stable configuration, and I could later even remove these protections. But it surely does not seem like the correct approach.
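
One way such a guard could be implemented is sketched below. This is not the exact patch that was used, and it assumes the vulnerable normalization layers are plain nn.GroupNorm modules, which may not match the actual implementation.

```python
# Sketch of a nan_to_num guard (not the exact patch used): sanitize the
# activations entering every GroupNorm, which is where NaN/inf values were
# most commonly observed.
import torch
import torch.nn as nn

def _sanitize_inputs(module, inputs):
    # Replace NaN/inf entries in the incoming tensors with finite values.
    return tuple(torch.nan_to_num(x) if torch.is_tensor(x) else x for x in inputs)

def add_nan_guards(model: nn.Module):
    for module in model.modules():
        if isinstance(module, nn.GroupNorm):
            module.register_forward_pre_hook(_sanitize_inputs)
```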

daskol commented 1 year ago

@jploski Thank you very much for your detailed comment. I've faced the same instability issues with a model of similar size and I'm trying to figure them out.

syncdoth commented 11 months ago

I haven't been able to try this model in a large training setting yet, and on my tiny synthetic dataset I didn't have issues. But it's great to know about these issues; let's work together to solve them!

I have been working on porting the official torchscale implementation to HF, which is almost finished except for chunkwise forward. In fact, I have pushed the branch (official_implementation), so if you want "early access" you can try it out and let me know if you find any bugs. The main difference should be some tricks for stability, which will hopefully stabilize training and maybe even enable FP16 :)

daskol commented 11 months ago

@syncdoth There are definitely some stability issues with this implementation, whereas everything is fine with torchscale.

> I have been working on porting the official torchscale implementation to HF

Why not just copy-paste the original torchscale.architecture.retnet with all its dependencies and wrap it with HF's model mix-ins? The torchscale repo is under MIT, so there are no legal limitations.

syncdoth commented 11 months ago

@daskol That's one way to do it, but some things can't be done that way, the most important of which is attention_mask handling, which is not in the torchscale repo.

Still, I followed your advice and added the torchscale/ directory, which is a basic copy & paste of the official code with minor changes (args -> config, no MoE, no multiway). I have tests/ that suggest my implementations in retnet/ and torchscale/ are functionally identical during forward/backward.
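
A hypothetical sketch of that kind of forward-equivalence check is shown below; the model objects and output format are placeholders, not the actual contents of tests/.

```python
# Hypothetical sketch of a forward-equivalence check between two implementations
# (placeholder code, not the repo's actual tests/; assumes HF-style outputs with .logits).
import torch

def forward_outputs_match(model_a, model_b, vocab_size, atol=1e-5):
    model_a.eval()
    model_b.eval()
    input_ids = torch.randint(0, vocab_size, (2, 16))
    with torch.no_grad():
        logits_a = model_a(input_ids).logits
        logits_b = model_b(input_ids).logits
    return torch.allclose(logits_a, logits_b, atol=atol)
```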

I can also confirm that model training is stable with bf16. Not sure about FP16, but I'm not planning to test it in the near future :(
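
For reference, bf16 is typically enabled through the same TrainingArguments mechanism as fp16; the snippet below is an illustrative sketch, not the repo's training script.

```python
# Illustrative sketch: enabling bf16 (reported stable) via the Hugging Face Trainer.
from transformers import TrainingArguments

args_bf16 = TrainingArguments(output_dir="out", bf16=True)  # requires hardware with bf16 support
```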

Shreyas-Dongre commented 11 months ago

This might answer your query. https://github.com/microsoft/torchscale/issues/48