qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

kwargs requirement for jit (coefficient) functions #29

Open: flowerthrower opened this issue 7 months ago

flowerthrower commented 7 months ago

Minor thing, but if it's not too much work, I think it would be nice if one could define multiple jitted JAX coefficient functions (with different parameters) without needing **kwargs, as is already possible with regular (non-jitted) Python functions. The following example does not work if we drop **kwargs:

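import jax
import jax.numpy as jnp
import qutip as qt
import qutip_jax  # provides the 'diffrax' integrator used below

def sin(t, p):
    # p = [amplitude, angular frequency, phase]
    return p[0] * jnp.sin(p[1] * t + p[2])

# Each coefficient only uses its own parameter ('p' or 'q') from args,
# yet removing **kwargs from either signature makes the example fail.
@jax.jit
def sin_x(t, p, **kwargs): return sin(t, p)

@jax.jit
def sin_y(t, q, **kwargs): return sin(t, q)

H = [[qt.sigmax(), sin_x],
     [qt.sigmay(), sin_y]]

evo = qt.mesolve(H, qt.basis(2, 0), tlist=[0, 1],
                 args={'p': [1, 1, 0], 'q': [1, 1, 0]},
                 options={'method': 'diffrax'})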
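Until this is supported, one possible user-side workaround is sketched below. It assumes (not confirmed in this issue) that the failure happens because every entry of `args` is passed to each jitted coefficient as a keyword argument, so `sin_x(t, p)` receives an unexpected `q`. The helper `accept_extra_kwargs` and the `*_raw` functions are hypothetical names: the helper reads the plain function's signature with `inspect.signature`, forwards only the parameters it declares, and the result is then passed through `jax.jit`. It only exercises JAX, not QuTiP.

```python
import inspect

import jax
import jax.numpy as jnp


def sin(t, p):
    # p = [amplitude, angular frequency, phase]
    return p[0] * jnp.sin(p[1] * t + p[2])


def accept_extra_kwargs(func):
    """Hypothetical helper: allow ``func(t, a, b)`` to be called as ``func(t, **args)``.

    Only the keyword arguments that ``func`` declares after ``t`` are forwarded;
    every other entry is dropped before ``func`` is called.
    """
    params = list(inspect.signature(func).parameters.values())[1:]  # skip t
    names = [param.name for param in params
             if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)]

    def wrapper(t, **kwargs):
        return func(t, **{name: kwargs[name] for name in names if name in kwargs})

    return wrapper


# Coefficients written without **kwargs ...
def sin_x_raw(t, p):
    return sin(t, p)


def sin_y_raw(t, q):
    return sin(t, q)


# ... wrapped first, then jitted, so the compiled function accepts any kwargs.
sin_x = jax.jit(accept_extra_kwargs(sin_x_raw))
sin_y = jax.jit(accept_extra_kwargs(sin_y_raw))

args = {"p": [1.0, 1.0, 0.0], "q": [2.0, 1.0, 0.0]}

# Calling a raw coefficient with every entry of args raises TypeError ...
try:
    sin_x_raw(0.5, **args)
except TypeError as err:
    print("without wrapper:", err)

# ... while each wrapped, jitted version picks out only its own parameter.
print(sin_x(0.5, **args), sin_y(0.5, **args))
```

Wrapping before jitting keeps the signature inspection on an ordinary Python function. Whether `sin_x` and `sin_y` built this way can be dropped into `H` unchanged depends on how qutip-jax actually calls its coefficients, which this sketch does not test.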