atomistic-machine-learning / schnetpack

SchNetPack - Deep Neural Networks for Atomistic Systems

Error in qm9 tutorial #592

Closed JonathanSchmidt1 closed 10 months ago

JonathanSchmidt1 commented 10 months ago

Thank you for the great package. Running `spktrain experiment=qm9_atomwise` produces an error after the first epoch, while the md17 example runs fine. It seems the validation loss cannot be found. Here is the full output and stack trace:

⚙ Running with the following config:
├── run
│   └── work_dir: ${hydra:runtime.cwd}                                                                                                                                                                             
│       data_dir: ${run.work_dir}/data                                                                                                                                                                             
│       path: ${run.work_dir}/runs                                                                                                                                                                                 
│       experiment: qm9_${globals.property}                                                                                                                                                                        
│       id: ${uuid:1}                                                                                                                                                                                              
│       ckpt_path: null                                                                                                                                                                                            
│                                                                                                                                                                                                                  
├── globals
│   └── model_path: best_model                                                                                                                                                                                     
│       cutoff: 5.0                                                                                                                                                                                                
│       lr: 0.0005                                                                                                                                                                                                 
│       property: energy_U0                                                                                                                                                                                        
│       aggregation: sum                                                                                                                                                                                           
│                                                                                                                                                                                                                  
├── data
│   └── _target_: schnetpack.datasets.QM9                                                                                                                                                                          
│       datapath: ${run.data_dir}/qm9.db                                                                                                                                                                           
│       data_workdir: null                                                                                                                                                                                         
│       batch_size: 100                                                                                                                                                                                            
│       num_train: 110000                                                                                                                                                                                          
│       num_val: 10000                                                                                                                                                                                             
│       num_test: null                                                                                                                                                                                             
│       num_workers: 8                                                                                                                                                                                             
│       num_val_workers: null                                                                                                                                                                                      
│       num_test_workers: null                                                                                                                                                                                     
│       remove_uncharacterized: true                                                                                                                                                                               
│       distance_unit: Ang                                                                                                                                                                                         
│       property_units:                                                                                                                                                                                            
│         energy_U0: eV                                                                                                                                                                                            
│         energy_U: eV                                                                                                                                                                                             
│         enthalpy_H: eV                                                                                                                                                                                           
│         free_energy: eV                                                                                                                                                                                          
│         homo: eV                                                                                                                                                                                                 
│         lumo: eV                                                                                                                                                                                                 
│         gap: eV                                                                                                                                                                                                  
│         zpve: eV                                                                                                                                                                                                 
│       transforms:                                                                                                                                                                                                
│       - _target_: schnetpack.transform.SubtractCenterOfMass                                                                                                                                                      
│       - _target_: schnetpack.transform.RemoveOffsets                                                                                                                                                             
│         property: ${globals.property}                                                                                                                                                                            
│         remove_atomrefs: true                                                                                                                                                                                    
│         remove_mean: true                                                                                                                                                                                        
│       - _target_: schnetpack.transform.MatScipyNeighborList                                                                                                                                                      
│         cutoff: ${globals.cutoff}                                                                                                                                                                                
│       - _target_: schnetpack.transform.CastTo32                                                                                                                                                                  
│                                                                                                                                                                                                                  
├── model
│   └── representation:                                                                                                                                                                                            
│         radial_basis:                                                                                                                                                                                            
│           _target_: schnetpack.nn.radial.GaussianRBF                                                                                                                                                             
│           n_rbf: 20                                                                                                                                                                                              
│           cutoff: ${globals.cutoff}                                                                                                                                                                              
│         _target_: schnetpack.representation.PaiNN                                                                                                                                                                
│         n_atom_basis: 128                                                                                                                                                                                        
│         n_interactions: 3                                                                                                                                                                                        
│         shared_interactions: false                                                                                                                                                                               
│         shared_filters: false                                                                                                                                                                                    
│         cutoff_fn:                                                                                                                                                                                               
│           _target_: schnetpack.nn.cutoff.CosineCutoff                                                                                                                                                            
│           cutoff: ${globals.cutoff}                                                                                                                                                                              
│       _target_: schnetpack.model.NeuralNetworkPotential                                                                                                                                                          
│       input_modules:                                                                                                                                                                                             
│       - _target_: schnetpack.atomistic.PairwiseDistances                                                                                                                                                         
│       output_modules:                                                                                                                                                                                            
│       - _target_: schnetpack.atomistic.Atomwise                                                                                                                                                                  
│         output_key: ${globals.property}                                                                                                                                                                          
│         n_in: ${model.representation.n_atom_basis}                                                                                                                                                               
│         aggregation_mode: ${globals.aggregation}                                                                                                                                                                 
│       postprocessors:                                                                                                                                                                                            
│       - _target_: schnetpack.transform.CastTo64                                                                                                                                                                  
│       - _target_: schnetpack.transform.AddOffsets                                                                                                                                                                
│         property: ${globals.property}                                                                                                                                                                            
│         add_mean: true                                                                                                                                                                                           
│         add_atomrefs: true                                                                                                                                                                                       
│                                                                                                                                                                                                                  
├── task
│   └── optimizer_cls: torch.optim.AdamW                                                                                                                                                                           
│       optimizer_args:                                                                                                                                                                                            
│         lr: ${globals.lr}                                                                                                                                                                                        
│         weight_decay: 0.0                                                                                                                                                                                        
│       scheduler_cls: schnetpack.train.ReduceLROnPlateau                                                                                                                                                          
│       scheduler_monitor: val_loss                                                                                                                                                                                
│       scheduler_args:                                                                                                                                                                                            
│         mode: min                                                                                                                                                                                                
│         factor: 0.5                                                                                                                                                                                              
│         patience: 75                                                                                                                                                                                             
│         threshold: 0.0                                                                                                                                                                                           
│         threshold_mode: rel                                                                                                                                                                                      
│         cooldown: 10                                                                                                                                                                                             
│         min_lr: 0.0                                                                                                                                                                                              
│         smoothing_factor: 0.0                                                                                                                                                                                    
│       _target_: schnetpack.AtomisticTask                                                                                                                                                                         
│       outputs:                                                                                                                                                                                                   
│       - _target_: schnetpack.task.ModelOutput                                                                                                                                                                    
│         name: ${globals.property}                                                                                                                                                                                
│         loss_fn:                                                                                                                                                                                                 
│           _target_: torch.nn.MSELoss                                                                                                                                                                             
│         metrics:                                                                                                                                                                                                 
│           mae:                                                                                                                                                                                                   
│             _target_: torchmetrics.regression.MeanAbsoluteError                                                                                                                                                  
│           rmse:                                                                                                                                                                                                  
│             _target_: torchmetrics.regression.MeanSquaredError                                                                                                                                                   
│             squared: false                                                                                                                                                                                       
│         loss_weight: 1.0                                                                                                                                                                                         
│       warmup_steps: 0                                                                                                                                                                                            
│                                                                                                                                                                                                                  
├── trainer
│   └── _target_: pytorch_lightning.Trainer                                                                                                                                                                        
│       devices: 1                                                                                                                                                                                                 
│       min_epochs: null                                                                                                                                                                                           
│       max_epochs: 100000                                                                                                                                                                                         
│       enable_model_summary: true                                                                                                                                                                                 
│       profiler: null                                                                                                                                                                                             
│       gradient_clip_val: 0                                                                                                                                                                                       
│       accumulate_grad_batches: 1                                                                                                                                                                                 
│       val_check_interval: 1.0                                                                                                                                                                                    
│       check_val_every_n_epoch: 1                                                                                                                                                                                 
│       num_sanity_val_steps: 0                                                                                                                                                                                    
│       fast_dev_run: false                                                                                                                                                                                        
│       overfit_batches: 0                                                                                                                                                                                         
│       limit_train_batches: 1.0                                                                                                                                                                                   
│       limit_val_batches: 1.0                                                                                                                                                                                     
│       limit_test_batches: 1.0                                                                                                                                                                                    
│       detect_anomaly: false                                                                                                                                                                                      
│       precision: 32                                                                                                                                                                                              
│       accelerator: auto                                                                                                                                                                                          
│       num_nodes: 1                                                                                                                                                                                               
│       deterministic: false                                                                                                                                                                                       
│       inference_mode: false                                                                                                                                                                                      
│                                                                                                                                                                                                                  
├── callbacks
│   └── model_checkpoint:                                                                                                                                                                                          
│         _target_: schnetpack.train.ModelCheckpoint                                                                                                                                                               
│         monitor: val_loss                                                                                                                                                                                        
│         save_top_k: 1                                                                                                                                                                                            
│         save_last: true                                                                                                                                                                                          
│         mode: min                                                                                                                                                                                                
│         verbose: false                                                                                                                                                                                           
│         dirpath: checkpoints/                                                                                                                                                                                    
│         filename: '{epoch:02d}'                                                                                                                                                                                  
│         model_path: ${globals.model_path}                                                                                                                                                                        
│       early_stopping:                                                                                                                                                                                            
│         _target_: pytorch_lightning.callbacks.EarlyStopping                                                                                                                                                      
│         monitor: val_loss                                                                                                                                                                                        
│         patience: 200                                                                                                                                                                                            
│         mode: min                                                                                                                                                                                                
│         min_delta: 0.0                                                                                                                                                                                           
│         check_on_train_epoch_end: false                                                                                                                                                                          
│       lr_monitor:                                                                                                                                                                                                
│         _target_: pytorch_lightning.callbacks.LearningRateMonitor                                                                                                                                                
│         logging_interval: epoch                                                                                                                                                                                  
│       ema:                                                                                                                                                                                                       
│         _target_: schnetpack.train.ExponentialMovingAverage                                                                                                                                                      
│         decay: 0.995                                                                                                                                                                                             
│                                                                                                                                                                                                                  
├── logger
│   └── tensorboard:                                                                                                                                                                                               
│         _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger                                                                                                                                        
│         save_dir: tensorboard/                                                                                                                                                                                   
│         name: default                                                                                                                                                                                            
│                                                                                                                                                                                                                  
└── seed
    └── None                                                                                                                                                                                                       
[2024-01-05 16:46:48,119][schnetpack.cli][INFO] - Seed randomly...
/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/lightning_fabric/utilities/seed.py:40: No seed found, seed set to 2839986104
Seed set to 2839986104
[2024-01-05 16:46:48,124][schnetpack.cli][INFO] - Instantiating datamodule <schnetpack.datasets.QM9>
[2024-01-05 16:46:48,130][schnetpack.cli][INFO] - Instantiating model <schnetpack.model.NeuralNetworkPotential>
[2024-01-05 16:46:48,143][schnetpack.cli][INFO] - Instantiating task <schnetpack.AtomisticTask>
[2024-01-05 16:46:48,171][torch.distributed.nn.jit.instantiator][INFO] - Created a temporary directory at /tmp/tmp_i12qvww
[2024-01-05 16:46:48,171][torch.distributed.nn.jit.instantiator][INFO] - Writing /tmp/tmp_i12qvww/_remote_module_non_scriptable.py
/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'model' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['model'])`.
[2024-01-05 16:46:48,207][schnetpack.cli][INFO] - Instantiating callback <schnetpack.train.ModelCheckpoint>
[2024-01-05 16:46:48,209][schnetpack.cli][INFO] - Instantiating callback <pytorch_lightning.callbacks.EarlyStopping>
[2024-01-05 16:46:48,210][schnetpack.cli][INFO] - Instantiating callback <pytorch_lightning.callbacks.LearningRateMonitor>
[2024-01-05 16:46:48,210][schnetpack.cli][INFO] - Instantiating callback <schnetpack.train.ExponentialMovingAverage>
[2024-01-05 16:46:48,211][schnetpack.cli][INFO] - Instantiating logger <pytorch_lightning.loggers.tensorboard.TensorBoardLogger>
[2024-01-05 16:46:48,237][schnetpack.cli][INFO] - Instantiating trainer <pytorch_lightning.Trainer>
/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/jschmidt/envs/schnet/bin/spktrain e ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
`Trainer(limit_train_batches=1.0)` was configured so 100% of the batches per epoch will be used..
`Trainer(limit_val_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(limit_test_batches=1.0)` was configured so 100% of the batches will be used..
`Trainer(val_check_interval=1.0)` was configured so validation will run at the end of the training epoch..
[2024-01-05 16:46:48,522][schnetpack.cli][INFO] - Logging hyperparameters.
[2024-01-05 16:46:48,734][schnetpack.cli][INFO] - Starting training.
/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/snx3000/jschmidt/envs/schnet/bin/spktrain e ...
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:04<00:00, 12.37it/s]
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type                   | Params
---------------------------------------------------
0 | model   | NeuralNetworkPotential | 589 K 
1 | outputs | ModuleList             | 0     
---------------------------------------------------
589 K     Trainable params
0         Non-trainable params
589 K     Total params
2.356     Total estimated model params size (MB)
/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/utilities/data.py:104: Total length of `AtomsLoader` across ranks is zero. Please make sure this was your intention.
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:07<00:00,  6.63it/s, v_num=0]/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:369: `ModelCheckpoint(monitor='val_loss')` could not find the monitored key in the returned metrics: ['lr_schedule', 'train_loss', 'train_energy_U0_mae', 'train_energy_U0_rmse', 'epoch', 'step']. HINT: Did you call `log('val_loss', value)` in the `LightningModule`?
Error executing job with overrides: ['experiment=qm9_atomwise']
Traceback (most recent call last):
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/schnetpack/cli.py", line 158, in train
    trainer.fit(model=task, datamodule=datamodule, ckpt_path=config.run.ckpt_path)
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1035, in _run_stage
    self.fit_loop.run()
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 203, in run
    self.on_advance_end()
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 382, in on_advance_end
    self.epoch_loop.update_lr_schedulers("epoch", update_plateau_schedulers=not self.restarting)
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 337, in update_lr_schedulers
    self._update_learning_rates(interval=interval, update_plateau_schedulers=update_plateau_schedulers)
  File "/scratch/snx3000/jschmidt/envs/schnet/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 372, in _update_learning_rates
    raise MisconfigurationException(
lightning_fabric.utilities.exceptions.MisconfigurationException: ReduceLROnPlateau conditioned on metric val_loss which is not available. Available metrics are: ['lr_schedule', 'train_loss', 'train_energy_U0_mae', 'train_energy_U0_rmse']. Condition can be set using `monitor` key in lr scheduler dict

Set the environment variable HYDRA_FULL_ERROR=1 for a complete stack trace.
Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 50/50 [00:08<00:00,  6.21it/s, v_num=0]
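
For context, the exception comes from PyTorch Lightning trying to step the plateau scheduler on a metric that was never logged, which matches the warning above that the total length of the `AtomsLoader` is zero. Below is a minimal sketch of how such a scheduler is normally tied to a monitored metric in Lightning (a simplified toy module for illustration, not SchNetPack's actual `AtomisticTask`):

```python
# Minimal sketch: how a ReduceLROnPlateau scheduler is bound to a logged
# metric in PyTorch Lightning. Toy module for illustration only.
import torch
import pytorch_lightning as pl


class ToyTask(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(1, 1)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        loss = self.layer(batch).pow(2).mean()
        # If no validation batch ever runs, this is never logged and
        # "val_loss" is missing from the metrics Lightning collects.
        self.log("val_loss", loss)

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.parameters(), lr=5e-4)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min")
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                # Lightning raises the MisconfigurationException above when
                # this monitored key is absent at the end of the epoch.
                "monitor": "val_loss",
            },
        }
```

If `validation_step` never runs (for example, because the validation dataloader is empty), `val_loss` never appears in the collected metrics, and both `ModelCheckpoint(monitor='val_loss')` and the scheduler update fail with exactly the messages shown above.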
JonathanSchmidt1 commented 10 months ago

Sorry, it turned out to be an issue with the QM9 database (`qm9.db`) not being saved correctly. Somehow training still ran through the first epoch until the validation step failed.
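
In case someone else runs into this: deleting the broken `qm9.db` so it is downloaded and parsed again resolved it. A quick sanity check of the database contents is sketched below (using ASE's `db` module, which SchNetPack's dataset classes build on; the path corresponds to `data.datapath` in the config above):

```python
# Count the structures stored in the (possibly truncated) QM9 database.
# Illustrative check using ASE's database interface; adjust db_path as needed.
from ase.db import connect

db_path = "data/qm9.db"
db = connect(db_path)
num_structures = db.count()
print(f"{db_path} contains {num_structures} structures")

# QM9 holds roughly 130k molecules; a much smaller count suggests an
# incomplete download, in which case the num_train=110000 / num_val=10000
# split from the config cannot be satisfied.
```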