LouisDesdoigts / zodiax

Object-oriented Jax framework extending Equinox for scientific programming
https://louisdesdoigts.github.io/zodiax/
BSD 3-Clause "New" or "Revised" License

Add `get_jit_model`. #4

Closed: LouisDesdoigts closed this 1 year ago.

LouisDesdoigts commented 1 year ago

Related to this dLux issue: https://github.com/LouisDesdoigts/dLux/issues/187

LouisDesdoigts commented 1 year ago

Add this to experimental:

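import jax
import jax.numpy as np
import equinox as eqx

def float_from_0d(x):
    # Convert 0-d arrays to Python floats; leave every other leaf unchanged
    if isinstance(x, np.ndarray) and x.ndim == 0:
        return float(x)
    return x

def get_jit_model(model, args):
    # `args` is a boolean filter spec: True marks the leaves being optimised
    opt, non_opt = eqx.partition(model, args)
    # jax.tree_map is deprecated; jax.tree_util.tree_map is the stable name
    float_model = jax.tree_util.tree_map(float_from_0d, non_opt)
    return eqx.combine(opt, float_model)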

Modify this to take in just the parameters and build the boolean filter spec inside the function.