aesara-devs / aeppl

Tools for an Aesara-based PPL.
https://aeppl.readthedocs.io
MIT License
64 stars 20 forks source link

`test_scan.py::test_initial_values` failing with latest Aesara release #148

Open ricardoV94 opened 2 years ago

ricardoV94 commented 2 years ago

This showed up in #147 (https://github.com/aesara-devs/aeppl/runs/6918324742?check_suite_focus=true)

https://github.com/aesara-devs/aeppl/blob/cc78f30b5ed89e5b247b88d697467a76cc1e424e/tests/test_scan.py#L349

Traceback

```python
test_scan.py::test_initial_values FAILED [100%]
tests/test_scan.py:347 (test_initial_values)
node = Shape_i{0}(*2-)

    def compute_test_value(node: Apply):
        r"""Computes the test value of a node.

        Parameters
        ----------
        node : Apply
            The `Apply` node for which the test value is computed.

        Returns
        -------
        None
            The `tag.test_value`\s are updated in each `Variable` in `node.outputs`.

        """
        # Gather the test values for each input of the node
        storage_map = {}
        compute_map = {}
        for i, ins in enumerate(node.inputs):
            try:
>               storage_map[ins] = [ins.get_test_value()]

../venv/lib/python3.8/site-packages/aesara/graph/op.py:89:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = NominalTensorVariable(2, TensorType(float64, (None, None)))

    def get_test_value(self):
        """Get the test value.

        Raises
        ------
        TestValueError

        """
        if not hasattr(self.tag, "test_value"):
            detailed_err_msg = get_variable_trace_string(self)
>           raise TestValueError(f"{self} has no test value {detailed_err_msg}")
E           aesara.graph.utils.TestValueError: *2- has no test value

../venv/lib/python3.8/site-packages/aesara/graph/basic.py:472: TestValueError

During handling of the above exception, another exception occurred:

    @aesara.config.change_flags(compute_test_value="raise")
    def test_initial_values():
        srng = at.random.RandomStream(seed=2320)

        p_S_0 = np.array([0.9, 0.1])
        S_0_rv = srng.categorical(p_S_0, name="S_0")
        S_0_rv.tag.test_value = 0

        Gamma_at = at.matrix("Gamma")
        Gamma_at.tag.test_value = np.array([[0, 1], [1, 0]])

        s_0_vv = S_0_rv.clone()
        s_0_vv.name = "s_0"

        def step_fn(S_tm1, Gamma):
            S_t = srng.categorical(Gamma[S_tm1], name="S_t")
            return S_t

        S_1T_rv, _ = aesara.scan(
            fn=step_fn,
            outputs_info=[{"initial": S_0_rv, "taps": [-1]}],
            non_sequences=[Gamma_at],
            strict=True,
            n_steps=10,
            name="S_0T",
        )

        S_1T_rv.name = "S_1T"
        s_1T_vv = S_1T_rv.clone()
        s_1T_vv.name = "s_1T"

>       logp_parts = factorized_joint_logprob({S_1T_rv: s_1T_vv, S_0_rv: s_0_vv})

test_scan.py:379:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../aeppl/joint_logprob.py:147: in factorized_joint_logprob
    q_logprob_vars = _logprob(
/usr/lib/python3.8/functools.py:875: in wrapper
    return dispatch(args[0].__class__)(*args, **kw)
../aeppl/scan.py:270: in logprob_ScanRV
    logp_scan_args = convert_outer_out_to_in(
../aeppl/scan.py:210: in convert_outer_out_to_in
    inner_out_fn(remapped_io_to_ii)
../aeppl/scan.py:267: in create_inner_out_logp
    logp_parts = factorized_joint_logprob(value_map, warn_missing_rvs=False)
../aeppl/joint_logprob.py:79: in factorized_joint_logprob
    fgraph, rv_values, _ = construct_ir_fgraph(rv_values)
../aeppl/opt.py:301: in construct_ir_fgraph
    fgraph = FunctionGraph(
../venv/lib/python3.8/site-packages/aesara/graph/fg.py:153: in __init__
    self.add_output(output, reason="init")
../venv/lib/python3.8/site-packages/aesara/graph/fg.py:163: in add_output
    self.import_var(var, reason=reason, import_missing=import_missing)
../venv/lib/python3.8/site-packages/aesara/graph/fg.py:304: in import_var
    self.import_node(var.owner, reason=reason, import_missing=import_missing)
../venv/lib/python3.8/site-packages/aesara/graph/fg.py:385: in import_node
    self.execute_callbacks("on_import", node, reason)
../venv/lib/python3.8/site-packages/aesara/graph/fg.py:725: in execute_callbacks
    fn(self, *args, **kwargs)
../venv/lib/python3.8/site-packages/aesara/tensor/basic_opt.py:1270: in on_import
    self.init_r(r)
../venv/lib/python3.8/site-packages/aesara/tensor/basic_opt.py:1222: in init_r
    self.set_shape(r, self.shape_tuple(r))
../venv/lib/python3.8/site-packages/aesara/tensor/basic_opt.py:964: in shape_tuple
    return tuple(self.shape_ir(i, r) for i in range(r.type.ndim))
../venv/lib/python3.8/site-packages/aesara/tensor/basic_opt.py:964: in <genexpr>
    return tuple(self.shape_ir(i, r) for i in range(r.type.ndim))
../venv/lib/python3.8/site-packages/aesara/tensor/basic_opt.py:952: in shape_ir
    s = Shape_i(i)(r)
../venv/lib/python3.8/site-packages/aesara/graph/op.py:299: in __call__
    compute_test_value(node)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
node = Shape_i{0}(*2-)

    def compute_test_value(node: Apply):
        r"""Computes the test value of a node.

        Parameters
        ----------
        node : Apply
            The `Apply` node for which the test value is computed.

        Returns
        -------
        None
            The `tag.test_value`\s are updated in each `Variable` in `node.outputs`.

        """
        # Gather the test values for each input of the node
        storage_map = {}
        compute_map = {}
        for i, ins in enumerate(node.inputs):
            try:
                storage_map[ins] = [ins.get_test_value()]
                compute_map[ins] = [True]
            except TestValueError:
                # no test-value was specified, act accordingly
                if config.compute_test_value == "warn":
                    warnings.warn(
                        f"Warning, Cannot compute test value: input {i} ({ins}) of Op {node} missing default value",
                        stacklevel=2,
                    )
                    return
                elif config.compute_test_value == "raise":
                    detailed_err_msg = get_variable_trace_string(ins)
>                   raise ValueError(
                        f"Cannot compute test value: input {i} ({ins}) of Op {node} missing default value. {detailed_err_msg}"
E                   ValueError: Cannot compute test value: input 0 (*2-) of Op Shape_i{0}(*2-) missing default value.

../venv/lib/python3.8/site-packages/aesara/graph/op.py:102: ValueError
```
ricardoV94 commented 2 years ago

@brandonwillard's suggestion for what might be going on:

> It has something to do with test values assigned to `NominalVariable`s. Remember that those kinds of `Variable`s are identical when their IDs are equal.
>
> Anyway, the `Scan` manipulations in aeppl needed some kind of adjustment to account for `NominalVariable`s, and their associated test values, appropriately.
>
> It might be that the `Scan` inner-graphs aren't creating `NominalVariable`s that are properly ordered. E.g. if an inner-graph is constructed from other graphs that have the same `NominalVariable`s in them, then the new inner-graph needs to reorder/recreate the `NominalVariable`s so that they're deduplicated and represent all the new inner-graph inputs distinctly.