flexcompute / tidy3d

Fast electromagnetic solver (FDTD) at scale.
https://docs.flexcompute.com/projects/tidy3d/en/latest/
GNU Lesser General Public License v2.1

more general `TriangleMesh` differentiation support #1208

Open tylerflex opened 10 months ago

so-rose commented 3 months ago

Just my 2c: massive +1 for this.

My integration relies on geometry generated using Blender 3D's "Geometry Nodes", a visual scripting system for intuitive and interactive geometry generation.

With the flexibility of arbitrary triangle meshes come some caveats. Namely, the scripted, parameterized inputs to geometry-generating GeoNodes trees must be realized at the time a td.TriangleMesh object is generated. On some level the mesh generation is just a function, but Blender-the-software must be invoked somewhere in order to actually run it.

I imagine this is far from the most exotic geometry generation floating around. In general, then, the optimization loop for any "generate arbitrary geometry with optimization of symbolic inputs" workflow would probably need to be a local+cloud loop, e.g. with td.TriangleMesh wrapping a non-jax callback while itself being registered with jax.
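Here is a minimal sketch of the "non-jax callback registered with jax" idea. It assumes JAX is installed. `jax.pure_callback` hands the parameters to ordinary host-side code. Because the external generator has no derivatives of its own, `jax.custom_vjp` supplies a gradient estimated by central finite differences. `external_box_volume` is a hypothetical stand-in for invoking Blender: it returns the volume of a box rather than a `td.TriangleMesh`, and nothing here is tidy3d API.

```python
import numpy as np
import jax
import jax.numpy as jnp


def external_box_volume(params):
    """Hypothetical stand-in for an external, non-jax geometry generator
    (e.g. a Blender Geometry Nodes tree run on the host). Here it just
    returns the volume of a box with side lengths params[0..2]."""
    params = np.asarray(params, dtype=np.float32)
    return np.asarray(np.prod(params), dtype=np.float32)


def _call_external(params):
    # pure_callback lets traced jax code call ordinary host-side Python/numpy.
    result_shape = jax.ShapeDtypeStruct((), jnp.float32)
    return jax.pure_callback(external_box_volume, result_shape, params)


@jax.custom_vjp
def box_volume(params):
    return _call_external(params)


def _box_volume_fwd(params):
    # Save the inputs as residuals for the backward pass.
    return _call_external(params), params


def _box_volume_bwd(params, g):
    # The external code provides no derivatives, so estimate them with
    # central finite differences: two extra external calls per parameter.
    eps = 1e-2  # assumed step size; would need tuning for real geometry
    partials = []
    for i in range(params.shape[0]):
        step = jnp.zeros_like(params).at[i].set(eps)
        partials.append(
            (_call_external(params + step) - _call_external(params - step)) / (2 * eps)
        )
    return (g * jnp.stack(partials),)


box_volume.defvjp(_box_volume_fwd, _box_volume_bwd)

p = jnp.array([1.0, 2.0, 3.0])
volume = box_volume(p)               # 6.0
gradient = jax.grad(box_volume)(p)   # approximately [6., 3., 2.]
```

The finite-difference backward pass costs two extra generator calls per parameter. It therefore only scales to a modest number of symbolic inputs.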

So, not easy. Still, if this were undertaken, we users would get almost infinite flexibility, with very little effort, in the structures we can optimize! Which is very fun.

tylerflex commented 3 months ago

I think this should be possible. Hopefully we can implement it in the next few months.

tylerflex commented 3 months ago

FYI: We are in the process of deprecating the adjoint (jax) plugin and making tidy3d natively compatible with autograd. So you should be able to just take your existing functions using regular tidy3d components and call autograd.grad() on them without any API modifications.

Once we get this working, implementing things like triangle mesh will be a lot simpler.

so-rose commented 3 months ago

@tylerflex That sounds amazing! I'm glad it's in the plan. Just to make sure I understand correctly: is "autograd" in reference to this? https://github.com/HIPS/autograd

If so, two thoughts/worries:

tylerflex commented 3 months ago

> @tylerflex That sounds amazing! I'm glad it's in the plan. Just to make sure I understand correctly: is "autograd" in reference to this? https://github.com/HIPS/autograd

Yeah, that's the one.

> If so, two thoughts/worries:
>
>   • HIPS/autograd seems to not be actively developed anymore, per the message on the main page of the repository. Incidentally, the four main devs seem to now be working on jax?

That's true. We are considering forking autograd ourselves and maintaining a version of it, since we're mainly just using it for auto-diff. jax is proving quite challenging to work with for auto-diff alone.

  1. jax and jaxlib are ~50 MB, whereas autograd is just ~50 kB, so making autograd a core dependency is much less of an issue. It also has very few dependencies of its own, mainly just numpy.
  2. Many users have installation issues with jax.
>   • On a personal note, I'm relying quite heavily on jax, not just for gradients, but also for JIT optimization + GPU-based processing of output data, sympy support (see sympy's codegen, which is extra nice together with the optics module), and possibly sharding in the future.
>     • I'd just like to give feedback that jax support in tidy3d truly is a "killer feature" for me (and maybe others), for far more reasons than just gradients.
>     • For example, CPU-based jax alone already lets me manipulate FieldTime monitor output as real-time, interactively adjustable videos of field monitors, which should scale even to large/volumetric fields thanks to jax's GPU support.

It would be interesting to learn more about how you use jax + tidy3d for JIT and GPU processing on the front end. These features seem to not work for me with tidy3d.

We are planning to support autograd 'natively', but also to write converters from jax/pytorch/tensorflow to tidy3d autograd. So you should still be able to use jax auto-diff features. And we'll keep the adjoint plugin around, although we probably won't develop it much further.

so-rose commented 3 months ago

> jax and jaxlib are hundreds of MB, whereas autograd is just 50kB. So there is much less of an issue making autograd a core dependency.

Well, that makes sense! jaxlib[cpu] is indeed quite big. I have no GPU kernels installed right now, but I can see it takes up 250 MB.

As mentioned, I'm currently sticking to jaxlib[cpu], to validate my methodology. However, I'm keeping a rather strict "jit/vmap/etc. everything" usage pattern that plays nice with the way GPUs like their memory, and I'd be surprised if jaxlib[cuda12] were to have trouble with the XLA code I'd like it to run. After all, the whole point of jax is to make GPU-based machine learning go fast.
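Here is a small sketch of the "jit/vmap everything" pattern applied to time-domain field data, such as turning FieldTime monitor output into video frames. It assumes JAX is installed and that the monitor data has already been reduced to a plain array with time on axis 2, i.e. shaped (nx, ny, nt). `normalize_frame` and `frames_from_slice` are hypothetical names, not tidy3d or blender_maxwell API.

```python
import jax
import jax.numpy as jnp


def normalize_frame(frame):
    """One time step: |E|^2 scaled to [0, 1] for display."""
    intensity = jnp.abs(frame) ** 2
    return intensity / jnp.maximum(intensity.max(), 1e-12)


# vmap maps the per-frame function over the time axis (assumed to be axis 2)
# and stacks the results along a new leading axis; jit then compiles the
# whole batched computation into a single XLA program.
frames_from_slice = jax.jit(jax.vmap(normalize_frame, in_axes=2, out_axes=0))

frames = frames_from_slice(jnp.ones((64, 64, 200), dtype=jnp.complex64))  # (200, 64, 64)
```

Writing the per-frame logic once and letting `vmap` batch it keeps the code free of Python loops over time steps. That is the shape of program XLA can place efficiently on a GPU.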

> It would be interesting to learn more about how you use jax + tidy3d for JIT and GPU processing on the front end.

Sure, if you're curious. All of this is in the context of a visual node-based system I'm building in Blender for my BSc thesis (https://github.com/so-rose/blender_maxwell/; I wouldn't try to install it right now, as it's quite unstable as of this writing).

Tl;dr: My use case is not "normal" by a long shot, but I deal with the data directly, leaving xarray land as fast as possible to take advantage of jax's happy path of "jit everything and run only that".

Each "math node" lazily composes a function, with data originating from an "extract data" node that pulls out a jax-compatible array sourced directly from xarray's `data` attribute (I don't think a copy happens, but if one does, it's only one).
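Here is a minimal sketch of lazy node composition. It assumes JAX is installed. Each "node" only contributes a function to a chain, and nothing runs until the composed function is jit-compiled and called. The node functions and `compose` are hypothetical illustrations, not the blender_maxwell implementation. A small numpy array stands in for the "extract data" node; with xarray, this would be `data_array.data`.

```python
from functools import reduce

import numpy as np
import jax
import jax.numpy as jnp


def compose(*fns):
    """Chain single-argument functions left to right, without running them."""
    return lambda x: reduce(lambda acc, f: f(acc), fns, x)


# Hypothetical "math nodes": each one only contributes a function to the chain.
def take_real(x):
    return jnp.real(x)


def scale_by_two(x):
    return 2.0 * x


def sum_all(x):
    return jnp.sum(x)


# jit traces the whole chain once and compiles it into a single program.
pipeline = jax.jit(compose(take_real, scale_by_two, sum_all))

# Stand-in for the "extract data" node's output.
raw = np.array([1 + 1j, 2 - 1j, 3 + 0j], dtype=np.complex64)
result = pipeline(jnp.asarray(raw))  # 2 * (1 + 2 + 3) = 12.0
```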

My completely average laptop can process a 100x1x100x200 complex frequency-domain FieldMonitor with a squeeze, frequency index selection, any of real/imag/abs, and a BW->RGB interpolated colormap on the order of microseconds. I've yet to throw an operation at it that meaningfully slows it down (even things like computing the determinant of matrix-folded F-X-Y data, SVDs, etc.). That's why I'm quite convinced that far larger datasets will scale beautifully when I flick the "GPU" button: 41 ms (one frame at 24 fps) really is a long, long time for a modern GPU at most of the data sizes one cares to deal with.
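Here is a sketch of the pipeline described above as a single jitted function: squeeze, frequency-index selection, magnitude, then an interpolated colormap. It assumes JAX is installed and that the complex field array is laid out as (nx, 1, nz, nf). The three-stop colormap and all function names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical three-stop colormap: 0 -> blue, 0.5 -> white, 1 -> red.
STOPS = jnp.array([0.0, 0.5, 1.0])
RGB_AT_STOPS = jnp.array([
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0],
    [1.0, 0.0, 0.0],
])


def colormap(values):
    """Map scalars in [0, 1] to RGB by linear interpolation between stops."""
    flat = values.ravel()
    rgb = jnp.stack(
        [jnp.interp(flat, STOPS, RGB_AT_STOPS[:, c]) for c in range(3)], axis=-1
    )
    return rgb.reshape(values.shape + (3,))


@jax.jit
def field_to_rgb(field, freq_index):
    """field: complex array shaped (nx, 1, nz, nf), e.g. 100x1x100x200."""
    plane = jnp.squeeze(field, axis=1)[..., freq_index]  # (nx, nz)
    magnitude = jnp.abs(plane)  # swap in jnp.real / jnp.imag as needed
    magnitude = magnitude / jnp.maximum(magnitude.max(), 1e-12)  # to [0, 1]
    return colormap(magnitude)  # (nx, nz, 3) RGB image


image = field_to_rgb(jnp.zeros((100, 1, 100, 200), dtype=jnp.complex64), 42)
```

`freq_index` is a traced argument, so selecting a different frequency does not trigger recompilation. Only a new input shape or dtype does, which suits interactive scrubbing through frequencies.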

Of course, generally one just needs to deduce things like transmission losses, for which all of this is vastly over-engineered. But my hope is that this approach is also flexible enough to do some really exotic stuff: especially loss functions, or trying to deduce which assumptions differ between simulations and experiments.

> These features seem to not work for me with tidy3d.

I mean, there's relatively little that "just works" if we're talking about directly @jit-ing a lot of what Tidy3D comes with. But honestly, it's not a showstopper: data is data, after all.

> We are planning to support autograd 'natively', but also to write converters from jax/pytorch/tensorflow to tidy3d autograd.

Fantastic. I'd be happy to give feedback on the jax side of things once it's ready for user consumption. Though I'll be sticking with the adjoint plugin for now, of course.

I hope it's at least interesting why I ask about jax support! It can be a bit sharp, but sometimes also magical!