PattanaikL closed this issue 2 years ago.
Hm, that's odd: indices ought to be a regular NumPy array. Do you have a minimal script that reproduces this issue? Your proposed solution will work, but I'm worried about why indices is becoming a jax.numpy.ndarray object in the first place.
Ahh ok, I think this issue has already been resolved in the master branch. On my server I had installed Sella with pip, and that version was causing the issue. Locally I used the most recent master branch, which doesn't raise the error.
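When two machines behave differently like this, it can help to confirm which copy of Sella each environment actually imports. The sketch below uses only the standard library's importlib; the helper names installed_version and import_location are hypothetical, and the distribution name "sella" is assumed to match the PyPI package name.

```python
import importlib.metadata
import importlib.util


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return importlib.metadata.version(dist_name)
    except importlib.metadata.PackageNotFoundError:
        return None


def import_location(module_name):
    """Return the file a top-level module would be imported from, or None."""
    spec = importlib.util.find_spec(module_name)
    return spec.origin if spec is not None else None


if __name__ == "__main__":
    # Run this on both machines (pip install vs. master branch) and compare.
    print("sella version:", installed_version("sella"))
    print("sella imported from:", import_location("sella"))
```

Comparing the version string and import path from the server and the local machine shows whether they are running the same code.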
In any case, here's some code to reproduce the original issue:
import numpy as np
from sella import Sella
from ase import Atoms
from ase.calculators.emt import EMT

# Atomic numbers (C, C, O, C, N, C, O, then nine H) and Cartesian positions in Å
numbers = [6, 6, 8, 6, 7, 6, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1]
positions = np.array([[ 1.6527,  0.1839,  0.5471],
                      [ 0.8306, -0.7019, -0.3921],
                      [ 0.1444, -1.7057,  0.3074],
                      [-0.1588,  0.1407, -1.2323],
                      [-1.0435,  0.9014, -0.3839],
                      [-1.9624,  0.3492,  0.3955],
                      [-2.6059,  1.2438,  1.1877],
                      [ 2.2265,  0.9113, -0.0210],
                      [ 1.0005,  0.7078,  1.2430],
                      [ 2.3375, -0.4341,  1.1212],
                      [ 1.5007, -1.2317, -1.0781],
                      [-0.5900, -1.2882,  0.7884],
                      [ 0.3890,  0.8214, -1.8838],
                      [-0.7507, -0.5456, -1.8453],
                      [-0.8885,  1.9072, -0.2888],
                      [-2.9743,  0.1089,  0.9871]])
atoms = Atoms(positions=positions, numbers=numbers)
atoms.calc = EMT()  # ASE's simple EMT potential, enough to exercise the optimizer

# With the pip-installed Sella this run raises the error; with master it does not.
opt = Sella(atoms, trajectory="irc.traj")
opt.run()
Alright, good to know.
We'll probably be updating the version of Sella on PyPI soon-ish. For now, please continue to use the git master.
I'm running into an issue when running the Sella optimizer. Here's the stack trace:
I'm not too familiar with JAX's JIT compilation, but perhaps the issue stems from using non-hashable objects as static arguments?
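The hashability constraint can be seen without JAX installed. Like jax.jit with static arguments (which the JAX documentation requires to be hashable), functools.lru_cache builds its cache key by hashing the arguments. In this sketch lru_cache is only a stand-in for a JIT cache and count_unique is a hypothetical function: a NumPy array is rejected as unhashable, while a tuple built with .tolist() is accepted.

```python
import functools

import numpy as np


@functools.lru_cache(maxsize=None)
def count_unique(indices):
    # lru_cache hashes its arguments to build the cache key,
    # much like a JIT cache keyed on static arguments.
    return len(set(indices))


def try_call(arg):
    """Call count_unique and return the TypeError message instead of raising."""
    try:
        return count_unique(arg)
    except TypeError as exc:
        return f"TypeError: {exc}"


indices = np.array([3, 5, 5, 7])

print(try_call(indices))                  # prints a TypeError: ndarray is unhashable
print(try_call(tuple(indices.tolist())))  # prints 3
```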
If I replace line 334:
set(self.indices[:-1]) == set(other.indices[:-1])
with the following:
set(self.indices[:-1].tolist()) == set(other.indices[:-1].tolist())
there's no error and the algorithm works as expected.
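A likely reason .tolist() helps: set() hashes every element it receives. Iterating a 1-D NumPy integer array yields NumPy scalars, which are hashable, but iterating a jax.numpy array yields 0-d JAX arrays, which are not, so set(...) on a JAX array raises a TypeError. .tolist() (available on both NumPy and JAX arrays) turns the elements into plain Python ints. Below is a sketch of a type-agnostic version of that comparison; the function name same_leading_indices is hypothetical, it assumes 1-D index arrays, and it routes through np.asarray so it also accepts plain lists.

```python
import numpy as np


def same_leading_indices(a, b):
    """Compare all but the last entry of two 1-D index sequences as sets.

    np.asarray accepts lists, NumPy arrays and JAX arrays (via __array__);
    .tolist() then yields plain Python ints, which are hashable.
    """
    a_list = np.asarray(a).tolist()
    b_list = np.asarray(b).tolist()
    return set(a_list[:-1]) == set(b_list[:-1])


print(same_leading_indices(np.array([0, 1, 2, 9]), [1, 0, 2, 4]))     # True
print(same_leading_indices(np.array([0, 1, 2]), np.array([0, 3, 2])))  # False
```

Because the ordering of the leading indices is ignored, [0, 1, 2, ...] and [1, 0, 2, ...] compare equal, matching the set-based check on line 334.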