Add the method ptcl.raveled_id, which gives the flattened particle IDs, and replace ptcl_pos with ptcl.pos. Even though these look more OOP than functional, these two methods are unlikely ever to be transformed by JAX through their input arguments (other than self).
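As a rough illustration of why methods are acceptable here, the sketch below registers a small particle container as a JAX pytree, so that transformations such as jax.jit trace through self while the methods take no other traced arguments. The class layout, the field names (pmid, disp, mesh_shape, cell_size) and the method bodies are assumptions made for this example, not the actual implementation in this change.

```python
from dataclasses import dataclass

import jax
import jax.numpy as jnp
from jax import tree_util


@tree_util.register_pytree_node_class
@dataclass(frozen=True)
class Particles:
    """Hypothetical particle container; all field names are illustrative."""
    pmid: jnp.ndarray        # (N, 3) integer mesh indices of each particle
    disp: jnp.ndarray        # (N, 3) displacements from those mesh points
    mesh_shape: tuple        # static mesh shape, e.g. (2, 3, 4)
    cell_size: float = 1.0   # static mesh spacing

    def tree_flatten(self):
        # Arrays are traced children; shape and spacing are static aux data.
        return (self.pmid, self.disp), (self.mesh_shape, self.cell_size)

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children, *aux)

    def raveled_id(self):
        """Flattened (row-major) particle IDs computed from the mesh indices."""
        _, ny, nz = self.mesh_shape
        i, j, k = self.pmid[:, 0], self.pmid[:, 1], self.pmid[:, 2]
        return (i * ny + j) * nz + k

    def pos(self):
        """Particle positions: mesh point coordinates plus displacements."""
        return self.pmid * self.cell_size + self.disp


ptcl = Particles(
    pmid=jnp.array([[0, 0, 0], [1, 2, 3], [0, 1, 2]]),
    disp=jnp.full((3, 3), 0.5),
    mesh_shape=(2, 3, 4),
)

ids = ptcl.raveled_id()                      # [0, 23, 6]
pos_jit = jax.jit(lambda p: p.pos())(ptcl)   # traced only through self
```

Because the only traced inputs are the array fields carried by self, calling ptcl.pos() inside a transformed function behaves the same as calling a free function on those arrays.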