txie-93 / cdvae

An SE(3)-invariant autoencoder for generating the periodic structure of materials [ICLR 2022]
MIT License

Problem with the dataset MP-20 #5

Open ubikpt opened 2 years ago

ubikpt commented 2 years ago

I found another issue... Maybe this is also related to changes in PyTorch?

I think I managed to train CDVAE with the Perov-5 and Carbon-24 datasets, but when I try to train it with the MP-20 dataset, I get the following error:

```
Epoch 46: 0%| | 0/212 [00:00<?, ?it/s, loss=2.08e+28, v_num=0sev, val_loss=11.
Error executing job with overrides: ['data=mp_20', 'expname=mp_20']
Traceback (most recent call last):
  File "cdvae/run.py", line 166, in main
    run(cfg)
  File "cdvae/run.py", line 154, in run
    trainer.fit(model=model, datamodule=datamodule)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 460, in fit
    self._run(model)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 758, in _run
    self.dispatch()
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 799, in dispatch
    self.accelerator.start_training(self)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 96, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 144, in start_training
    self._results = trainer.run_stage()
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in run_stage
    return self.run_train()
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 871, in run_train
    self.train_loop.run_training_epoch()
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 499, in run_training_epoch
    batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 738, in run_training_batch
    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 434, in optimizer_step
    model_ref.optimizer_step(
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1403, in optimizer_step
    optimizer.step(closure=optimizer_closure)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 214, in step
    self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 134, in __optimizer_step
    trainer.accelerator.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 329, in optimizer_step
    self.run_optimizer_step(optimizer, opt_idx, lambda_closure, **kwargs)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 336, in run_optimizer_step
    self.training_type_plugin.optimizer_step(optimizer, lambda_closure=lambda_closure, **kwargs)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 193, in optimizer_step
    optimizer.step(closure=lambda_closure, **kwargs)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/torch/optim/optimizer.py", line 89, in wrapper
    return func(*args, **kwargs)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/torch/optim/adam.py", line 66, in step
    loss = closure()
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 732, in train_step_and_backward_closure
    result = self.training_step_and_backward(
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 823, in training_step_and_backward
    result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/trainer/training_loop.py", line 290, in training_step
    training_step_output = self.trainer.accelerator.training_step(args)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 204, in training_step
    return self.training_type_plugin.training_step(*args)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/training_type_plugin.py", line 155, in training_step
    return self.lightning_module.training_step(*args, **kwargs)
  File "/home/CDVAE/cdvae-main/cdvae/pl_modules/model.py", line 528, in training_step
    outputs = self(batch, teacher_forcing, training=True)
  File "/home/.conda/envs/cdvae2/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/home/CDVAE/cdvae-main/cdvae/pl_modules/model.py", line 338, in forward
    rand_atom_types = torch.multinomial(
RuntimeError: probability tensor contains either inf, nan or element < 0
```

txie-93 commented 2 years ago

I think this is due to training instability, which we also saw in our experiments. In my experience you can restart training from the latest checkpoint and it will usually continue fine. You can also decrease the learning rate if it remains too unstable.
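For reference, a minimal sketch of what resuming looks like with the PyTorch Lightning 1.3.x API that appears in the traceback. The checkpoint path and the `model`/`datamodule` objects below are placeholders for whatever run.py builds from the Hydra config, not literal CDVAE code:

```python
import pytorch_lightning as pl

# Hedged sketch, not CDVAE's actual run.py: PyTorch Lightning 1.3.x accepts a
# `resume_from_checkpoint` argument on the Trainer, so pointing it at the last
# checkpoint written for the run continues training from that state
# (optimizer state, epoch counter and LR schedulers included).
trainer = pl.Trainer(
    gpus=1,                                             # match the original setup
    max_epochs=1000,                                     # placeholder value
    resume_from_checkpoint="path/to/expname/last.ckpt",  # hypothetical path
)

# `model` and `datamodule` stand in for the objects run.py instantiates from
# the Hydra config before calling trainer.fit.
trainer.fit(model=model, datamodule=datamodule)
```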

ubikpt commented 2 years ago

Thank you very much for all your answers! Now that you mention it, I think I also got this error once with the carbon dataset. At the time I attributed it to my own experiments (Python library upgrades and some changes I made to your code to eliminate warnings). About the learning rate: how should I decrease it? I found the lr parameter in conf/optim/default.yaml and reduced it to half, but training still crashes...
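Roughly what I changed, paraphrasing from memory, so the exact keys may not match the file:

```yaml
# Hedged sketch of the optimizer block in conf/optim/default.yaml with the
# learning rate halved; the surrounding keys are assumptions, the point is
# that `lr` sits under the optimizer target that Hydra instantiates.
optimizer:
  _target_: torch.optim.Adam
  lr: 5.0e-4   # halved from 1e-3
```

If the keys match, Hydra should also accept the same change as a command-line override (something like `optim.optimizer.lr=5e-4` appended to the training command), which avoids editing the file.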