LouisDesdoigts closed this issue 1 year ago.
Add the following to experimental:
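```python
import equinox as eqx
import jax
import jax.numpy as np  # assumed: np is jax.numpy, as elsewhere in dLux

def float_from_0d(x):
    # Turn 0-d arrays into Python floats (static under jit); leave other leaves as-is.
    if isinstance(x, np.ndarray):
        return float(x) if x.ndim == 0 else x
    else:
        return x

def get_jit_model(model, args):
    # `args` is a boolean filter spec: True marks leaves to keep as arrays (optimised).
    opt, non_opt = eqx.partition(model, args)
    float_model = jax.tree_util.tree_map(float_from_0d, non_opt)
    return eqx.combine(opt, float_model)
```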
Modify `get_jit_model` to take in just the parameters to optimise and build the boolean filter spec inside the function, instead of requiring the caller to pass it.
This relates to issue https://github.com/LouisDesdoigts/dLux/issues/187.