uncscode / particula

a simple, fast, and powerful particle simulator
https://uncscode.github.io/particula
MIT License

pyTorch explore for Lagrangian #406

Closed Gorkowski closed 9 months ago

Gorkowski commented 9 months ago

Initial exploration of the Lagrangian model, targeting GPUs (hence PyTorch) for high particle counts. Using torch still allows development on CPUs with identical code.

ngmahfouz commented 9 months ago

> @ngmahfouz What do you think about waiting a bit for Python 3.12? It has only been out for about two months, so pytorch and pytz don't support it yet.

We can wait, no problem at all. How did we not catch the pytz error btw?

ngmahfouz commented 9 months ago

Worth mentioning that JAX can be a drop-in replacement for numpy if needed btw, likely with fewer code changes than PyTorch. We can also potentially support different backends...

See an example of switching back and forth between numpy, numba, and jax: https://github.com/uncscode/hypersolver/blob/main/hypersolver/util.py (both numba and jax are inactive now)
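To make the backend idea concrete, here is a minimal sketch of one way to switch between numpy and JAX. It is not the code in the linked `util.py`; `get_array_module` and `kinetic_energy` are hypothetical names invented for illustration, and the fallback to numpy when JAX is missing is an assumption about what we would want.

```python
import importlib

import numpy as np


def get_array_module(backend="numpy"):
    """Return a NumPy-compatible module for the requested backend.

    Hypothetical helper: falls back to numpy if JAX is not installed.
    """
    if backend == "numpy":
        return np
    if backend == "jax":
        try:
            return importlib.import_module("jax.numpy")
        except ImportError:
            return np  # assumption: silently fall back rather than fail
    raise ValueError(f"unknown backend: {backend!r}")


def kinetic_energy(mass, velocity, xp=np):
    """Total kinetic energy, written only against the shared array API."""
    mass = xp.asarray(mass)
    velocity = xp.asarray(velocity)
    speed_squared = xp.sum(velocity**2, axis=1)
    return 0.5 * xp.sum(mass * speed_squared)
```

Calling `kinetic_energy(mass, velocity, xp=get_array_module("jax"))` would then run the same function body on either backend, as long as it sticks to functions both modules provide.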

ngmahfouz commented 9 months ago

For the plot, it may be a good idea to scale the size of the spheres :) I think the option is s= (normalize mass first, then set s=normalized_mass) but I am not sure...
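A rough sketch of that suggestion, assuming the plot is a matplotlib 3D scatter (the plotting code itself is not shown in this thread). `marker_sizes` and `plot_particles` are hypothetical names, and the 5 to 200 size range is an arbitrary choice; matplotlib's `s` is the marker area in points squared, so a normalized mass in [0, 1] used directly would give nearly invisible markers.

```python
import numpy as np


def marker_sizes(mass, min_size=5.0, max_size=200.0):
    """Map particle masses linearly onto a marker-area range (points^2)."""
    mass = np.asarray(mass, dtype=float)
    span = mass.max() - mass.min()
    if span == 0:
        # All particles have the same mass: use one mid-range size.
        return np.full(mass.shape, 0.5 * (min_size + max_size))
    normalized = (mass - mass.min()) / span  # values in [0, 1]
    return min_size + normalized * (max_size - min_size)


def plot_particles(position, mass):
    """Scatter particles in 3D with marker area scaled by mass."""
    import matplotlib.pyplot as plt  # imported lazily; only needed to plot

    position = np.asarray(position, dtype=float)
    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    ax.scatter(
        position[:, 0], position[:, 1], position[:, 2], s=marker_sizes(mass)
    )
    return fig, ax
```

If the positions and masses live in torch tensors on a GPU, they would need `.cpu().numpy()` before being passed in.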

Gorkowski commented 9 months ago

> > @ngmahfouz What do you think about waiting a bit for Python 3.12? It has only been out for about two months, so pytorch and pytz don't support it yet.
>
> We can wait, no problem at all. How did we not catch the pytz error btw?

It's just a warning right now, I think; see issue #405.

Gorkowski commented 9 months ago

> Worth mentioning that JAX can be a drop-in replacement for numpy if needed btw, likely with fewer code changes than PyTorch. We can also potentially support different backends...
>
> See an example of switching back and forth between numpy, numba, and jax: https://github.com/uncscode/hypersolver/blob/main/hypersolver/util.py (both numba and jax are inactive now)

@ngmahfouz I did some research on JAX vs PyTorch. The one sticking point I have with JAX right now is the lack of GPU support on Windows, which PyTorch does have. I don't want to add hurdles for new Ph.D. students (and myself) if there isn't a need. Both have a compile feature.

For PyTorch on macOS, it looks like the default install supports MPS (Apple Metal). I am not sure how that affects things on the code side; I assume it is just a device selection keyword, like CUDA vs CPU vs MPS.
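A minimal sketch of what that device selection could look like, assuming a preference order of CUDA, then MPS, then CPU. `choose_device_name` and `get_device` are hypothetical helpers, not anything in particula; the torch calls used (`torch.cuda.is_available`, `torch.backends.mps.is_available`, `torch.device`) are the standard ones.

```python
def choose_device_name(cuda_available, mps_available):
    """Pick a device string; the preference order is an assumption."""
    if cuda_available:
        return "cuda"
    if mps_available:
        return "mps"
    return "cpu"


def get_device():
    """Query torch for available accelerators and return a torch.device."""
    import torch  # imported lazily so the policy above works without torch

    # Older torch builds may not ship the MPS backend at all.
    mps_backend = getattr(torch.backends, "mps", None)
    mps_available = mps_backend is not None and mps_backend.is_available()
    name = choose_device_name(torch.cuda.is_available(), mps_available)
    return torch.device(name)
```

The rest of the code would then stay identical across machines, e.g. `position = torch.zeros((n_particles, 3), device=get_device())`, where `n_particles` is whatever particle count the simulation uses.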

For us, I don't see a major problem if we write it in PyTorch now, then rewrite it in JAX/numpy for a v2. I would be ecstatic if we had enough users that a v2 was called for :)

ngmahfouz commented 9 months ago

Sounds good to me about PyTorch. In general, PyTorch is more mature than the alternatives.

Before v0.1.0, we may want to discuss whether all these extensions must live in the main repo/package, since they make it heavier to install (the GPU PyTorch package is pretty hefty). We can discuss that when we are closer to v0.1.0.