ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems

Getting a RuntimeError when trying to sample a PDE as a Petri net #487

Closed: sabinala closed this issue 7 months ago

sabinala commented 7 months ago

I get the error below when calling `result = pyciemss.sample(MODEL, end_time, logging_step_size, num_samples, start_time=start_time)`, where `MODEL` is an advection equation with a backward derivative, taken from here. See this notebook, related to PR #460, and also this PR in DARPA-ASKEM/Model-Representations.
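For context, here is a minimal sketch of what "advection with a backward derivative" usually means. It assumes the 1D linear advection equation u_t + c u_x = 0, semi-discretized on N cells with a first-order backward difference in x: du_i/dt = -c (u_i - u_{i-1}) / dx. The left boundary uses a fixed inflow value. The names `advection_rhs`, `c`, `dx`, and `inflow` are hypothetical and are not taken from the linked model. The property that matters for the solver is that the derivative vector has exactly one entry per state variable.

```python
def advection_rhs(u, c, dx, inflow=0.0):
    """Return du/dt for u_t + c*u_x = 0 using a backward difference in x.

    Hypothetical sketch: du_i/dt = -c * (u_i - u_{i-1}) / dx, where the
    ghost value u_{-1} is an assumed fixed inflow at the left boundary.
    """
    # u_{i-1} for every cell i; the first cell sees the inflow value.
    left = [inflow] + list(u[:-1])
    # One derivative entry per state entry, so len(output) == len(u).
    return [-c * (ui - uim1) / dx for ui, uim1 in zip(u, left)]


# Usage: three cells, unit speed, spacing 0.5.
print(advection_rhs([1.0, 2.0, 4.0], c=1.0, dx=0.5))
```

With three cells the state has three entries and so does the derivative. An ODE solver such as torchdiffeq relies on that one-to-one match.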

RuntimeError: The size of tensor a (9) must match the size of tensor b (3) at non-singleton dimension 0

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/Users/altu809/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped
    result = function(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 298, in sample
    samples = pyro.infer.Predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward
    return _predictive(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/infer/predictive.py", line 137, in _predictive
    trace = poutine.trace(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace
    self(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 180, in __call__
    raise exc from e
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__
    ret = self.fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap
    return fn(*args, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/interfaces.py", line 282, in wrapped_model
    full_trajectory = model(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__
    result = super().__call__(*args, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/Users/altu809/Projects/pyciemss/pyciemss/compiled_dynamics.py", line 77, in forward
    simulate(self.deriv, self.initial_state(), start_time, end_time)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate
    state, start_time, next_interruption = simulate_to_interruption(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption
    msg["value"] = torchdiffeq_simulate_to_interruption(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
    value = simulate_point(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point
    trajectory: State[T] = simulate_trajectory(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn
    apply_stack(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack
    frame._process_message(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message
    return method(msg)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory
    msg["value"] = torchdiffeq_simulate_trajectory(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
    return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
    solns = _batched_odeint(  # torchdiffeq.odeint(
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint
    yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 77, in odeint
    solution = solver.integrate(t)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/solvers.py", line 28, in integrate
    self._before_integrate(t)
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/rk_common.py", line 163, in _before_integrate
    first_step = _select_initial_step(self.func, t[0], self.y0, self.order - 1, self.rtol, self.atol,
  File "/Users/altu809/anaconda3/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 54, in _select_initial_step
    d1 = norm(f0 / scale)
RuntimeError: The size of tensor a (9) must match the size of tensor b (3) at non-singleton dimension 0
                                       Trace Shapes:    
                                        Param Sites:    
        numeric_initial_state_func$$$_nodes.0._value    
        numeric_initial_state_func$$$_nodes.1._value    
        numeric_initial_state_func$$$_nodes.2._value    
numeric_deriv_func$$$_nodes.0._args.0._args.0._value    
                                       Sample Sites:    
                                  persistent_dx dist 3 |
                                               value 3 |
                                   persistent_u dist 3 |
                                               value 3 |
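The failing line is `d1 = norm(f0 / scale)` inside torchdiffeq's `_select_initial_step`. There, `f0` is the derivative evaluated at the initial state, and `scale` is built from the initial state `y0` using the solver tolerances. Going by the message, the derivative appears to have 9 entries along dimension 0 while the state-derived scale has 3. Elementwise division only works when the shapes broadcast. The pure-Python sketch below mirrors the standard broadcasting rule that PyTorch uses. Shapes are aligned from the trailing dimension, and two sizes are compatible only if they are equal or one of them is 1 (a "singleton"). `broadcast_shapes` is a hypothetical helper written for illustration, not a torch or pyciemss function.

```python
def broadcast_shapes(a, b):
    """Return the broadcast shape of two tensor shapes, or raise like PyTorch.

    Illustrative re-implementation of the broadcasting rule: align shapes
    from the trailing dimension, padding the shorter one with leading 1s;
    sizes must be equal or one of them must be 1.
    """
    ndim = max(len(a), len(b))
    a_pad = (1,) * (ndim - len(a)) + tuple(a)
    b_pad = (1,) * (ndim - len(b)) + tuple(b)
    out = []
    for dim, (x, y) in enumerate(zip(a_pad, b_pad)):
        if x == y or y == 1:
            out.append(x)
        elif x == 1:
            out.append(y)
        else:
            # Neither size is a singleton, so the shapes cannot broadcast.
            raise RuntimeError(
                f"The size of tensor a ({x}) must match the size of tensor b ({y}) "
                f"at non-singleton dimension {dim}"
            )
    return tuple(out)


print(broadcast_shapes((2, 3), (3,)))  # compatible: (2, 3)
try:
    broadcast_shapes((9,), (3,))  # derivative of size 9 vs. scale of size 3
except RuntimeError as err:
    print(err)
```

The `(9,)` vs `(3,)` case reproduces the wording of the error above. This suggests checking why the compiled derivative returns more entries than the initial state has, rather than looking inside the solver itself.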

SamWitty commented 7 months ago

@sabinala, could you create a PR that adds this model to the tests? Thanks!

sabinala commented 7 months ago

@SamWitty, see https://github.com/ciemss/pyciemss/pull/490.