Jacob-Stevens-Haas / gen-experiments

Pysindy experiments. Composable and extensible
MIT License

Handle different simulation lengths when LSODA returns early #3

Status: Open. Jacob-Stevens-Haas opened this issue 1 year ago.

Jacob-Stevens-Haas commented 1 year ago

Sometimes with stiff equations (e.g. Rossler), the result of scipy.integrate.solve_ivp() contains fewer points than t_eval. Do we reject these samples? Or do we adapt the code to not stack the arrays and allow different simulation lengths (which would then bias the regression toward the blow-up trajectories)?

For instance, the final line of the following raises "ValueError: all input arrays must have the same shape":

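```python
import numpy as np
from pysindy.utils.odes import rossler
from scipy.integrate import solve_ivp

# Two initial conditions for the Rossler system
y0_1 = np.array([4.50040013, -1.72136214, -5.54952269])
y0_2 = np.array([0.41888012, -5.95648243, 2.85360632])
t_train = np.arange(0, 4, 0.01)  # 400 evaluation times: 0.00, 0.01, ..., 3.99

y_train = []
for y0 in (y0_1, y0_2):
    sol = solve_ivp(
        rossler,
        t_span=(t_train[0], t_train[-1]),
        y0=y0,
        t_eval=t_train,
        method="LSODA",
        rtol=1e-12,
        atol=1e-12,
    )
    # If LSODA fails partway, sol.y has fewer than len(t_train) columns
    y_train.append(sol.y.T)

# The trajectories now differ in length, so this raises
# ValueError: all input arrays must have the same shape
np.stack(y_train)
```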
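A shortened run can be detected right after integration, before anything is stacked. The sketch below is minimal and assumes only the documented solve_ivp result fields: status is negative when an integration step fails, and message describes the failure. It also compares the number of returned points against t_eval. The names integrate_full_length and decay are hypothetical and exist only for illustration; decay is a well-behaved system chosen so that the example completes.

```python
import warnings

import numpy as np
from scipy.integrate import solve_ivp


def integrate_full_length(fun, y0, t_eval, **integrator_kws):
    """Integrate fun over t_eval, or return None if the solver stops early.

    Hypothetical helper for illustration; not part of gen-experiments or pysindy.
    """
    sol = solve_ivp(
        fun,
        t_span=(t_eval[0], t_eval[-1]),
        y0=y0,
        t_eval=t_eval,
        **integrator_kws,
    )
    # status == -1 means an integration step failed; solve_ivp then returns
    # only the t_eval points reached before the failure.
    if sol.status < 0 or sol.t.size < len(t_eval):
        warnings.warn(
            f"solve_ivp returned {sol.t.size} of {len(t_eval)} points "
            f"(status={sol.status}): {sol.message}"
        )
        return None
    return sol.y.T  # shape (len(t_eval), n_states)


def decay(t, y):
    # Well-behaved test system, dy/dt = -y, that integrates to the end
    return -y


t_train = np.arange(0, 4, 0.01)
y = integrate_full_length(
    decay, np.array([1.0, 2.0]), t_train, method="LSODA", rtol=1e-12, atol=1e-12
)
print(y.shape)  # (400, 2)
```

In the Rossler loop, a None result could then mark a trial to skip rather than a row to stack.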

I'm thinking we drop them with a warning, and raise an error if we drop all trials.
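That proposal (drop short trials with a warning, raise an error if none are left) could look roughly like this. It is a minimal sketch. stack_complete_trials is a hypothetical name, not an existing gen-experiments function. The sketch assumes every trajectory is an array of shape (n_times, n_states) sampled on the same t_eval, so any early stop shows up as a shorter first axis.

```python
import warnings

import numpy as np


def stack_complete_trials(trajectories, n_timepoints):
    """Stack full-length trajectories; warn about and drop any short ones.

    Hypothetical helper sketching the proposal; not gen-experiments' API.
    Each trajectory is assumed to have shape (n_times, n_states).
    """
    kept = [traj for traj in trajectories if traj.shape[0] == n_timepoints]
    n_dropped = len(trajectories) - len(kept)
    if not kept:
        raise ValueError(
            f"All {len(trajectories)} trials stopped before reaching "
            f"{n_timepoints} time points"
        )
    if n_dropped:
        warnings.warn(
            f"Dropped {n_dropped} of {len(trajectories)} trials whose "
            "integration stopped early"
        )
    return np.stack(kept)  # shape (n_kept, n_timepoints, n_states)


full = np.zeros((400, 3))
short = np.zeros((251, 3))  # stand-in for a trial where LSODA gave up partway
x = stack_complete_trials([full, short, full], n_timepoints=400)
print(x.shape)  # (2, 400, 3), after a UserWarning about 1 dropped trial
```

Dropping keeps every remaining trial on the same time grid, so np.stack and downstream code work unchanged and no variable-length weighting question arises. The cost is fewer usable trials for a given seed.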

Jacob-Stevens-Haas commented 1 year ago

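The same problem shows up in gridsearch.py, where plot_gridpoint() calls plot_test_trajectories() (in utils.py) and the model simulation comes back shorter than the test time grid: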


```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[2], line 1
----> 1 results = module.run(seed, **args)

File ~/github/gen-experiments/src/gen_experiments/gridsearch.py:95, in run(seed, ex_name, grid_params, grid_vals, grid_decisions, other_params, series_params, metrics, plot_prefs, skinny_specs)
     91     curr_results, recent_data = base_ex.run(
     92         new_seed, **curr_other_params, display=False, return_all=True
     93     )
     94     if _params_match(curr_other_params, plot_prefs.grid_plot_match) and plot_prefs:
---> 95         plot_gridpoint(recent_data, curr_other_params)
     96     full_results[(slice(None), *ind)] = [
     97         curr_results[metric] for metric in metrics
     98     ]
     99 series_searches.append(_marginalize_grid_views(new_grid_decisions, full_results))

File ~/github/gen-experiments/src/gen_experiments/gridsearch.py:197, in plot_gridpoint(grid_data, other_params)
    190 plot_training_data(x_train, x_true, smooth_train)
    191 compare_coefficient_plots(
    192     grid_data["coefficients"],
    193     grid_data["coeff_true"],
    194     input_features=grid_data["input_features"],
    195     feature_names=grid_data["feature_names"],
    196 )
--> 197 plot_test_trajectories(grid_data["x_test"][sim_ind], model, grid_data["dt"])
    198 plt.show()

File ~/github/gen-experiments/src/gen_experiments/utils.py:467, in plot_test_trajectories(last_test, model, dt)
    465 for i in range(last_test.shape[1]):
    466     axs[i].plot(t_test, last_test[:, i], "k", label="true trajectory")
--> 467     axs[i].plot(t_test, x_test_sim[:, i], "r--", label="model simulation")
    468     axs[i].legend()
    469     axs[i].set(xlabel="t", ylabel="$x_{}$".format(i))

File ~/github/gen-experiments/env/lib/python3.10/site-packages/matplotlib/axes/_axes.py:1688, in Axes.plot(self, scalex, scaley, data, *args, **kwargs)
   1445 """
   1446 Plot y versus x as lines and/or markers.
   1447 
   (...)
   1685 (``'green'``) or hex strings (``'#008000'``).
   1686 """
   1687 kwargs = cbook.normalize_kwargs(kwargs, mlines.Line2D)
-> 1688 lines = [*self._get_lines(*args, data=data, **kwargs)]
   1689 for line in lines:
   1690     self.add_line(line)

File ~/github/gen-experiments/env/lib/python3.10/site-packages/matplotlib/axes/_base.py:311, in _process_plot_var_args.__call__(self, data, *args, **kwargs)
    309     this += args[0],
    310     args = args[1:]
--> 311 yield from self._plot_args(
    312     this, kwargs, ambiguous_fmt_datakey=ambiguous_fmt_datakey)

File ~/github/gen-experiments/env/lib/python3.10/site-packages/matplotlib/axes/_base.py:504, in _process_plot_var_args._plot_args(self, tup, kwargs, return_kwargs, ambiguous_fmt_datakey)
    501     self.axes.yaxis.update_units(y)
    503 if x.shape[0] != y.shape[0]:
--> 504     raise ValueError(f"x and y must have same first dimension, but "
    505                      f"have shapes {x.shape} and {y.shape}")
    506 if x.ndim > 2 or y.ndim > 2:
    507     raise ValueError(f"x and y can be no greater than 2D, but have "
    508                      f"shapes {x.shape} and {y.shape}")

ValueError: x and y must have same first dimension, but have shapes (1600,) and (1441,)
```
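Separately from how training trials are handled, the plotting error could be avoided by plotting the model simulation only against the times it actually reached. The following is a minimal sketch under an assumption: the simulation was evaluated on t_test and simply stopped early, so its rows line up with the first len(x_test_sim) entries of t_test. simulation_time_axis is a hypothetical helper, and dt = 0.01 is an illustrative value rather than the experiment's actual step.

```python
import numpy as np


def simulation_time_axis(t_test, x_sim):
    """Return the prefix of t_test that matches a possibly truncated simulation.

    Hypothetical helper: assumes x_sim was evaluated on t_test and stopped
    early, so its samples correspond to the first len(x_sim) times.
    """
    n_sim = x_sim.shape[0]
    if n_sim > t_test.shape[0]:
        raise ValueError("simulation has more samples than the time grid")
    return t_test[:n_sim]


dt = 0.01  # illustrative step size
t_test = np.arange(1600) * dt
x_test_sim = np.zeros((1441, 3))  # stand-in for a model simulation that stopped early
t_sim = simulation_time_axis(t_test, x_test_sim)
print(t_sim.shape)  # (1441,)
# Inside plot_test_trajectories, the two curves could then be drawn as:
#   axs[i].plot(t_test, last_test[:, i], "k", label="true trajectory")
#   axs[i].plot(t_sim, x_test_sim[:, i], "r--", label="model simulation")
```

With this approach, the truncated red curve itself shows where the model simulation blew up, rather than hiding the failure behind an exception.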