patrick-kidger / diffrax

Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/

Problem adding a gradient penalty to the SDE-GAN example #419

Open · Dvunty opened this issue 3 months ago

Dvunty commented 3 months ago

Hi Patrick. First, congratulations on the excellent paper "Neural SDEs as Infinite-Dimensional GANs". Following the paper, I am trying to add a gradient penalty to your example code: /examples/neural_sde.ipynb .

# x_hat: interpolate between real and generated sample paths
interps = epsilon * ys_i + (1 - epsilon) * fake_ys_i

# E[D(x_hat)]: mean discriminator score over the interpolated paths
def loss_x_(discriminator, ys, ts):
    return discriminator(ts, ys).mean()

# gradient penalty: penalise the gradient norm of D w.r.t. ys away from 1
# (note this takes a single norm over the whole batch; the standard WGAN-GP
# penalty uses one norm per sample, as in the variant sketched below)
def grad_penalty(discriminator, ys, ts):
    loss_grad = jax.grad(loss_x_, argnums=1)(discriminator, ys, ts)
    grad_norm = jnp.sqrt(jnp.sum(loss_grad ** 2))
    return (grad_norm - 1.0) ** 2

# differentiate the penalty with respect to the discriminator's parameters
gradient_penalty = eqx.filter_grad(grad_penalty)(discriminator, interps, ts_i)

However, after debugging I found that eqx.filter_grad does not seem to compose with jax.grad here. I'm new to JAX and am stuck, so I would be grateful for any ideas. The error is:

ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.
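(As an aside, if I understand the WGAN-GP paper correctly, the penalty is usually applied to one gradient norm per sample rather than a single norm over the whole batch. A per-sample sketch of the function above, assuming here that discriminator(ts, ys) maps a single path of shape (times, channels) to a scalar score:)

def grad_penalty_per_sample(discriminator, ys, ts):
    # score one interpolated path of shape (times, channels)
    def score(ys_single):
        return discriminator(ts, ys_single)
    # one gradient, and hence one norm, per sample in the batch
    grads = jax.vmap(jax.grad(score))(ys)
    norms = jnp.sqrt(jnp.sum(grads ** 2, axis=(1, 2)))
    return jnp.mean((norms - 1.0) ** 2)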

patrick-kidger commented 3 months ago

Take a look at the example on calculating second-order derivatives, which might help: https://docs.kidger.site/diffrax/examples/hessian/

Dvunty commented 3 months ago

Thank you very much; the problem has been solved perfectly. (•̀ᴗ• ) ̑̑
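In case it helps anyone else who hits this: the key ingredient in that example appears to be passing adjoint=diffrax.DirectAdjoint() to diffeqsolve, since the default adjoint only supports first-order reverse-mode autodifferentiation. A minimal sketch with a placeholder ODE (not the actual SDE-GAN discriminator):

import diffrax
import jax
import jax.numpy as jnp

def solve(y0):
    # Placeholder scalar ODE dy/dt = -y, standing in for the discriminator's
    # CDE solve; the important part is the `adjoint` argument.
    term = diffrax.ODETerm(lambda t, y, args: -y)
    sol = diffrax.diffeqsolve(
        term,
        diffrax.Tsit5(),
        t0=0.0,
        t1=1.0,
        dt0=0.1,
        y0=y0,
        # The default RecursiveCheckpointAdjoint only supports first-order
        # reverse-mode autodiff, which is what raised the lax.while_loop
        # ValueError; DirectAdjoint supports higher-order autodifferentiation.
        adjoint=diffrax.DirectAdjoint(),
    )
    return sol.ys[-1]

# Second-order differentiation through the solve now works, e.g.:
second_derivative = jax.hessian(solve)(jnp.array(1.0))

With that change in the discriminator's solve, eqx.filter_grad(grad_penalty) composed fine with the inner jax.grad.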