MilesCranmer / lagrangian_nns

Lagrangian Neural Networks
Apache License 2.0
449 stars · 90 forks

Issue with gln loss #6

Open MammaM14 opened 1 year ago

MammaM14 commented 1 year ago

Hello! I was trying to use the code in the "experiment_dblpend" directory. As long as I train the neural network with the "baseline_nn" loss, everything is fine.

The moment I try to use the "gln" loss (which should reproduce the training from the paper), something goes wrong. The resulting error is as follows (without my having touched the code):

TypeError: Gradient only defined for scalar-output functions. Output had shape: (4,).

This error refers to the application of gln_loss at the following line:

preds = jax.vmap(partial(lagrangian_eom, learned_dynamics(params)))(state)

Specifically, to the call

jax.grad(lagrangian, 0)(q, q_t)

of the function lagrangian_eom.

I honestly cannot explain this, because everything seems correct.
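For context, the error message quoted above is JAX's standard complaint when `jax.grad` is applied to a function whose output is not a scalar. A minimal sketch (toy functions, not the repo's code) that reproduces it:

```python
import jax
import jax.numpy as jnp

def vector_valued(q):
    # Returns shape (4,), like the error in this issue reports.
    return q * 2.0

def scalar_valued(q):
    # Returns a true scalar, as jax.grad requires.
    return jnp.sum(q ** 2)

q = jnp.ones(4)

try:
    jax.grad(vector_valued)(q)
except TypeError as e:
    print("TypeError:", e)  # "Gradient only defined for scalar-output functions..."

print(jax.grad(scalar_valued)(q))  # gradient of sum(q^2) is 2*q
```

So the error means that somewhere along the way, `lagrangian` is returning an array rather than a scalar.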

MilesCranmer commented 1 year ago

Hm, this seems weird. The code in the notebooks still works though, right? Maybe you just need to jnp.sum the output?

MammaM14 commented 1 year ago

Sorry, I forgot to mention that in the "learned_dynamics" function I had to remove the "jnp.squeeze" call

jnp.squeeze(nn_forward_fn(params, state), axis=-1)

because otherwise the code would not run.
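That removal is likely the cause of the error: the network has a single output unit, so its forward pass returns shape (1,), and jax.grad treats shape (1,) as a vector, not a scalar. A sketch of the difference, using a toy stand-in (`toy_forward_fn` is hypothetical, not the repo's `nn_forward_fn`):

```python
import jax
import jax.numpy as jnp

def toy_forward_fn(params, state):
    # Stand-in for a one-output-unit network: returns shape (1,).
    return jnp.dot(state, params)[None]

params = jnp.ones(4)
state = jnp.arange(4.0)

with_squeeze = lambda s: jnp.squeeze(toy_forward_fn(params, s), axis=-1)
without_squeeze = lambda s: toy_forward_fn(params, s)

# Squeezing the last axis yields a true scalar, so jax.grad works.
print(jax.grad(with_squeeze)(state))

try:
    # Without the squeeze, the output has shape (1,) and jax.grad fails.
    jax.grad(without_squeeze)(state)
except TypeError as e:
    print("TypeError:", e)
```

If the squeeze fails in the repo because the forward pass no longer has a trailing axis of size 1 (e.g. after a JAX or stax API change), the fix would be to adjust the network output shape rather than drop the squeeze.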

Sorry for the trouble. Your project is really well done; most likely I am the one struggling with JAX.

MilesCranmer commented 1 year ago

Oh, I see. Is that a bug in the code (maybe due to recent JAX updates)? If so, do you think you could submit a PR to patch it? It would be much appreciated (I am currently drowning in grant writing...)!

Cheers, Miles