general-molecular-simulations / so3lr

SO3krates and Universal Pairwise Force Field for Molecular Simulation
MIT License

ValueError: vmap got inconsistent sizes for array axes to be mapped #2

Open lamthuy opened 1 month ago

lamthuy commented 1 month ago

Hi, I tried to run so3lr to calculate the energy of the system:

calc = So3lrCalculator()
symbols = ['H', 'H', 'H', 'H', 'C', 'H', 'C', 'H', 'C', 'H', 'H', 'C', 'H', 'H', 'C', 'H', 'H', 'H', 'H', 'C', 'H', 'C', 'H', 'H', 'C', 'C', 'H', 'H', 'H', 'O', 'C', 'H', 'C', 'C', 'H', 'N', 'H', 'O', 'O', 'H', 'N', 'H', 'H', 'H', 'O', 'C', 'H', 'H', 'H', 'H', 'H', 'S', 'H', 'H', 'O', 'F', 'O', 'N', 'N', 'N', 'N', 'N', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C']
coordinates = tensor([[  3.5120,  15.7560,  -3.3390],
        [  2.1040,  16.3200,  -2.4180],
        [  0.0410,  15.0000,  -7.6510],
        [ -6.8870,   9.4890,  -8.6300],
        [ -8.0550,   7.7580,  -9.1600],
        [ -7.2220,   7.2370,  -9.6090],
        [ -1.5880,  15.7860,  -4.4260],
        [ -0.8390,  15.7620,  -5.2180],
        [ -0.8910,  15.5980,  -3.0720],
        [ -0.3440,  14.6550,  -3.0640],
        [ -1.6260,  15.5900,  -2.2670],
        [ -2.5420,  14.6010,  -4.6250],
        [ -1.9920,  13.6640,  -4.5340],
        [ -2.9910,  14.6390,  -5.6170],
        [  0.9810,  15.4370,   0.7910],
        [  0.4470,  15.3120,  -0.1510],
        [  0.6770,  14.6530,   1.4850],
        [  2.0520,  15.3490,   0.6080],
        [ -0.8780,   8.4070,   2.3880],
        [  0.9290,   9.7010,   4.0010],
        [  0.8350,  10.1120,   2.9960],
        [ -1.2740,  13.1340,   3.7930],
        [ -0.6910,  12.2200,   3.8990],
        [ -0.8080,  13.7220,   3.0020],
        [ -2.7030,  12.7690,   3.3390],
        [ -2.6970,  12.3090,   1.8810],
        [ -2.0430,  11.4450,   1.7660],
        [ -3.7090,  12.0460,   1.5740],
        [ -2.3380,  13.1220,   1.2500],
        [  2.6930,  12.7470,   3.7220],
        [  4.6830,  14.5690,   3.0960],
        [  4.0060,  14.3270,   2.2770],
        [  5.8210,  13.5460,   3.0990],
        [  5.5240,  16.1870,   0.3450],
        [  4.4490,  16.1930,   0.2400],
        [  5.9040,  12.7470,   2.0380],
        [  5.2150,  12.8440,   1.3050],
        [  6.7430,  12.1600,  -0.5410],
        [ -4.8040,   2.7240,  -7.3800],
        [ -0.8810,   4.1440,  -9.1580],
        [ -2.9680,   3.7440,  -9.0530],
        [ -3.5850,   3.2860,  -8.3970],
        [ -3.2230,   4.7190,  -8.9860],
        [ -3.1820,   3.3230,  -9.9460],
        [ -1.0070,   5.0910,  -5.6680],
        [  3.0320,   8.8190,   0.6930],
        [  2.0150,   9.0580,   0.3820],
        [  3.6120,   9.7410,   0.7290],
        [  4.5580,   9.2440,  -1.6680],
        [  2.8110,   9.0140,  -1.9010],
        [ -0.6840,   7.8370,   0.2180],
        [  0.4240,   6.8650,  -1.6690],
        [  1.4380,   7.7300,  -1.5710],
        [ -2.6590,   7.1040,  -3.3470],
        [ -4.4660,   6.4850,  -5.4810],
        [ -1.1144,   8.9273, -10.8767],
        [ -1.3417,   8.1687,  -6.5208],
        [  2.8746,  13.0024,   0.5512],
        [  3.5102,  13.0607,  -0.6100],
        [ -0.5359,  10.2647,  -6.9262],
        [  3.0124,  12.5055,  -2.9554],
        [ -2.3783,   9.2686,  -8.2274],
        [  1.5227,  12.2295,  -0.9969],
        [  0.3518,  11.7073,  -1.4823],
        [  1.6563,  12.5017,   0.3658],
        [  0.1405,  11.3587,  -2.8872],
        [  2.6909,  12.5863,  -1.5889],
        [ -0.6768,  11.4601,  -0.6202],
        [  0.6022,  12.2405,   1.2057],
        [ -0.0495,  12.3204,  -3.8582],
        [  0.1070,  10.0377,  -3.2454],
        [ -0.5594,  11.7217,   0.7110],
        [ -0.3092,  10.6305,  -5.5710],
        [ -0.2712,  11.9555,  -5.1822],
        [ -0.1175,   9.6944,  -4.5848],
        [ -2.6926,   8.3168,  -9.2236],
        [ -1.4378,   9.1750,  -7.2098],
        [ -4.1347,   6.4364,  -9.8255],
        [ -3.7321,   7.4068,  -8.9462],
        [ -2.1119,   8.1580, -10.4440],
        [ -3.5288,   6.3054, -11.0435],
        [ -2.5175,   7.1652, -11.3543],
        [ -5.2368,   5.4858,  -9.5086]])
atoms = Atoms(symbols=symbols, positions=coord)
atoms.calc = calc
energy = atoms.get_potential_energy()

But I got the following error:

File "correlation_analysis_so3lr.py", line 94, in <module>
    energy = atoms.get_potential_energy()
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/ase/atoms.py", line 731, in get_potential_energy
    energy = self._calc.get_potential_energy(self)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/ase/calculators/calculator.py", line 709, in get_potential_energy
    energy = self.get_property('energy', atoms)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/ase/calculators/calculator.py", line 737, in get_property
    self.calculate(atoms, [name], system_changes)
  File "/anaconda3/lib/python3.11/site-packages/mlff/md/calculator_sparse.py", line 277, in calculate
    neighbors = self.spatial_partitioning.update_fn(system.R, self.neighbors, new_cell=cell)
                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/glp/neighborlist.py", line 235, in update_fn
    force_update | need_update_fn(neighbors, positions, new_cell),
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/lib/python3.11/site-packages/glp/neighborlist.py", line 158, in need_update_fn
    movement = make_squared_distance(new_cell)(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ValueError: vmap got inconsistent sizes for array axes to be mapped:
  * one axis had size 101: axis 0 of argument Ra of type float32[101,3];
  * one axis had size 83: axis 0 of argument Rb of type float32[83,3]

Any idea why this exception occurs?

thorben-frank commented 1 month ago

Hey, thanks for raising the issue. Unfortunately, I was not able to execute the code as posted in the issue.

Along the way I ran into two issues: what is tensor in your case? And coord is not defined. After replacing tensor with numpy.array(...) and coord with coordinates, I was able to run the code. The code that worked for me is below.


import numpy as np

from ase import Atoms
from so3lr import So3lrCalculator

calc = So3lrCalculator()
symbols = ['H', 'H', 'H', 'H', 'C', 'H', 'C', 'H', 'C', 'H', 'H', 'C', 'H', 'H', 'C', 'H', 'H', 'H', 'H', 'C', 'H', 'C', 'H', 'H', 'C', 'C', 'H', 'H', 'H', 'O', 'C', 'H', 'C', 'C', 'H', 'N', 'H', 'O', 'O', 'H', 'N', 'H', 'H', 'H', 'O', 'C', 'H', 'H', 'H', 'H', 'H', 'S', 'H', 'H', 'O', 'F', 'O', 'N', 'N', 'N', 'N', 'N', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C', 'C']
coordinates = np.array(
    [
        [  3.5120,  15.7560,  -3.3390],
        [  2.1040,  16.3200,  -2.4180],
        [  0.0410,  15.0000,  -7.6510],
        [ -6.8870,   9.4890,  -8.6300],
        [ -8.0550,   7.7580,  -9.1600],
        [ -7.2220,   7.2370,  -9.6090],
        [ -1.5880,  15.7860,  -4.4260],
        [ -0.8390,  15.7620,  -5.2180],
        [ -0.8910,  15.5980,  -3.0720],
        [ -0.3440,  14.6550,  -3.0640],
        [ -1.6260,  15.5900,  -2.2670],
        [ -2.5420,  14.6010,  -4.6250],
        [ -1.9920,  13.6640,  -4.5340],
        [ -2.9910,  14.6390,  -5.6170],
        [  0.9810,  15.4370,   0.7910],
        [  0.4470,  15.3120,  -0.1510],
        [  0.6770,  14.6530,   1.4850],
        [  2.0520,  15.3490,   0.6080],
        [ -0.8780,   8.4070,   2.3880],
        [  0.9290,   9.7010,   4.0010],
        [  0.8350,  10.1120,   2.9960],
        [ -1.2740,  13.1340,   3.7930],
        [ -0.6910,  12.2200,   3.8990],
        [ -0.8080,  13.7220,   3.0020],
        [ -2.7030,  12.7690,   3.3390],
        [ -2.6970,  12.3090,   1.8810],
        [ -2.0430,  11.4450,   1.7660],
        [ -3.7090,  12.0460,   1.5740],
        [ -2.3380,  13.1220,   1.2500],
        [  2.6930,  12.7470,   3.7220],
        [  4.6830,  14.5690,   3.0960],
        [  4.0060,  14.3270,   2.2770],
        [  5.8210,  13.5460,   3.0990],
        [  5.5240,  16.1870,   0.3450],
        [  4.4490,  16.1930,   0.2400],
        [  5.9040,  12.7470,   2.0380],
        [  5.2150,  12.8440,   1.3050],
        [  6.7430,  12.1600,  -0.5410],
        [ -4.8040,   2.7240,  -7.3800],
        [ -0.8810,   4.1440,  -9.1580],
        [ -2.9680,   3.7440,  -9.0530],
        [ -3.5850,   3.2860,  -8.3970],
        [ -3.2230,   4.7190,  -8.9860],
        [ -3.1820,   3.3230,  -9.9460],
        [ -1.0070,   5.0910,  -5.6680],
        [  3.0320,   8.8190,   0.6930],
        [  2.0150,   9.0580,   0.3820],
        [  3.6120,   9.7410,   0.7290],
        [  4.5580,   9.2440,  -1.6680],
        [  2.8110,   9.0140,  -1.9010],
        [ -0.6840,   7.8370,   0.2180],
        [  0.4240,   6.8650,  -1.6690],
        [  1.4380,   7.7300,  -1.5710],
        [ -2.6590,   7.1040,  -3.3470],
        [ -4.4660,   6.4850,  -5.4810],
        [ -1.1144,   8.9273, -10.8767],
        [ -1.3417,   8.1687,  -6.5208],
        [  2.8746,  13.0024,   0.5512],
        [  3.5102,  13.0607,  -0.6100],
        [ -0.5359,  10.2647,  -6.9262],
        [  3.0124,  12.5055,  -2.9554],
        [ -2.3783,   9.2686,  -8.2274],
        [  1.5227,  12.2295,  -0.9969],
        [  0.3518,  11.7073,  -1.4823],
        [  1.6563,  12.5017,   0.3658],
        [  0.1405,  11.3587,  -2.8872],
        [  2.6909,  12.5863,  -1.5889],
        [ -0.6768,  11.4601,  -0.6202],
        [  0.6022,  12.2405,   1.2057],
        [ -0.0495,  12.3204,  -3.8582],
        [  0.1070,  10.0377,  -3.2454],
        [ -0.5594,  11.7217,   0.7110],
        [ -0.3092,  10.6305,  -5.5710],
        [ -0.2712,  11.9555,  -5.1822],
        [ -0.1175,   9.6944,  -4.5848],
        [ -2.6926,   8.3168,  -9.2236],
        [ -1.4378,   9.1750,  -7.2098],
        [ -4.1347,   6.4364,  -9.8255],
        [ -3.7321,   7.4068,  -8.9462],
        [ -2.1119,   8.1580, -10.4440],
        [ -3.5288,   6.3054, -11.0435],
        [ -2.5175,   7.1652, -11.3543],
        [ -5.2368,   5.4858,  -9.5086]
    ]
)
atoms = Atoms(symbols=symbols, positions=coordinates)
atoms.calc = calc
energy = atoms.get_potential_energy()

Let me know if something remains unclear or is not working!

lamthuy commented 1 month ago

Hi, thank you for the response. I ran the code above and it works well, but I found that the actual issue is different. Below is my code:

calc = So3lrCalculator()
for idx, d in enumerate(pbar):
    atoms, coordinates, atom_masks, y = d
    for atom, coord, label in zip(atoms, coordinates, y):
        symbols = [chemical_symbols[number] for number in atom]
        print(symbols)
        coord = coord.cpu().numpy()
        print(coord)
        atoms = Atoms(symbols=symbols, positions=coord)
        atoms.calc = calc
        energy = atoms.get_potential_energy()
        energies.append(energy)
        affinity.append(label.item())
        print("Energy", energy)

In the code above I reused calc for new systems, and that causes the issue. The following code works:

for idx, d in enumerate(pbar):
    atoms, coordinates, atom_masks, y = d
    for atom, coord, label in zip(atoms, coordinates, y):
        calc = So3lrCalculator()
        symbols = [chemical_symbols[number] for number in atom]
        print(symbols)
        coord = coord.cpu().numpy()
        print(coord)
        atoms = Atoms(symbols=symbols, positions=coord)
        atoms.calc = calc
        energy = atoms.get_potential_energy()
        energies.append(energy)
        affinity.append(label.item())
        print("Energy", energy)

But that code seems inefficient because it has to create the calculator and load the checkpoint again for every new atomic system. Is there a better solution that does not require creating a new So3lrCalculator() every time we work with a new system?
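
One possible workaround, offered only as an untested sketch: the traceback shows the calculator passing its stored self.neighbors to update_fn together with the new positions, so the size mismatch (101 vs 83) presumably comes from neighbor-list state left over from a previous system with a different number of atoms. Under that assumption, keeping one calculator per atom count would avoid the shape mismatch while creating far fewer calculators than one per system. CalculatorCache below is a hypothetical helper, not part of so3lr or ASE, and the thread does not confirm that reusing a calculator across different molecules with the same number of atoms is safe.

```python
from typing import Any, Callable, Dict


class CalculatorCache:
    """Hypothetical helper: reuse one calculator per system size.

    Assumption (not confirmed in this thread): the calculator's cached
    neighbor-list state only becomes invalid when the number of atoms changes.
    """

    def __init__(self, factory: Callable[[], Any]) -> None:
        # factory is any zero-argument callable that builds a calculator,
        # for example the So3lrCalculator class itself.
        self._factory = factory
        self._by_size: Dict[int, Any] = {}
        self.n_created = 0  # number of calculators actually constructed

    def get(self, n_atoms: int) -> Any:
        """Return the calculator for this atom count, creating it on first use."""
        if n_atoms not in self._by_size:
            self._by_size[n_atoms] = self._factory()
            self.n_created += 1
        return self._by_size[n_atoms]


# Intended use inside the loop from the post above (left as comments so the
# block runs without so3lr installed):
#
#   cache = CalculatorCache(So3lrCalculator)
#   ...
#   system = Atoms(symbols=symbols, positions=coord)
#   system.calc = cache.get(len(symbols))
#   energy = system.get_potential_energy()
```

With this pattern the checkpoint would be loaded once per distinct system size rather than once per system. If energies differ from those obtained with a fresh calculator, the assumption above does not hold and a new calculator per system remains the safe choice.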