lucidrains / meshgpt-pytorch

Implementation of MeshGPT, SOTA Mesh generation using Attention, in Pytorch
MIT License

Transformer - token_embed outputs nan values #44

Closed: MarcusLoppe closed this issue 5 months ago

MarcusLoppe commented 8 months ago

This issue occurs if the learning rate is too high (1e-2) at a low loss (0.3), though it also occurred with a 1e-3 lr at 0.01 loss. Edit: using flash attention, the loss goes from 5.0 to NaN in the 5th epoch with a 1e-4 lr.

After the codes are masked and token_embed is called, it outputs NaN values. Not sure if this is a PyTorch, meshgpt-pytorch or user error :)

```python
codes = codes.masked_fill(codes == self.pad_id, 0)
codes = self.token_embed(codes)
```
```
codes  after  masked_fill  torch.Size([2, 912]) tensor([[11965,   608, 11350,  ...,     0,     0,     0],
        [15507, 13398,  5400,  ...,  8247, 13231,  5280]], device='cuda:0')

codes token_embed after  torch.Size([2, 912, 512]) tensor([[[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]],

        [[nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         ...,
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan],
         [nan, nan, nan,  ..., nan, nan, nan]]], device='cuda:0',
       grad_fn=<EmbeddingBackward0>)
```
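An `nn.Embedding` forward pass is just a row lookup into its `weight` table, and `masked_fill` on integer codes cannot create NaN. So NaN embeddings like the ones above most likely mean the weight rows themselves had already diverged (for example after a bad optimizer step). A minimal diagnostic sketch, assuming plain PyTorch; `nonfinite_parameters` is a hypothetical helper and the toy sizes are made up:

```python
import torch
from torch import nn

def nonfinite_parameters(model: nn.Module) -> list[str]:
    """Return the names of parameters that contain NaN or inf values."""
    return [
        name
        for name, param in model.named_parameters()
        if not torch.isfinite(param).all()
    ]

# Toy stand-in for a token embedding (hypothetical sizes).
token_embed = nn.Embedding(num_embeddings=16, embedding_dim=4)
codes = torch.tensor([[3, 5, 0, 0]])

# Simulate a diverged update by corrupting a single row of the table.
with torch.no_grad():
    token_embed.weight[5].fill_(float("nan"))

out = token_embed(codes)
print(nonfinite_parameters(token_embed))  # ['weight']
print(torch.isnan(out).any(dim=-1))       # only the position holding code 5 is NaN
```

Running such a check on the whole model after each optimizer step shows which parameters blew up first, rather than where the NaN was first noticed.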
lucidrains commented 8 months ago

yea, this is just normal transformer instability

there's a bag of tricks for tackling this

MarcusLoppe commented 8 months ago

> yea, this is just normal transformer instability
>
> there's a bag of tricks for tackling this

@lucidrains

Shoot, I'm using a dataset of 120 mesh models (1200 after augmentation). It worked a bit better with a bigger dataset, so it might be due to the 'small' dataset.

lr 1e-4:

```
Epoch 1/10: 100%|██████████| 600/600 [02:34<00:00,  3.89it/s, loss=8.54]
Epoch 1 average loss: 8.743859918912252
Epoch 2/10: 100%|██████████| 600/600 [02:31<00:00,  3.95it/s, loss=8.15]
Epoch 2 average loss: 8.339149476687114
Epoch 3/10: 100%|██████████| 600/600 [02:31<00:00,  3.96it/s, loss=6.67]
Epoch 3 average loss: 7.025277642409007
Epoch 4/10: 100%|██████████| 600/600 [02:32<00:00,  3.94it/s, loss=5.83]
Epoch 4 average loss: 5.839961892763774           avg loss speed: 2.196133786572349
Epoch 5/10: 100%|██████████| 600/600 [02:31<00:00,  3.95it/s, loss=5.23]
Epoch 5 average loss: 5.08304128130277           avg loss speed: 1.9850883893171947
Epoch 6/10: 100%|██████████| 600/600 [02:32<00:00,  3.94it/s, loss=4.39]
Epoch 6 average loss: 4.479391298294067           avg loss speed: 1.5033689738644487
Epoch 7/10: 100%|██████████| 600/600 [02:23<00:00,  4.19it/s, loss=nan]
Epoch 7 average loss: nan
Epoch 8/10: 100%|██████████| 600/600 [02:17<00:00,  4.35it/s, loss=nan]
Epoch 8 average loss: nan
```
Kurokabe commented 8 months ago

> yea, this is just normal transformer instability
>
> there's a bag of tricks for tackling this

Could you give some examples of how to tackle this? I'm also getting NaN after a few epochs (~5) when training on the full ShapeNet (~15k different mesh models) with a 1e-4 lr. I'm still investigating, so I'm not sure if it's exactly the same problem as @MarcusLoppe's, but it would be nice to have some ideas on how to solve it :)

lucidrains commented 8 months ago

there are no solutions. stabilizing transformers is still an active area of research, esp as you increase parameter count. there are various bandaids however. most practitioners have a couple they apply, but none of them are panaceas yet

lucidrains commented 8 months ago

you can check out my x-transformers repo for more info

MarcusLoppe commented 8 months ago

> you can check out my x-transformers repo for more info

Any particular feature? I'm seeing gate_residual, sandwich_norm, ResiDual and scale_residual. Btw, do you already have, or plan on implementing, sliding window attention in x-transformers?

> Could you give some examples of how to tackle this? I'm also getting NaN after a few epochs (~5) when training on the full ShapeNet (~15k different mesh models) with a 1e-4 lr. I'm still investigating, so I'm not sure if it's exactly the same problem as @MarcusLoppe's, but it would be nice to have some ideas on how to solve it :)

I think experimenting with the optimizer would be a good start as well; the easiest parameters to tweak are probably max_grad_norm and weight_decay. I'll do some testing and let you know what I find.
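A minimal sketch of both knobs in a generic PyTorch training step, assuming a plain loop rather than the meshgpt-pytorch trainer (whose own argument names may differ): `weight_decay` is set on `torch.optim.AdamW`, and `max_grad_norm` (a hypothetical variable here) is passed to `torch.nn.utils.clip_grad_norm_`. The toy model and values are placeholders, not tuned settings.

```python
import torch
from torch import nn

# Hypothetical toy model and values; real settings need tuning.
model = nn.Linear(8, 8)
max_grad_norm = 0.5
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)

def train_step(x: torch.Tensor, y: torch.Tensor) -> float:
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    # Rescale all gradients in place so their combined L2 norm is at most max_grad_norm.
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    # AdamW applies weight_decay directly to the weights (decoupled from the gradient).
    optimizer.step()
    return total_norm.item()  # norm *before* clipping; log it to spot spikes

print(train_step(torch.randn(4, 8), torch.randn(4, 8)))
```

Logging the pre-clip norm is often more informative than the loss alone, since gradient-norm spikes tend to show up a few steps before the loss turns into NaN.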

The paper doesn't mention any details other than using Adam with a batch size of 64. I believe increasing the batch size might help as well; due to VRAM constraints I'm only using a batch size of 1 or 2.

lucidrains commented 8 months ago

@MarcusLoppe you could try qk norm. some researchers at google brain are attached to this, but i suspect it has a slight generalization cost
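A minimal sketch of one common form of qk norm, assuming plain PyTorch: queries and keys are L2-normalized per head and multiplied by a learnable scale instead of `1/sqrt(head_dim)`, so attention logits are bounded by that scale and cannot grow without limit. `QKNormAttention` and `init_scale=10.0` are illustrative choices, not the meshgpt-pytorch or x-transformers implementation; other variants apply LayerNorm to q and k instead.

```python
import torch
import torch.nn.functional as F
from torch import nn

class QKNormAttention(nn.Module):
    """Multi-head self-attention with query/key normalization (hypothetical sketch)."""

    def __init__(self, dim: int, heads: int = 8, init_scale: float = 10.0):
        super().__init__()
        assert dim % heads == 0, "dim must be divisible by heads"
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim, bias=False)
        # Learnable temperature replaces the usual 1/sqrt(head_dim) scaling.
        self.scale = nn.Parameter(torch.tensor(init_scale))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, n, d = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, d) -> (b, heads, n, head_dim)
        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        # Unit-length q and k make every dot product lie in [-1, 1] before scaling.
        q, k = F.normalize(q, dim=-1), F.normalize(k, dim=-1)
        sim = (q @ k.transpose(-2, -1)) * self.scale
        attn = sim.softmax(dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(b, n, d)
        return self.to_out(out)

attn = QKNormAttention(dim=64, heads=8)
print(attn(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
```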

yea, you are right about the optimizer. values to play with are beta1, beta2, and eps. your batch size def needs to be bigger once you scale up, but you can use gradient accumulation for this (which is built-in)
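A minimal sketch of both suggestions in plain PyTorch, assuming a generic loop rather than the built-in meshgpt-pytorch trainer: Adam's `betas` and `eps` are set explicitly, and gradients are accumulated over several micro-batches before each optimizer step. The values and the `grad_accum_every` name are hypothetical placeholders.

```python
import torch
from torch import nn

model = nn.Linear(16, 1)

# Hypothetical starting values: beta2 below the 0.999 default and eps above the
# 1e-8 default are common stability tweaks, but they need tuning per run.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, betas=(0.9, 0.98), eps=1e-6)

grad_accum_every = 4  # hypothetical name; effective batch = micro-batch size * 4

def train_on(micro_batches) -> None:
    """Accumulate gradients over groups of micro-batches, stepping once per group."""
    optimizer.zero_grad()
    for i, (x, y) in enumerate(micro_batches, start=1):
        loss = nn.functional.mse_loss(model(x), y)
        # Scale so the summed gradient equals the mean over the whole group
        # (exact when all micro-batches have the same size).
        (loss / grad_accum_every).backward()
        if i % grad_accum_every == 0:
            optimizer.step()
            optimizer.zero_grad()

# Usage: 8 micro-batches of 2 samples -> 2 optimizer steps, effective batch of 8.
batches = [(torch.randn(2, 16), torch.randn(2, 1)) for _ in range(8)]
train_on(batches)
```

This trades wall-clock time for memory: a batch size of 2 with 32 accumulation steps gives the same gradient estimate as a batch of 64, without needing the VRAM for 64 samples at once.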

lucidrains commented 8 months ago

other things that would help are warmup and gradient clipping of 0.5, or 0.25 if you want to be really aggressive
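A minimal sketch of linear learning-rate warmup combined with gradient clipping at 0.5, assuming plain PyTorch and `torch.optim.lr_scheduler.LambdaLR`; `warmup_steps = 1000` and the toy model are hypothetical placeholders, not values from the repo.

```python
import torch
from torch import nn

model = nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

warmup_steps = 1000  # hypothetical; tune for your dataset

def warmup_factor(step: int) -> float:
    # Ramp the learning rate linearly from ~0 to its full value, then hold it.
    return min(1.0, (step + 1) / warmup_steps)

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=warmup_factor)

def train_step(loss: torch.Tensor, max_grad_norm: float = 0.5) -> None:
    optimizer.zero_grad()
    loss.backward()
    # 0.5 as suggested; 0.25 is the more aggressive option.
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
    optimizer.step()
    scheduler.step()  # advance the warmup schedule once per optimizer step

for _ in range(3):
    train_step(model(torch.randn(2, 16)).pow(2).mean())
print(optimizer.param_groups[0]["lr"])  # 4/1000 of the base lr after 3 steps
```

Warmup keeps the first updates small while Adam's moment estimates are still noisy, which is when a single large step is most likely to push weights into a diverging region.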

lucidrains commented 8 months ago

@MarcusLoppe scratch everything i said, as Kurokabe noted that a potential source of instability was actually due to the gateloop layers

MarcusLoppe commented 8 months ago

> @MarcusLoppe scratch everything i said, as Kurokabe noted that a potential source of instability was actually due to the gateloop layers

I still get a NaN loss once the loss reaches 0.07 with a 1e-4 learning rate, but above that it no longer causes any issues. I'll try to replicate it and use detect_anomaly to see what happens.
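`torch.autograd.detect_anomaly()` covers the backward pass: wrapped around `loss.backward()`, it makes autograd raise an error naming the backward function that first returned NaN, at a noticeable speed cost. For the forward pass, a minimal sketch using forward hooks (assuming plain PyTorch; `find_nonfinite_outputs` is a hypothetical helper) lists the modules whose outputs contain NaN/inf in call order, so the first entry is where the NaN appears:

```python
import torch
from torch import nn

def find_nonfinite_outputs(model: nn.Module, *inputs) -> list[str]:
    """Run one forward pass; return names of submodules (in call order) with NaN/inf output."""
    bad: list[str] = []
    handles = []

    def make_hook(name: str):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor) and not torch.isfinite(output).all():
                bad.append(name)
        return hook

    for name, module in model.named_modules():
        if name:  # skip the root module itself
            handles.append(module.register_forward_hook(make_hook(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for h in handles:  # always detach the hooks, even if forward raises
            h.remove()
    return bad

net = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 2))
with torch.no_grad():
    net[0].bias.fill_(float("nan"))  # simulate a diverged parameter
print(find_nonfinite_outputs(net, torch.randn(3, 4))[0])  # '0'
```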

MarcusLoppe commented 5 months ago

Resolved by using a larger dataset. Possible explanation: https://github.com/lucidrains/meshgpt-pytorch/issues/68#issuecomment-1991789517