facebookresearch / fmmax

Fourier modal method with Jax
MIT License

Hanging (or slow) fmmax on GPU with vector formulation and uniform patterned layers #88

Open mfschubert opened 8 months ago

mfschubert commented 8 months ago

I am finding that fmmax can hang or be unexpectedly slow in certain scenarios, specifically on the GPU platform with uniform patterned layers and a vector formulation.

I put together a Colab notebook that shows non-problematic and problematic cases. Right now it uses the invrs.io gym polarization sorter example, so a simpler fmmax-only repro is likely possible.

https://colab.research.google.com/drive/1BdntytzVa8Li66VeY0QCOX8LMoyRL1dG#scrollTo=RT9goZ8pFAXd

smartalecH commented 8 months ago

@mfschubert have you tried profiling, e.g. using Perfetto or TensorBoard, to see which functions are hanging?

It does a good job even with really complicated stacks. I've profiled entire optimization loops wrapped around really complicated fmmax gym problems, and it offers really useful granularity.
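For anyone who wants to try this, here is a minimal sketch of capturing a trace with JAX's built-in profiler, which can be viewed in TensorBoard or Perfetto. The function name `profile_eigh`, the annotation label, and the workload are all hypothetical: a Hermitian eigendecomposition stands in for a real fmmax simulation, partly because non-symmetric `eig` may not be available on every backend in every JAX version.

```python
import tempfile

import jax
import jax.numpy as jnp


def profile_eigh(log_dir, n=64):
    """Trace a stand-in workload (not an fmmax call) and return its eigenvalues."""
    a = jax.random.normal(jax.random.PRNGKey(0), (n, n))
    symmetric = a + a.T  # Hermitian stand-in so eigh applies on any backend.
    with jax.profiler.trace(log_dir):
        # TraceAnnotation labels this region in the trace viewer.
        with jax.profiler.TraceAnnotation("eigh_standin"):
            eigenvalues, _ = jnp.linalg.eigh(symmetric)
            # JAX dispatches asynchronously; block so the work lands in the trace.
            eigenvalues.block_until_ready()
    return eigenvalues


if __name__ == "__main__":
    trace_dir = tempfile.mkdtemp()
    profile_eigh(trace_dir)
    print("Trace written under", trace_dir)
```

Pointing TensorBoard at the log directory (with the profiler plugin installed) should show the annotated region. In a real run, the annotation would wrap the simulation or optimization step of interest instead of the stand-in.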

mfschubert commented 8 months ago

I haven't tried that here, but it seems like a good next step.

By the way, would you be willing to share some of that profiling info? It would be useful to know where the long poles are from a performance perspective, even if they aren't acted upon immediately.

smartalecH commented 8 months ago

> By the way, would you be willing to share some of that profiling info?

Here's some profiling I did on a problem that essentially combined the microlens.py example with a simple diffractive lens structure underneath. The excitation fields were computed from an external solver, and we injected them into the simulation using the source expansion feature. This was all on the CPU (didn't have a GPU with enough memory at the time) but this kind of profiling works great on the GPU too.

[image: profiling results]

As expected, the most expensive piece of the computation was the eigendecomposition. But it's interesting how long assembling the S-matrix takes as well. Expanding the source solutions onto the basis takes a while too.

All of the other "important" stuff (e.g. propagating the sources, computing Poynting fluxes and near2far transforms) doesn't even register, really, because its cost is negligible in comparison.

Of course, this is only one data point... would be good to profile a few different examples of different sizes on different hardware.
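As a rough starting point for that kind of comparison, here is a sketch that times two linear algebra kernels across a few matrix sizes on whatever backend JAX picks up. Everything here is an assumption for illustration: `time_call` and `sweep` are hypothetical helpers, and `jnp.linalg.eigh` and `jnp.linalg.solve` are only stand-ins for the eigendecomposition and S-matrix assembly steps, not the actual fmmax code paths.

```python
import time

import jax
import jax.numpy as jnp


def time_call(fn, *args, repeats=3):
    """Return the best wall-clock time in seconds over a few repeats."""
    jax.block_until_ready(fn(*args))  # Warm-up call so compilation is excluded.
    best = float("inf")
    for _ in range(repeats):
        start = time.perf_counter()
        # Block until the result is ready, since JAX dispatch is asynchronous.
        jax.block_until_ready(fn(*args))
        best = min(best, time.perf_counter() - start)
    return best


def sweep(sizes=(32, 64, 128)):
    """Time the stand-in kernels for each matrix size."""
    eigh = jax.jit(jnp.linalg.eigh)
    solve = jax.jit(jnp.linalg.solve)
    results = {}
    for n in sizes:
        a = jax.random.normal(jax.random.PRNGKey(n), (n, n))
        symmetric = a + a.T
        results[n] = {
            "eigh": time_call(eigh, symmetric),
            "solve": time_call(solve, a, symmetric),
        }
    return results


if __name__ == "__main__":
    print("backend:", jax.default_backend())
    for n, timings in sweep().items():
        print(n, timings)
```

Wall-clock timings like these are much coarser than a profiler trace, but they are cheap to collect on several machines and give a first look at how the relative costs scale with problem size.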