ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems

`ValueError` when running integration tests with `ensemble_calibrate` #540

Closed: sabinala closed this issue 3 months ago

sabinala commented 3 months ago

You can find the full error message here; this is the PyCIEMSS-specific part:

                ###############################

                There was an exception in pyciemss

                Error occured in function: ensemble_calibrate

                Function docs : 
    Infer parameters for an ensemble of DynamicalSystem models conditional on data.
    This uses variational inference with a mean-field variational family to infer the parameters of the model.

    Args:
    model_paths_or_jsons: List[Union[str, Dict]]
        - A list of paths to AMR model files or JSONs containing models in AMR form.
    solution_mappings: List[Callable[[Dict[str, torch.Tensor]], Dict[str, torch.Tensor]]]
        - A list of functions that map the solution of each model to a common solution space.
        - Each function takes in a dictionary of the form {state_variable_name: value}
            and returns a dictionary of the same form.
    data_path: str
        - A path to the data file.
    dirichlet_alpha: Optional[torch.Tensor]
        - A tensor of shape (num_models,) containing the Dirichlet alpha values for the ensemble.
            - A higher proportion of alpha values will result in higher weights for the corresponding models.
            - A larger total alpha values will result in more certain priors.
            - e.g. torch.tensor([1, 1, 1]) will result in a uniform prior over vectors of length 3 that sum to 1.
            - e.g. torch.tensor([1, 2, 3]) will result in a prior that is biased towards the third model.
        - If not provided, we will use a uniform Dirichlet prior.
    data_mapping: Dict[str, str]
        - A mapping from column names in the data file to state variable names in the model.
            - keys: str name of column in dataset
            - values: str name of state/observable in model
        - If not provided, we will assume that the column names in the data file match the state variable names.
        - Note: This mapping must match output of `solution_mappings`.
    noise_model: str
        - The noise model to use for the data.
        - Currently we only support the normal distribution.
    noise_model_kwargs: Dict[str, Any]
        - Keyword arguments to pass to the noise model.
        - Currently we only support the `scale` keyword argument for the normal distribution.
    solver_method: str
        - The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
        - If performance is incredibly slow, we suggest using `euler` to debug.
            If using `euler` results in faster simulation, the issue is likely that the model is stiff.
    solver_options: Dict[str, Any]
        - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
    start_time: float
        - The start time of the model. This is used to align the `start_state` from the
            AMR model with the simulation timepoints.
        - By default we set the `start_time` to be 0.
    num_iterations: int
        - The number of iterations to run the inference algorithm for.
    lr: float
        - The learning rate to use for the inference algorithm.
    verbose: bool
        - Whether to print out the loss at each iteration.
    num_particles: int
        - The number of particles to use for the inference algorithm.
    deterministic_learnable_parameters: List[str]
        - A list of parameter names that should be learned deterministically.
        - By default, all parameters are learned probabilistically.
    progress_hook: Callable[[int, float], None]
        - A function that takes in the current iteration and the current loss.
        - This is called at the beginning of each iteration.
        - By default, this is a no-op.
        - This can be used to implement custom progress bars.

    Returns:
        result: Dict[str, Any]
            - Dictionary with the following key-value pairs.
                - inferred_parameters: pyro.nn.PyroModule
                    - A Pyro module that contains the inferred parameters of the model.
                    - This can be passed to `ensemble_sample` to sample from the model conditional on the data.
                - loss: float
                    - The final loss value of the approximate ELBO loss.

                ################################

Traceback (most recent call last):
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/interfaces.py", line 303, in wrapped_model
    observe(solution, obs)
  File "/nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/chirho/observational/internals.py", line 56, in _observe_dict
    obs = obs(rv)
          ^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/observation.py", line 44, in forward
    self.markov_kernel(k, state[k]),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/observation.py", line 63, in markov_kernel
    return pyro.distributions.Normal(val, self.scale * torch.abs(val)).to_event(1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/distributions/distribution.py", line 24, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/torch/distributions/normal.py", line 56, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/torch/distributions/distribution.py", line 68, in __init__
    raise ValueError(
ValueError: Expected parameter scale (Tensor of shape (5,)) of distribution Normal(loc: torch.Size([5]), scale: torch.Size([5])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([0.0000, 0.0357, 0.0796, 0.1264, 0.1738], grad_fn=<MulBackward0>)

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/interfaces.py", line 305, in ensemble_calibrate
    inferred_parameters = autoguide(wrapped_model)
                          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/interfaces.py", line 273, in autoguide
    mvn_guide._setup_prototype()
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 1002, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 636, in _setup_prototype
    super()._setup_prototype(*args, **kwargs)
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/infer/autoguide/guides.py", line 157, in _setup_prototype
    self.prototype_trace = poutine.block(poutine.trace(model).get_trace)(
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/interfaces.py", line 303, in wrapped_model
    observe(solution, obs)
  File "/nix/store/y027d3bvlaizbri04c1bzh28hqd6lj01-python3-3.11.7/lib/python3.11/functools.py", line 909, in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/chirho/observational/internals.py", line 56, in _observe_dict
    obs = obs(rv)
          ^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/observation.py", line 44, in forward
    self.markov_kernel(k, state[k]),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyciemss/observation.py", line 63, in markov_kernel
    return pyro.distributions.Normal(val, self.scale * torch.abs(val)).to_event(1)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/pyro/distributions/distribution.py", line 24, in __call__
    return super().__call__(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/torch/distributions/normal.py", line 56, in __init__
    super().__init__(batch_shape, validate_args=validate_args)
  File "/home/five/Code/GitHub/pyciemss-service/.venv/lib/python3.11/site-packages/torch/distributions/distribution.py", line 68, in __init__
    raise ValueError(
ValueError: Expected parameter scale (Tensor of shape (5,)) of distribution Normal(loc: torch.Size([5]), scale: torch.Size([5])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values:
tensor([0.0000, 0.0357, 0.0796, 0.1264, 0.1738], grad_fn=<MulBackward0>)
                     Trace Shapes:    
                      Param Sites:    
                     Sample Sites:    
                model_weights dist | 2
                             value | 2
    model_0/persistent_beta_c dist |  
                             value |  
     model_0/persistent_kappa dist |  
                             value |  
     model_0/persistent_gamma dist |  
                             value |  
      model_0/persistent_hosp dist |  
                             value |  
model_0/persistent_death_hosp dist |  
                             value |  
        model_0/persistent_I0 dist |  
                             value |  
      model_1/persistent_beta dist |  
                             value |  
     model_1/persistent_gamma dist |  
                             value |  
      model_1/persistent_hosp dist |  
                             value |  
model_1/persistent_death_hosp dist |  
                             value |  
        model_1/persistent_I0 dist |  
                             value |  
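The final frames of the traceback show the mechanism: `pyciemss/observation.py` (line 63 in this installed version) builds the observation noise as `pyro.distributions.Normal(val, self.scale * torch.abs(val))`, so the noise standard deviation is proportional to the magnitude of the model's value `val` at each observation time. PyTorch requires a `Normal` scale to be strictly positive (`GreaterThan(lower_bound=0.0)`), so a single exact zero in `val`, like the leading `0.0000` in the reported tensor, is enough to raise this `ValueError`. The sketch below reproduces that check with plain `torch.distributions.Normal`; the trajectory values and the `0.1` relative scale are hypothetical, and the clamped scale at the end is only an illustrative workaround, not PyCIEMSS's behavior or the resolution adopted in this issue.

```python
import torch
from torch.distributions import Normal

# Hypothetical model trajectory at five observation times; the first entry is
# exactly zero, e.g. a quantity that is still 0 at the first data timepoint.
val = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0])
relative_scale = 0.1  # hypothetical stand-in for `self.scale`


def builds_ok(loc, scale):
    """Return True if Normal(loc, scale) passes PyTorch's argument validation."""
    try:
        Normal(loc, scale, validate_args=True)
        return True
    except ValueError:
        return False


# Same form as the expression in the traceback: scale proportional to |val|.
proportional_scale = relative_scale * torch.abs(val)
print(proportional_scale)                  # first entry is 0.0
print(builds_ok(val, proportional_scale))  # False: violates scale > 0

# Hypothetical workaround (not PyCIEMSS behavior): floor the scale at a small
# positive value so zero-valued predictions still yield a valid Normal.
floored_scale = torch.clamp(proportional_scale, min=1e-6)
print(builds_ok(val, floored_scale))       # True
```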
SamWitty commented 3 months ago

@sabinala, I am unable to diagnose this bug without seeing how it was generated. Could you please create a minimal working example that reproduces the issue using PyCIEMSS?

sabinala commented 3 months ago

@SamWitty, see this notebook for an example that reproduces the error: https://github.com/ciemss/pyciemss/blob/540-valueerror-when-running-integration-tests-with-ensemble_calibrate/docs/source/Issue540.ipynb

SamWitty commented 3 months ago

@sabinala, do you have a link to the model that you saved locally?

SamWitty commented 3 months ago

Closing, as this turned out to be an issue with the dataset. See Slack for details: https://askemgroup.slack.com/archives/C05G7RDELKF/p1710793987089479
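The `dirichlet_alpha` notes in the `ensemble_calibrate` docstring quoted above can be checked numerically. A minimal sketch, assuming only `torch.distributions.Dirichlet` (no PyCIEMSS API is used): under a Dirichlet(alpha) prior the mean weight of model i is `alpha_i / sum(alpha)`, so `[1, 1, 1]` gives equal prior weights and `[1, 2, 3]` favors the third model, while multiplying every alpha by the same factor (10 here, an arbitrary illustrative choice) leaves the means unchanged and shrinks the variances, which is what "more certain priors" means.

```python
import torch
from torch.distributions import Dirichlet


def prior_summary(alpha):
    """Mean and variance of each model weight under a Dirichlet(alpha) prior."""
    d = Dirichlet(alpha)
    return d.mean, d.variance


# The docstring's examples: a uniform prior and one favoring the third model.
uniform = torch.tensor([1.0, 1.0, 1.0])
biased = torch.tensor([1.0, 2.0, 3.0])

mean_u, var_u = prior_summary(uniform)
mean_b, var_b = prior_summary(biased)
print(mean_u)  # each weight has prior mean 1/3
print(mean_b)  # alpha / alpha.sum() = [1/6, 2/6, 3/6]

# Scaling every alpha by the same factor keeps the mean weights
# but shrinks their variance (a more concentrated prior).
mean_s, var_s = prior_summary(10 * biased)
print(torch.allclose(mean_s, mean_b))  # True
print(bool((var_s < var_b).all()))     # True
```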