torchmd / torchmd-net

Training neural network potentials
MIT License

Tensor size mismatch when training tensornet #309

Closed: FranklinHu1 closed this issue 5 months ago.

FranklinHu1 commented 5 months ago

Hello,

I am trying to do some tensornet training using the latest version of torchmd-net and a single H100 GPU. However, I encounter the following error:

/home/frankhu/torchmd-net/torchmdnet/module.py:168: UserWarning: Using a target size (torch.Size([98464, 1])) that is different to the input size (torch.Si
  loss_y = loss_fn(y, batch.y)
Traceback (most recent call last):
  File "/home/frankhu/torchmd-net/scripts/train.py", line 233, in <module>
    main()
  File "/home/frankhu/torchmd-net/scripts/train.py", line 218, in main
    trainer.fit(model, data, ckpt_path=None if args.reset_trainer else args.load_model)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 544, in fit
    call._call_and_handle_interrupt(
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 580, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 989, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1033, in _run_stage
    self._run_sanity_check()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/trainer/trainer.py", line 1062, in _run_sanity_check
    val_loop.run()
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/loops/utilities.py", line 182, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 134, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/loops/evaluation_loop.py", line 391, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/trainer/call.py", line 309, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/lightning/pytorch/strategies/strategy.py", line 403, in validation_step
    return self.lightning_module.validation_step(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/torchmd-net/torchmdnet/module.py", line 141, in validation_step
    return self.step(batch, **step_type)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/torchmd-net/torchmdnet/module.py", line 227, in step
    step_losses = self._compute_losses(y, neg_dy, batch, loss_fn, stage)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/torchmd-net/torchmdnet/module.py", line 168, in _compute_losses
    loss_y = loss_fn(y, batch.y)
             ^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/torch/nn/functional.py", line 3297, in l1_loss
    expanded_input, expanded_target = torch.broadcast_tensors(input, target)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/frankhu/mambaforge/envs/torchmd-net/lib/python3.11/site-packages/torch/functional.py", line 73, in broadcast_tensors
    return _VF.broadcast_tensors(tensors)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The size of tensor a (16) must match the size of tensor b (98464) at non-singleton dimension 0

My generated hparams.yaml for this experiment is as follows:

activation: silu
aggr: add
atom_filter: -1
attn_activation: silu
batch_size: 16
box_vecs:
- - 12.42
  - 0
  - 0
- - 0
  - 12.42
  - 0
- - 0
  - 0
  - 12.42
charge: false
check_errors: true
conf: null
coord_files: null
cutoff_lower: 0.0
cutoff_upper: 5.0
dataset: HDF5
dataset_arg: null
dataset_preload_limit: 1024
dataset_root: W64_revPBE.h5
derivative: true
distance_influence: both
early_stopping_patience: 95
ema_alpha_neg_dy: 1.0
ema_alpha_y: 0.0
embed_files: null
embedding_dimension: 64
energy_files: null
equivariance_invariance_group: O(3)
force_files: null
gradient_clipping: 0.0
inference_batch_size: 16
load_model: null
log_dir: H2O_exp
lr: 0.001
lr_factor: 0.9
lr_metric: val_total_mse_loss
lr_min: 1.0e-07
lr_patience: 5
lr_warmup_steps: 0
max_num_neighbors: 256
max_z: 100
model: tensornet
neg_dy_weight: 1.0
neighbor_embedding: true
ngpus: -1
num_epochs: 1000
num_heads: 2
num_layers: 0
num_nodes: 1
num_rbf: 32
num_workers: 32
output_model: Scalar
precision: 64
prior_args: []
prior_model: null
rbf_type: expnorm
redirect: true
reduce_op: mean
remove_ref_energy: false
reset_trainer: false
save_interval: 1
seed: 42
spin: false
splits: h2o_splits.npz
standardize: false
static_shapes: false
tensorboard_use: true
test_interval: 10
test_size: 0.1
train_size: 0.8
trainable_rbf: true
val_size: 0.1
vector_cutoff: false
wandb_name: training
wandb_project: training_
wandb_resume_from_id: null
wandb_use: false
weight_decay: 0.0
y_weight: 0.0

My dataset is in the HDF5 format and consists of boxes of 64 water molecules. I have attached it as a zip file to this issue, along with an npz file of the splits I use for training. I am a little confused because I have successfully trained TensorNet models in the past using this data and these settings.

Any help would be greatly appreciated. Thank you very much!

Attachment: water_data.zip

RaulPPelaez commented 5 months ago

We think this is a problem with the dataset. The error occurs while computing the losses, i.e. after the model has run, when the trainer compares the model outputs with the labels in your dataset.

Try loading the dataset manually and checking the shapes of a sample:

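from torchmdnet.datasets import HDF5
ds = HDF5("water_data.hdf")  # path to your HDF5 file
sample = ds.get(0)  # first sample (one conformation)
print(sample)  # shows the shape of each stored field (pos, z, y, neg_dy, ...)

FranklinHu1 commented 5 months ago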

Yes, it does seem to be a problem with the dataset. When I run the suggested code snippet, I get the following output:

>>> from torchmdnet.datasets import HDF5
>>> ds = HDF5("W64_revPBE.h5")
Loading 1 HDF5 files (28.20 MB)
Preloading 1 HDF5 files (28.20 MB)
>>> sample = ds.get(0)
>>> print(sample)
Data(pos=[192, 3], z=[192], y=[6154], neg_dy=[192, 3])

The shape of y is wrong: there should be only one energy per frame, but y has 6154 entries, which is the total number of frames in my dataset. Each frame has 64 water molecules, i.e. 192 atoms, so the shapes of pos, z (atom types), and neg_dy (forces) are correct.
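The 98464 in the original warning and error is consistent with this diagnosis: 16 × 6154 = 98464, i.e. a batch of 16 samples that each carry all 6154 energies. The stdlib-only sketch below reproduces the mismatch using shapes alone. It assumes (1) that batching concatenates each sample's y along dimension 0, as PyTorch Geometric's default collation does for such attributes, and (2) that the model predicts one energy per structure, i.e. shape (16, 1). `broadcast_shapes` and `batched_rows` are hypothetical helpers that re-implement the NumPy/PyTorch broadcasting rule by hand; they are not torchmd-net code.

```python
from typing import Tuple

Shape = Tuple[int, ...]


def broadcast_shapes(a: Shape, b: Shape) -> Shape:
    """Broadcast two shapes with the NumPy/PyTorch rules, raising on mismatch."""
    ndim = max(len(a), len(b))
    # Left-pad the shorter shape with 1s so both shapes have the same rank.
    a_pad = (1,) * (ndim - len(a)) + tuple(a)
    b_pad = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for dim, (x, y) in enumerate(zip(a_pad, b_pad)):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise ValueError(
                f"size {x} must match size {y} at non-singleton dimension {dim}"
            )
    return tuple(out)


def batched_rows(batch_size: int, y_entries_per_sample: int) -> int:
    # Assumed collation: each sample's y is concatenated along dimension 0.
    return batch_size * y_entries_per_sample


batch_size = 16
prediction = (batch_size, 1)  # assumed: one predicted energy per structure

# Buggy preload: every sample carries all 6154 energies.
bad_target = (batched_rows(batch_size, 6154), 1)
print(bad_target)  # (98464, 1)
try:
    broadcast_shapes(prediction, bad_target)
except ValueError as err:
    print(err)  # size 16 must match size 98464 at non-singleton dimension 0

# Expected layout: one energy per sample.
good_target = (batched_rows(batch_size, 1), 1)
print(broadcast_shapes(prediction, good_target))  # (16, 1)
```

With one energy per sample, the batched target collapses to (16, 1), which matches the prediction exactly. No broadcasting is needed, so the UserWarning also disappears.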

Right now, my dataset is an h5 file with the following shapes for each of the keys:

- energy: (6154,), i.e. (N_frames,)
- forces: (6154, 192, 3), i.e. (N_frames, N_atoms, direction)
- pos: (6154, 192, 3), i.e. (N_frames, N_atoms, direction)
- types: (6154, 192), i.e. (N_frames, N_atoms)
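A layout like the one above can be sanity-checked before training. The sketch below is a hypothetical, stdlib-only helper; it is not part of torchmd-net. It compares a dict of key → shape against the layout in this post, with energies stored as (N_frames, 1), which is the shape that resolved the issue later in this thread. The key names follow this post. The shapes themselves could be collected, for example, from the `.shape` attribute of each h5py dataset.

```python
from typing import Dict, List, Tuple


def check_layout(shapes: Dict[str, Tuple[int, ...]]) -> List[str]:
    """Return the problems found in a key -> shape dict (empty list if OK).

    Hypothetical helper: the expected layout is the one discussed in this
    thread, with one energy per frame stored as (N_frames, 1).
    """
    problems = []
    # Frame and atom counts are taken from the 2-D "types" array.
    n_frames, n_atoms = shapes["types"]
    expected = {
        "types": (n_frames, n_atoms),
        "pos": (n_frames, n_atoms, 3),
        "forces": (n_frames, n_atoms, 3),
        "energy": (n_frames, 1),
    }
    for key, want in expected.items():
        got = shapes.get(key)
        if got is None:
            problems.append(f"missing key {key!r}")
        elif tuple(got) != want:
            problems.append(f"{key}: expected {want}, got {tuple(got)}")
    return problems


shapes = {
    "energy": (6154,),
    "forces": (6154, 192, 3),
    "pos": (6154, 192, 3),
    "types": (6154, 192),
}
print(check_layout(shapes))  # ['energy: expected (6154, 1), got (6154,)']
shapes["energy"] = (6154, 1)
print(check_layout(shapes))  # []
```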

Is the fix here simply to change the shape of the energies to add an extra dimension, i.e. (6154,) --> (6154, 1)?

Thank you!

guillemsimeon commented 5 months ago

I think so. Can you try it and let us know? By the way, I noticed that you are using a very small model (0 layers, 64 hidden channels). Just out of curiosity: does it perform satisfactorily? I also saw that you are using mean as the aggregation scheme (reduce_op: mean). Is there a particular reason for this? I have never tried that.

G

FranklinHu1 commented 5 months ago

Hi @RaulPPelaez @guillemsimeon,

So sorry for the late response! Yes, reshaping the energies to (N_frames, 1) resolved the issue. This makes sense intuitively, but it might help to document the expected shape of each field in the HDF5 dataset.

Regarding the model size, @guillemsimeon: I mostly use this size because it trains quickly and works quite well, at least for the bulk water dynamics I am currently studying. The mean aggregation scheme, along with the previous standardization feature, turned out to be the root cause of many of the issues I had in the past, so using additive aggregation everywhere is definitely the way to go. I will let you know whether this model size keeps working for the more challenging systems I plan to tackle next.
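As a side note on sum versus mean readouts: summing per-atom contributions makes the predicted energy scale with system size (energy is extensive), while averaging them does not. The sketch below illustrates this with a hypothetical `readout` function and made-up per-atom values. It is not torchmd-net's implementation and does not by itself diagnose the earlier issues described above.

```python
from typing import List


def readout(atomic_energies: List[float], mode: str) -> float:
    """Combine per-atom contributions into one energy (hypothetical helper)."""
    if mode == "sum":
        return sum(atomic_energies)
    if mode == "mean":
        return sum(atomic_energies) / len(atomic_energies)
    raise ValueError(f"unknown mode: {mode!r}")


# Made-up per-atom contributions for one water molecule (O, H, H).
water = [-2.0, -0.5, -0.5]
two_waters = water * 2  # two non-interacting copies

print(readout(water, "sum"), readout(two_waters, "sum"))    # -3.0 -6.0
print(readout(water, "mean"), readout(two_waters, "mean"))  # -1.0 -1.0
```

Doubling the system doubles the summed energy, but the mean stays the same. Only the sum therefore matches a total-energy label that grows with the number of atoms.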

Thanks again for all your help!

RaulPPelaez commented 5 months ago

Glad you got it working. I was reading the code to add a comment about this, and I believe your use case should be supported. You found a bug!

This line: https://github.com/torchmd/torchmd-net/blob/8b472462f212aa58a36c03b26d75900acc09647c/torchmdnet/datasets/hdf.py#L92 should be:

tmp = tmp.unsqueeze(-1)
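Both the workaround that resolved the issue (storing energies as (N_frames, 1)) and this one-line fix add a trailing axis of size 1 to the energy array. Below is a minimal NumPy sketch on placeholder data, where `np.expand_dims(x, -1)` plays the role of `torch.Tensor.unsqueeze(-1)`. It only illustrates the shapes and is not the surrounding code in hdf.py.

```python
import numpy as np

n_frames = 6154
energy = np.zeros(n_frames)  # placeholder values in the original (N_frames,) layout

# Add a trailing axis of size 1, the NumPy counterpart of tmp.unsqueeze(-1).
energy_col = np.expand_dims(energy, -1)  # same result as energy.reshape(-1, 1)

print(energy.shape, energy_col.shape)        # (6154,) (6154, 1)
print(energy[0].shape, energy_col[0].shape)  # () (1,)
```

After the extra axis is added, selecting one frame yields a one-element array (a single energy) instead of a bare scalar.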
RaulPPelaez commented 5 months ago

You can confirm this is a bug in the cache preloading code by setting

dataset_preload_limit: 0

in your YAML config. This skips the preloading code.

RaulPPelaez commented 5 months ago

Should be fixed by https://github.com/torchmd/torchmd-net/pull/313

FranklinHu1 commented 5 months ago

Awesome, thanks for looking into this, @RaulPPelaez! As far as model performance goes, this doesn't affect training or any other operations, right?

RaulPPelaez commented 5 months ago

It should not have a large effect. In principle preloading should be faster, but YMMV. Let me know how it goes!

RaulPPelaez commented 5 months ago

Please feel free to reopen if the issue resurfaces.