I get the following error when calling sample with a dynamic parameter intervention (the dynamic_parameter_interventions argument):
ERROR:root:
###############################
There was an exception in pyciemss
Error occured in function: sample
Function docs :
Load a model from a file, compile it into a probabilistic program, and sample from it.
Args:
model_path_or_json: Union[str, Dict]
- A path to a AMR model file or JSON containing a model in AMR form.
end_time: float
- The end time of the sampled simulation.
logging_step_size: float
- The step size to use for logging the trajectory.
num_samples: int
- The number of samples to draw from the model.
solver_method: str
- The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
- By default we set the `start_time` to be 0.
inferred_parameters: Optional[pyro.nn.PyroModule]
- A Pyro module that contains the inferred parameters of the model.
This is typically the result of `calibrate`.
- If not provided, we will use the default values from the AMR model.
static_state_interventions: Dict[float, Dict[str, Intervention]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
static_parameter_interventions: Dict[float, Dict[str, Intervention]]
- A dictionary of static interventions to apply to the model.
- Each key is the time at which the intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
dynamic_state_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention]
]
- A dictionary of dynamic interventions to apply to the model.
- Each key is a function that takes in the current state of the model and returns a tensor.
When this function crosses 0, the dynamic intervention is applied.
- Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
dynamic_parameter_interventions: Dict[
Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
Dict[str, Intervention]
]
- A dictionary of dynamic interventions to apply to the model.
- Each key is a function that takes in the current state of the model and returns a tensor.
When this function crosses 0, the dynamic intervention is applied.
- Each value is a dictionary of the form {parameter_name: intervention_assignment}.
- Note that the `intervention_assignment` can be any type supported by
:func:`~chirho.interventional.ops.intervene`, including functions.
alpha: float
- Risk level for alpha-superquantile outputs in the results dictionary.
Returns:
result: Dict[str, torch.Tensor]
- Dictionary of outputs with following attributes:
- data: The samples from the model as a pandas DataFrame.
- unprocessed_result: Dictionary of outputs from the model.
- Each key is the name of a parameter or state variable in the model.
- Each value is a tensor of shape (num_samples, num_timepoints) for state variables
and (num_samples,) for parameters.
- quantiles: The quantiles for ensemble score calculation as a pandas DataFrames.
- risk: Dictionary with each key as the name of a state with
a dictionary of risk estimates for each state at the final timepoint.
- risk: alpha-superquantile risk estimate
Superquantiles can be intuitively thought of as a tail expectation, or an average
over a portion of worst-case outcomes. Given a distribution of a
quantity of interest (QoI), the superquantile at level \alpha\in[0, 1] is
the expected value of the largest 100(1 -\alpha)% realizations of the QoI.
- qoi: Samples of quantity of interest (value of the state at the final timepoint)
- schema: Visualization. (If visual_options is truthy)
################################
Traceback (most recent call last):
File "/Users/altu809/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
result = function(*args, **kwargs)
File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 332, in sample
samples = pyro.infer.Predictive(
File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
return _predictive(
File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 127, in _predictive
return _predictive_sequential(
File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 55, in _predictive_sequential
{site: trace.nodes[site]["value"] for site in return_site_shapes}
File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 55, in <dictcomp>
{site: trace.nodes[site]["value"] for site in return_site_shapes}
KeyError: 'parameter_intervention_value_p_cbeta_0'
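The docstring above describes each key of dynamic_parameter_interventions as an event function of (time, state) that triggers the intervention when it crosses 0. Each value maps parameter names to their new values. The helper make_var_threshold used in the notebook cell is not shown in this report. The following is therefore only a hypothetical sketch of what it is assumed to return: state[var] minus a threshold. It uses plain floats so it runs without torch. The call in the report passes torch.tensor values, and those support the same subtraction.

```python
from typing import Any, Callable, Dict

# Event-function signature from the docstring: (time, state dict) -> value.
EventFn = Callable[[Any, Dict[str, Any]], Any]


def make_var_threshold(var: str, threshold: Any) -> EventFn:
    """Hypothetical helper (its real body is not shown in this report).

    Returns an event function that is negative while state[var] is below
    `threshold`, zero at the threshold, and positive above it. It therefore
    crosses 0 when the state variable reaches the threshold.
    """

    def event_fn(t: Any, state: Dict[str, Any]) -> Any:
        # `t` is unused here; the trigger depends only on the current state.
        return state[var] - threshold

    return event_fn


# Mirrors the notebook cell, with floats standing in for torch.tensor values.
infection_threshold = make_var_threshold("I", 400.0)
dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": 0.3}}

print(infection_threshold(0.0, {"I": 350.0}))  # -50.0
print(infection_threshold(5.0, {"I": 450.0}))  # 50.0
```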
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
Cell In[15], line 10
7 infection_threshold = make_var_threshold("I", torch.tensor(400.0))
8 dynamic_parameter_interventions1 = {infection_threshold: {"p_cbeta": torch.tensor(0.3)}}
---> 10 result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time,
11 dynamic_parameter_interventions=dynamic_parameter_interventions1,
12 solver_method="dopri5")
13 display(result["data"].head())
15 # Plot the result
File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
17 log_message = """
18 ###############################
19
(...)
26 ################################
27 """
28 logging.exception(log_message, function.__name__, function.__doc__)
---> 29 raise e
File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
8 try:
9 start_time = time.perf_counter()
---> 10 result = function(*args, **kwargs)
11 end_time = time.perf_counter()
12 logging.info(
13 "Elapsed time for %s: %f", function.__name__, end_time - start_time
14 )
File ~/Projects/pyciemss/pyciemss/interfaces.py:332, in sample(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, time_unit, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions, alpha)
320 compiled_noise_model(full_trajectory)
322 parallel = (
323 False
324 if len(
(...)
329 else True
330 )
--> 332 samples = pyro.infer.Predictive(
333 wrapped_model,
334 guide=inferred_parameters,
335 num_samples=num_samples,
336 parallel=parallel,
337 )()
339 risk_results = {}
340 for k, vals in samples.items():
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
263 return_sites = None if not return_sites else return_sites
264 posterior_samples = _predictive(
265 self.guide,
266 posterior_samples,
(...)
271 model_kwargs=kwargs,
272 )
--> 273 return _predictive(
274 self.model,
275 posterior_samples,
276 self.num_samples,
277 return_sites=return_sites,
278 parallel=self.parallel,
279 model_args=args,
280 model_kwargs=kwargs,
281 )
File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:127, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
124 return_site_shapes["_RETURN"] = shape
126 if not parallel:
--> 127 return _predictive_sequential(
128 model,
129 posterior_samples,
130 model_args,
131 model_kwargs,
132 num_samples,
133 return_site_shapes,
134 return_trace=False,
135 )
137 trace = poutine.trace(
138 poutine.condition(vectorize(model), reshaped_samples)
139 ).get_trace(*model_args, **model_kwargs)
140 predictions = {}
File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:55, in _predictive_sequential(model, posterior_samples, model_args, model_kwargs, num_samples, return_site_shapes, return_trace)
52 collected.append(trace)
53 else:
54 collected.append(
---> 55 {site: trace.nodes[site]["value"] for site in return_site_shapes}
56 )
58 if return_trace:
59 return collected
File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:55, in <dictcomp>(.0)
52 collected.append(trace)
53 else:
54 collected.append(
---> 55 {site: trace.nodes[site]["value"] for site in return_site_shapes}
56 )
58 if return_trace:
59 return collected
KeyError: 'parameter_intervention_value_p_cbeta_0'
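The failing line in _predictive_sequential builds each sample's output as {site: trace.nodes[site]["value"] for site in return_site_shapes}. The traceback also shows that return_site_shapes is built in _predictive before the per-sample loop starts. Any site listed there but missing from a later sample's trace therefore raises exactly this KeyError. One possible explanation, not confirmed here, is that the parameter_intervention_value_p_cbeta_0 site is only recorded when the I = 400 threshold is actually crossed. Samples whose trajectory never reaches the threshold would then lack the site. The stdlib-only sketch below imitates that pattern. Its run_model is a hypothetical toy, not pyciemss or pyro code, and the conditional site is the assumption being illustrated.

```python
import math
from typing import Dict, List

SITE = "parameter_intervention_value_p_cbeta_0"


def run_model(peak_infections: float, threshold: float = 400.0) -> Dict[str, float]:
    """Hypothetical toy model returning a 'trace' of site name -> value.

    ASSUMPTION: the intervention site is only recorded when the trajectory
    reaches the threshold, i.e. when the dynamic intervention fires.
    """
    trace = {"I_peak": peak_infections}
    if peak_infections >= threshold:
        trace[SITE] = 0.3  # new p_cbeta value, recorded only if the event fires
    return trace


def collect_strict(peaks: List[float]) -> List[Dict[str, float]]:
    """Mirrors the failing pattern: return sites are fixed from one earlier run."""
    return_sites = list(run_model(peaks[0]))
    return [{site: run_model(p)[site] for site in return_sites} for p in peaks]


def collect_tolerant(peaks: List[float]) -> List[Dict[str, float]]:
    """Same loop, but a missing site is filled with NaN instead of raising."""
    return_sites = list(run_model(peaks[0]))
    return [
        {site: run_model(p).get(site, math.nan) for site in return_sites}
        for p in peaks
    ]


peaks = [550.0, 320.0]  # first sample crosses I = 400, second never does
try:
    collect_strict(peaks)
except KeyError as err:
    print("KeyError:", err)  # KeyError: 'parameter_intervention_value_p_cbeta_0'

print(collect_tolerant(peaks))
```

The tolerant variant only shows the shape of the problem. It is not a proposed patch to pyro or pyciemss.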