ciemss / pyciemss

Causal and probabilistic reasoning with continuous time dynamical systems

Regnets fail differently on unstable simulations where petrinets persist #513

Closed: sabinala closed this issue 3 months ago

sabinala commented 4 months ago

I get the following error when simulating a regnet model that blows up in finite time, whereas the petrinet version of the same model produces a result:


ERROR:root: 
       ############################### 

       There was an exception in CIEMSS Service 

       job: 4f12c3b4-fa06-4d2f-a861-7499e424657f 
       <class 'TypeError'>: `y0` must be a floating point Tensor but is a torch.LongTensor 
       ################################ 

Traceback (most recent call last): 
 File "/usr/local/lib/python3.10/site-packages/rq/worker.py", line 1428, in perform_job 
   rv = job.perform() 
 File "/usr/local/lib/python3.10/site-packages/rq/job.py", line 1278, in perform 
   self._result = self._execute() 
 File "/usr/local/lib/python3.10/site-packages/rq/job.py", line 1315, in _execute 
   result = self.func(*self.args, **self.kwargs) 
 File "/service/./execute.py", line 34, in run 
   output = eval(operation_name)(**kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyciemss/integration_utils/custom_decorators.py", line 29, in wrapped 
   raise e 
 File "/usr/local/lib/python3.10/site-packages/pyciemss/integration_utils/custom_decorators.py", line 10, in wrapped 
   result = function(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyciemss/interfaces.py", line 314, in sample 
   samples = pyro.infer.Predictive( 
 File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl 
   return self._call_impl(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl 
   return forward_call(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyro/infer/predictive.py", line 273, in forward 
   return _predictive( 
 File "/usr/local/lib/python3.10/site-packages/pyro/infer/predictive.py", line 78, in _predictive 
   max_plate_nesting = _guess_max_plate_nesting(model, model_args, model_kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyro/infer/predictive.py", line 21, in _guess_max_plate_nesting 
   model_trace = poutine.trace(model).get_trace(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 198, in get_trace 
   self(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/trace_messenger.py", line 174, in __call__ 
   ret = self.fn(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context 
   return func(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 12, in _context_wrap 
   return fn(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyciemss/interfaces.py", line 300, in wrapped_model 
   full_trajectory = model( 
 File "/usr/local/lib/python3.10/site-packages/pyro/nn/module.py", line 449, in __call__ 
   result = super().__call__(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl 
   return self._call_impl(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl 
   return forward_call(*args, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/pyciemss/compiled_dynamics.py", line 77, in forward 
   simulate(self.deriv, self.initial_state(), start_time, end_time) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn 
   apply_stack(msg) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack 
   frame._process_message(msg) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message 
   return method(msg) 
 File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/solver.py", line 109, in _pyro_simulate 
   state, start_time, next_interruption = simulate_to_interruption( 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn 
   apply_stack(msg) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack 
   frame._process_message(msg) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message 
   return method(msg) 
 File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 89, in _pyro_simulate_to_interruption 
   msg["value"] = torchdiffeq_simulate_to_interruption( 
 File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 248, in torchdiffeq_simulate_to_interruption
   value = simulate_point( 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn 
   apply_stack(msg) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack 
   frame._process_message(msg) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message 
   return method(msg) 
 File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/handlers/trajectory.py", line 97, in _pyro_simulate_point 
   trajectory: State[T] = simulate_trajectory( 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 281, in _fn 
   apply_stack(msg) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/runtime.py", line 212, in apply_stack 
   frame._process_message(msg) 
 File "/usr/local/lib/python3.10/site-packages/pyro/poutine/messenger.py", line 162, in _process_message 
   return method(msg) 
 File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/handlers/solver.py", line 77, in _pyro_simulate_trajectory 
   msg["value"] = torchdiffeq_simulate_trajectory( 
 File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 165, in torchdiffeq_simulate_trajectory
   return _torchdiffeq_ode_simulate_inner(dynamics, initial_state, timespan, **kwargs) 
 File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 71, in _torchdiffeq_ode_simulate_inner
   solns = _batched_odeint( # torchdiffeq.odeint( 
 File "/usr/local/lib/python3.10/site-packages/chirho/dynamical/internals/backends/torchdiffeq.py", line 123, in _batched_odeint 
   yt_raw = torchdiffeq.odeint(func, y0_expanded, t, **odeint_kwargs) 
 File "/usr/local/lib/python3.10/site-packages/torchdiffeq/_impl/odeint.py", line 72, in odeint 
   shapes, func, y0, t, rtol, atol, method, options, event_fn, t_is_reversed = _check_inputs(func, y0, t, rtol, atol, method, options, event_fn, SOLVERS)
 File "/usr/local/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 213, in _check_inputs 
   _assert_floating('y0', y0) 
 File "/usr/local/lib/python3.10/site-packages/torchdiffeq/_impl/misc.py", line 106, in _assert_floating 
   raise TypeError('`{}` must be a floating point Tensor but is a {}'.format(name, t.type())) 
TypeError: `y0` must be a floating point Tensor but is a torch.LongTensor 
16:41:17 [Job 4f12c3b4-fa06-4d2f-a861-7499e424657f]: exception raised while executing (execute.run) 
ERROR:rq.worker:[Job 4f12c3b4-fa06-4d2f-a861-7499e424657f]: exception raised while executing (execute.run) 
sabinala commented 4 months ago

Note that simulating the regnet with the Euler solver produces a result in a notebook, but the same run fails in Terarium.
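One observation that may help narrow this down: in the traceback, torchdiffeq's `odeint` calls `_check_inputs`, which runs `_assert_floating('y0', y0)` before a solver is dispatched, so as far as the traceback shows, an integer initial state would be rejected under Euler too. That suggests the notebook and Terarium runs may be building different initial states. The sketch below mirrors that check using only `torch`; `assert_floating` is a stand-in written for illustration (not torchdiffeq's API), and the integer values are made up.

```python
import torch


def assert_floating(name: str, t: torch.Tensor) -> None:
    # Illustrative stand-in for the y0 dtype check seen in the traceback.
    if not torch.is_floating_point(t):
        raise TypeError(f"`{name}` must be a floating point Tensor but is a {t.type()}")


# Integer initial values (e.g. populations written as 1000, 1, 0) infer int64.
y0_int = torch.tensor([1000, 1, 0])
print(y0_int.type())  # torch.LongTensor

try:
    assert_floating("y0", y0_int)
except TypeError as err:
    print(err)  # `y0` must be a floating point Tensor but is a torch.LongTensor

# Casting to the default floating dtype (float32 unless changed) passes the check.
y0_float = y0_int.to(torch.get_default_dtype())
assert_floating("y0", y0_float)
print(y0_float.type())  # torch.FloatTensor
```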

djinnome commented 4 months ago

@sabinala can you point me to the notebook where this error occurs?

sabinala commented 3 months ago

@djinnome This issue can probably be closed: it seems the problem was actually in how the regnet equations were being read from the AMR.
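If the misread values end up as Python ints or numeric strings (an assumption; the thread does not say which AMR fields were affected), one defensive option is to coerce them to the default floating dtype when the initial state is built. `to_float_state` below is a hypothetical helper for illustration, not part of pyciemss or chirho, and the example values are invented.

```python
from typing import Dict, Union

import torch

# Values parsed from a model file may be ints, floats, or numeric strings.
Number = Union[int, float, str]


def to_float_state(values: Dict[str, Number]) -> Dict[str, torch.Tensor]:
    """Hypothetical helper: build a floating-point initial state.

    torch.as_tensor(1000) would infer int64 (a LongTensor), which torchdiffeq
    rejects, so every value is converted to float and cast to the default
    floating dtype explicitly.
    """
    dtype = torch.get_default_dtype()
    return {name: torch.as_tensor(float(v), dtype=dtype) for name, v in values.items()}


# Illustrative values only; these are not taken from the model in this issue.
state = to_float_state({"S": 1000, "I": "1", "R": 0})
print({name: v.dtype for name, v in state.items()})
```

Casting where values are parsed keeps the fix out of the solver path, so a misread field surfaces as a parsing problem rather than as a dtype error deep inside `odeint`.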