GalSim-developers / JAX-GalSim

JAX port of GalSim, for parallelized, GPU accelerated, and differentiable galaxy image simulations.

GPU performance degradation with wrapping code #96

Closed beckermr closed 3 months ago

beckermr commented 3 months ago

@ismael-mendoza and @aguinot found that the wrapping code for hermitian images is causing a big performance degradation on GPUs. We should investigate and fix.

See this notebook: https://colab.research.google.com/drive/14RQYB4BSPcv-TeHzz-cDkTZa_jqkXfGe?usp=sharing

beckermr commented 3 months ago

Some idea on this

ismael-mendoza commented 3 months ago

Thanks for posting the issue. One more data point that may or may not be helpful: it runs fast on a TPU in Google Colab (though I don't know enough about TPUs to say whether this clarifies the issue further).

aguinot commented 3 months ago

The fact that the code works "well" on CPU and TPU makes me think there could be an architecture-dependent bug in a JAX function. I haven't tracked down which part of the code is the bottleneck, but I was thinking of the for loops written with jax.lax.fori_loop. So far I haven't found any insight there.
@beckermr I found the C version of the wrapping in GalSim (here), but do you know where we could find a description of what this function does?
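One way to test the fori_loop hypothesis is to time a loop-based fold against a vectorized one on each backend. The sketch below is not JAX-GalSim code: `wrap_rows_loop`, `wrap_rows_vectorized`, and `bench` are hypothetical helpers that only fold image rows periodically (assuming the row count is a multiple of `ny`), a simplified stand-in for the real 2D wrapping.

```python
import time

import jax
import jax.numpy as jnp


def wrap_rows_loop(image, ny):
    """Hypothetical: fold rows onto `ny` rows, one row per lax.fori_loop step."""
    out = jnp.zeros((ny, image.shape[1]), dtype=image.dtype)

    def body(i, acc):
        # Row i of the full image lands on row i % ny (periodic BCs).
        return acc.at[i % ny].add(image[i])

    return jax.lax.fori_loop(0, image.shape[0], body, out)


def wrap_rows_vectorized(image, ny):
    """Hypothetical: same fold as reshape + sum (needs image.shape[0] % ny == 0)."""
    return image.reshape(-1, ny, image.shape[1]).sum(axis=0)


def bench(fn, *args, n_iter=10):
    """Average wall time per call in seconds, excluding compilation."""
    fn(*args).block_until_ready()  # first call triggers compilation
    t0 = time.perf_counter()
    for _ in range(n_iter):
        fn(*args).block_until_ready()
    return (time.perf_counter() - t0) / n_iter


# ny sets an output shape, so it must be static under jit.
loop_fn = jax.jit(wrap_rows_loop, static_argnums=1)
vec_fn = jax.jit(wrap_rows_vectorized, static_argnums=1)

img = jax.random.normal(jax.random.PRNGKey(0), (1024, 256))
print("backend:", jax.default_backend())
print("fori_loop :", bench(loop_fn, img, 128))
print("vectorized:", bench(vec_fn, img, 128))
```

Running the same cell on CPU, GPU, and TPU runtimes would show whether the loop-based version is the one whose timing depends on the backend; no particular result is implied here.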

beckermr commented 3 months ago

That function is a bear. I'd look at the non-hermitian version in JAX to get the basic idea. It wraps a full image into a subimage at some location, assuming periodic boundary conditions (BCs).
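As a rough picture of the non-hermitian case: every pixel of the full image is added to the pixel of an `ny` x `nx` window whose offset from the window origin `(y0, x0)` matches modulo the window size. The helper `wrap_periodic` below is a hypothetical sketch of that idea using a scatter-add, assuming this reading of the semantics; it returns a new array instead of writing into a subimage, and it is not the JAX-GalSim implementation.

```python
import jax
import jax.numpy as jnp


def wrap_periodic(image, y0, x0, ny, nx):
    """Hypothetical: fold `image` onto an (ny, nx) window at (y0, x0), periodic BCs."""
    full_ny, full_nx = image.shape
    # Window coordinates of every full-image row/column: pixels already inside
    # the window map to themselves, all others wrap around modulo the window size.
    ys = (jnp.arange(full_ny) - y0) % ny
    xs = (jnp.arange(full_nx) - x0) % nx
    yy, xx = jnp.meshgrid(ys, xs, indexing="ij")
    # Scatter-add: pixels that land on the same window pixel are summed.
    return jnp.zeros((ny, nx), dtype=image.dtype).at[yy, xx].add(image)


# The window shape must be static under jit; the window origin may be traced.
wrap_periodic_jit = jax.jit(wrap_periodic, static_argnums=(3, 4))

img = jnp.arange(16.0).reshape(4, 4)
print(wrap_periodic_jit(img, 0, 0, 2, 2))  # values [[20, 24], [36, 40]]
```

Because the fold only sums pixels, the window total equals the full-image total, which makes a cheap sanity check when comparing implementations.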

beckermr commented 3 months ago

The hermitian stuff is just dealing with the half-complex format of some of the images.
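For context on the half-complex format: the FFT of a real image is Hermitian, F[-ky, -kx] = conj(F[ky, kx]), so only about half of k-space needs to be stored. The sketch below uses `jnp.fft.rfft2`, which keeps the columns kx = 0..n//2, plus a hypothetical helper `half_to_full` that rebuilds the missing columns from that symmetry. It uses FFT index ordering, which may differ from the centered k-space layout of GalSim's hermitian images, so treat it only as an illustration of the format.

```python
import jax.numpy as jnp
import numpy as np


def half_to_full(half, n):
    """Hypothetical: rebuild the full (n, n) spectrum of a real image from rfft2 output."""
    ky = jnp.arange(n)
    kx_missing = jnp.arange(n // 2 + 1, n)
    # Hermitian symmetry: F[ky, kx] = conj(F[(-ky) % n, n - kx]), and every
    # n - kx here is <= n // 2, so it is one of the stored columns.
    mirrored = jnp.conj(half[(-ky[:, None]) % n, (n - kx_missing)[None, :]])
    return jnp.concatenate([half, mirrored], axis=1)


n = 8
img = jnp.asarray(np.random.default_rng(0).normal(size=(n, n)), dtype=jnp.float32)
full = jnp.fft.fft2(img)   # shape (8, 8)
half = jnp.fft.rfft2(img)  # shape (8, 5): only kx = 0..n//2 stored
print(full.shape, half.shape)
print(jnp.allclose(half_to_full(half, n), full, atol=1e-4))
```

Any wrapping of a half-complex image has to account for the implicit mirrored half as well, which is part of why the hermitian case is more involved than the plain periodic fold.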