ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems

RuntimeError when sampling with interventions #524

Closed: sabinala closed this issue 4 months ago

sabinala commented 4 months ago

When I try to run interventions in the interfaces notebook, I get the following RuntimeError:

ERROR:root: ###############################

            There was an exception in pyciemss

            Error occured in function: sample

            Function docs : 
Load a model from a file, compile it into a probabilistic program, and sample from it.

Args:
    model_path_or_json: Union[str, Dict]
        - A path to a AMR model file or JSON containing a model in AMR form.
    end_time: float
        - The end time of the sampled simulation.
    logging_step_size: float
        - The step size to use for logging the trajectory.
    num_samples: int
        - The number of samples to draw from the model.
    solver_method: str
        - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
        - If performance is incredibly slow, we suggest using `euler` to debug.
          If using `euler` results in faster simulation, the issue is likely that the model is stiff.
    solver_options: Dict[str, Any]
        - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
    start_time: float
        - The start time of the model. This is used to align the `start_state` from the
          AMR model with the simulation timepoints.
        - By default we set the `start_time` to be 0.
    inferred_parameters: Optional[pyro.nn.PyroModule]
        - A Pyro module that contains the inferred parameters of the model.
          This is typically the result of `calibrate`.
        - If not provided, we will use the default values from the AMR model.
    static_state_interventions: Dict[float, Dict[str, Intervention]]
        - A dictionary of static interventions to apply to the model.
        - Each key is the time at which the intervention is applied.
        - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
        - Note that the `intervention_assignment` can be any type supported by
          :func:`~chirho.interventional.ops.intervene`, including functions.
    static_parameter_interventions: Dict[float, Dict[str, Intervention]]
        - A dictionary of static interventions to apply to the model.
        - Each key is the time at which the intervention is applied.
        - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
        - Note that the `intervention_assignment` can be any type supported by
          :func:`~chirho.interventional.ops.intervene`, including functions.
    dynamic_state_interventions: Dict[
                                    Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                    Dict[str, Intervention]
                                    ]
        - A dictionary of dynamic interventions to apply to the model.
        - Each key is a function that takes in the current state of the model and returns a tensor.
          When this function crosses 0, the dynamic intervention is applied.
        - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
        - Note that the `intervention_assignment` can be any type supported by
          :func:`~chirho.interventional.ops.intervene`, including functions.
    dynamic_parameter_interventions: Dict[
                                        Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                        Dict[str, Intervention]
                                        ]
        - A dictionary of dynamic interventions to apply to the model.
        - Each key is a function that takes in the current state of the model and returns a tensor.
          When this function crosses 0, the dynamic intervention is applied.
        - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
        - Note that the `intervention_assignment` can be any type supported by
          :func:`~chirho.interventional.ops.intervene`, including functions.
    alpha: float
        - Risk level for alpha-superquantile outputs in the results dictionary.

Returns:
    result: Dict[str, torch.Tensor]
        - Dictionary of outputs with following attributes:
            - data: The samples from the model as a pandas DataFrame.
            - unprocessed_result: Dictionary of outputs from the model.
                - Each key is the name of a parameter or state variable in the model.
                - Each value is a tensor of shape (num_samples, num_timepoints) for state variables
                and (num_samples,) for parameters.
            - quantiles: The quantiles for ensemble score calculation as a pandas DataFrames.
            - risk: Dictionary with each key as the name of a state with
            a dictionary of risk estimates for each state at the final timepoint.
                - risk: alpha-superquantile risk estimate
                Superquantiles can be intuitively thought of as a tail expectation, or an average
                over a portion of worst-case outcomes. Given a distribution of a
                quantity of interest (QoI), the superquantile at level \alpha\in[0, 1] is
                the expected value of the largest 100(1 -\alpha)% realizations of the QoI.
                - qoi: Samples of quantity of interest (value of the state at the final timepoint)
            - schema: Visualization. (If visual_options is truthy)

            ################################
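The `risk` entry in the docstring above is defined as an alpha-superquantile: the mean of the largest 100(1 - α)% of the QoI samples. A minimal empirical sketch of that definition in plain Python follows. The function name `superquantile` and the choice to round the tail size up to a whole sample are assumptions for illustration, not pyciemss's implementation, which works on tensors.

```python
import math
from typing import Sequence


def superquantile(samples: Sequence[float], alpha: float) -> float:
    """Empirical alpha-superquantile: mean of the largest 100*(1 - alpha)% of samples.

    Illustrative sketch only (assumed estimator): the tail size is rounded up to a
    whole number of samples and always includes at least the maximum.
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    if len(samples) == 0:
        raise ValueError("need at least one sample")
    ordered = sorted(samples, reverse=True)
    # round() guards against float noise such as 10 * (1 - 0.7) == 3.0000000000000004
    k = max(1, math.ceil(round(len(ordered) * (1.0 - alpha), 9)))
    return sum(ordered[:k]) / k


# Hypothetical final-timepoint values of one state across 10 samples.
qoi = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0, 5.0, 3.0]
print(superquantile(qoi, 0.8))  # 7.5, the mean of the two largest values (9 and 6)
```

At α = 0 this reduces to the plain sample mean, and at α = 1 it reduces to the sample maximum, which is why it is read as a tail expectation over the worst-case outcomes.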

Traceback (most recent call last):
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 305, in wrapped_model
    full_trajectory = model(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 28, in integrate
    self._before_integrate(t)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py", line 163, in _before_integrate
    first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 54, in _select_initial_step
    d1 = norm(f0 / scale)
RuntimeError: The size of tensor a (300) must match the size of tensor b (3) at non-singleton dimension 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/altu809/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 332, in sample
    samples = pyro.infer.Predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 137, in _predictive
    trace = poutine.trace(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 305, in wrapped_model
    full_trajectory = model(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 28, in integrate
    self._before_integrate(t)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py", line 163, in _before_integrate
    first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 54, in _select_initial_step
    d1 = norm(f0 / scale)
RuntimeError: The size of tensor a (300) must match the size of tensor b (3) at non-singleton dimension 0
Trace Shapes:
 Param Sites:
  numeric_deriv_func$$$_nodes.0._args.0._value
  numeric_deriv_func$$$_nodes.1._args.0._value
Sample Sites:
  persistent_p_cbeta dist 100 |
                    value 100 |
     persistent_p_tr dist 100 |
                    value 100 |
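The last line is PyTorch's broadcasting error. In torchdiffeq's `_select_initial_step`, `f0 / scale` combines an operand of size 300 along dimension 0 (apparently `f0`) with one of size 3 (apparently `scale`), and neither size is 1, so the two cannot be broadcast together. The trace shapes also show the intervened parameter `persistent_p_cbeta` sampled with a batch dimension of 100, matching `num_samples = 100` in the notebook cell. As a sketch of which shape pairs are compatible, the plain-Python helper below re-implements the broadcasting rule PyTorch shares with NumPy. `broadcast_shape` is hypothetical, and its error text only imitates PyTorch's message.

```python
from typing import Sequence, Tuple


def broadcast_shape(shape_a: Sequence[int], shape_b: Sequence[int]) -> Tuple[int, ...]:
    """Broadcast two shapes under the NumPy/PyTorch rule, or raise RuntimeError.

    Hypothetical helper for illustration; the error text imitates PyTorch's.
    """
    ndim = max(len(shape_a), len(shape_b))
    # Left-pad the shorter shape with 1s so the trailing dimensions line up.
    a = (1,) * (ndim - len(shape_a)) + tuple(shape_a)
    b = (1,) * (ndim - len(shape_b)) + tuple(shape_b)
    out = []
    for dim, (x, y) in enumerate(zip(a, b)):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            raise RuntimeError(
                f"The size of tensor a ({x}) must match the size of tensor b ({y}) "
                f"at non-singleton dimension {dim}"
            )
    return tuple(out)


print(broadcast_shape((100, 3), (3,)))  # (100, 3): trailing sizes agree
print(broadcast_shape((100, 1), (3,)))  # (100, 3): a size-1 dimension stretches
try:
    broadcast_shape((300,), (3,))
except RuntimeError as err:
    print(err)  # same wording as the last line of the traceback
```

Because shapes are aligned from the right, a batched `(100, 3)` tensor combines cleanly with an unbatched `(3,)` one, but a flat `(300,)` tensor does not.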

RuntimeError                              Traceback (most recent call last)
File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     26 with self.clone():
---> 27     return func(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     11 with context:
---> 12     return fn(*args, **kwargs)

File ~/Projects/pyciemss/pyciemss/interfaces.py:305, in sample.<locals>.wrapped_model()
    304     stack.enter_context(handler)
--> 305 full_trajectory = model(
    306     torch.as_tensor(start_time),
    307     torch.as_tensor(end_time),
    308     logging_times=logging_times,
    309     is_traced=True,
    310 )
    312 if noise_model is not None:

File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    448 with self._pyro_context:
--> 449     result = super().__call__(*args, **kwargs)
    450 if (
    451     pyro.settings.get("validate_poutine")
    452     and not self._pyro_context.active
    453     and _is_module_local_param_enabled()
    454 ):

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used

File ~/Projects/pyciemss/pyciemss/compiled_dynamics.py:77, in CompiledDynamics.forward(self, start_time, end_time, logging_times, is_traced)
     76 with LogTrajectory(logging_times) as lt:
---> 77     simulate(self.deriv, self.initial_state(), start_time, end_time)
     78     state = lt.trajectory

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    210 pointer = pointer + 1
--> 212 frame._process_message(msg)
    214 if msg["stop"]:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    161 if method is not None:
--> 162     return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py:109, in Solver._pyro_simulate(self, msg)
    107         break
--> 109 state, start_time, next_interruption = simulate_to_interruption(
    110     possible_interruptions,
    111     dynamics,
    112     state,
    113     start_time,
    114     end_time,
    115     **msg["kwargs"],
    116 )
    118 if next_interruption is not None:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    210 pointer = pointer + 1
--> 212 frame._process_message(msg)
    214 if msg["stop"]:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    161 if method is not None:
--> 162     return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py:89, in TorchDiffEq._pyro_simulate_to_interruption(self, msg)
     88 msg["kwargs"].update(self.odeint_kwargs)
---> 89 msg["value"] = torchdiffeq_simulate_to_interruption(
     90     interruptions,
     91     dynamics,
     92     initial_state,
     93     start_time,
     94     end_time,
     95     **msg["kwargs"],
     96 )
     97 msg["done"] = True

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:248, in torchdiffeq_simulate_to_interruption(interruptions, dynamics, initial_state, start_time, end_time, **kwargs)
    244 (next_interruption,), interruption_time = _torchdiffeq_get_next_interruptions(
    245     dynamics, initial_state, start_time, interruptions, **kwargs
    246 )
--> 248 value = simulate_point(
    249     dynamics, initial_state, start_time, interruption_time, **kwargs
    250 )
    251 return value, interruption_time, next_interruption

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    210 pointer = pointer + 1
--> 212 frame._process_message(msg)
    214 if msg["stop"]:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    161 if method is not None:
--> 162     return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py:97, in LogTrajectory._pyro_simulate_point(self, msg)
     96 with pyro.poutine.messenger.block_messengers(lambda m: m is self):
---> 97     trajectory: State[T] = simulate_trajectory(
     98         dynamics, initial_state, timespan, **msg["kwargs"]
     99     )
    101 # TODO support dim != -1

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    210 pointer = pointer + 1
--> 212 frame._process_message(msg)
    214 if msg["stop"]:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    161 if method is not None:
--> 162     return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py:77, in TorchDiffEq._pyro_simulate_trajectory(self, msg)
     75 msg["kwargs"].update(self.odeint_kwargs)
---> 77 msg["value"] = torchdiffeq_simulate_trajectory(
     78     dynamics, initial_state, timespan, **msg["kwargs"]
     79 )
     80 msg["done"] = True

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:165, in torchdiffeq_simulate_trajectory(dynamics, initial_state, timespan, **kwargs)
    159 def torchdiffeq_simulate_trajectory(
    160     dynamics: Dynamics[torch.Tensor],
    161     initial_state: State[torch.Tensor],
    162     timespan: torch.Tensor,
    163     **kwargs,
    164 ) -> State[torch.Tensor]:
--> 165     return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:71, in _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **odeint_kwargs)
     70 if torch.any(diff):
---> 71     solns = _batched_odeint(  # torchdiffeq.odeint(
     72         functools.partial(_deriv, dynamics, var_order),
     73         tuple(initial_state[v] for v in var_order),
     74         timespan,
     75         **odeint_kwargs,
     76     )
     77 else:

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:123, in _batched_odeint(func, y0, t, event_fn, **odeint_kwargs)
    122 else:
--> 123     yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
    125 yt = tuple(
    126     torch.transpose(
    127         yt[(..., None) + yt.shape[len(yt.shape) - event_dim :]],
   (...)
    131     for yt in yt_raw
    132 )

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py:77, in odeint(func, y0, t, rtol, atol, method, options, event_fn)
     76 if event_fn is None:
---> 77     solution = solver.integrate(t)
     78 else:

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py:28, in AdaptiveStepsizeODESolver.integrate(self, t)
     27 t = t.to(self.dtype)
---> 28 self._before_integrate(t)
     29 for i in range(1, len(t)):

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py:163, in RKAdaptiveStepsizeODESolver._before_integrate(self, t)
    162 if self.first_step is None:
--> 163     first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
    164                                       self.norm, f0=f0)
    165 else:

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py:54, in _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0)
     53 d0 = norm(y0 / scale)
---> 54 d1 = norm(f0 / scale)
     56 if d0 < 1e-5 or d1 < 1e-5:

RuntimeError: The size of tensor a (300) must match the size of tensor b (3) at non-singleton dimension 0

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
Cell In[10], line 6
      3 logging_step_size = 1.0
      4 num_samples = 5 if smoke_test else 100
----> 6 result = pyciemss.sample(model3, end_time, logging_step_size, num_samples, start_time=start_time,
      7                          static_parameter_interventions={torch.tensor(1.): {"p_cbeta": torch.tensor(0.5)}})
      8 display(result["data"].head())
     10 # Plot the result

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
     17 log_message = """
     18     ###############################
     19 
   (...)
     26     ################################
     27 """
     28 logging.exception(log_message, function.__name__, function.__doc__)
---> 29 raise e

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
      8 try:
      9     start_time = time.perf_counter()
---> 10     result = function(*args, **kwargs)
     11     end_time = time.perf_counter()
     12     logging.info(
     13         "Elapsed time for %s: %f", function.__name__, end_time - start_time
     14     )

File ~/Projects/pyciemss/pyciemss/interfaces.py:332, in sample(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, time_unit, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions, alpha)
    320         compiled_noise_model(full_trajectory)
    322 parallel = (
    323     False
    324     if len(
   (...)
    329     else True
    330 )
--> 332 samples = pyro.infer.Predictive(
    333     wrapped_model,
    334     guide=inferred_parameters,
    335     num_samples=num_samples,
    336     parallel=parallel,
    337 )()
    339 risk_results = {}
    340 for k, vals in samples.items():

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130     return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:273, in Predictive.forward(self, *args, **kwargs)
    263     return_sites = None if not return_sites else return_sites
    264     posterior_samples = _predictive(
    265         self.guide,
    266         posterior_samples,
   (...)
    271         model_kwargs=kwargs,
    272     )
--> 273 return _predictive(
    274     self.model,
    275     posterior_samples,
    276     self.num_samples,
    277     return_sites=return_sites,
    278     parallel=self.parallel,
    279     model_args=args,
    280     model_kwargs=kwargs,
    281 )

File ~/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py:137, in _predictive(model, posterior_samples, num_samples, return_sites, return_trace, parallel, model_args, model_kwargs)
    126 if not parallel:
    127     return _predictive_sequential(
    128         model,
    129         posterior_samples,
   (...)
    134         return_trace=False,
    135     )
--> 137 trace = poutine.trace(
    138     poutine.condition(vectorize(model), reshaped_samples)
    139 ).get_trace(*model_args, **model_kwargs)
    140 predictions = {}
    141 for site, shape in return_site_shapes.items():

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:198, in TraceHandler.get_trace(self, *args, **kwargs)
    190 def get_trace(self, *args, **kwargs):
    191     """
    192     :returns: data structure
    193     :rtype: pyro.poutine.Trace
   (...)
    196     Calls this poutine and returns its trace instead of the function's return value.
    197     """
--> 198     self(*args, **kwargs)
    199     return self.msngr.get_trace()

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:180, in TraceHandler.__call__(self, *args, **kwargs)
    178     exc = exc_type("{}\n{}".format(exc_value, shapes))
    179     exc = exc.with_traceback(traceback)
--> 180     raise exc from e
    181 self.msngr.trace.add_node(
    182     "_RETURN", name="_RETURN", type="return", value=ret
    183 )
    184 return ret

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py:174, in TraceHandler.__call__(self, *args, **kwargs)
    170 self.msngr.trace.add_node(
    171     "_INPUT", name="_INPUT", type="args", args=args, kwargs=kwargs
    172 )
    173 try:
--> 174     ret = self.fn(*args, **kwargs)
    175 except (ValueError, RuntimeError) as e:
    176     exc_type, exc_value, traceback = sys.exc_info()

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:12, in _context_wrap(context, fn, *args, **kwargs)
     10 def _context_wrap(context, fn, *args, **kwargs):
     11     with context:
---> 12         return fn(*args, **kwargs)

File ~/Projects/pyciemss/pyciemss/interfaces.py:305, in sample..wrapped_model() 303 for handler in intervention_handlers: 304 stack.enter_context(handler) --> 305 full_trajectory = model( 306 torch.as_tensor(start_time), 307 torch.as_tensor(end_time), 308 logging_times=logging_times, 309 is_traced=True, 310 ) 312 if noise_model is not None: 313 compiled_noise_model = compile_noise_model( 314 noise_model, 315 vars=set(full_trajectory.keys()), 316 observables=model.observables, 317 **noise_model_kwargs, 318 )

File ~/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py:449, in PyroModule.__call__(self, *args, **kwargs)
    447 def __call__(self, *args, **kwargs):
    448 with self._pyro_context:
--> 449 result = super().__call__(*args, **kwargs)
    450 if (
    451 pyro.settings.get("validate_poutine")
    452 and not self._pyro_context.active
    453 and _is_module_local_param_enabled()
    454 ):
    455 self._check_module_local_param_usage()

File ~/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1130, in Module._call_impl(self, *input, **kwargs)
   1126 # If we don't have any hooks, we want to skip the rest of the logic in
   1127 # this function, and just call forward.
   1128 if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1129 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1130 return forward_call(*input, **kwargs)
   1131 # Do not call functions when jit is used
   1132 full_backward_hooks, non_full_backward_hooks = [], []

File ~/Projects/pyciemss/pyciemss/compiled_dynamics.py:77, in CompiledDynamics.forward(self, start_time, end_time, logging_times, is_traced)
     75 if logging_times is not None:
     76 with LogTrajectory(logging_times) as lt:
---> 77 simulate(self.deriv, self.initial_state(), start_time, end_time)
     78 state = lt.trajectory
     79 else:

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    264 msg = {
    265 "type": type,
    266 "name": name,
   (...)
    278 "infer": infer,
    279 }
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    210 pointer = pointer + 1
--> 212 frame._process_message(msg)
    214 if msg["stop"]:
    215 break

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    160 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    161 if method is not None:
--> 162 return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py:109, in Solver._pyro_simulate(self, msg)
    106 if ph.priority > start_time:
    107 break
--> 109 state, start_time, next_interruption = simulate_to_interruption(
    110 possible_interruptions,
    111 dynamics,
    112 state,
    113 start_time,
    114 end_time,
    115 **msg["kwargs"],
    116 )
    118 if next_interruption is not None:
    119 dynamics, state = next_interruption.callback(dynamics, state)

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    264 msg = {
    265 "type": type,
    266 "name": name,
   (...)
    278 "infer": infer,
    279 }
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    210 pointer = pointer + 1
--> 212 frame._process_message(msg)
    214 if msg["stop"]:
    215 break

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    160 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    161 if method is not None:
--> 162 return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py:89, in TorchDiffEq._pyro_simulate_to_interruption(self, msg)
     87 interruptions, dynamics, initial_state, start_time, end_time = msg["args"]
     88 msg["kwargs"].update(self.odeint_kwargs)
---> 89 msg["value"] = torchdiffeq_simulate_to_interruption(
     90 interruptions,
     91 dynamics,
     92 initial_state,
     93 start_time,
     94 end_time,
     95 **msg["kwargs"],
     96 )
     97 msg["done"] = True

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:248, in torchdiffeq_simulate_to_interruption(interruptions, dynamics, initial_state, start_time, end_time, **kwargs)
    242 assert len(interruptions) > 0, "should have at least one interruption here"
    244 (next_interruption,), interruption_time = _torchdiffeq_get_next_interruptions(
    245 dynamics, initial_state, start_time, interruptions, **kwargs
    246 )
--> 248 value = simulate_point(
    249 dynamics, initial_state, start_time, interruption_time, **kwargs
    250 )
    251 return value, interruption_time, next_interruption

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    264 msg = {
    265 "type": type,
    266 "name": name,
   (...)
    278 "infer": infer,
    279 }
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    210 pointer = pointer + 1
--> 212 frame._process_message(msg)
    214 if msg["stop"]:
    215 break

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    160 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    161 if method is not None:
--> 162 return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py:97, in LogTrajectory._pyro_simulate_point(self, msg)
     92 timespan = torch.concat(
     93 (start_time.unsqueeze(-1), filtered_timespan, end_time.unsqueeze(-1))
     94 )
     96 with pyro.poutine.messenger.block_messengers(lambda m: m is self):
---> 97 trajectory: State[T] = simulate_trajectory(
     98 dynamics, initial_state, timespan, **msg["kwargs"]
     99 )
    101 # TODO support dim != -1
    102 idx_name = "__time"

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:281, in effectful.<locals>._fn(*args, **kwargs)
    264 msg = {
    265 "type": type,
    266 "name": name,
   (...)
    278 "infer": infer,
    279 }
    280 # apply the stack and return its return value
--> 281 apply_stack(msg)
    282 return msg["value"]

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py:212, in apply_stack(initial_msg)
    209 for frame in reversed(stack):
    210 pointer = pointer + 1
--> 212 frame._process_message(msg)
    214 if msg["stop"]:
    215 break

File ~/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py:162, in Messenger._process_message(self, msg)
    160 method = getattr(self, "_pyro_{}".format(msg["type"]), None)
    161 if method is not None:
--> 162 return method(msg)
    163 return None

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py:77, in TorchDiffEq._pyro_simulate_trajectory(self, msg)
     74 dynamics, initial_state, timespan = msg["args"]
     75 msg["kwargs"].update(self.odeint_kwargs)
---> 77 msg["value"] = torchdiffeq_simulate_trajectory(
     78 dynamics, initial_state, timespan, **msg["kwargs"]
     79 )
     80 msg["done"] = True

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:165, in torchdiffeq_simulate_trajectory(dynamics, initial_state, timespan, **kwargs)
    159 def torchdiffeq_simulate_trajectory(
    160 dynamics: Dynamics[torch.Tensor],
    161 initial_state: State[torch.Tensor],
    162 timespan: torch.Tensor,
    163 **kwargs,
    164 ) -> State[torch.Tensor]:
--> 165 return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:71, in _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **odeint_kwargs)
     68 time_dim = -1
     70 if torch.any(diff):
---> 71 solns = _batched_odeint( # torchdiffeq.odeint(
     72 functools.partial(_deriv, dynamics, var_order),
     73 tuple(initial_state[v] for v in var_order),
     74 timespan,
     75 **odeint_kwargs,
     76 )
     77 else:
     78 solns = tuple(initial_state[v].unsqueeze(time_dim) for v in var_order)

File ~/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py:123, in _batched_odeint(func, y0, t, event_fn, **odeint_kwargs)
    119 event_t, yt_raw = torchdiffeq.odeint_event(
    120 func, y0_expanded, t, event_fn=event_fn, **odeint_kwargs
    121 )
    122 else:
--> 123 yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
    125 yt = tuple(
    126 torch.transpose(
    127 yt[(..., None) + yt.shape[len(yt.shape) - event_dim :]],
   (...)
    131 for yt in yt_raw
    132 )
    133 return yt if event_fn is None else (event_t, yt)

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py:77, in odeint(func, y0, t, rtol, atol, method, options, event_fn)
     74 solver = SOLVERS[method](func=func, y0=y0, rtol=rtol, atol=atol, **options)
     76 if event_fn is None:
---> 77 solution = solver.integrate(t)
     78 else:
     79 event_t, solution = solver.integrate_until_event(t[0], event_fn)

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py:28, in AdaptiveStepsizeODESolver.integrate(self, t)
     26 solution[0] = self.y0
     27 t = t.to(self.dtype)
---> 28 self._before_integrate(t)
     29 for i in range(1, len(t)):
     30 solution[i] = self._advance(t[i])

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py:163, in RKAdaptiveStepsizeODESolver._before_integrate(self, t)
    161 f0 = self.func(t[0], self.y0)
    162 if self.first_step is None:
--> 163 first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
    164 self.norm, f0=f0)
    165 else:
    166 first_step = self.first_step

File ~/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py:54, in _select_initial_step(func, t0, y0, order, rtol, atol, norm, f0)
     51 scale = atol + torch.abs(y0) * rtol
     53 d0 = norm(y0 / scale)
---> 54 d1 = norm(f0 / scale)
     56 if d0 < 1e-5 or d1 < 1e-5:
     57 h0 = torch.tensor(1e-6, dtype=dtype, device=device)

RuntimeError: The size of tensor a (300) must match the size of tensor b (3) at non-singleton dimension 0
Trace Shapes:
 Param Sites:
  numeric_deriv_func$$$_nodes.0._args.0._value
  numeric_deriv_func$$$_nodes.1._args.0._value
Sample Sites:
 persistent_p_cbeta dist 100 |
                   value 100 |
 persistent_p_tr dist 100 |
                value 100 |
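The final error is a shape mismatch inside torchdiffeq's `_select_initial_step`: the element-wise division `f0 / scale` receives one tensor whose leading dimension is 300 and another whose leading dimension is 3. PyTorch combines tensors with NumPy-style broadcasting, so shapes must agree dimension by dimension from the right, or one of the two sizes must be 1. The pure-Python sketch below re-implements that rule to show why `(300,)` and `(3,)` cannot be combined while `(100, 3)` and `(3,)` can. The helper `broadcast_shapes` is hypothetical and illustrative only; it is not part of pyciemss, ChiRho, or torchdiffeq, and its error text merely imitates the message above.

```python
from itertools import zip_longest
from typing import Sequence, Tuple


def broadcast_shapes(a: Sequence[int], b: Sequence[int]) -> Tuple[int, ...]:
    """Return the broadcast shape of `a` and `b`, or raise if incompatible.

    Sketch of the NumPy/PyTorch rule: compare dimensions from the right;
    each pair must be equal, or one of the two sizes must be 1.
    """
    ndim = max(len(a), len(b))
    result = []
    # Walk both shapes from the trailing dimension; missing dims count as 1.
    for offset, (da, db) in enumerate(
        zip_longest(reversed(a), reversed(b), fillvalue=1)
    ):
        if da == db or db == 1:
            result.append(da)
        elif da == 1:
            result.append(db)
        else:
            # Index of the offending dimension, counted from the left.
            dim = ndim - 1 - offset
            raise ValueError(
                f"The size of tensor a ({da}) must match the size of "
                f"tensor b ({db}) at non-singleton dimension {dim}"
            )
    return tuple(reversed(result))


if __name__ == "__main__":
    print(broadcast_shapes((100, 3), (3,)))  # compatible: (100, 3)
    try:
        broadcast_shapes((300,), (3,))  # incompatible, like the traceback
    except ValueError as err:
        print(err)
```

With PyTorch installed, the real `torch.broadcast_shapes` performs the same compatibility check and raises on `(300,)` and `(3,)`, so it can be used to test candidate shapes directly.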

sabinala commented 4 months ago

See this notebook to reproduce the error: https://github.com/ciemss/pyciemss/blob/524-runtimeerror-when-sampling-with-interventions/docs/source/Notebook%20for%20issue%20524.ipynb

SamWitty commented 4 months ago

@sabinala, I do not get an error when running that notebook locally. Have you updated your Jupyter kernel since #509? The error you're seeing looks like it comes from an outdated version of ChiRho. I would recommend building a fresh environment and trying again.
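One way to confirm which package versions the notebook's kernel actually sees is to query installed package metadata from inside that kernel. The sketch below uses only the standard library's `importlib.metadata`. The helper name `installed_versions` is hypothetical, and the distribution names are assumptions: Pyro is published as `pyro-ppl`, and pyciemss is assumed to be installed under the name `pyciemss` (for example, in editable mode from the `~/Projects/pyciemss` checkout that appears in the traceback).

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional

# Assumed distribution names; adjust if your environment installs them differently.
PACKAGES = ("pyciemss", "chirho", "pyro-ppl", "torch", "torchdiffeq")


def installed_versions(names: Iterable[str] = PACKAGES) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if absent."""
    found: Dict[str, Optional[str]] = {}
    for name in names:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


if __name__ == "__main__":
    # Run this in the same Jupyter kernel that produced the error.
    for name, ver in installed_versions().items():
        print(f"{name:12} {ver or 'not installed'}")
```

If the reported ChiRho version is older than what pyciemss expects after #509, rebuilding the environment as suggested is the simplest fix.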