CNCLgithub / Cusanus

Event-driven physical simulation: dynamics with objects as distributions

Jax port #6

flxsosa closed this 6 months ago

flxsosa commented 6 months ago

Pull request porting the siren.py module from PyTorch to JAX.

Changes include:

  1. Removed imports: torch, math, cusanus.pytypes
  2. Added imports: equinox, jax, jaxtyping, jax.numpy
  3. All module classes implement their forward pass in __call__ instead of a forward method.
  4. All module classes ported to JAX (obviously).
  5. "Better Comments 2" part 1.

belledon commented 6 months ago

Thanks, @flxsosa!