lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

Optimizer error when resuming from state-dict #115

Closed: nousr closed this issue 2 years ago

nousr commented 2 years ago

Hey, this one has me a little stumped.

Problem: DiffusionPriorTrainer cannot reload from a checkpoint and then continue training.

Calling .save() and then .load() seems to corrupt the optimizer state in some way. Here is a minimal script that should reproduce the error on main.

The steps to reproduce are simple and are in the file's docstring: https://gist.github.com/nousr/5f6877c2bcb8b0c1b806ed8e206d39af
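For reference, a typical resume flow rebuilds the model and optimizer exactly as in the original run and only then restores both state dicts. The sketch below assumes a toy nn.Sequential model trained with AdamW; make_model_and_optimizer, train_step, save_checkpoint and load_checkpoint are hypothetical helpers, not DiffusionPriorTrainer's actual save()/load() methods.

```python
import torch
from torch import nn


def make_model_and_optimizer(seed: int = 0):
    # Hypothetical helper: rebuild model and optimizer identically on every
    # call, so the optimizer's param_groups list parameters in the same order.
    torch.manual_seed(seed)
    model = nn.Sequential(nn.Linear(8, 16), nn.GELU(), nn.Linear(16, 4))
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-2)
    return model, optimizer


def train_step(model, optimizer, x, y) -> float:
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    optimizer.step()
    return loss.item()


def save_checkpoint(f, model, optimizer, step: int) -> None:
    # f may be a file path or a binary file-like object.
    torch.save(
        {"model": model.state_dict(), "optimizer": optimizer.state_dict(), "step": step},
        f,
    )


def load_checkpoint(f):
    # Recreate the objects exactly as in the original run, then restore
    # the model weights and the optimizer state into them.
    model, optimizer = make_model_and_optimizer()
    ckpt = torch.load(f)
    model.load_state_dict(ckpt["model"])
    optimizer.load_state_dict(ckpt["optimizer"])
    return model, optimizer, ckpt["step"]
```

Because the resumed optimizer is constructed the same way as the original, its saved moment buffers line up with the right parameters and training continues as if it had never stopped.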

My hunch is that it has something to do with the param_groups, because setting weight_decay to 0 in the trainer (i.e. using Adam instead of AdamW) fixes the issue. This PyTorch issue seems related: https://github.com/pytorch/pytorch/issues/40769, and the most relevant reply there is probably:

...It seems like some important aspect of your model changes between checkpoints. From the error, it looks like the shape of the gradient average does not match the shape of the current gradient....
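The param_groups hypothesis is easier to see once the grouping is written out. A common convention (assumed here, not taken from this repository) applies weight decay to matrices and exempts biases and normalization gains. The hypothetical helper below builds both groups as lists from named_parameters(), so the parameter order is identical every time the optimizer is rebuilt. Collecting parameters in a Python set instead would make the order depend on tensor ids, which can change between processes.

```python
import torch
from torch import nn


def build_param_groups(model: nn.Module, weight_decay: float):
    # Hypothetical helper. Assumed convention: tensors with ndim >= 2
    # (weight matrices) are decayed; ndim < 2 (biases, norm gains) are not.
    # Plain lists keep named_parameters() registration order, which is
    # the same on every run, unlike iteration order over a set of tensors.
    decay, no_decay = [], []
    for _, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (decay if param.ndim >= 2 else no_decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

Usage would look like torch.optim.AdamW(build_param_groups(model, 1e-2), lr=1e-4); the per-group weight_decay values override AdamW's default.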

Other than skipping the optimizer reload altogether, I haven't been able to fully solve this issue. Does anyone have any ideas?
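One way to make the "skip the optimizer reload" workaround explicit is to validate the loaded state and fall back to a fresh optimizer state only when it does not fit. This is a sketch with hypothetical helper names; it compares each per-parameter state tensor against its parameter's shape, so it catches the kind of mismatch reported in this issue but cannot detect two swapped parameters that happen to share a shape.

```python
import torch


def optimizer_state_matches(optimizer: torch.optim.Optimizer) -> bool:
    # Every non-scalar per-parameter state tensor (e.g. Adam's exp_avg and
    # exp_avg_sq) should have the same shape as the parameter it belongs to.
    # Scalar entries such as "step" are skipped.
    for group in optimizer.param_groups:
        for param in group["params"]:
            for value in optimizer.state.get(param, {}).values():
                if torch.is_tensor(value) and value.dim() > 0 and value.shape != param.shape:
                    return False
    return True


def load_optimizer_or_reset(optimizer: torch.optim.Optimizer, state_dict: dict) -> bool:
    # Hypothetical helper: returns True if the saved state was kept,
    # False if it was discarded.
    optimizer.load_state_dict(state_dict)
    if optimizer_state_matches(optimizer):
        return True
    # Keep the restored hyperparameters (lr, betas, ...) but drop the
    # mismatched moment buffers; they are re-initialized on the next step().
    optimizer.state.clear()
    return False
```

Returning a flag rather than silently resetting lets the training script log that the optimizer state was discarded, which matters because a reset Adam state behaves like a fresh warmup.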

nousr commented 2 years ago

Here is the traceback:

Traceback (most recent call last):
  File "/home/x/dalle2/checkpoint_issue.py", line 77, in <module>
    train()
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
    return self.main(*args, **kwargs)
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/click/core.py", line 1053, in main
    rv = self.invoke(ctx)
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/click/core.py", line 754, in invoke
    return __callback(*args, **kwargs)
  File "/home/x/dalle2/checkpoint_issue.py", line 68, in train
    trainer.update()
  File "/home/x/dalle2/DALLE2-pytorch/dalle2_pytorch/trainer.py", line 339, in update
    self.scaler.step(self.optimizer)
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py", line 310, in step
    return optimizer.step(*args, **kwargs)
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/optim/optimizer.py", line 88, in wrapper
    return func(*args, **kwargs)
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/optim/adamw.py", line 145, in step
    F.adamw(params_with_grad,
  File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/optim/_functional.py", line 143, in adamw
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (64) must match the size of tensor b (2048) at non-singleton dimension 1
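This kind of error can be reproduced without the trainer. Optimizer.load_state_dict pairs the saved per-parameter state with the current parameters by their position in param_groups, not by name, so if the resumed optimizer lists its parameters in a different order, AdamW's exp_avg buffers end up attached to parameters of a different shape and the next step() fails in the moment update, as in the trace above. The sketch below uses a single nn.Linear with weight and bias deliberately swapped; resume_step is a hypothetical name, and this illustrates the mechanism rather than confirming the diagnosis for DiffusionPriorTrainer.

```python
import torch
from torch import nn


def resume_step(swap_order: bool) -> str:
    torch.manual_seed(0)
    model = nn.Linear(4, 3)  # weight: (3, 4), bias: (3,)
    x, y = torch.randn(8, 4), torch.randn(8, 3)

    def train_step(opt):
        opt.zero_grad()
        nn.functional.mse_loss(model(x), y).backward()
        opt.step()

    # "Original run": parameters registered as [weight, bias].
    opt = torch.optim.AdamW([model.weight, model.bias], lr=1e-3)
    train_step(opt)
    saved = opt.state_dict()  # per-parameter state is keyed by index 0, 1

    # "Resumed run": same tensors, optionally listed in the other order.
    params = [model.bias, model.weight] if swap_order else [model.weight, model.bias]
    resumed = torch.optim.AdamW(params, lr=1e-3)
    # Saved index 0 (weight's (3, 4) buffers) goes to whichever param is first now.
    resumed.load_state_dict(saved)
    try:
        train_step(resumed)
    except RuntimeError as err:
        return str(err)
    return "ok"
```

With the original order, resume_step(False) returns "ok"; with the swapped order, resume_step(True) should return a size-mismatch message similar to the one in the trace, though the exact wording can vary between PyTorch versions.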
lucidrains commented 2 years ago

@nousr I'm not sure what's going on! I added https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/trainer.py#L257 so we can test your hypothesis that the param groups are the root cause.