ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems

KeyError for dynamic parameter intervention #525

Status: closed by sabinala 4 months ago

sabinala commented 4 months ago

I get the following error when calling `pyciemss.sample` with `dynamic_parameter_interventions`:

ERROR:root:
                ###############################

                There was an exception in pyciemss

                Error occured in function: sample

                Function docs : 
    Load a model from a file, compile it into a probabilistic program, and sample from it.

    Args:
        model_path_or_json: Union[str, Dict]
            - A path to a AMR model file or JSON containing a model in AMR form.
        end_time: float
            - The end time of the sampled simulation.
        logging_step_size: float
            - The step size to use for logging the trajectory.
        num_samples: int
            - The number of samples to draw from the model.
        solver_method: str
            - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
            - If performance is incredibly slow, we suggest using `euler` to debug.
              If using `euler` results in faster simulation, the issue is likely that the model is stiff.
        solver_options: Dict[str, Any]
            - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
        start_time: float
            - The start time of the model. This is used to align the `start_state` from the
              AMR model with the simulation timepoints.
            - By default we set the `start_time` to be 0.
        inferred_parameters: Optional[pyro.nn.PyroModule]
            - A Pyro module that contains the inferred parameters of the model.
              This is typically the result of `calibrate`.
            - If not provided, we will use the default values from the AMR model.
        static_state_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        static_parameter_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_state_interventions: Dict[
                                        Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                        Dict[str, Intervention]
                                        ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_parameter_interventions: Dict[
                                            Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                            Dict[str, Intervention]
                                            ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        alpha: float
            - Risk level for alpha-superquantile outputs in the results dictionary.

    Returns:
        result: Dict[str, torch.Tensor]
            - Dictionary of outputs with following attributes:
                - data: The samples from the model as a pandas DataFrame.
                - unprocessed_result: Dictionary of outputs from the model.
                    - Each key is the name of a parameter or state variable in the model.
                    - Each value is a tensor of shape (num_samples, num_timepoints) for state variables
                    and (num_samples,) for parameters.
                - quantiles: The quantiles for ensemble score calculation as a pandas DataFrames.
                - risk: Dictionary with each key as the name of a state with
                a dictionary of risk estimates for each state at the final timepoint.
                    - risk: alpha-superquantile risk estimate
                    Superquantiles can be intuitively thought of as a tail expectation, or an average
                    over a portion of worst-case outcomes. Given a distribution of a
                    quantity of interest (QoI), the superquantile at level \alpha\in[0, 1] is
                    the expected value of the largest 100(1 -\alpha)% realizations of the QoI.
                    - qoi: Samples of quantity of interest (value of the state at the final timepoint)
                - schema: Visualization. (If visual_options is truthy)

                ################################

Traceback (most recent call last):
  File "/Users/altu809/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 332, in sample
    samples = pyro.infer.Predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 127, in _predictive
    return _predictive_sequential(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 55, in _predictive_sequential
    {site: trace.nodes[site]["value"] for site in return_site_shapes}
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 55, in <dictcomp>
    {site: trace.nodes[site]["value"] for site in return_site_shapes}
KeyError: 'parameter_intervention_value_p_cbeta_0'
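The failing line (`predictive.py`, line 55) builds one dict per sample by looking up a fixed set of site names, `return_site_shapes`, in each new trace. In the Pyro version shown, that set is computed once from an initial trace of the model, before the per-sample loop runs. One plausible explanation, not confirmed in this report, is that `parameter_intervention_value_p_cbeta_0` is only recorded when the `I` threshold event actually fires. If it fires in the initial trace but not in a later sample, the lookup raises exactly this `KeyError`. The plain-Python toy below mimics that loop. `run_model`, `predictive_sequential`, and the `peak_I` site are hypothetical stand-ins, not Pyro or pyciemss code.

```python
from typing import Dict, List

# Site name taken from the error message; everything else here is a toy.
SITE = "parameter_intervention_value_p_cbeta_0"


def run_model(peak_infected: float, threshold: float = 400.0) -> Dict[str, float]:
    """Hypothetical stand-in for one traced model run (site name -> value)."""
    sites = {"peak_I": peak_infected}
    if peak_infected >= threshold:
        # The intervention site only exists when the threshold event fires.
        sites[SITE] = 0.3
    return sites


def predictive_sequential(initial_peak: float, peaks: List[float]) -> List[Dict[str, float]]:
    """Mimic the sequential loop: site names are fixed once, then looked up per sample."""
    return_sites = list(run_model(initial_peak))  # plays the role of the initial trace
    collected = []
    for peak in peaks:
        trace = run_model(peak)
        # Same pattern as predictive.py line 55; raises KeyError if a site is missing.
        collected.append({site: trace[site] for site in return_sites})
    return collected


try:
    # The event fires in the initial run (450 >= 400) but not in the second sample (380).
    predictive_sequential(450.0, [520.0, 380.0])
except KeyError as err:
    print("KeyError:", err)  # prints: KeyError: 'parameter_intervention_value_p_cbeta_0'
```

If the event fires in every sample, or in none of them including the initial run, the loop completes without error, which would make the failure depend on the sampled parameters.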
---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[15], line 10
      7 infection_threshold = make_var_threshold("I", torch.tensor(400.0))
      8 dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
---> 10 result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time, 
     11                          dynamic_parameter_interventions=dynamic_parameter_interventions1, 
     12                          solver_method="dopri5")
     13 display(result["data"].head())
     15 # Plot the result

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
     17 log_message = """
     18     ###############################
     19 
   (...)
     26     ################################
     27 """
     28 logging.exception(log_message, function.__name__, function.__doc__)
---> 29 raise e

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
      8 try:
      9     start_time = time.perf_counter()
---> 10     result = function(*args, **kwargs)
     11     end_time = time.perf_counter()
     12     logging.info(
     13         "Elapsed time for %s: %f", function.__name__, end_time - start_time
     14     )

File ~/Projects/pyciemss/pyciemss/interfaces.py:332, in sample(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, time_unit, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions, alpha)
    320         compiled_noise_model(full_trajectory)
    322 parallel = (
    323     False
    324     if len(
   (...)
    329     else True
    330 )
--> 332 samples = pyro.infer.Predictive(
    333     wrapped_model,
    334     guide=inferred_parameters,
    335     num_samples=num_samples,
    336     parallel=parallel,
    337 )()
    339 risk_results = {}
    340 for k, vals in samples.items():

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
    263     return_sites = None if not return_sites else return_sites
    264     posterior_samples = _predictive(
    265         self.guide,
    266         posterior_samples,
   (...)
    271         model_kwargs=kwargs,
    272     )
--> 273 return _predictive(
    274     self.model,
    275     posterior_samples,
    276     self.num_samples,
    277     return_sites=return_sites,
    278     parallel=self.parallel,
    279     model_args=args,
    280     model_kwargs=kwargs,
    281 )

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:127, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
    124     return_site_shapes["_RETURN"] = shape
    126 if not parallel:
--> 127     return _predictive_sequential(
    128         model,
    129         posterior_samples,
    130         model_args,
    131         model_kwargs,
    132         num_samples,
    133         return_site_shapes,
    134         return_trace=False,
    135     )
    137 trace = poutine.trace(
    138     poutine.condition(vectorize(model), reshaped_samples)
    139 ).get_trace(*model_args, **model_kwargs)
    140 predictions = {}

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:55, in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
     52         collected.append(trace)
     53     else:
     54         collected.append(
---> 55             {site: trace.nodes[site]["value"] for site in return_site_shapes}
     56         )
     58 if return_trace:
     59     return collected

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:55, in <dictcomp>(.0)
     52         collected.append(trace)
     53     else:
     54         collected.append(
---> 55             {site: trace.nodes[site]["value"] for site in return_site_shapes}
     56         )
     58 if return_trace:
     59     return collected

KeyError: 'parameter_intervention_value_p_cbeta_0'
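The notebook cell builds its event function with `make_var_threshold`, whose definition is not included in the report. Per the `sample` docstring, each key of `dynamic_parameter_interventions` takes the current time and state and returns a tensor, and the intervention is applied when that value crosses 0. Below is a minimal sketch, assuming the helper returns `state[var] - threshold`, with plain floats standing in for `torch.Tensor` values. `first_crossing` is a hypothetical helper that only checks for sign changes between logged points; the ODE solver instead locates the root between its own steps.

```python
from typing import Callable, Dict, List, Optional

StateDict = Dict[str, float]
EventFn = Callable[[float, StateDict], float]


def make_var_threshold(var: str, threshold: float) -> EventFn:
    """Hypothetical re-creation of the notebook helper (its definition is not shown)."""
    def var_threshold(time: float, state: StateDict) -> float:
        # Negative below the threshold, positive above it, zero at the crossing.
        return state[var] - threshold
    return var_threshold


def first_crossing(event_fn: EventFn, times: List[float], states: List[StateDict]) -> Optional[float]:
    """Return the first logged time where event_fn goes from < 0 to >= 0, else None."""
    prev = event_fn(times[0], states[0])
    for t, state in zip(times[1:], states[1:]):
        cur = event_fn(t, state)
        if prev < 0 <= cur:
            return t
        prev = cur
    return None


infection_threshold = make_var_threshold("I", 400.0)
times = [0.0, 1.0, 2.0, 3.0]
rising = [{"I": v} for v in (100.0, 250.0, 420.0, 500.0)]
flat = [{"I": v} for v in (100.0, 150.0, 180.0, 190.0)]

print(first_crossing(infection_threshold, times, rising))  # 2.0
print(first_crossing(infection_threshold, times, flat))    # None
```

In the second trajectory the event never fires, so no `p_cbeta` intervention would be applied in that sample. This is consistent with, but does not prove, the missing-site explanation for the `KeyError`.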