ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems

Error in the Sample function when running with `num_samples` > 1 #496

Closed liunelson closed 4 months ago

liunelson commented 4 months ago

I'm trying to run a baseline scenario to compare with the results of an Optimize run.

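import os

import pyciemss
import torch  # only needed if the commented-out intervention below is enabled

MODELS_PATH = "https://raw.githubusercontent.com/DARPA-ASKEM/simulation-integration/main/data/models/"
model3 = os.path.join(MODELS_PATH, "SIR_stockflow.json")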

start_time = 0.0
end_time = 50.0
logging_step_size = 1.0
num_samples = 1

results_baseline = pyciemss.sample(
    model3, 
    end_time, 
    logging_step_size, 
    num_samples, 
    start_time = start_time, 
    # static_parameter_interventions = {torch.tensor(0.0): {intervened_params: torch.tensor(0.35)}},
    solver_method = "euler"
)

This call works with num_samples = 1, but it raises the error below for any other value of num_samples. However, it runs without error with, for example, num_samples = 100 if I also supply static_parameter_interventions.

ERROR:root:
                ###############################

                There was an exception in pyciemss

                Error occured in function: sample

                Function docs : 
    Load a model from a file, compile it into a probabilistic program, and sample from it.

    Args:
        model_path_or_json: Union[str, Dict]
            - A path to a AMR model file or JSON containing a model in AMR form.
        end_time: float
            - The end time of the sampled simulation.
        logging_step_size: float
            - The step size to use for logging the trajectory.
        num_samples: int
            - The number of samples to draw from the model.
        solver_method: str
            - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
            - If performance is incredibly slow, we suggest using `euler` to debug.
              If using `euler` results in faster simulation, the issue is likely that the model is stiff.
        solver_options: Dict[str, Any]
            - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
        start_time: float
            - The start time of the model. This is used to align the `start_state` from the
              AMR model with the simulation timepoints.
            - By default we set the `start_time` to be 0.
        inferred_parameters: Optional[pyro.nn.PyroModule]
            - A Pyro module that contains the inferred parameters of the model.
              This is typically the result of `calibrate`.
            - If not provided, we will use the default values from the AMR model.
        static_state_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        static_parameter_interventions: Dict[float, Dict[str, Intervention]]
            - A dictionary of static interventions to apply to the model.
            - Each key is the time at which the intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_state_interventions: Dict[
                                        Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                        Dict[str, Intervention]
                                        ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {state_variable_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.
        dynamic_parameter_interventions: Dict[
                                            Callable[[torch.Tensor, Dict[str, torch.Tensor]], torch.Tensor],
                                            Dict[str, Intervention]
                                            ]
            - A dictionary of dynamic interventions to apply to the model.
            - Each key is a function that takes in the current state of the model and returns a tensor.
              When this function crosses 0, the dynamic intervention is applied.
            - Each value is a dictionary of the form {parameter_name: intervention_assignment}.
            - Note that the `intervention_assignment` can be any type supported by
              :func:`~chirho.interventional.ops.intervene`, including functions.

    Returns:
        result: Dict[str, torch.Tensor]
            - Dictionary of outputs from the model.
                - Each key is the name of a parameter or state variable in the model.
                - Each value is a tensor of shape (num_samples, num_timepoints) for state variables
                    and (num_samples,) for parameters.

                ################################

Traceback (most recent call last):
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 282, in wrapped_model
    full_trajectory = model(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 106, in integrate
    y1 = y0 + dy
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/nliu/projects/askem/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 298, in sample
    samples = pyro.infer.Predictive(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/infer/predictive.py", line 137, in _predictive
    trace = poutine.trace(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/interfaces.py", line 282, in wrapped_model
    full_trajectory = model(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/nliu/projects/askem/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/home/nliu/miniconda3/envs/CIEMSS_ENV/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 106, in integrate
    y1 = y0 + dy
RuntimeError: The size of tensor a (3) must match the size of tensor b (6) at non-singleton dimension 0
                               Trace Shapes:    
                                Param Sites:    
numeric_deriv_func$$$_nodes.0._args.0._value    
numeric_deriv_func$$$_nodes.1._args.0._value    
                               Sample Sites:    
                     persistent_p_cbeta dist 2 |
                                       value 2 |
                        persistent_p_tr dist 2 |
                                       value 2 |
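
For readers unfamiliar with this kind of failure, the final frame (`y1 = y0 + dy` in the Euler solver) adds a state of size 3 to an increment of size 6. One plausible reading, given that the sample sites above have size 2, is that the derivative was evaluated for two parameter samples while the initial state was not repeated per sample. The following is a minimal pure-Python sketch of that mismatch under this assumption. It is not the ChiRho or torchdiffeq implementation, and `sir_deriv`, `euler_step`, and all numeric values are hypothetical.

```python
def sir_deriv(state, beta, gamma):
    """Right-hand side of a basic SIR model for one parameter sample."""
    s, i, r = state
    n = s + i + r
    infection = beta * s * i / n
    recovery = gamma * i
    return [-infection, infection - recovery, recovery]


def euler_step(y0, dy, dt):
    """Explicit Euler update y1 = y0 + dt * dy on flat lists of equal length."""
    if len(y0) != len(dy):
        raise ValueError(
            f"size of y0 ({len(y0)}) must match size of dy ({len(dy)})"
        )
    return [y + dt * d for y, d in zip(y0, dy)]


initial_state = [990.0, 10.0, 0.0]      # S, I, R (hypothetical values)
samples = [(0.30, 0.10), (0.35, 0.12)]  # two hypothetical (beta, gamma) draws

# Derivative evaluated once per sample and flattened: 2 samples x 3 states = 6 entries.
dy_batched = [
    d for beta, gamma in samples for d in sir_deriv(initial_state, beta, gamma)
]

try:
    euler_step(initial_state, dy_batched, dt=1.0)  # 3 entries vs 6 entries
    mismatch_message = None
except ValueError as err:
    mismatch_message = str(err)

# Repeating the state once per sample gives both operands 6 entries.
y0_batched = initial_state * len(samples)
y1_batched = euler_step(y0_batched, dy_batched, dt=1.0)
```

In this sketch the update only succeeds once the state carries one copy per sample, which matches the general shape of the reported error (3 versus 6) but does not confirm the actual root cause in the library.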

SamWitty commented 4 months ago

@liunelson, are you using the version of pyciemss on main or the last tagged release? I believe we addressed this issue in #491. If that doesn't fix it, there should be a more robust upstream fix from ChiRho, which was merged this morning: https://github.com/BasisResearch/chirho/pull/525
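
A quick way to report which versions are installed is sketched below using the standard library's `importlib.metadata`. The distribution names `pyciemss` and `chirho` are assumptions about how the packages are registered in the environment, and `installed_version` is a hypothetical helper. A development install from main may report the same version string as a tagged release, so the output is a starting point rather than proof of which commit is in use.

```python
from importlib import metadata


def installed_version(package_name):
    """Return the installed version string of a package, or None if it is absent."""
    try:
        return metadata.version(package_name)
    except metadata.PackageNotFoundError:
        return None


# Distribution names are assumed; adjust them if the packages are registered differently.
for name in ("pyciemss", "chirho"):
    print(name, installed_version(name))
```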

liunelson commented 4 months ago

@SamWitty Just confirming that your fix resolved this issue.