PlasmaControl / DESC

Stellarator Equilibrium and Optimization Suite
MIT License

Pre-compute `Curve` and `Surface` transforms in objectives that compute the magnetic field #1079

Open: ddudt opened this issue 1 week ago

ddudt commented 1 week ago

QuadraticFlux.compute calls _Coil.compute_magnetic_field, which calls Curve.compute. The error is then thrown at these lines:

```python
if transforms is None:
    transforms = get_transforms(
        names, obj=self, grid=grid, jitable=True, **kwargs
    )
```

You fixed the error that occurs when these transforms are created inside a jitable function, but it is wasteful to keep rebuilding the same transforms every time the objective is computed. It would be safer and more efficient to pre-compute these transforms in QuadraticFlux.build and then pass them through to the relevant compute functions. That would also fix this bug and speed up the code.
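A minimal sketch of the suggested pattern in plain NumPy: the objective builds the transform once in build() and passes it into compute, and compute only rebuilds it when none is supplied (mirroring the `if transforms is None` fallback). ToyFourierCurve, ToyObjective, build_transform and the cosine-only basis are hypothetical stand-ins, not DESC's Curve, QuadraticFlux or get_transforms API.

```python
import numpy as np


class ToyFourierCurve:
    """Hypothetical stand-in for a Fourier-parameterized curve (not DESC's Curve)."""

    def __init__(self, n_modes):
        self.modes = np.arange(n_modes)

    def build_transform(self, grid):
        # The "transform" is just the cosine basis evaluated on the grid:
        # one row per grid point, one column per Fourier mode.
        return np.cos(np.outer(grid, self.modes))

    def compute(self, params, grid, transform=None):
        # Mirrors the `if transforms is None` fallback: rebuild only if
        # the caller did not pass a precomputed transform.
        if transform is None:
            transform = self.build_transform(grid)
        return transform @ params


class ToyObjective:
    """Hypothetical objective that builds the curve transform once."""

    def __init__(self, curve, grid):
        self.curve = curve
        self.grid = np.asarray(grid, dtype=float)
        self._transform = None

    def build(self):
        # Pay the transform cost once, outside any jitted compute call.
        self._transform = self.curve.build_transform(self.grid)

    def compute(self, params):
        # Pass the cached transform through instead of rebuilding it.
        return self.curve.compute(params, self.grid, transform=self._transform)


grid = np.linspace(0, 2 * np.pi, 8, endpoint=False)
obj = ToyObjective(ToyFourierCurve(n_modes=3), grid)
obj.build()
values = obj.compute(np.array([1.0, 0.5, 0.0]))
```

Because the cached transform depends only on the grid and the mode numbers, not on the optimization parameters, it stays valid across every call to compute.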

Originally posted by @ddudt in https://github.com/PlasmaControl/DESC/pull/1069#pullrequestreview-2138807945

dpanici commented 1 week ago

Finding an elegant way to do this is a bit difficult, as the compute_magnetic_field function lies outside of the normal compute hierarchy that get_transforms relies on. As the code stands right now, we would have to put some clunky logic inside the build method of every single objective that uses compute_magnetic_field to check whether the field passed in is a Coil and, if it is, to compute the required transforms for it.
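To make the "clunky logic" concrete, here is a hypothetical sketch of the type check each objective's build would need. ToyCoil, ToyAnalyticField and build_field_transforms are invented names, not DESC classes; the point is only that this branch would have to be copied into every objective that calls compute_magnetic_field.

```python
import numpy as np


class ToyCoil:
    """Hypothetical coil whose field evaluation needs a curve transform."""

    def __init__(self, n_modes):
        self.modes = np.arange(n_modes)

    def build_transform(self, grid):
        # Cosine basis on the coil's curve grid (rows: points, cols: modes).
        return np.cos(np.outer(grid, self.modes))


class ToyAnalyticField:
    """Hypothetical field with a closed-form B, so it needs no transforms."""


def build_field_transforms(field, grid):
    # The per-objective branch: every objective's build would repeat this
    # check before it could pass transforms to compute_magnetic_field.
    if isinstance(field, ToyCoil):
        return field.build_transform(grid)
    return None


grid = np.linspace(0, 2 * np.pi, 4, endpoint=False)
coil_tf = build_field_transforms(ToyCoil(n_modes=2), grid)
analytic_tf = build_field_transforms(ToyAnalyticField(), grid)
```

Each new field type that needs transforms would add another branch here, and in every objective that carries a copy of it.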

f0uriest commented 1 week ago

I agree this is something we should figure out longer term, but as @dpanici points out, it's actually a non-trivial amount of work to do well. FWIW, all the transforms used by coils and magnetic field classes are just Fourier series, so they are basically free to compute; I'm not even sure precomputing them saves that much time. Also, in theory JAX may optimize that away as long as the grids are static.
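One way to see why static grids keep this cheap under JAX: if the basis is built with plain NumPy from a static grid, that Python code only runs while jax.jit traces the function, and later calls with the same input shape and dtype reuse the compiled result. This is a hypothetical sketch (make_basis, GRID and objective are invented names) and it demonstrates trace-time evaluation only; whether XLA additionally constant-folds a basis built with jax.numpy is not shown here.

```python
import numpy as np
import jax
import jax.numpy as jnp

calls = {"n": 0}  # counts how often the Python basis builder actually runs


def make_basis(grid, n_modes):
    # Plain NumPy on a static grid: this runs during tracing,
    # not on every call of the compiled function.
    calls["n"] += 1
    return np.cos(np.outer(grid, np.arange(n_modes)))


GRID = np.linspace(0, 2 * np.pi, 16, endpoint=False)


@jax.jit
def objective(params):
    # params.shape is static under jit, so the basis size is known at trace time.
    basis = make_basis(GRID, params.shape[0])
    return jnp.sum(jnp.dot(basis, params) ** 2)


p = jnp.array([1.0, 0.5, 0.25])
for _ in range(3):
    objective(p)
```

After the loop, calls["n"] is 1: the function was traced once, so the basis was built once, even though the objective was evaluated three times.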

dpanici commented 6 days ago

Also edited the title, as we should be able to pre-compute transforms for Surface objects as well, such as FourierCurrentPotentialField, though again this is not super straightforward to do elegantly.

dpanici commented 6 days ago

Just to leave a further note:

It is not hard to do this on an objective-by-objective basis for lone magnetic field objects, as we can simply check what the type of object is (FourierRZCoil, FourierCurrentPotentialField, ToroidalMagneticField, etc) and then create the transform with the appropriate grid for that object. However this is

One way forward could be to have a utility function which

Then use a PyTree utility to apply this utility function to each leaf of the given MagneticField object (each leaf being any MagneticField object which is not a MixedCoilSet or SumMagneticField) and get back the correctly pre-computed transforms for that MagneticField PyTree object.
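A sketch of the PyTree approach, assuming the container classes are registered as pytree nodes so that jax.tree_util.tree_map descends into them and treats each individual field as a leaf. ToyCoil, ToyAnalyticField, ToyFieldCollection (standing in for MixedCoilSet or SumMagneticField) and leaf_transforms are hypothetical; DESC's real classes and their pytree registration may differ.

```python
import numpy as np
import jax


class ToyCoil:
    """Hypothetical single coil (an unregistered object, so a pytree leaf)."""

    def __init__(self, n_modes):
        self.n_modes = n_modes


class ToyAnalyticField:
    """Hypothetical field with closed-form B (a leaf needing no transforms)."""


@jax.tree_util.register_pytree_node_class
class ToyFieldCollection:
    """Hypothetical container playing the role of MixedCoilSet/SumMagneticField."""

    def __init__(self, *children):
        self.children = tuple(children)

    def tree_flatten(self):
        # Children are traversed by tree_map; there is no static aux data.
        return self.children, None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def leaf_transforms(field, grid):
    # Hypothetical per-leaf utility: the type check lives in one place.
    if isinstance(field, ToyCoil):
        return np.cos(np.outer(grid, np.arange(field.n_modes)))
    return None


grid = np.linspace(0, 2 * np.pi, 8, endpoint=False)
field = ToyFieldCollection(
    ToyCoil(2), ToyFieldCollection(ToyCoil(3), ToyAnalyticField())
)
transforms = jax.tree_util.tree_map(lambda leaf: leaf_transforms(leaf, grid), field)
```

The result mirrors the nesting of the input field, with an array at each coil leaf and None where no transform is needed, so it could be passed down alongside the field itself.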

The ideal approach, which avoids all the if statements and the need for a utility function in the first place, would be to incorporate the B computation from compute_magnetic_field into the normal compute method and add the MagneticField classes as parameterizations in the data index. But again, this will also be a bit of work; I just wanted to lay out the ideas.
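A toy sketch of the data-index idea, assuming a registry keyed by parameterization and quantity name whose entries record which transforms each quantity needs; a generic transform builder can then read that entry instead of branching on field type. DATA_INDEX, register, toy_get_transforms and toy_compute are hypothetical and deliberately simplified; they are not DESC's actual data_index, get_transforms or compute API, and the "B" function is a placeholder rather than a Biot-Savart calculation.

```python
import numpy as np

# Hypothetical registry: parameterization -> quantity name -> recipe.
DATA_INDEX = {}


def register(parameterization, name, transforms):
    # Decorator that records a compute function and the transforms it needs.
    def decorator(fun):
        DATA_INDEX.setdefault(parameterization, {})[name] = {
            "fun": fun,
            "transforms": transforms,
        }
        return fun

    return decorator


@register("ToyCoil", "B", transforms=["x"])
def _b_toy_coil(params, transforms):
    # Placeholder physics: a real implementation would do Biot-Savart here.
    return transforms["x"] @ params


def toy_get_transforms(parameterization, name, grid, n_modes):
    # Generic: look up which transforms the quantity needs and build only those.
    needed = DATA_INDEX[parameterization][name]["transforms"]
    builders = {"x": lambda: np.cos(np.outer(grid, np.arange(n_modes)))}
    return {key: builders[key]() for key in needed}


def toy_compute(parameterization, name, params, transforms):
    return DATA_INDEX[parameterization][name]["fun"](params, transforms)


grid = np.linspace(0, 2 * np.pi, 4, endpoint=False)
tf = toy_get_transforms("ToyCoil", "B", grid, n_modes=2)  # built once, e.g. in build
B = toy_compute("ToyCoil", "B", np.array([1.0, 0.0]), tf)
```

With the transform requirements stored next to each registered quantity, an objective's build can collect transforms for any registered field type without knowing which type it was given.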