ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems
Other
12 stars 4 forks source link

AssertionError: Event handling for fixed step solvers currently requires `step_size` to be provided in options. #502

Closed sabinala closed 4 months ago

sabinala commented 4 months ago

I get the error message below when I run `pyciemss.sample` with a dynamic parameter intervention and the fixed-step `euler` solver:

def make_var_threshold(var: str, threshold: torch.Tensor):
    return lambda time, state: state[var] - threshold  

infection_threshold = make_var_threshold("I", torch.tensor(400.0))
dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}

result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         dynamic_parameter_interventions=dynamic_parameter_interventions1, 
                         solver_method="euler")

Here, `model3 = os.path.join(MODEL_PATH, "SIR_stockflow.json")` and `MODEL_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/"`.

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[6], line 10
      7 infection_threshold = make_var_threshold("I", torch.tensor(400.0))
      8 dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
---> 10 result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
     11                          dynamic_parameter_interventions=dynamic_parameter_interventions1, 
     12                          solver_method="euler")
     13 display(result["data"].head())
     15 # Plot the result

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
     17 log_message = """
     18     ###############################
     19 
   (...)
     26     ################################
     27 """
     28 logging.exception(log_message, function.__name__, function.__doc__)
---> 29 raise e

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
      8 try:
      9     start_time = time.perf_counter()
---> 10     result = function(*args, **kwargs)
     11     end_time = time.perf_counter()
     12     logging.info(
     13         "Elapsed time for %s: %f", function.__name__, end_time - start_time
     14     )

File ~/Projects/pyciemss/pyciemss/interfaces.py:298, in sample(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions)
    294         compiled_noise_model(full_trajectory)
    296 parallel = False if len(intervention_handlers) > 0 else True
--> 298 samples = pyro.infer.Predictive(
    299     wrapped_model,
    300     guide=inferred_parameters,
    301     num_samples=num_samples,
    302     parallel=parallel,
    303 )()
    305 return prepare_interchange_dictionary(samples)

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
    263     return_sites = None if not return_sites else return_sites
    264     posterior_samples = _predictive(
    265         self.guide,
    266         posterior_samples,
   (...)
    271         model_kwargs=kwargs,
    272     )
--> 273 return _predictive(
    274     self.model,
    275     posterior_samples,
    276     self.num_samples,
    277     return_sites=return_sites,
    278     parallel=self.parallel,
    279     model_args=args,
    280     model_kwargs=kwargs,
    281 )

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:78, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
     67 def _predictive(
     68     model,
     69     posterior_samples,
   (...)
     75     model_kwargs={},
     76 ):
     77     model = torch.no_grad()(poutine.mask(model, mask=False))
---> 78     max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs)
     79     vectorize = pyro.plate(
     80         "_num_predictive_samples", num_samples, dim=-max_plate_nesting - 1
     81     )
     82     model_trace = prune_subsample_sites(
     83         poutine.trace(model).get_trace(*model_args, **model_kwargs)
     84     )

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:21, in _guess_max_plate_nesting(model, args, kwargs)
     15 """
     16 Guesses max_plate_nesting by running the model once
     17 without enumeration. This optimistically assumes static model
     18 structure.
     19 """
     20 with poutine.block():
---> 21     model_trace = poutine.trace(model).get_trace(*args, **kwargs)
     22 sites = [site for site in model_trace.nodes.values() if site["type"] == "sample"]
     24 dims = [
     25     frame.dim
     26     for site in sites
     27     for frame in site["cond_indep_stack"]
     28     if frame.vectorized
     29 ]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
    190 def get_trace(self, *args, **kwargs):
    191     """
    192     :returns: data structure
    193     :rtype: pyro.poutine.Trace
   (...)
    196     Calls this poutine and returns its trace instead of the function's return value.
    197     """
--> 198     self(*args, **kwargs)
    199     return self.msngr.get_trace()

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    170 self.msngr.trace.add_node(
    171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    172 )
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
    176     exc_type, exc_value, traceback = sys.exc_info()

File ~/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/Projects/pyciemss/pyciemss/interfaces.py:282, in sample.<locals>.wrapped_model()
    280         for handler in intervention_handlers:
    281             stack.enter_context(handler)
--> 282         full_trajectory = model(
    283             torch.as_tensor(start_time),
    284             torch.as_tensor(end_time),
    285             logging_times=logging_times,
    286             is_traced=True,
    287         )
    289 if noise_model is not None:
    290     compiled_noise_model = compile_noise_model(
    291         noise_model, vars=set(full_trajectory.keys()), **noise_model_kwargs
    292     )

File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    447 def __call__(self, *args, **kwargs):
    448     with self._pyro_context:
--> 449         result = super().__call__(*args, **kwargs)
    450     if (
    451         pyro.settings.get("validate_poutine")
    452         and not self._pyro_context.active
    453         and _is_module_local_param_enabled()
    454     ):
    455         self._check_module_local_param_usage()

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Projects/pyciemss/pyciemss/compiled_dynamics.py:77, in CompiledDynamics.forward(self, start_time, end_time, logging_times, is_traced)
     75 if logging_times is not None:
     76     with LogTrajectory(logging_times) as lt:
---> 77         simulate(self.deriv, self.initial_state(), start_time, end_time)
     78         state = lt.trajectory
     79 else:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    264 msg = {
    265     "type": type,
    266     "name": name,
   (...)
    278     "infer": infer,
    279 }
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    210     pointer = pointer + 1
--> 212     frame._process_message(msg)
    214     if msg["stop"]:
    215         break

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    160 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    161 if method is not None:
--> 162     return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py:109, in Solver._pyro_simulate(self, msg)
    106     if ph.priority > start_time:
    107         break
--> 109 state, start_time, next_interruption = simulate_to_interruption(
    110     possible_interruptions,
    111     dynamics,
    112     state,
    113     start_time,
    114     end_time,
    115     **msg["kwargs"],
    116 )
    118 if next_interruption is not None:
    119     dynamics, state = next_interruption.callback(dynamics, state)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    264 msg = {
    265     "type": type,
    266     "name": name,
   (...)
    278     "infer": infer,
    279 }
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    210     pointer = pointer + 1
--> 212     frame._process_message(msg)
    214     if msg["stop"]:
    215         break

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    160 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    161 if method is not None:
--> 162     return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py:89, in TorchDiffEq._pyro_simulate_to_interruption(self, msg)
     87 interruptions, dynamics, initial_state, start_time, end_time = msg["args"]
     88 msg["kwargs"].update(self.odeint_kwargs)
---> 89 msg["value"] = torchdiffeq_simulate_to_interruption(
     90     interruptions,
     91     dynamics,
     92     initial_state,
     93     start_time,
     94     end_time,
     95     **msg["kwargs"],
     96 )
     97 msg["done"] = True

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:244, in torchdiffeq_simulate_to_interruption(interruptions, dynamics, initial_state, start_time, end_time, **kwargs)
    234 def torchdiffeq_simulate_to_interruption(
    235     interruptions: List[Interruption[torch.Tensor]],
    236     dynamics: Dynamics[torch.Tensor],
   (...)
    240     **kwargs,
    241 ) -> Tuple[State[torch.Tensor], torch.Tensor, Optional[Interruption[torch.Tensor]]]:
    242     assert len(interruptions) > 0, "should have at least one interruption here"
--> 244     (next_interruption,), interruption_time = _torchdiffeq_get_next_interruptions(
    245         dynamics, initial_state, start_time, interruptions, **kwargs
    246     )
    248     value = simulate_point(
    249         dynamics, initial_state, start_time, interruption_time, **kwargs
    250     )
    251     return value, interruption_time, next_interruption

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:192, in _torchdiffeq_get_next_interruptions(dynamics, start_state, start_time, interruptions, **kwargs)
    189 combined_event_f = torchdiffeq_combined_event_f(interruptions, var_order)
    191 # Simulate to the event execution.
--> 192 event_time, event_solutions = _batched_odeint(  # torchdiffeq.odeint_event(
    193     functools.partial(_deriv, dynamics, var_order),
    194     tuple(start_state[v] for v in var_order),
    195     start_time,
    196     event_fn=combined_event_f,
    197     **kwargs,
    198 )
    200 # event_state has both the first and final state of the interrupted simulation. We just want the last.
    201 event_solution: Tuple[torch.Tensor, ...] = tuple(
    202     s[..., -1] for s in event_solutions
    203 )  # TODO support event_dim > 0

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:119, in _batched_odeint(func, y0, t, event_fn, **odeint_kwargs)
    112 y0_expanded = tuple(
    113     # y0_[(None,) * (len(y0_batch_shape) - (len(y0_.shape) - event_dim)) + (...,)]
    114     y0_.expand(y0_batch_shape + y0_.shape[len(y0_.shape) - event_dim :])
    115     for y0_ in y0
    116 )
    118 if event_fn is not None:
--> 119     event_t, yt_raw = torchdiffeq.odeint_event(
    120         func, y0_expanded, t, event_fn=event_fn, **odeint_kwargs
    121     )
    122 else:
    123     yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py:101, in odeint_event(func, y0, t0, event_fn, reverse_time, odeint_interface, **kwargs)
     98 else:
     99     t = torch.cat([t0.reshape(-1), t0.reshape(-1).detach() + 1.0])
--> 101 event_t, solution = odeint_interface(func, y0, t, event_fn=event_fn, **kwargs)
    103 # Dummy values for rtol, atol, method, and options.
    104 shapes, _func, _, t, _, _, _, _, event_fn, _ = _check_inputs(func, y0, t, 0.0, 0.0, None, None, event_fn, SOLVERS)

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py:79, in odeint(func, y0, t, rtol, atol, method, options, event_fn)
     77     solution = solver.integrate(t)
     78 else:
---> 79     event_t, solution = solver.integrate_until_event(t[0], event_fn)
     80     event_t = event_t.to(t)
     81     if t_is_reversed:

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py:122, in FixedGridODESolver.integrate_until_event(self, t0, event_fn)
    121 def integrate_until_event(self, t0, event_fn):
--> 122     assert self.step_size is not None, "Event handling for fixed step solvers currently requires `step_size` to be provided in options."
    124     t0 = t0.type_as(self.y0)
    125     y0 = self.y0

AssertionError: Event handling for fixed step solvers currently requires `step_size` to be provided in options.
sabinala commented 4 months ago

This minimal notebook produces the error: https://github.com/ciemss/pyciemss/blob/502-assertionerror-event-handling-for-fixed-step-solvers-currently-requires-step_size-to-be-provided-in-options/docs/source/step_size_error_producing_ntbk.ipynb

djinnome commented 4 months ago

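Passing a `step_size` through `solver_options` fixes the problem. Fixed-step solvers such as `euler` need it for event handling. (Note that this snippet also lowers the infection threshold from 400.0 to 40.0.)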

def make_var_threshold(var: str, threshold: torch.Tensor):
    return lambda time, state: state[var] - threshold  

infection_threshold = make_var_threshold("I", torch.tensor(40.0))
dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
# Specify solver options including the step_size
solver_options = {"step_size": 1e-2}  # Example step size, adjust as needed

result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         dynamic_parameter_interventions=dynamic_parameter_interventions1, 
                         solver_method="euler",
                         solver_options=solver_options)

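Alternatively, you can switch to the adaptive `dopri5` solver. It does not require `step_size`, and it is much faster and more stable. With `dopri5`, the `solver_options` definition and the commented-out argument in the snippet below are unnecessary.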

def make_var_threshold(var: str, threshold: torch.Tensor):
    return lambda time, state: state[var] - threshold  

infection_threshold = make_var_threshold("I", torch.tensor(400.0))
dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
# Specify solver options including the step_size
solver_options = {"step_size": 1e-2}  # Example step size, adjust as needed

result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
                         dynamic_parameter_interventions=dynamic_parameter_interventions1, 
                         solver_method="dopri5")
                         #solver_options=solver_options)
sabinala commented 4 months ago

I believe this issue is now resolved and can be closed.