tylerflex opened 1 year ago
I think this should be possible; hopefully in the next few months we can implement it.
FYI: we are in the process of deprecating the adjoint (`jax`) plugin and making `tidy3d` natively compatible with `autograd`. So you should be able to take your existing functions using regular `tidy3d` components and call `autograd.grad()` on them without any API modifications.
Once we get this working, implementing things like triangle mesh will be a lot simpler.
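A minimal sketch of what that workflow could look like, assuming the native `autograd` integration lands as described (`make_sim` and the `"mode"` figure of merit are hypothetical placeholders, not confirmed API):

```python
import autograd
import autograd.numpy as anp
import tidy3d as td
import tidy3d.web as web

def make_sim(permittivity: float) -> td.Simulation:
    # hypothetical helper: build a regular td.Simulation whose structure
    # medium depends on the parameter -- no special adjoint wrappers
    ...

def objective(permittivity: float) -> float:
    sim_data = web.run(make_sim(permittivity), task_name="autograd_example")
    # hypothetical figure of merit: power in a mode monitor named "mode"
    return anp.sum(anp.abs(sim_data["mode"].amps.values) ** 2)

# plain autograd.grad on an unmodified function, as described above
grad_fn = autograd.grad(objective)
```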
@tylerflex That sounds amazing! I'm glad it's in the plan. Just to make sure I'm understanding correctly: is the `autograd` in reference to this? https://github.com/HIPS/autograd
If so, two thoughts/worries:
- `HIPS/autograd` seems to not be actively developed anymore, per the message on the main page of the repository. Incidentally, the four main devs seem to now be working on `jax`?
- On a personal note, I'm relying quite heavily on `jax`, not just for gradients, but also for JIT optimization + GPU-based processing of output data, `sympy` support (see `sympy`'s codegen, which is extra nice together with the optics module), and possibly sharding in the future.
- Perhaps I'd just like to give feedback that `jax` support in `tidy3d` truly is a "killer feature" for me (and maybe others), for far more reasons than just gradients.
- For example, CPU-based `jax` alone is allowing me to manipulate `FieldTime` monitor output as real-time, interactively adjustable videos of field monitors, which should scale even to large / volumetric fields due to `jax`'s GPU support.

> @tylerflex That sounds amazing! I'm glad it's in the plan. Just to make sure I'm understanding correctly: is the `autograd` in reference to this? https://github.com/HIPS/autograd
Yea that's the one.
> If so, two thoughts/worries:
>
> - `HIPS/autograd` seems to not be actively developed anymore, per the message on the main page of the repository. Incidentally, the four main devs seem to now be working on `jax`?
That's true. We are considering forking `autograd` ourselves and maintaining a version of it, since we're mainly just using it for auto-diff. `jax` is proving quite challenging to work with for auto-diff alone.
> - On a personal note, I'm relying quite heavily on `jax`, not just for gradients, but also for JIT optimization + GPU-based processing of output data, `sympy` support (see `sympy`'s codegen, which is extra nice together with the optics module), and possibly sharding in the future.
> - Perhaps I'd just like to give feedback that `jax` support in `tidy3d` truly is a "killer feature" for me (and maybe others), for far more reasons than just gradients.
> - For example, CPU-based `jax` alone is allowing me to manipulate `FieldTime` monitor output as real-time, interactively adjustable videos of field monitors, which should scale even to large / volumetric fields due to `jax`'s GPU support.
It would be interesting to learn more about how you use `jax` + `tidy3d` for JIT and GPU processing on the front end. These features don't seem to work for me with `tidy3d`.
We are planning to support `autograd` "natively", but also write converters from `jax`/`pytorch`/`tensorflow` to `tidy3d` autograd. So you should still be able to use `jax` auto-diff features. And we'll keep the adjoint plugin around, although we probably won't develop much for it.
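One hedged guess at what the `jax` side of such a converter could look like, using `jax.custom_vjp` plus `jax.pure_callback` so the `autograd`-differentiable function runs outside the `jax` trace (nothing here is confirmed tidy3d API; `simulate_and_measure` is a stand-in for a simulation-based objective):

```python
import autograd
import autograd.numpy as anp
import jax
import jax.numpy as jnp
import numpy as np

def simulate_and_measure(x):
    # placeholder for a tidy3d-based, autograd-differentiable objective
    return anp.sum(anp.sin(x) ** 2)

_grad = autograd.grad(simulate_and_measure)

@jax.custom_vjp
def objective(x):
    # run the autograd-side function outside the jax trace
    return jax.pure_callback(
        lambda v: np.float32(simulate_and_measure(np.asarray(v))),
        jax.ShapeDtypeStruct((), jnp.float32),
        x,
    )

def _fwd(x):
    return objective(x), x  # keep x as the residual for the backward pass

def _bwd(x, ct):
    # chain rule: scalar cotangent times the autograd gradient
    grad_x = jax.pure_callback(
        lambda v: np.asarray(_grad(np.asarray(v)), dtype=np.float32),
        jax.ShapeDtypeStruct(x.shape, jnp.float32),
        x,
    )
    return (ct * grad_x,)

objective.defvjp(_fwd, _bwd)

print(jax.grad(objective)(jnp.linspace(0.0, 1.0, 4)))
```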
`jax` and `jaxlib` are hundreds of MB, whereas `autograd` is just 50kB. So there is much less of an issue making `autograd` a core dependency.
Well, that makes sense! `jaxlib[cpu]` is indeed quite big; I have no GPU kernels installed right now, but I can see it eats 250MB. Digging through the `.venv`, `xla_extension.so` is the culprit - which must be a brute-force "include everything" set of architecture-specific binary procedures? I wonder if anything can be done. It's probably hard, though. Unrolled loops and all that.

As mentioned, I'm currently sticking to `jaxlib[cpu]` to validate my methodology. However, I'm keeping a rather strict "jit/vmap/etc. everything" usage pattern that plays nice with the way GPUs like their memory (a tiny sketch below), and I'd be surprised if `jaxlib[cuda12]` were to have trouble with the XLA code I'd like it to run. After all, the whole point of `jax` is to make GPU-based machine learning go fast.
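For concreteness, a tiny illustration of that "jit/vmap everything" pattern (names and shapes are illustrative): batch a per-pixel function over the whole field with `vmap`, then `jit` the result so XLA can fuse it into one program.

```python
import jax
import jax.numpy as jnp

def pixel_value(spectrum, freq_idx):
    # one pixel: pick a frequency bin and take |E|^2
    return jnp.abs(spectrum[freq_idx]) ** 2

# vmap over x, then y, and jit the whole thing into a single fused kernel
field_to_image = jax.jit(
    jax.vmap(jax.vmap(pixel_value, in_axes=(0, None)), in_axes=(0, None))
)

field = jnp.ones((100, 100, 200), dtype=jnp.complex64)  # (x, y, freq)
image = field_to_image(field, 25)                       # shape (100, 100)
```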
> It would be interesting to learn more about how you use jax + tidy3d for JIT and GPU processing on the front end.
Sure, if you're curious. So, all of this is in the context of a visual node-based system I'm building in Blender for my BSc thesis (https://github.com/so-rose/blender_maxwell/; I wouldn't try to install it right now, it's quite unstable as of writing).
Tl;dr: My use case is not "normal" by a long stretch, but I just deal with the data directly by leaving `xarray` land as fast as possible, to take advantage of `jax`'s happy path of "jit everything and run only that".
Each "math node" composes a function lazily, with data originating from an "extract data" node that pulls out a jax-compatible array sourced directly from xarray
's data
attribute (don't think a copy happens, but if so, only one).
jnp
operations (only jnp
), run through a lambdify
'ed sympy
expression, or declare parameters which must be all eventually inserted in the top-level function. This might sound restrictive, but I promise it isn't!@jit
, which compiles the super-inefficient function of function of ... into optimized XLA bytecode, which by its nature should run on anything jax
supports.jax
) lazily load data from a disk, for final evaluation only when the result is needed.jit
and (implicitly) caches it, then runs it to produce pixels that are blitted directly to a Blender image buffer (which I've maximally seen take 3ms
for a high-res image), so the user can see the result almost instantly.My completely average laptop can process a 100x1x100x200 complex frequency-domain FieldMonitor
with a squeeze, frequency index selection, any real/imag/abs, and an BW->RGB interpolated colormap on the order of microseconds. I've yet to throw any operations that meaningfully makes it slow slow down (even things like computing the determinant of matrix-folded F-X-Y data, SVDs, etc.). Which is why I'm quite convinced that far larger datasets will scale beautifully when I flick the "GPU" button; 41ms
(24fps) really is a long, long time for most data sizes that one cares to deal with to bog down modern GPUs.
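A condensed sketch of that composition pattern (self-contained and illustrative; the node system itself is much more involved): pull the raw array out of `xarray`, compose plain `jnp` closures, and `jit` only the final top-level function.

```python
import jax
import jax.numpy as jnp
import numpy as np

# stand-in for an xarray DataArray's `.data` from a FieldMonitor (x, y, z, freq)
raw = jnp.asarray(np.random.rand(100, 1, 100, 200).astype(np.complex64))

# each "math node" wraps the previous function in a new closure
def compose(prev, op):
    return lambda x: op(prev(x))

pipeline = lambda x: x
pipeline = compose(pipeline, jnp.squeeze)              # drop the size-1 axis
pipeline = compose(pipeline, lambda x: x[..., 25])     # frequency index select
pipeline = compose(pipeline, jnp.abs)                  # |E|
pipeline = compose(pipeline, lambda x: x / x.max())    # normalize to [0, 1]

# the "viewer" jits the whole composed function once; XLA flattens the
# function-of-function-of-... chain into one optimized program
render = jax.jit(pipeline)
pixels = render(raw)  # (100, 100), ready to blit into an image buffer
```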
Of course, generally one just needs to deduce things like transmission losses, for which all of this is vastly over-engineered. But my hope is that this approach is also flexible enough to do some really exotic stuff. Especially loss functions. Or for trying to deduce which presumptions differ between simulations and experiments.
> These features seem to not work for me with tidy3d.
I mean, there's relatively little that "just works" if we're talking about directly `@jit`-ing a lot of what Tidy3D comes with. But honestly, it's not a show stopper - data is data, after all.
In practice, I `%s/np/jnp/g` any functions I need to trace (like `amp_time` for source time dependence), to get around the primary differences between `np` and `jnp`; a hedged sketch of the idea follows. It's not ideal, but it's manageable in the few cases I need it.
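For instance, a `jnp`-flavored Gaussian source envelope could look like this (the formula is illustrative only, not `tidy3d`'s exact `amp_time`), so it can be traced and `@jit`-ed:

```python
import jax
import jax.numpy as jnp

@jax.jit
def amp_time_jnp(t, freq0=2e14, fwidth=1e13, offset=5.0):
    """Complex amplitude of a Gaussian pulse at time(s) t (illustrative)."""
    twidth = 1.0 / (2.0 * jnp.pi * fwidth)
    t0 = offset * twidth
    envelope = jnp.exp(-((t - t0) ** 2) / (2.0 * twidth**2))
    return envelope * jnp.exp(-2j * jnp.pi * freq0 * t)
```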
The main culprits are plain `np` operations sneaking into traced code (`numpy` happily does it on its own, and `jax` loses the ability to trace), which `jax` really hates, as well as anything that changes the array shape (`jax` simply doesn't `jit` with dynamically sized arrays; it's a design choice). Simply using `jnp` fixes the first, not the second, but Tidy3D seems generally designed in a way that makes the second not such a big issue. A minimal illustration of both failure modes is below.
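Self-contained illustration (not tidy3d code) of the two failure modes just described:

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def with_np(x):
    return np.sin(x)   # np op on a tracer -> TracerArrayConversionError

@jax.jit
def with_jnp(x):
    return jnp.sin(x)  # fine: jnp ops trace and compile

@jax.jit
def dynamic_shape(x):
    # boolean masking gives a data-dependent output shape,
    # which jit rejects by design
    return x[x > 0]

x = jnp.linspace(-1.0, 1.0, 8)
with_jnp(x)           # works
# with_np(x)          # raises: numpy can't consume a jax tracer
# dynamic_shape(x)    # raises: NonConcreteBooleanIndexError
```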
`xarray` really is lovely, but the happy path in my case has been to extract the raw data (for composing high-performance operations on) and manually track index names, coordinates, etc. on the side (for deducing which operations exactly to compose). So, reinventing tensors again, I suppose!

> We are planning to support autograd "natively", but also write converters from jax/pytorch/tensorflow to tidy3d autograd.
Fantastic. I'd be happy to give feedback on the `jax` side of things once it's ready for user consumption. Though I'll be sticking with the adjoint plugin for now, of course.

I hope it's at least interesting why I ask about `jax` support! It can be a bit sharp. But sometimes also magical!
Just 2c; massive +1 for this.
My integration relies on geometry generated using Blender 3D's "Geometry Nodes", a visual scripting system for intuitive and interactive geometry generation.
With the flexibility of arbitrary triangle meshes come some caveats; namely, that the scripted, parameterized inputs to geometry-generating GeoNodes trees must be realized at the time of generating a `td.TriangleMesh` object. While the mesh generation is on some level just a function, Blender-the-software must be invoked somewhere in order to actually run that function.

I imagine this is far from the most exotic geometry generation floating around. Thus, the optimization loop of any "generate arbitrary geometry w/ optimization of symbolic inputs" would probably need to be a local+cloud loop, e.g. with `td.TriangleMesh` wrapping a non-`jax` callback while itself being registered with `jax` (a hedged sketch below).

So, not easy. Still, if this were to be undertaken, then almost infinite flexibility would be available to us users with very little effort, when it comes to structures that can be optimized! Which is very fun.
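For what it's worth, a hedged sketch of the shape such a callback registration could take, using `jax.pure_callback` (everything here is hypothetical: `blender_generate_mesh` stands in for evaluating a GeoNodes tree, and a real integration would still need a custom VJP rule, like the converter sketch earlier in the thread, for gradients to flow):

```python
import jax
import jax.numpy as jnp
import numpy as np

N_VERTS = 1024  # jax needs a static output shape, so fix the vertex count

def blender_generate_mesh(params: np.ndarray) -> np.ndarray:
    # stand-in for invoking Blender: a parametrized ring of vertices
    radius, height = float(params[0]), float(params[1])
    theta = np.linspace(0.0, 2.0 * np.pi, N_VERTS, dtype=np.float32)
    return np.stack(
        [radius * np.cos(theta), radius * np.sin(theta),
         np.full_like(theta, height)],
        axis=-1,
    )

def mesh_vertices(params):
    # escape hatch: run the non-jax callback even inside traced code
    return jax.pure_callback(
        blender_generate_mesh,
        jax.ShapeDtypeStruct((N_VERTS, 3), jnp.float32),
        params,
    )

verts = mesh_vertices(jnp.array([1.0, 0.5]))  # (N_VERTS, 3)
```

The resulting vertex (and face) arrays could then feed mesh construction on the `td.TriangleMesh` side, with the cloud solve happening per optimization step.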