ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems
Other
12 stars 4 forks source link

KeyError in calibrate when data is mapped to observables #492

Closed sabinala closed 4 months ago

sabinala commented 4 months ago

For example, with this model (which has states I and H and observables infected and hospitalized), calling calibrate (as in calibrated_results = pyciemss.calibrate(model1, dataset1, data_mapping=data_mapping, num_iterations=num_iterations)) with the data mapping data_mapping = {"case": "I", "hosp": "H"} works fine, but with the data mapping data_mapping = {"case": "infected", "hosp": "hospitalized"} I get the following error:


KeyError Traceback (most recent call last) Cell In[42], line 3 1 data_mapping = {"case": "infected", "hosp": "hospitalized"} # data_mapping = {"column_name": "observable/state_variable"} 2 num_iterations = 10 if smoke_test else 1000 ----> 3 calibrated_results = pyciemss.calibrate(model1, dataset1, data_mapping=data_mapping, num_iterations=num_iterations) 4 parameter_estimates = calibrated_results["inferred_parameters"] 5 calibrated_results

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs) 17 log_message = """ 18 ############################### 19 (...) 26 ################################ 27 """ 28 logging.exception(log_message, function.__name__, function.__doc__) ---> 29 raise e

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs) 8 try: 9 start_time = time.perf_counter() ---> 10 result = function(*args, **kwargs) 11 end_time = time.perf_counter() 12 logging.info( 13 "Elapsed time for %s: %f", function.__name__, end_time - start_time 14 )

File ~/Projects/pyciemss/pyciemss/interfaces.py:505, in calibrate(model_path_or_json, data_path, data_mapping, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions, num_iterations, lr, verbose, num_particles, deterministic_learnable_parameters, progress_hook) 499 stack.enter_context(handler) 500 model( 501 torch.as_tensor(start_time), 502 torch.as_tensor(data_timepoints[-1]), 503 ) --> 505 inferred_parameters = autoguide(wrapped_model) 507 optim = pyro.optim.Adam({"lr": lr}) 508 loss = pyro.infer.Trace_ELBO(num_particles=num_particles)

File ~/Projects/pyciemss/pyciemss/interfaces.py:447, in calibrate.<locals>.autoguide(model) 443 try: 444 mvn_guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal( 445 pyro.poutine.block(model, hide=deterministic_learnable_parameters) 446 ) --> 447 mvn_guide._setup_prototype() 448 guide.append(mvn_guide) 449 except RuntimeError as re:

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:1002, in AutoLowRankMultivariateNormal._setup_prototype(self, *args, **kwargs) 1001 def _setup_prototype(self, *args, **kwargs): -> 1002 super()._setup_prototype(*args, **kwargs) 1003 # Initialize guide params 1004 self.loc = nn.Parameter(self._init_loc())

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:636, in AutoContinuous._setup_prototype(self, *args, **kwargs) 635 def _setup_prototype(self, *args, **kwargs): --> 636 super()._setup_prototype(*args, **kwargs) 637 self._unconstrained_shapes = {} 638 self._cond_indep_stacks = {}

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/autoguide/guides.py:157, in AutoGuide._setup_prototype(self, *args, **kwargs) 154 def _setup_prototype(self, *args, **kwargs): 155 # run the model so we can inspect its structure 156 model = poutine.block(self.model, self._prototype_hide_fn) --> 157 self.prototype_trace = poutine.block(poutine.trace(model).get_trace)( 158 *args, **kwargs 159 ) 160 if self.master is not None: 161 self.master()._check_prototype(self.prototype_trace)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs) 10 def _context_wrap(context, fn, *args, **kwargs): 11 with context: ---> 12 return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs) 190 def get_trace(self, *args, **kwargs): 191 """ 192 :returns: data structure 193 :rtype: pyro.poutine.Trace (...) 196 Calls this poutine and returns its trace instead of the function's return value. 197 """ --> 198 self(*args, **kwargs) 199 return self.msngr.get_trace()

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs) 170 self.msngr.trace.add_node( 171 "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs 172 ) 173 try: --> 174 ret = self.fn(*args, **kwargs) 175 except (ValueError, RuntimeError) as e: 176 exc_type, exc_value, traceback = sys.exc_info()

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs) 10 def _context_wrap(context, fn, *args, **kwargs): 11 with context: ---> 12 return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs) 10 def _context_wrap(context, fn, *args, **kwargs): 11 with context: ---> 12 return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs) 10 def _context_wrap(context, fn, *args, **kwargs): 11 with context: ---> 12 return fn(*args, **kwargs)

File ~/Projects/pyciemss/pyciemss/interfaces.py:500, in calibrate.<locals>.wrapped_model() 498 for handler in intervention_handlers: 499 stack.enter_context(handler) --> 500 model( 501 torch.as_tensor(start_time), 502 torch.as_tensor(data_timepoints[-1]), 503 )

File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs) 447 def __call__(self, *args, **kwargs): 448 with self._pyro_context: --> 449 result = super().__call__(*args, **kwargs) 450 if ( 451 pyro.settings.get("validate_poutine") 452 and not self._pyro_context.active 453 and _is_module_local_param_enabled() 454 ): 455 self._check_module_local_param_usage()

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs) 1126 # If we don't have any hooks, we want to skip the rest of the logic in 1127 # this function, and just call forward. 1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1129 or _global_forward_hooks or _global_forward_pre_hooks): -> 1130 return forward_call(*input, **kwargs) 1131 # Do not call functions when jit is used 1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Projects/pyciemss/pyciemss/compiled_dynamics.py:80, in CompiledDynamics.forward(self, start_time, end_time, logging_times, is_traced) 78 state = lt.trajectory 79 else: ---> 80 state = simulate(self.deriv, self.initial_state(), start_time, end_time) 82 observables = self.observables(state) 84 if is_traced: 85 # Add the observables to the trace so that they can be accessed later.

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs) 264 msg = { 265 "type": type, 266 "name": name, (...) 278 "infer": infer, 279 } 280 # apply the stack and return its return value --> 281 apply_stack(msg) 282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:220, in apply_stack(initial_msg) 217 default_process_message(msg) 219 for frame in stack[-pointer:]: --> 220 frame._postprocess_message(msg) 222 cont = msg["continuation"] 223 if cont is not None:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:168, in Messenger._postprocess_message(self, msg) 166 method = getattr(self, "_pyro_post_{}".format(msg["type"]), None) 167 if method is not None: --> 168 return method(msg) 169 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/interruption.py:249, in StaticBatchObservation._pyro_post_simulate(self, msg) 247 def _pyro_post_simulate(self, msg: dict) -> None: 248 super()._pyro_post_simulate(msg) --> 249 self.trajectory = observe(self.trajectory, self.observation)

File ~/anaconda3/lib/python3.10/functools.py:889, in singledispatch.<locals>.wrapper(*args, **kw) 885 if not args: 886 raise TypeError(f'{funcname} requires at least ' 887 '1 positional argument') --> 889 return dispatch(args[0].__class__)(*args, **kw)

File ~/anaconda3/lib/python3.10/site-packages/chirho/observational/internals.py:56, in _observe_dict(rv, obs, name, **kwargs) 47 @observe.register(dict) 48 def _observe_dict( 49 rv: Mapping[K, T], (...) 53 **kwargs, 54 ) -> Mapping[K, T]: 55 if callable(obs): ---> 56 obs = obs(rv) 57 if obs is not rv and obs is not None: 58 raise NotImplementedError("Dependent observations are not yet supported")

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs) 10 def _context_wrap(context, fn, *args, **kwargs): 11 with context: ---> 12 return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs) 447 def __call__(self, *args, **kwargs): 448 with self._pyro_context: --> 449 result = super().__call__(*args, **kwargs) 450 if ( 451 pyro.settings.get("validate_poutine") 452 and not self._pyro_context.active 453 and _is_module_local_param_enabled() 454 ): 455 self._check_module_local_param_usage()

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs) 1126 # If we don't have any hooks, we want to skip the rest of the logic in 1127 # this function, and just call forward. 1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks 1129 or _global_forward_hooks or _global_forward_pre_hooks): -> 1130 return forward_call(*input, **kwargs) 1131 # Do not call functions when jit is used 1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Projects/pyciemss/pyciemss/observation.py:33, in StateIndependentNoiseModel.forward(self, state) 29 def forward(self, state: Dict[str, torch.Tensor]) -> None: 30 for k in self.vars: 31 pyro.sample( 32 f"{k}_noisy", ---> 33 self.markov_kernel(k, state[k]), 34 )

KeyError: 'hospitalized'