zadorlab / sella

A Python software package for saddle point optimization and minimization of atomic systems.
https://www.ecc-project.org/

TypeError during sella optimization #11

Closed: PattanaikL closed this issue 2 years ago

PattanaikL commented 2 years ago

I'm running into an issue when running the Sella optimizer. Here's the stack trace:

[image: screenshot of the stack trace]

I'm not too familiar with JAX's JIT compilation, but perhaps the issue stems from using non-hashable objects as static arguments?

If I change line 334 from `set(self.indices[:-1]) == set(other.indices[:-1])` to `set(self.indices[:-1].tolist()) == set(other.indices[:-1].tolist())`, there's no error and the algorithm works as expected.
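A minimal sketch of why the `.tolist()` change helps, assuming the TypeError comes from hashing the array elements inside `set()`. JAX is not needed to run it: a 0-d NumPy array stands in for the 0-d elements you get when iterating a JAX array, since neither type is hashable. The arrays `a` and `b` are made-up example indices.

```python
import numpy as np

# Made-up index arrays; only the first n-1 entries are compared, as on line 334.
a = np.array([3, 1, 4, 0])
b = np.array([4, 3, 1, 7])

# Iterating a 1-D NumPy integer array yields NumPy scalars (e.g. np.int64),
# which are hashable, so set() works on the array directly.
print(set(a[:-1]) == set(b[:-1]))  # True

# Stand-in for iterating a JAX array: a list of 0-d arrays.
# ndarray defines no __hash__, so set() raises TypeError.
zero_d = [np.array(i) for i in a[:-1]]
try:
    set(zero_d)
    zero_d_hashable = True
except TypeError as exc:
    zero_d_hashable = False
    print("TypeError:", exc)

# .tolist() converts the elements to built-in Python ints, which are hashable.
# JAX arrays also provide .tolist(), which is why the proposed fix avoids the error.
print(set(a[:-1].tolist()) == set(b[:-1].tolist()))  # True
```

The `.tolist()` form trades a small conversion cost for independence from whichever array type ends up holding the indices.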

ehermes commented 2 years ago

Hm, that's odd: `indices` ought to be a regular NumPy array. Do you have a minimal script that reproduces this issue? Your proposed solution will work, but I'm worried about why `indices` is becoming a `jax.numpy.ndarray` object in the first place.
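An alternative to patching the comparison is to make sure `indices` really is a NumPy array by converting it when the object is created. The sketch below uses a hypothetical `IndexedCoordinate` class (not Sella's actual code) and assumes the incoming indices are any array-like that `np.asarray` can convert, such as a list, a tuple, or a JAX array.

```python
import numpy as np


class IndexedCoordinate:
    """Hypothetical stand-in for an object that stores atom indices.

    Not Sella's class; it only illustrates normalizing `indices` to a
    NumPy integer array at construction time.
    """

    def __init__(self, indices):
        # np.asarray converts lists, tuples, and objects implementing
        # __array__ (assumed here to include JAX arrays) to a NumPy array.
        self.indices = np.asarray(indices, dtype=np.int64)

    def same_leading_atoms(self, other):
        # Same expression as the original line 334: compare all indices
        # except the last one, ignoring order.
        return set(self.indices[:-1]) == set(other.indices[:-1])


a = IndexedCoordinate([0, 1, 2])
b = IndexedCoordinate((1, 0, 5))
print(a.same_leading_atoms(b))  # True
```

With the indices normalized in `__init__`, the original comparison works unchanged, because iterating a NumPy array yields hashable scalars.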

PattanaikL commented 2 years ago

Ahh OK, I think this issue has already been fixed on the master branch. On my server I had installed Sella with pip, and that release raises the error; locally I'm using the most recent master branch, which doesn't.

In any case, here's some code to reproduce the original issue:

```python
import numpy as np
from sella import Sella
from ase import Atoms
from ase.calculators.emt import EMT

numbers = [6, 6, 8, 6, 7, 6, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1]
positions = np.array([[ 1.6527,  0.1839,  0.5471],
                      [ 0.8306, -0.7019, -0.3921],
                      [ 0.1444, -1.7057,  0.3074],
                      [-0.1588,  0.1407, -1.2323],
                      [-1.0435,  0.9014, -0.3839],
                      [-1.9624,  0.3492,  0.3955],
                      [-2.6059,  1.2438,  1.1877],
                      [ 2.2265,  0.9113, -0.021 ],
                      [ 1.0005,  0.7078,  1.243 ],
                      [ 2.3375, -0.4341,  1.1212],
                      [ 1.5007, -1.2317, -1.0781],
                      [-0.59  , -1.2882,  0.7884],
                      [ 0.389 ,  0.8214, -1.8838],
                      [-0.7507, -0.5456, -1.8453],
                      [-0.8885,  1.9072, -0.2888],
                      [-2.9743,  0.1089,  0.9871]])

atoms = Atoms(positions=positions, numbers=numbers)
atoms.calc = EMT()

opt = Sella(atoms, trajectory="irc.traj")
opt.run()
```

ehermes commented 2 years ago

Alright, good to know.

We'll probably be updating the version of Sella on PyPI soon-ish. For now, please continue to use the git master.
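Since the bug was present in the pip-installed release but not on master, it helps to confirm which copy of Sella an environment actually has. A minimal sketch using the standard library's `importlib.metadata`; `describe_install` is a hypothetical helper name. It assumes the package was installed by a recent pip, which records a PEP 610 `direct_url.json` file for installs from a URL, VCS, or local path; when that file is absent, the package most likely came from an index such as PyPI.

```python
import json
from importlib import metadata


def describe_install(dist_name):
    """Return the version and install source of a distribution, or None.

    Hypothetical helper. Relies on PEP 610's direct_url.json; its absence
    is only assumed to mean an install from a package index.
    """
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None  # not installed in this environment

    info = {"version": dist.version, "source": "package index (assumed)"}
    raw = dist.read_text("direct_url.json")  # None when the file is missing
    if raw:
        direct = json.loads(raw)
        info["source"] = direct.get("url", "direct URL")
        vcs = direct.get("vcs_info")
        if vcs:
            # For git installs this is the exact commit that was built.
            info["commit"] = vcs.get("commit_id")
    return info


if __name__ == "__main__":
    print(describe_install("sella"))
```

A git install of master should report the repository URL and a commit hash, while a plain `pip install sella` should report only the version.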