MariosGkMeng opened 1 year ago
Thanks, will fix the issue about the moved stax location soon. @greydanus do you want to update your Colab notebook?
What line is issue 2 from?
Ah, JAX changed `mxsteps` to `mxstep`... why they made such a small breaking change, I do not know...
PR #4 will fix this.
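For anyone hitting this later, a minimal sketch of the renamed keyword (the toy ODE and tolerances here are illustrative, not the repo's actual call):

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def f(y, t):
    # toy dynamics: exponential decay, dy/dt = -y
    return -y

y0 = jnp.array([1.0])
t = jnp.linspace(0.0, 1.0, 5)

# Older JAX accepted `mxsteps=...`; recent versions spell it `mxstep=...`
ys = odeint(f, y0, t, rtol=1e-10, atol=1e-10, mxstep=1000)
```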
> What line is issue 2 from?

Once I replicate the error (I am not getting it right now), I will specify :)
> Ah, JAX changed `mxsteps` to `mxstep`... Why make such a small breaking change I do not know...

Haha, it actually does ring a bell, but after many different packages and trials I had forgotten about it. Thanks for fixing 😀
@MilesCranmer I added a 4th issue in my initial comment with a (probably improper) fix.
For the 4th issue, what is the full traceback? i.e., what chunk of code is it coming from?
You should try this replacement instead:
```python
def dynamics(q, q_t):
    state = wrap_coords(jnp.concatenate([q, q_t]))
    updated_state = nn_forward_fn(params, state)
    # only squeeze when the output actually has a trailing singleton axis
    if len(updated_state.shape) == 2:
        return jnp.squeeze(updated_state, axis=-1)
    else:
        return updated_state
```
Not sure why one of the states has dim=1
> For the 4th issue, what is the full traceback? i.e., what chunk of code is it coming from?

Updated the comment on bullet number 4.
> What line is issue 2 from?

I reproduced it just by using the most recent version of JAX. I had reverted to an older version (0.1.68) and the error was not occurring, but I installed the newest again because 0.1.68 seemed unable to find my Colab GPU, even though TensorFlow can find it.
@MilesCranmer Also, I noticed that I did not fully reply to your question regarding issue 2: the issue is triggered when I run `from HyperparameterSearch import learned_dynamics`, and only when I am using the recent JAX version (0.3.25).
I am running JAX 0.4.6. I corrected the import statements for stax and optimizers, replaced `mxsteps` with `mxstep` in one place, and changed the definition of the function `dynamics` in train.py and HyperparameterSearch.py. Should the recommended change be made in both places? I am still getting the following error in cell [4] of DoublePendulum-Baseline.ipynb; the same error appears in the other notebooks. error_lagrange.txt
Hi @zdjordje123, do you want to push your edits as a draft pull request and I can help work on it with you? It would be great if the code would be updated for modern JAX!
I would be glad to help. Unfortunately, I am not sure my edits are very useful. If you give me precise tasks, I could work on them. Regards, Zoran Djordjevic
I tried creating a virtual environment with JAX 0.1.68, which I see someone mentioned as providing error-free runs. I was hoping to go through both the old and new environments in parallel and see the differences. Unfortunately, my 0.1.68 JAX environment simply will not work. I will have time to revisit this code only at the end of May. If someone migrates the code to the newest JAX, I will be happy to use it and reference it. Otherwise, I could try to migrate it myself then. Best Regards, Zoran Djordjevic
Will let you know. Busy time, but hope I can get a chance to update things to new JAX.
Hi, I am running JAX 0.4.13. I corrected the import statements for stax and optimizers, replaced `mxsteps` with `mxstep`, and changed the definition of the function `dynamics` in HyperparameterSearch.py. However, when I run `from HyperparameterSearch import learned_dynamics`, I get this error:

> UnexpectedTracerError: Found a JAX Tracer object passed as an argument to a custom_vjp function in a position indicated by nondiff_argnums as non-differentiable. Tracers cannot be passed as non-differentiable arguments to custom_vjp functions; instead, nondiff_argnums should only be used for arguments that can't be or contain JAX tracers, e.g. function-valued arguments. In particular, array-valued arguments should typically not be indicated as nondiff_argnums. See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.UnexpectedTracerError.

How can I fix it? Thanks.
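This is not the repo's actual code path, but a hedged sketch of what that error message is complaining about: array-valued arguments must not be listed in `nondiff_argnums`; pass them as ordinary arguments and return a zero cotangent for them instead (the function names here are made up for illustration):

```python
import jax
import jax.numpy as jnp

# Broken pattern (raises UnexpectedTracerError under jit/grad):
#   @partial(jax.custom_vjp, nondiff_argnums=(1,))  # array-valued `scale`!
# Working pattern: `scale` is a normal argument with a zero cotangent.

@jax.custom_vjp
def scaled_square(x, scale):
    return scale * x ** 2

def fwd(x, scale):
    # save the primals needed by the backward pass
    return scaled_square(x, scale), (x, scale)

def bwd(res, g):
    x, scale = res
    # cotangent for x; zero cotangent for the array-valued scale
    return (g * 2.0 * scale * x, jnp.zeros_like(scale))

scaled_square.defvjp(fwd, bwd)

g = jax.jit(jax.grad(scaled_square))(3.0, 2.0)  # d/dx (2 x^2) = 4 x = 12
```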
Hmm, I will take a look this weekend. Thanks for the report!!
(If you find a fix before I can get to it I will gladly accept a PR btw)
Any updates on this issue?
Thanks for the ping and sorry for not finding enough time to fix this yet. Please keep pinging me though, eventually will be able to get it done.
(However, I will also say, for anybody reading this, that I am always immensely appreciative of PRs, and will gladly review + merge them!)
bump :) I've been trying to solve this problem for a few hours and can't get it to work! I'm on the `fix-stax-import` branch; all the other problems seem fixed there.
One other option is to use https://github.com/astrofrog/pypi-timemachine to install the exact versions of the dependencies as they existed at publication time. And maybe we could pin those in requirements.txt.
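For reference, a hedged sketch of that workflow (the date and port are placeholders, not the actual publication date):

```
# Start a PyPI proxy frozen at a chosen date; it prints a local index URL
pypi-timemachine 2020-03-10
# In another terminal, install through that index (port is whatever was printed)
pip install --index-url http://localhost:<port>/ -r requirements.txt
```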
But yes, I'd still like to fix this for recent JAX; I've just been too short on time to do it myself.
Would someone be willing to test out the code in https://github.com/MilesCranmer/lagrangian_nns/pull/8? I've tried to fix most of the issues caused by JAX updates.
Issues when I open the main branch on Google Colab and use base-nn-double-pendulum.ipynb: the stax import fails (it is no longer under `jax.experimental`). Fixed it by replacing the line with `from jax.example_libraries import stax`.
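For reference, a minimal sketch of the import migration (the tiny network here is illustrative, not the repo's actual model):

```python
import jax

# Pre-0.2.x JAX:
#   from jax.experimental import stax, optimizers
# Recent JAX:
from jax.example_libraries import stax
from jax.example_libraries import optimizers

# Smoke test: the moved modules still expose the same API
init_fn, apply_fn = stax.serial(stax.Dense(8), stax.Softplus, stax.Dense(1))
opt_init, opt_update, get_params = optimizers.adam(1e-3)
out_shape, params = init_fn(jax.random.PRNGKey(0), (-1, 4))
```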
In HyperParameterSearch.py, for the function `odeint`: `ValueError: Non-hashable static arguments are not supported, as this can lead to unexpected cache-misses. Static argument (index 3) of type <class 'jax.interpreters.partial_eval.DynamicJaxprTracer'> for function _odeint_wrapper is non-hashable.`
```
TypeError                                 Traceback (most recent call last)
/content/lagrangian_nns/experiment_dblpend/data.py in get_trajectory_analytic(y0, times, **kwargs)
     38 @partial(jax.jit, backend='cpu')
     39 def get_trajectory_analytic(y0, times, **kwargs):
---> 40     return odeint(analytical_fn, y0, t=times, rtol=1e-10, atol=1e-10, **kwargs)
     41
     42 def get_dataset(seed=0, samples=1, t_span=[0,2000], fps=1, test_split=0.5, **kwargs):

TypeError: odeint() got an unexpected keyword argument 'mxsteps'
```
```
/content/lagrangian_nns__modified/hyperopt/HyperparameterSearch.py in dynamics(q, q_t)
     32         # assert q.shape == (2,)
     33         state = wrap_coords(jnp.concatenate([q, q_t]))
---> 34         return jnp.squeeze(nn_forward_fn(params, state), axis=-1)
     35     return dynamics
     36

/usr/local/lib/python3.7/dist-packages/jax/numpy/lax_numpy.py in squeeze(a, axis)
   1165   axis = frozenset(_canonicalize_axis(i, ndim(a)) for i in axis)
   1166   if _any(shape_a[a] != 1 for a in axis):
-> 1167     raise ValueError("cannot select an axis to squeeze out which has size "
   1168                      "not equal to one")
   1169   newshape = [d for i, d in enumerate(shape_a)

ValueError: cannot select an axis to squeeze out which has size not equal to one
```
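This traceback matches the guarded-`squeeze` suggestion earlier in the thread; a small self-contained sketch of that guard (the helper name is made up for illustration):

```python
import jax.numpy as jnp

def maybe_squeeze(out):
    # `jnp.squeeze(..., axis=-1)` raises ValueError unless that axis has
    # size 1, so only squeeze genuinely 2-D (batch, 1) outputs.
    if out.ndim == 2:
        return jnp.squeeze(out, axis=-1)
    return out

a = maybe_squeeze(jnp.ones((3, 1)))  # squeezed to shape (3,)
b = maybe_squeeze(jnp.ones((3,)))    # left unchanged, shape (3,)
```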