Open ordabayevy opened 3 years ago
To construct a dim_to_symbol that matches the shape of site["log_prob"], I'm using the following hack, but I'm not sure whether it is a correct solution or whether it might break the code in other places:
 class EnumMessenger(Messenger):
     ...
     def _pyro_sample(self, msg):
         """
         :param msg: current message at a trace site.
         :returns: a sample from the stochastic function at the site.
         """
+        param_dims = _ENUM_ALLOCATOR.dim_to_id.copy()  # enum dim -> unique id
+        self._param_dims[msg["name"]] = param_dims
         if msg["done"] or not isinstance(msg["fn"], TorchDistributionMixin):
             return
         # Compute upstream dims in scope; these are unsafe to use for this site's target_dim.
         scope = msg["infer"].get("_markov_scope")  # site name -> markov depth
-        param_dims = _ENUM_ALLOCATOR.dim_to_id.copy()  # enum dim -> unique id
         if scope is not None:
             for name, depth in scope.items():
                 if self._markov_depths[name] == depth:  # hide sites whose markov context has exited
                     param_dims.update(self._value_dims[name])
             self._markov_depths[msg["name"]] = msg["infer"]["_markov_depth"]
-        self._param_dims[msg["name"]] = param_dims
         if msg["is_observed"] or msg["infer"].get("enumerate") != "parallel":
             return
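The kind of mismatch this hack works around can be illustrated without Pyro. The following is only a minimal sketch, assuming that dim_to_symbol maps negative tensor dims to symbols and that every dim of size greater than one needs an entry. The helper unlabeled_dims and the shapes below are hypothetical and are not part of Pyro.

```python
from typing import Dict, List, Tuple


def unlabeled_dims(shape: Tuple[int, ...], dim_to_symbol: Dict[int, str]) -> List[int]:
    """Return the negative dims of `shape` with size > 1 that have no symbol.

    Hypothetical helper for illustration only; not a Pyro function.
    """
    ndim = len(shape)
    return [
        i - ndim  # convert a left-based index into a negative, right-based dim
        for i, size in enumerate(shape)
        if size > 1 and (i - ndim) not in dim_to_symbol
    ]


# Assumed example: dim -1 is a plate of size 5, dim -2 is an enumeration dim of size 3.
log_prob_shape = (3, 5)
incomplete = {-1: "a"}          # the enumeration dim was never recorded
complete = {-1: "a", -2: "b"}   # both dims recorded

print(unlabeled_dims(log_prob_shape, incomplete))  # [-2]
print(unlabeled_dims(log_prob_shape, complete))    # []
```

A non-empty result means the shape has a dim that the mapping cannot name, which is the situation I'm trying to avoid by recording param_dims for every site before the early return.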
Issue Description
I'm trying to use infer_discrete for a model (code snippet below) that has a local latent variable (locs) that depends on a discrete variable (assignment). This leads to an error where site["log_prob"].shape and dim_to_symbol don't match:
Using pdb:
Environment
Code Snippet