Minor thing, but if it's not too much work, I think it would be cool if one could define multiple JAX functions (with different parameters) without needing `**kwargs`, as is already possible with regular (non-jit) Python functions.
The following example does not work if we drop `**kwargs`.