Closed by raimis.
The minimizer switches to the CPU when the forces are too large for the GPU:
The symmetry operation, however, sets up a device once and reuses it:
A quick solution would be to keep a map with both implementations (initialized on demand) and choose the correct one depending on the device of the positions. Let's see if doing this in SymmetryFunctions is enough to get rid of the issue.
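The proposed fix can be sketched in plain Python. This is a simplified model, not the NNPOps source: devices are represented by strings such as "cuda:0" and "cpu" (in PyTorch the key would presumably be the device of the positions tensor), and `factory` is a hypothetical stand-in for whatever builds a SymmetryFunctions instance. `OnceDeviceCache` mimics the set-up-once behaviour described above; `PerDeviceCache` is the per-device map with lazy initialization.

```python
from typing import Callable, Dict, Generic, Optional, TypeVar

Impl = TypeVar("Impl")


class OnceDeviceCache(Generic[Impl]):
    """Failing pattern: set up on the first device seen, then reuse it."""

    def __init__(self, factory: Callable[[str], Impl]) -> None:
        self._factory = factory
        self._device: Optional[str] = None
        self._impl: Optional[Impl] = None

    def get(self, device: str) -> Impl:
        if self._device is None:
            # First call: build the implementation for this device and keep it.
            self._device = device
            self._impl = self._factory(device)
        elif device != self._device:
            # Later input on a different device cannot be handled.
            raise RuntimeError(f"set up for {self._device}, but got input on {device}")
        return self._impl


class PerDeviceCache(Generic[Impl]):
    """Proposed fix: one implementation per device, created on first use."""

    def __init__(self, factory: Callable[[str], Impl]) -> None:
        self._factory = factory
        self._impls: Dict[str, Impl] = {}

    def get(self, device: str) -> Impl:
        # Build lazily, so a CPU implementation only exists if it is needed.
        if device not in self._impls:
            self._impls[device] = self._factory(device)
        return self._impls[device]
```

With the lazy map, a CPU implementation would only be built if the minimizer actually falls back to the CPU; GPU-only runs pay nothing extra.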
This works, but it just makes the same kind of error pop up a bit later. Let me evaluate how hard it would be to apply the same approach to the rest of OptimizedTorchANI. In the meantime, you can use this as a quick workaround:
from typing import Optional, Tuple

from torch import Tensor
from torchani.nn import SpeciesEnergies


def forward(self, species_coordinates: Tuple[Tensor, Tensor],
            cell: Optional[Tensor] = None,
            pbc: Optional[Tensor] = None) -> SpeciesEnergies:
    # Workaround: move every input to CUDA, so each call runs on the same
    # device regardless of where the caller's tensors live.
    species_coordinates = (species_coordinates[0].cuda(), species_coordinates[1].cuda())
    if cell is not None:
        cell = cell.cuda()
    if pbc is not None:
        pbc = pbc.cuda()
    # The rest is the unchanged model pipeline.
    species_coordinates = self.species_converter(species_coordinates)
    species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
    species_energies = self.neural_networks(species_aevs)
    species_energies = self.energy_shifter(species_energies)
    return species_energies
While I understand that the problem here is an unexpected device migration, and I also see the rationale behind recomputing on the CPU, I do not see how the lines you shared actually change the device of the positions. What I gather from them is that the energy is recomputed on the CPU side, without changing anything else. Could you share some more context?
I have no idea where the transfer happens.
@peastman should know.
Let's move the discussion to #113.
I think this can be closed, since the issue was solved in OpenMM-Torch.
Create a system with two hydrogen atoms. Artificially scale up the energy (and forces) of ANI-2x and minimize the system.
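Scaling the energy by a constant factor scales the forces by the same factor, because the forces are the negative gradient of the energy with respect to the positions. Here $s$ is a hypothetical scale factor; the report does not say which value was used.

```latex
E_s(\mathbf{x}) = s\,E(\mathbf{x})
\quad\Longrightarrow\quad
\mathbf{F}_s(\mathbf{x}) = -\nabla_{\mathbf{x}} E_s(\mathbf{x})
  = -s\,\nabla_{\mathbf{x}} E(\mathbf{x})
  = s\,\mathbf{F}(\mathbf{x})
```

A large enough $s$ therefore pushes the forces past whatever limit makes the minimizer consider them too large for the GPU, at the same positions.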
The script crashes with a peculiar error:
If the energy scale is reduced, or the simulation is switched to the CPU platform, the error disappears:
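A toy model shows why both changes avoid the crash, assuming (as stated earlier in the thread) that the minimizer recomputes forces on the CPU once they are too large for the GPU. This is not OpenMM's minimizer code: `FORCE_LIMIT` is a made-up threshold, and the per-evaluation maximum forces are invented inputs.

```python
from typing import List

# Hypothetical limit: the thread only says forces can be "too large for the
# GPU"; the actual criterion used by the minimizer is not given here.
FORCE_LIMIT = 1.0e6


def evaluation_devices(platform: str, max_forces: List[float]) -> List[str]:
    """Device used for each force evaluation in this toy model of the minimizer."""
    devices = []
    for max_force in max_forces:
        if platform != "CPU" and max_force > FORCE_LIMIT:
            # Forces too large for the GPU: recompute them on the CPU.
            devices.append("CPU")
        else:
            devices.append(platform)
    return devices


def sees_device_change(platform: str, max_forces: List[float]) -> bool:
    """True if a model that fixes its device on the first call would break."""
    return len(set(evaluation_devices(platform, max_forces))) > 1


# Scaled-up forces on CUDA: evaluations mix CUDA and CPU.
print(sees_device_change("CUDA", [5.0e6, 2.0e3]))  # True
# Reduced energy scale: forces stay below the limit, only CUDA is used.
print(sees_device_change("CUDA", [5.0e3, 2.0e2]))  # False
# CPU platform: there is no fallback, everything runs on the CPU.
print(sees_device_change("CPU", [5.0e6, 2.0e3]))  # False
```

In this model the crash needs both a non-CPU platform and forces above the limit, which matches the two ways of making the error disappear.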