Implementing a constraint on max current #421

Open abaillod opened 4 months ago

abaillod commented 4 months ago


I tried implementing a constraint on the max current in a coil. The objective looks like this:

def current_penalty_pure(I, threshold):
    return jnp.maximum(abs(I) - threshold, 0)**2

class CurrentPenalty(Optimizable):
    A :obj:`CurrentPenalty` can be used to penalize
    large currents in coils.
    def __init__(self, current, threshold=0):
        self.current = current
        self.threshold = threshold

        self.J_jax = lambda I: current_penalty_pure(I, self.threshold)
        self.this_grad = lambda I: grad(self.J_jax, argnums=0)(I)


    def J(self):
        return self.J_jax(self.current.x[0])

    def dJ(self):
        grad0 = self.this_grad(self.current.x[0])
        return self.current.vjp(grad0)

However, when running the following simple test,

from simsopt.field import CurrentPenalty

c = Current(1e5)
test = CurrentPenalty(c)

I get the error

IndexError                                Traceback (most recent call last)
Cell In[3], line 5
      3 c = Current(1e5)
      4 test = CurrentPenalty(c)
----> 5 test.dJ()

File ~/Github/simsopt/src/simsopt/_core/derivative.py:217, in derivative_dec.<locals>._derivative_dec(self, partials, *args, **kwargs)
    215     return func(self, *args, **kwargs)
    216 else:
--> 217     return func(self, *args, **kwargs)(self)

File ~/Github/simsopt/src/simsopt/_core/derivative.py:185, in Derivative.__call__(self, optim, as_derivative)
    183 local_derivs = np.zeros(k.local_dof_size)
    184 for opt in k.dofs.dep_opts():
--> 185     local_derivs += self.data[opt][opt.local_dofs_free_status]
    186     keys.append(opt)
    187 derivs.append(local_derivs)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/array.py:317, in ArrayImpl.__getitem__(self, idx)
    315   return lax_numpy._rewriting_take(self, idx)
    316 else:
--> 317   return lax_numpy._rewriting_take(self, idx)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4142, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4136     if (isinstance(aval, core.DShapedArray) and aval.shape == () and
   4137         dtypes.issubdtype(aval.dtype, np.integer) and
   4138         not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
   4139         isinstance(arr.shape[0], int)):
   4140       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4142 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   4143 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4144                unique_indices, mode, fill_value)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4220, in _split_index_for_jit(idx, shape)
   4216   raise TypeError(f"JAX does not support string indexing; got {idx=}")
   4218 # Expand any (concrete) boolean indices. We can then use advanced integer
   4219 # indexing logic to handle them.
-> 4220 idx = _expand_bool_indices(idx, shape)
   4222 leaves, treedef = tree_flatten(idx)
   4223 dynamic = [None] * len(leaves)

File /opt/homebrew/Caskroom/miniconda/base/envs/simsopt/lib/python3.8/site-packages/jax/_src/numpy/lax_numpy.py:4542, in _expand_bool_indices(idx, shape)
   4540     expected_shape = shape[start: start + _ndim(i)]
   4541     if i_shape != expected_shape:
-> 4542       raise IndexError("boolean index did not match shape of indexed array in index "
   4543                        f"{dim_number}: got {i_shape}, expected {expected_shape}")
   4544     out.extend(np.where(i))
   4545 else:

IndexError: boolean index did not match shape of indexed array in index 0: got (1,), expected ()

Does anyone have an idea how to fix it?

andrewgiuliani commented 4 months ago

not sure, but one suggestion would be to implement this without Jax since it's such a simple penalty

smiet commented 4 months ago

Looks like the derivative decorator trips over the fact that the optimizable only has one dof. The Jax internals do not make it very clear, but check if current.dofs.dep_opts()[0].dofs_free_status is a boolean instead of a length one array thereof