Open ddudt opened 5 months ago
Thinking of an elegant way to do this is a bit difficult, as the `compute_magnetic_field` function lies outside of the normal compute hierarchy that is used when we call `get_transforms`. As the code stands right now, we would have to put some clunky logic inside the build of every single objective function where `compute_magnetic_field` is used, to check whether the `field` passed in is a `Coil` or not and, if it is, to compute the required transforms for it.
I agree this is something we should figure out longer term, but as @dpanici points out it's actually a non-trivial amount of work to get it to work well. FWIW, all the transforms used by coils and magnetic field classes are just Fourier series, so they are basically free to compute, and I'm not even sure it saves that much time to precompute them. Also, in theory JAX may optimize that away as long as the grids are static.
Also edited the title, as we should be able to pre-compute for Surface objects as well, such as `FourierCurrentPotentialField`, though again this is not super straightforward to do elegantly.
Just to leave a further note:
It is not hard to do this on an objective-by-objective basis for lone magnetic field objects, as we can simply check what the type of object is (`FourierRZCoil`, `FourierCurrentPotentialField`, `ToroidalMagneticField`, etc.) and then create the transform with the appropriate grid for that object. However, this is harder for `MixedCoilSet` or `SumMagneticField` objects, where it matters internally which object gets which grid, and would require some PyTree structure like the `_Coil` objectives have.

One way forward could be to have a utility function which creates the appropriate grid for a given object, calls `get_transforms` with that grid, and returns the transform dict for that object. Then use a PyTree utility to apply this utility function to each leaf of the given `MagneticField` object (each leaf being any `MagneticField` object which is not a `MixedCoilSet` or `SumMagneticField`) and get back the correctly pre-computed transforms for that `MagneticField` PyTree object.
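A rough sketch of the per-leaf utility plus tree mapping, in plain Python with no DESC imports. Everything here is a hypothetical stand-in made up for illustration: `ToyCoil` plays the role of a leaf field object, `ToySumField` plays the role of a container like `MixedCoilSet` or `SumMagneticField`, and the "transform dict" is just a dict recording which grid was built for which leaf. In practice the mapping step would presumably use a real PyTree utility rather than the hand-written recursion below.

```python
from dataclasses import dataclass, field
from typing import Any, Callable, List


@dataclass
class ToyCoil:
    """Stand-in for a leaf field object (e.g. a single coil)."""

    name: str
    n_grid: int  # number of source-grid points this leaf wants


@dataclass
class ToySumField:
    """Stand-in for a container such as a coil set or a summed field."""

    children: List[Any] = field(default_factory=list)


def leaf_transforms(obj: ToyCoil) -> dict:
    """Hypothetical per-leaf utility: pick a grid, return a transform dict."""
    grid = [i / obj.n_grid for i in range(obj.n_grid)]
    return {"owner": obj.name, "grid": grid}


def map_field_leaves(fn: Callable[[Any], Any], obj: Any) -> Any:
    """Apply fn to every leaf; the nested-list result mirrors the tree."""
    if isinstance(obj, ToySumField):
        return [map_field_leaves(fn, child) for child in obj.children]
    return fn(obj)


field_tree = ToySumField(
    [ToyCoil("a", 4), ToySumField([ToyCoil("b", 2), ToyCoil("c", 3)])]
)
transforms = map_field_leaves(leaf_transforms, field_tree)
```

Because the result has the same nesting as the field object, each leaf can later be paired with its own transforms without any bookkeeping about which object got which grid.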
The ideal way, which avoids all the if statements and the necessity of a utility function to begin with, would be to incorporate the `B` computation from `compute_magnetic_field` into the normal `compute` method, and add the `MagneticField` classes as parameterizations in the data index. But again this will also be a bit of work, just wanted to lay out the ideas.
`QuadraticFlux.compute` calls `_Coil.compute_magnetic_field`, which calls `Curve.compute`. Then the error is being thrown at the line:

You fixed the issue when creating these transforms inside a jitable function. But it is wasteful to keep re-building these same transforms each time the objective is computed. It would be safer and more efficient to pre-compute these transforms in `QuadraticFlux.build` and then pass them through to the relevant compute functions. That would also fix this bug and speed up the code.

Originally posted by @ddudt in https://github.com/PlasmaControl/DESC/pull/1069#pullrequestreview-2138807945