NeuroDiffGym / neurodiffeq

A library for solving differential equations using neural networks based on PyTorch, used by multiple research groups around the world, including at Harvard IACS.
http://pypi.org/project/neurodiffeq/
MIT License
664 stars, 87 forks

Missing eq_param_index when loading BundleSolution #211

Open ptflores1 opened 6 months ago

ptflores1 commented 6 months ago

When calling BundleSolution1D.load(), the new instance is initialized without an eq_param_index argument. Because the _diff_eqs_wrapper function reads eq_param_index as a local variable of the constructor, and the loader never passes it back in, there is no way to recover the original value: the equation parameters it selects are no longer passed to the ODE system.
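The failure mode can be reproduced in plain Python. The sketch below uses a hypothetical ToySolver whose load() rebuilds the instance from saved state without eq_param_index, mirroring the report. The class, its save()/load() methods and the argument layout (functions, then t, then the bundle parameters) are illustrative assumptions, not neurodiffeq's implementation.

```python
class ToySolver:
    """Hypothetical stand-in for a bundle solver; not neurodiffeq's class."""

    def __init__(self, ode_system, n_funcs, eq_param_index=()):
        self.n_funcs = n_funcs
        self.ode_system = ode_system

        # The wrapper closes over the *local* eq_param_index; nothing on
        # `self` records it, so it cannot be recovered from the instance.
        def _diff_eqs_wrapper(*variables):
            funcs_and_t = variables[: n_funcs + 1]
            thetas = variables[n_funcs + 1:]
            eq_params = tuple(thetas[i] for i in eq_param_index)
            return ode_system(*funcs_and_t, *eq_params)

        self.diff_eqs = _diff_eqs_wrapper

    def save(self):
        # Mirrors the bug: eq_param_index is not part of the saved state.
        return {"ode_system": self.ode_system, "n_funcs": self.n_funcs}

    @classmethod
    def load(cls, state):
        # eq_param_index silently falls back to its default of ().
        return cls(ode_system=state["ode_system"], n_funcs=state["n_funcs"])


def ode(u, t, *params):
    # For demonstration, report how many equation parameters arrived.
    return len(params)


original = ToySolver(ode, n_funcs=1, eq_param_index=(0,))
restored = ToySolver.load(original.save())
print(original.diff_eqs(1.0, 0.5, 2.0))  # 1: theta_0 reaches the ODE
print(restored.diff_eqs(1.0, 0.5, 2.0))  # 0: the parameter is dropped
```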

ptflores1 commented 6 months ago

The code at https://github.com/NeuroDiffGym/neurodiffeq/blob/1b6d8c5679b8403a31470c446e8fcf57e186b55f/neurodiffeq/solvers_utils.py#L519-L533 could be modified as follows:

elif load_dict["type_name"] == "BundleSolver1D" or load_dict["parent_type_name"] == "BundleSolver1D":
    t_min = load_dict['solver'].r_min[0]
    t_max = load_dict['solver'].r_max[0]

    solver = cls(ode_system=de_system,
                 conditions=cond,
                 metrics=load_dict['metrics'],
                 nets=nets,
                 optimizer=optimizer,
                 train_generator=train_generator,
                 valid_generator=valid_generator,
                 t_min=t_min,
                 t_max=t_max,
                 theta_min=tuple(load_dict['solver'].r_min[1:]),
                 theta_max=tuple(load_dict['solver'].r_max[1:]),
                 # new: restore the equation-parameter indices as a tuple,
                 # so they can be iterated on every call (a generator could not)
                 eq_param_index=tuple(index - len(cond) - 1 for index in load_dict['solver'].eq_param_index))
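The subtraction in the eq_param_index argument assumes that the solver stores indices as absolute positions in the full argument list of the differential equations: the dependent functions first (one per condition), then t, then the bundle parameters. The sketch below illustrates that assumed layout with hypothetical helpers (to_absolute, to_relative); it does not describe what solvers.py actually stores. It also shows why the indices are best materialized with tuple(...): a generator expression can be consumed only once, so a wrapper that iterates it on every call would see nothing after the first.

```python
def to_absolute(theta_indices, n_conditions):
    """Map theta-relative indices to positions in (u_1..u_n, t, theta_0..theta_k)."""
    # Assumption: n_conditions dependent functions come first, then t.
    return tuple(i + n_conditions + 1 for i in theta_indices)


def to_relative(absolute_indices, n_conditions):
    """Inverse of to_absolute; this is the subtraction used in the proposal."""
    return tuple(i - n_conditions - 1 for i in absolute_indices)


n_conditions = 2        # e.g. two dependent functions
user_indices = (0, 2)   # select theta_0 and theta_2
stored = to_absolute(user_indices, n_conditions)
print(stored)                             # (3, 5)
print(to_relative(stored, n_conditions))  # (0, 2)

# Why tuple(...) rather than a bare generator expression:
gen = (i - n_conditions - 1 for i in stored)
print(list(gen), list(gen))  # [0, 2] [] (the second pass sees nothing)
```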
ptflores1 commented 6 months ago

@shuheng-liu, would that be okay as a PR?

sathvikbhagavan commented 6 months ago

Thanks for the report! I don't think we need to recompute it, since the value is already stored at https://github.com/NeuroDiffGym/neurodiffeq/blob/master/neurodiffeq/solvers.py#L1354

It should be fixed in #212.
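Following the maintainer's suggestion, a loader can pass the stored attribute straight back to the constructor instead of recomputing it. The sketch below uses a hypothetical ToySolver (the class name, its arguments and the argument layout of the wrapped equations are illustrative, not neurodiffeq's API) and assumes the attribute holds exactly the value originally passed to the constructor.

```python
class ToySolver:
    """Hypothetical stand-in for a bundle solver; not neurodiffeq's class."""

    def __init__(self, ode_system, n_funcs, eq_param_index=()):
        self.ode_system = ode_system
        self.n_funcs = n_funcs
        # Keep the user-facing value on the instance so it survives saving.
        self.eq_param_index = tuple(eq_param_index)

        def _diff_eqs_wrapper(*variables):
            funcs_and_t = variables[: n_funcs + 1]
            thetas = variables[n_funcs + 1:]
            eq_params = tuple(thetas[i] for i in self.eq_param_index)
            return ode_system(*funcs_and_t, *eq_params)

        self.diff_eqs = _diff_eqs_wrapper

    @classmethod
    def load(cls, solver):
        # Pass the stored attribute through unchanged: no offset arithmetic.
        return cls(ode_system=solver.ode_system,
                   n_funcs=solver.n_funcs,
                   eq_param_index=solver.eq_param_index)


def ode(u, t, *params):
    return len(params)  # number of equation parameters received


original = ToySolver(ode, n_funcs=1, eq_param_index=(0,))
restored = ToySolver.load(original)
print(restored.eq_param_index)           # (0,)
print(restored.diff_eqs(1.0, 0.5, 2.0))  # 1: the parameter is preserved
```

Storing the value in the same form the user passed it keeps save and load symmetric, so the loader needs no knowledge of how the wrapper lays out its arguments.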

ptflores1 commented 6 months ago

Great! Thanks