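For example, using this model (which has states `I` and `H` and observables `infected` and `hospitalized`), calling `calibrate` (as in `calibrated_results = pyciemss.calibrate(model1, dataset1, data_mapping=data_mapping, num_iterations=num_iterations)`) with the data mapping `data_mapping = {"case": "I", "hosp": "H"}` works fine, but when the data mapping `data_mapping = {"case": "infected", "hosp": "hospitalized"}` is used, I get the error below. It ends in `KeyError: 'hospitalized'`, raised by `state[k]` in `StateIndependentNoiseModel.forward`. Judging from the traceback, the noise model is applied to the simulated trajectory inside `simulate(...)` (via chirho's `StaticBatchObservation._pyro_post_simulate`). That happens before `CompiledDynamics.forward` computes `observables = self.observables(state)`, which would explain why only state-variable names work in the data mapping: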
File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:157, in AutoGuide._setup_prototype(self, *args, **kwargs)
154 def _setup_prototype(self, *args, **kwargs):
155 # run the model so we can inspect its structure
156 model = poutine.block(self.model, self._prototype_hide_fn)
--> 157 self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
158 *args, **kwargs
159 )
160 if self.master is not None:
161 self.master()._check_prototype(self.prototype_trace)
File ~/Projects/pyciemss/pyciemss/interfaces.py:500, in calibrate.<locals>.wrapped_model()
498 for handler in intervention_handlers:
499 stack.enter_context(handler)
--> 500 model(
501 torch.as_tensor(start_time),
502 torch.as_tensor(data_timepoints[-1]),
503 )
File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
447 def __call__(self, *args, **kwargs):
448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
450 if (
451 pyro.settings.get("validate_poutine")
452 and not self._pyro_context.active
453 and _is_module_local_param_enabled()
454 ):
455 self._check_module_local_param_usage()
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Projects/pyciemss/pyciemss/compiled_dynamics.py:80, in CompiledDynamics.forward(self, start_time, end_time, logging_times, is_traced)
78 state = lt.trajectory
79 else:
---> 80 state = simulate(self.deriv, self.initial_state(), start_time, end_time)
82 observables = self.observables(state)
84 if is_traced:
85 # Add the observables to the trace so that they can be accessed later.
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
264 msg = {
265 "type": type,
266 "name": name,
(...)
278 "infer": infer,
279 }
280 # apply the stack and return its return value
--> 281 apply_stack(msg)
282 return msg["value"]
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:220, in apply_stack(initial_msg)
217 default_process_message(msg)
219 for frame in stack[-pointer:]:
--> 220 frame._postprocess_message(msg)
222 cont = msg["continuation"]
223 if cont is not None:
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:168, in Messenger._postprocess_message(self, msg)
166 method = getattr(self, "_pyro_post_{}".format(msg["type"]), None)
167 if method is not None:
--> 168 return method(msg)
169 return None
File ~/anaconda3/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/anaconda3/lib/python3.10/site-packages/chirho/observational/internals.py:56, in _observe_dict(rv, obs, name, **kwargs)
47 @observe.register(dict)
48 def _observe_dict(
49 rv: Mapping[K, T],
(...)
53 **kwargs,
54 ) -> Mapping[K, T]:
55 if callable(obs):
---> 56 obs = obs(rv)
57 if obs is not rv and obs is not None:
58 raise NotImplementedError("Dependent observations are not yet supported")
File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
447 def __call__(self, *args, **kwargs):
448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
450 if (
451 pyro.settings.get("validate_poutine")
452 and not self._pyro_context.active
453 and _is_module_local_param_enabled()
454 ):
455 self._check_module_local_param_usage()
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Projects/pyciemss/pyciemss/observation.py:33, in StateIndependentNoiseModel.forward(self, state)
29 def forward(self, state: Dict[str, torch.Tensor]) -> None:
30 for k in self.vars:
31 pyro.sample(
32 f"{k}_noisy",
---> 33 self.markov_kernel(k, state[k]),
34 )
For example, using this model (which has states `I` and `H` and observables `infected` and `hospitalized`), calling `calibrate` (as in `calibrated_results = pyciemss.calibrate(model1, dataset1, data_mapping=data_mapping, num_iterations=num_iterations)`) with the data mapping `data_mapping = {"case": "I", "hosp": "H"}` works fine, but when the data mapping `data_mapping = {"case": "infected", "hosp": "hospitalized"}` is used, I get the following error:

KeyError Traceback (most recent call last)
Cell In[42], line 3
1 data_mapping = {"case": "infected", "hosp": "hospitalized"} # data_mapping = {"column_name": "observable/state_variable"}
2 num_iterations = 10 if smoke_test else 1000
----> 3 calibrated_results = pyciemss.calibrate(model1, dataset1, data_mapping=data_mapping, num_iterations=num_iterations)
4 parameter_estimates = calibrated_results["inferred_parameters"]
5 calibrated_results
File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
17 log_message = """
18 ###############################
19
(...)
26 ################################
27 """
28 logging.exception(log_message, function.__name__, function.__doc__)
---> 29 raise e
File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
8 try:
9 start_time = time.perf_counter()
---> 10 result = function(*args, **kwargs)
11 end_time = time.perf_counter()
12 logging.info(
13 "Elapsed time for %s: %f", function.__name__, end_time - start_time
14 )
File ~/Projects/pyciemss/pyciemss/interfaces.py:505, in calibrate(model_path_or_json, data_path, data_mapping, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions, num_iterations, lr, verbose, num_particles, deterministic_learnable_parameters, progress_hook)
499 stack.enter_context(handler)
500 model(
501 torch.as_tensor(start_time),
502 torch.as_tensor(data_timepoints[-1]),
503 )
--> 505 inferred_parameters = autoguide(wrapped_model)
507 optim = pyro.optim.Adam({"lr": lr})
508 loss = pyro.infer.Trace_ELBO(num_particles=num_particles)
File ~/Projects/pyciemss/pyciemss/interfaces.py:447, in calibrate.<locals>.autoguide(model)
443 try:
444 mvn_guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal(
445 pyro.poutine.block(model, hide=deterministic_learnable_parameters)
446 )
--> 447 mvn_guide._setup_prototype()
448 guide.append(mvn_guide)
449 except RuntimeError as re:
File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:1002, in AutoLowRankMultivariateNormal._setup_prototype(self, *args, **kwargs)
1001 def _setup_prototype(self, *args, **kwargs):
-> 1002 super()._setup_prototype(*args, **kwargs)
1003 # Initialize guide params
1004 self.loc = nn.Parameter(self._init_loc())
File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:636, in AutoContinuous._setup_prototype(self, *args, **kwargs)
635 def _setup_prototype(self, *args, **kwargs):
--> 636 super()._setup_prototype(*args, **kwargs)
637 self._unconstrained_shapes = {}
638 self._cond_indep_stacks = {}
File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:157, in AutoGuide._setup_prototype(self, *args, **kwargs)
154 def _setup_prototype(self, *args, **kwargs):
155 # run the model so we can inspect its structure
156 model = poutine.block(self.model, self._prototype_hide_fn)
--> 157 self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
158 *args, **kwargs
159 )
160 if self.master is not None:
161 self.master()._check_prototype(self.prototype_trace)
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
190 def get_trace(self, *args, **kwargs):
191 """
192 :returns: data structure
193 :rtype: pyro.poutine.Trace
(...)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
170 self.msngr.trace.add_node(
171 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/Projects/pyciemss/pyciemss/interfaces.py:500, in calibrate.<locals>.wrapped_model()
498 for handler in intervention_handlers:
499 stack.enter_context(handler)
--> 500 model(
501 torch.as_tensor(start_time),
502 torch.as_tensor(data_timepoints[-1]),
503 )
File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
447 def __call__(self, *args, **kwargs):
448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
450 if (
451 pyro.settings.get("validate_poutine")
452 and not self._pyro_context.active
453 and _is_module_local_param_enabled()
454 ):
455 self._check_module_local_param_usage()
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Projects/pyciemss/pyciemss/compiled_dynamics.py:80, in CompiledDynamics.forward(self, start_time, end_time, logging_times, is_traced)
78 state = lt.trajectory
79 else:
---> 80 state = simulate(self.deriv, self.initial_state(), start_time, end_time)
82 observables = self.observables(state)
84 if is_traced:
85 # Add the observables to the trace so that they can be accessed later.
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
264 msg = {
265 "type": type,
266 "name": name,
(...)
278 "infer": infer,
279 }
280 # apply the stack and return its return value
--> 281 apply_stack(msg)
282 return msg["value"]
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:220, in apply_stack(initial_msg)
217 default_process_message(msg)
219 for frame in stack[-pointer:]:
--> 220 frame._postprocess_message(msg)
222 cont = msg["continuation"]
223 if cont is not None:
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:168, in Messenger._postprocess_message(self, msg)
166 method = getattr(self, "_pyro_post_{}".format(msg["type"]), None)
167 if method is not None:
--> 168 return method(msg)
169 return None
File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/interruption.py:249, in StaticBatchObservation._pyro_post_simulate(self, msg)
247 def _pyro_post_simulate(self, msg: dict) -> None:
248 super()._pyro_post_simulate(msg)
--> 249 self.trajectory = observe(self.trajectory, self.observation)
File ~/anaconda3/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw)
885 if not args:
886 raise TypeError(f'{funcname} requires at least '
887 '1 positional argument')
--> 889 return dispatch(args[0].__class__)(*args, **kw)
File ~/anaconda3/lib/python3.10/site-packages/chirho/observational/internals.py:56, in _observe_dict(rv, obs, name, **kwargs)
47 @observe.register(dict)
48 def _observe_dict(
49 rv: Mapping[K, T],
(...)
53 **kwargs,
54 ) -> Mapping[K, T]:
55 if callable(obs):
---> 56 obs = obs(rv)
57 if obs is not rv and obs is not None:
58 raise NotImplementedError("Dependent observations are not yet supported")
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
10 def _context_wrap(context, fn, *args, **kwargs):
11 with context:
---> 12 return fn(*args, **kwargs)
File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
447 def __call__(self, *args, **kwargs):
448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
450 if (
451 pyro.settings.get("validate_poutine")
452 and not self._pyro_context.active
453 and _is_module_local_param_enabled()
454 ):
455 self._check_module_local_param_usage()
File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
1126 # If we don't have any hooks, we want to skip the rest of the logic in
1127 # this function, and just call forward.
1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
1131 # Do not call functions when jit is used
1132 full_backward_hooks, non_full_backward_hooks = [], []
File ~/Projects/pyciemss/pyciemss/observation.py:33, in StateIndependentNoiseModel.forward(self, state)
29 def forward(self, state: Dict[str, torch.Tensor]) -> None:
30 for k in self.vars:
31 pyro.sample(
32 f"{k}_noisy",
---> 33 self.markov_kernel(k, state[k]),
34 )
KeyError: 'hospitalized'