pnkraemer / probdiffeq

Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem.
https://pnkraemer.github.io/probdiffeq/
MIT License

Reshape Error with v0.3.4 #744

Closed. adam-hartshorne closed this issue 2 months ago.

adam-hartshorne commented 2 months ago

Using JAX/jaxlib v0.4.28, upgrading probdiffeq to v0.3.4 results in the following error (which does not occur with v0.3.3):

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/media/adam/shared_drive/PycharmProjects/path_shape_arap_tests/learn_hand_deformation_node.py", line 1606, in <module>
    loss, model, opt_state = make_step(model, data, ode_output_scale, new_key, opt_state, optim, varifold_params)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_jit.py", line 206, in __call__
    return self._call(False, args, kwargs)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_module.py", line 1053, in __call__
    return self.__func__(self.__self__, *args, **kwargs)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/equinox/_jit.py", line 200, in _call
    out = self._cached(dynamic_donate, dynamic_nodonate, static)
  File "/media/adam/shared_drive/PycharmProjects/path_shape_arap_tests/learn_hand_deformation_node.py", line 1519, in make_step
    obj, grads = return_value_and_grad(model, data, ode_output_scale, varifold_params, random_key)
  File "/media/adam/shared_drive/PycharmProjects/path_shape_arap_tests/learn_hand_deformation_node.py", line 1514, in return_value_and_grad
    return calc_cost(model, data, ode_output_scale, varifold_params, random_key)
  File "/media/adam/shared_drive/PycharmProjects/path_shape_arap_tests/learn_hand_deformation_node.py", line 1414, in calc_cost
    sol, y_pred, y_pred_traj, rigid_transformed_sample_traj, transformed_sample_locations, sample_path, rigid_transformed_sample_field = model.calc_flow_field(
  File "/media/adam/shared_drive/PycharmProjects/path_shape_arap_tests/learn_hand_deformation_node.py", line 1363, in calc_flow_field
    sol, transformed_points, transformaed_points_traj = self.push_through_field_vectorised(x0=x_test,
  File "/media/adam/shared_drive/PycharmProjects/path_shape_arap_tests/learn_hand_deformation_node.py", line 1341, in push_through_field_vectorised
    sol = ivpsolve.solve_fixed_grid(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/ivpsolve.py", line 301, in solve_fixed_grid
    _, result_state = control_flow.scan(body_fn, init=state0, xs=np.diff(grid))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/backend/control_flow.py", line 31, in scan
    return _jax_scan(step_func, init=init, xs=xs, reverse=reverse, length=length)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/ivpsolve.py", line 296, in body_fn
    _error, s_new = solver.step(state=s, vector_field=vector_field, dt=dt)
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/solvers/uncalibrated.py", line 23, in step
    error, _observed, state_strategy = self.strategy.predict_error(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/solvers/strategies/strategy.py", line 100, in predict_error
    error, observed, corr = self.correction.estimate_error(
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/solvers/strategies/components/corrections.py", line 55, in estimate_error
    observed = impl.transform.marginalise(hidden_state, (A, b))
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/impl/isotropic/_transform.py", line 12, in marginalise
    cholesky_squeezed = np.reshape(cholesky_new, ())
  File "/home/adam/anaconda3/envs/jax_torch_latest/lib/python3.10/site-packages/probdiffeq/backend/numpy.py", line 52, in reshape
    return jnp.reshape(arr, shape=new_shape, order=order)
TypeError: reshape() got an unexpected keyword argument 'shape'
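The last frame shows probdiffeq's `reshape` wrapper passing the target shape by keyword (`shape=new_shape`). That keyword is only accepted if the installed `jax.numpy.reshape` has a parameter called `shape`. Older JAX releases appear to name it `newshape` instead, which would produce exactly this `TypeError`. The sketch below reproduces the mechanism with two hypothetical stand-in functions (`reshape_old`, `reshape_new`; these are not JAX's real implementations). It also shows that passing the shape positionally works with either parameter name.

```python
def reshape_old(a, newshape, order="C"):
    """Hypothetical stand-in for a signature that names the parameter `newshape`."""
    return ("reshaped", newshape, order)


def reshape_new(a, shape=None, order="C"):
    """Hypothetical stand-in for a signature that names the parameter `shape`."""
    return ("reshaped", shape, order)


def call_by_keyword(reshape, arr, new_shape, order="C"):
    # Mirrors the failing call in the traceback: the keyword must match the
    # parameter name of whichever signature is installed.
    return reshape(arr, shape=new_shape, order=order)


def call_positionally(reshape, arr, new_shape, order="C"):
    # Passing the shape positionally does not depend on the parameter's name.
    return reshape(arr, new_shape, order=order)


if __name__ == "__main__":
    try:
        call_by_keyword(reshape_old, [1.0], ())
    except TypeError as err:
        print("keyword call against old-style signature failed:", err)
    print(call_positionally(reshape_old, [1.0], ()))
    print(call_positionally(reshape_new, [1.0], ()))
```

In this thread the actual fix was simply to run against the intended JAX version; the positional form is only one way a library could stay agnostic to such a rename.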

pnkraemer commented 2 months ago

This is unexpected (the tests pass with JAX version 0.4.28). Can you post the output of pip freeze?

adam-hartshorne commented 2 months ago

Sorry, my bad. I had forgotten that I had switched PyCharm to a virtual environment with an earlier version of JAX to test something.

I can confirm it works fine on v0.4.28.
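Because the root cause was an IDE pointing at a different environment, it can help to print the versions the active interpreter actually sees before running the solver. The sketch below is a hypothetical helper, not part of probdiffeq. It treats 0.4.28 as the minimum JAX version only because that is the version confirmed working in this thread. It compares versions as integer tuples, so "0.4.3" correctly sorts below "0.4.28", which a plain string comparison would get wrong.

```python
from importlib import metadata


def parse_version(version: str) -> tuple:
    """Turn a release string such as '0.4.28' into a comparable tuple of ints.

    Only the leading numeric components are kept, so suffixes such as
    '.dev0' or 'rc1' are ignored (a simplification for this sketch).
    """
    parts = []
    for piece in version.split("."):
        digits = ""
        for ch in piece:
            if ch.isdigit():
                digits += ch
            else:
                break
        if not digits:
            break
        parts.append(int(digits))
    return tuple(parts)


def jax_is_recent_enough(installed: str, minimum: str = "0.4.28") -> bool:
    """Return True if `installed` is at least `minimum` (assumed threshold)."""
    return parse_version(installed) >= parse_version(minimum)


if __name__ == "__main__":
    # Report what the *active* interpreter sees, which may differ from the
    # environment you think your IDE is configured to use.
    for package in ("jax", "jaxlib", "probdiffeq"):
        try:
            print(package, metadata.version(package))
        except metadata.PackageNotFoundError:
            print(package, "not installed")
    try:
        print("JAX recent enough:", jax_is_recent_enough(metadata.version("jax")))
    except metadata.PackageNotFoundError:
        pass
```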

pnkraemer commented 2 months ago

Thanks for checking and clarifying! :)