MilesCranmer / lagrangian_nns

Lagrangian Neural Networks
Apache License 2.0

Issues when I open the main branch on google colab and use the `base-nn-double-pendulum.ipynb` #3

Open MariosGkMeng opened 1 year ago

MariosGkMeng commented 1 year ago

Issues when I open the main branch on Google Colab and use `base-nn-double-pendulum.ipynb`:

  1. Cannot import stax, since the notebook still uses the old `jax.experimental` location. Fixed it by replacing the line with: `from jax.example_libraries import stax` (a version-tolerant sketch is included at the end of this comment).
  2. Issues with HyperparameterSearch.py:
    • ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 3) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function _odeint_wrapper is non-hashable. This is raised for the `odeint` function.
  3. The above exception was the direct cause of the following exception:

    TypeError                                 Traceback (most recent call last)
    /content/lagrangian_nns/experiment_dblpend/data.py in get_trajectory_analytic(y0, times, **kwargs)
         38 @partial(jax.jit, backend='cpu')
         39 def get_trajectory_analytic(y0, times, **kwargs):
    ---> 40     return odeint(analytical_fn, y0, t=times, rtol=1e-10, atol=1e-10, **kwargs)
         41
         42 def get_dataset(seed=0, samples=1, t_span=[0,2000], fps=1, test_split=0.5, **kwargs):

    TypeError: odeint() got an unexpected keyword argument 'mxsteps'


  4. When I run `loss(get_params(opt_state), batch_data, 0.0)` I get:

    /content/lagrangian_nns__modified/hyperopt/HyperparameterSearch.py in dynamics(q, q_t)
         32     # assert q.shape == (2,)
         33     state = wrap_coords(jnp.concatenate([q, q_t]))
    ---> 34     return jnp.squeeze(nn_forward_fn(params, state), axis=-1)
         35   return dynamics
         36

    /usr/local/lib/python3.7/dist-packages/jax/numpy/lax_numpy.py in squeeze(a, axis)
       1165   axis = frozenset(_canonicalize_axis(i, ndim(a)) for i in axis)
       1166   if _any(shape_a[a] != 1 for a in axis):
    -> 1167     raise ValueError("cannot select an axis to squeeze out which has size "
       1168                      "not equal to one")
       1169   newshape = [d for i, d in enumerate(shape_a)

    ValueError: cannot select an axis to squeeze out which has size not equal to one



I managed to fix this by changing

    return jnp.squeeze(nn_forward_fn(params, state), axis=-1)

to

    return nn_forward_fn(params, state)
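
For reference, here is a version-tolerant form of the import fix in point 1 (a sketch only; it assumes the notebook originally imported stax from the old `jax.experimental` location):

    # Sketch: fall back to the old module path on older JAX versions.
    try:
        from jax.example_libraries import stax  # newer JAX
    except ImportError:
        from jax.experimental import stax       # older JAX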
MilesCranmer commented 1 year ago

Thanks, will fix the issue about the moved stax location soon. @greydanus do you want to update your colab notebook?

MilesCranmer commented 1 year ago

What line is issue 2 from?

MilesCranmer commented 1 year ago

Ah, JAX changed `mxsteps` to `mxstep`... Why make such a small breaking change, I do not know...
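
For anyone hitting this on a newer JAX, here is a minimal call that exercises the renamed keyword (the ODE and the values are just illustrative, not the repo's):

    import jax.numpy as jnp
    from jax.experimental.ode import odeint

    # Simple exponential decay, integrated with the renamed `mxstep` keyword.
    def f(y, t):
        return -y

    ys = odeint(f, jnp.array(1.0), jnp.linspace(0.0, 1.0, 10),
                rtol=1e-10, atol=1e-10, mxstep=1000)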

MilesCranmer commented 1 year ago

PR #4 will fix this.

MariosGkMeng commented 1 year ago

> What line is issue 2 from?

Once I replicate the error (not receiving it right now), I will specify :)

> Ah, JAX changed `mxsteps` to `mxstep`... Why make such a small breaking change, I do not know...

Haha, actually it does ring a bell, but after many different packages and trials, I had forgotten about it. Thanks for fixing 😀

MariosGkMeng commented 1 year ago

@MilesCranmer I added a 4th issue in my initial comment with a (probably improper) fix.

MilesCranmer commented 1 year ago

For the 4th issue, what is the full traceback? i.e., what chunk of code is it coming from?

You should try this replacement instead:

  def dynamics(q, q_t):
    state = wrap_coords(jnp.concatenate([q, q_t]))
    updated_state = nn_forward_fn(params, state)
    if len(updated_state.shape) == 2:
        return jnp.squeeze(updated_state, axis=-1)
    else:
        return updated_state

Not sure why one of the states has dim=1
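
To see which case the shape check is guarding against, here is a tiny illustration (the shapes are illustrative, not necessarily what the notebook produces):

    import jax.numpy as jnp

    batched = jnp.ones((4, 1))                  # 2-D output with a trailing singleton axis
    print(jnp.squeeze(batched, axis=-1).shape)  # (4,) -- squeeze is safe here

    unbatched = jnp.ones((2,))
    # jnp.squeeze(unbatched, axis=-1) would raise
    # "cannot select an axis to squeeze out which has size not equal to one",
    # so the shape check above returns this case unchanged.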

MariosGkMeng commented 1 year ago

> For the 4th issue, what is the full traceback? i.e., what chunk of code is it coming from?

Updated the comment on bullet number 4.

> What line is issue 2 from?

I reproduced it just by having the most recent version of JAX. See, I had reverted to an older version (0.1.68) and the error was not occurring. But I installed the newest again because 0.1.68 did not seem to find my Colab GPU, even though TensorFlow can find it.

MariosGkMeng commented 1 year ago

@MilesCranmer Also, I noticed that I did not fully reply to your question regarding issue 2:

The issue is triggered when I run `from HyperparameterSearch import learned_dynamics`, and only when I am using the recent JAX version (0.3.25).

zdjordje123 commented 1 year ago

I am running JAX 0.4.6. I corrected the import statements for stax and optimizers, replaced mxsteps with mxstep in one place, and changed the definition of the function dynamics in train.py and HyperparameterSearch.py. Should the recommended change be made in both places? I am still getting the following error in cell [4] of DoublePendulum-Baseline.ipynb; the same error appears in the other notebooks. error_lagrange.txt

MilesCranmer commented 1 year ago

Hi @zdjordje123, do you want to push your edits as a draft pull request and I can help work on it with you? It would be great if the code would be updated for modern JAX!

zdjordje123 commented 1 year ago

I would be glad to help. Unfortunately, I am not sure my edits are very useful. If you give me precise tasks I could work on them. Regards, Zoran Djordjevic

zdjordje123 commented 1 year ago

I tried creating a virtual environment with JAX 0.1.68, which I see someone mentioned as providing error-free runs. I was hoping to go through the old and new environments in parallel and see the differences. Unfortunately, my 0.1.68 JAX environment simply refuses to work. I will have time to revisit this code only at the end of May. If someone migrates the code to the newest JAX, I will be happy to use it and reference it. Otherwise, I could try to migrate it myself then. Best Regards, Zoran Djordjevic

MilesCranmer commented 1 year ago

Will let you know. It is a busy time, but I hope I can get a chance to update things to the new JAX.

xzhuzhu commented 1 year ago

Hi, I am running JAX 0.4.13. I corrected the import statements for stax and optimizers, replaced mxsteps with mxstep, and changed the definition of the function dynamics in HyperparameterSearch.py. However, when I run `from HyperparameterSearch import learned_dynamics`, I get this error:

    UnexpectedTracerError: Found a JAX Tracer object passed as an argument to a custom_vjp function in a position indicated by nondiff_argnums as non-differentiable. Tracers cannot be passed as non-differentiable arguments to custom_vjp functions; instead, nondiff_argnums should only be used for arguments that can't be or contain JAX tracers, e.g. function-valued arguments. In particular, array-valued arguments should typically not be indicated as nondiff_argnums. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError.

How can I fix it? Thanks.
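
For reference, this is the general pattern the error message is complaining about (a generic sketch, not the repo's actual code), along with the shape of the usual fix:

    from functools import partial
    import jax
    import jax.numpy as jnp

    # Broken pattern: an array-valued argument is marked non-differentiable
    # via nondiff_argnums.
    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def scale(x, c):
        return x * c

    def scale_fwd(x, c):
        return x * c, x

    def scale_bwd(c, res, g):   # non-diff args are passed first to the bwd rule
        return (g * c,)

    scale.defvjp(scale_fwd, scale_bwd)

    # Under jit the second argument becomes a tracer, which recent JAX rejects:
    # jax.jit(scale)(jnp.ones(3), 2.0)  # -> UnexpectedTracerError

    # Usual fix: make array-valued arguments ordinary (traceable) arguments and
    # reserve nondiff_argnums for things like Python callables or static config.
    @jax.custom_vjp
    def scale_ok(x, c):
        return x * c

    def scale_ok_fwd(x, c):
        return x * c, (x, c)

    def scale_ok_bwd(res, g):
        x, c = res
        return g * c, g * x

    scale_ok.defvjp(scale_ok_fwd, scale_ok_bwd)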

MilesCranmer commented 1 year ago

Hmm, I will take a look this weekend. Thanks for the report!!

MilesCranmer commented 1 year ago

(If you find a fix before I can get to it I will gladly accept a PR btw)

umerhuzaifa commented 1 year ago

> Hi, I am running JAX 0.4.13. [...] However, when I run `from HyperparameterSearch import learned_dynamics`, I get this error: UnexpectedTracerError: Found a JAX Tracer object passed as an argument to a custom_vjp function in a position indicated by nondiff_argnums as non-differentiable. [...] How can I fix it? Thanks.

Any updates on this issue?

MilesCranmer commented 11 months ago

Thanks for the ping, and sorry for not finding enough time to fix this yet. Please keep pinging me though; eventually I will be able to get it done.

(However, I will also say, for anybody reading this, that I am always immensely appreciative of PRs and will gladly review and merge them!)

Raul-Create commented 6 months ago

bump :) I've been trying to solve this problem for a few hours. Can't get it to work! I'm on the "fix-stax-import" branch, all other problems seem fixed there.

MilesCranmer commented 6 months ago

One other option is to use https://github.com/astrofrog/pypi-timemachine to install the exact versions of the dependencies as they existed at publication time. And maybe we can set those in requirements.txt.

But yes, I'd still like to fix this for recent JAX; I've just been pretty short on time to do things myself.

MilesCranmer commented 4 months ago

Would someone be willing to test out the code in https://github.com/MilesCranmer/lagrangian_nns/pull/8? I've tried to fix most of the issues caused by JAX updates.