Open adamcweiner opened 2 years ago
I’m using a CPU on a linux cluster with these versions:
funsor==0.4.3
pyro-api==0.1.2
pyro-ppl==1.8.2
torch==1.12.1
The link to my initial discussion board post can be found here
Last, here’s the full error trace:
AssertionError Traceback (most recent call last)
<ipython-input-7-3ccf5e1cbc19> in <module>
8
9 with pyro_backend(PYRO_BACKEND):
---> 10 trace, sequences, lengths = main(args)
<ipython-input-4-8c5832ced74d> in main(args)
169 temperature=0, first_available_dim=first_available_dim
170 )
--> 171 ).get_trace(sequences, lengths)
172
173 # trained_model = handlers.replay(model, trace=guide_trace)
/juno/work/venv3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in get_trace(self, *args, **kwargs)
196 Calls this poutine and returns its trace instead of the function's return value.
197 """
--> 198 self(*args, **kwargs)
199 return self.msngr.get_trace()
/juno/work/venv3/lib/python3.7/site-packages/pyro/poutine/trace_messenger.py in __call__(self, *args, **kwargs)
172 )
173 try:
--> 174 ret = self.fn(*args, **kwargs)
175 except (ValueError, RuntimeError) as e:
176 exc_type, exc_value, traceback = sys.exc_info()
/juno/work/venv3/lib/python3.7/site-packages/pyro/contrib/funsor/infer/discrete.py in _sample_posterior(model, first_available_dim, temperature, *args, **kwargs)
44
45 with approx:
---> 46 approx_factors = funsor.adjoint.adjoint(sum_op, prod_op, log_prob)
47
48 # construct a result trace to replay against the model
/juno/work/venv3/lib/python3.7/site-packages/funsor/adjoint.py in adjoint(sum_op, bin_op, expr)
140
141 def adjoint(sum_op, bin_op, expr):
--> 142 forward, backward = forward_backward(sum_op, bin_op, expr)
143 return backward
144
/juno/work/venv3/lib/python3.7/site-packages/funsor/adjoint.py in forward_backward(sum_op, bin_op, expr, batch_vars)
135 # TODO fix traversal order in AdjointTape instead of using stack_reinterpret
136 forward = stack_reinterpret(expr)
--> 137 backward = tape.adjoint(sum_op, bin_op, forward, batch_vars=batch_vars)
138 return forward, backward
139
/juno/work/venv3/lib/python3.7/site-packages/funsor/adjoint.py in adjoint(self, sum_op, bin_op, root, targets, batch_vars)
113 self._eager_to_lazy[output] = lazy_output
114
--> 115 in_adjs = adjoint_ops(fn, sum_op, bin_op, adjoint_values[output], *inputs)
116 for v, adjv in in_adjs:
117 # Marginalize out message variables that don't appear in recipients.
/juno/work/venv3/lib/python3.7/site-packages/funsor/registry.py in __call__(self, key, *args)
104
105 def __call__(self, key, *args):
--> 106 return self[key](*args)
107
108 def dispatch(self, key, *args):
/juno/work/venv3/lib/python3.7/site-packages/funsor/registry.py in __call__(self, *args)
61
62 def __call__(self, *args):
---> 63 return self.partial_call(*args)(*args)
64
65
/juno/work/venv3/lib/python3.7/site-packages/funsor/adjoint.py in adjoint_contract_generic(adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms)
215 adj_sum_op, adj_prod_op, out_adj, sum_op, prod_op, reduced_vars, terms
216 ):
--> 217 assert len(terms) == 1 or len(terms) == 2
218 return adjoint_ops(
219 Contraction,
AssertionError:
@ordabayevy any idea what might be happening?
I am trying to extract the discrete hidden states from the funsor HMM example, but am getting an odd error when using this pattern to infer discrete sites. My model exactly matches the example posted here and I added this chunk of code to the end of
main()
which is supposed to extract the hidden states. The error arises when
.get_trace()
is invoked, during its downstream call to forward_backward()
in adjoint.py. This error occurs no matter which model structure I pick within the funsor HMM example (i.e. it’s an issue pertaining to all funsor HMMs, not just model_7()
with vectorized time dimension). Also, this routine to infer discrete sites works just fine when I’m analyzing Bach Chorales using the standard pyro HMM example. I continue to get this same error when I use this condition workaround to avoid any compatibility issues between
replay()
and funsor's infer_discrete()