Here is the trace:
Traceback (most recent call last):
File "/home/x/dalle2/checkpoint_issue.py", line 77, in <module>
train()
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/click/core.py", line 1128, in __call__
return self.main(*args, **kwargs)
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/click/core.py", line 1053, in main
rv = self.invoke(ctx)
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/click/core.py", line 1395, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/click/core.py", line 754, in invoke
return __callback(*args, **kwargs)
File "/home/x/dalle2/checkpoint_issue.py", line 68, in train
trainer.update()
File "/home/x/dalle2/DALLE2-pytorch/dalle2_pytorch/trainer.py", line 339, in update
self.scaler.step(self.optimizer)
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/cuda/amp/grad_scaler.py", line 310, in step
return optimizer.step(*args, **kwargs)
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/optim/optimizer.py", line 88, in wrapper
return func(*args, **kwargs)
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
return func(*args, **kwargs)
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/optim/adamw.py", line 145, in step
F.adamw(params_with_grad,
File "/home/x/anaconda3/envs/dalle/lib/python3.9/site-packages/torch/optim/_functional.py", line 143, in adamw
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
RuntimeError: The size of tensor a (64) must match the size of tensor b (2048) at non-singleton dimension 1
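In isolation, that failing line is an AdamW first-moment buffer whose shape no longer matches the gradient of the parameter it is attached to. A minimal illustration, using the sizes reported in the trace (illustrative only, not the actual tensors from the trainer):

```python
import torch

# The exp_avg buffer restored from the checkpoint appears to be shaped for a
# different parameter than the one whose gradient is being applied to it.
exp_avg = torch.zeros(1, 64)    # first-moment ("exp_avg") buffer loaded from the checkpoint
grad    = torch.zeros(1, 2048)  # gradient of the parameter AdamW is actually stepping
beta1   = 0.9
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
# RuntimeError: The size of tensor a (64) must match the size of tensor b (2048)
# at non-singleton dimension 1
```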
@nousr I'm not sure what's going on! I added https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/trainer.py#L257 so we can test your hypothesis that the param groups are the root issue.
Hey, this one has me a little stumped.
Problem: DiffusionPriorTrainer is not able to reload from a checkpoint and continue training.
Calling .save() and then .load() seems to corrupt the optimizer in some way. Here is a minimal code example that should reproduce the error on main (the steps to reproduce are simple and included in the file as a docstring): https://gist.github.com/nousr/5f6877c2bcb8b0c1b806ed8e206d39af
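For reference, the failing sequence is roughly the following (just an outline; the gist has the full runnable script, and the constructor arguments and checkpoint path here are placeholders):

```python
# Rough outline of the repro -- see the gist above for the actual script.
trainer = DiffusionPriorTrainer(diffusion_prior, ...)  # construction args elided

# (forward/backward on a batch omitted)
trainer.update()                        # at least one step so AdamW allocates its exp_avg buffers
trainer.save('./prior_checkpoint.pth')  # persist model + optimizer state

trainer.load('./prior_checkpoint.pth')  # reload model + optimizer state
trainer.update()                        # -> RuntimeError in exp_avg.mul_(beta1).add_(grad, ...)
```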
My hunch is that it has something to do with the param_groups, as setting weight_decay to 0 in the trainer (i.e. using Adam instead of AdamW) fixes the issue. This seems to be a related issue: https://github.com/pytorch/pytorch/issues/40769, with potentially the most relevant reply being: "Other than just skipping the re-load of the optimizer all together I haven't really been able to solve this issue fully." Does anyone have any ideas?