Open buddejul opened 7 months ago
Thank you for opening the issue!
This seems to be related to JAX's out-of-bounds indexing. The following code runs without error:
import jax.numpy as jnp
jnp.arange(10)[11]
> Array(9, dtype=int32)
For reference, see: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing
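To make the clamping visible, here is a minimal sketch that contrasts default retrieval with JAX's explicit "fill" indexing mode. The sentinel value -1 is an arbitrary choice for illustration; nothing here is taken from the LCM code base.

```python
import jax.numpy as jnp

x = jnp.arange(10)

# Default retrieval clamps the out-of-bounds index to the last valid one.
clamped = x[11]

# "fill" mode returns a sentinel instead, so the out-of-bounds access
# becomes visible to the caller. The sentinel -1 is an arbitrary choice.
filled = x.at[11].get(mode="fill", fill_value=-1)

print(int(clamped), int(filled))  # 9 -1
```

Neither variant raises, so an informative error still needs an explicit check before indexing.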
However, I will have to check whether this points to a more serious problem. In any case, we have to throw an informative error on the LCM side.
Great detective work. Sounds like we'll need to add bound checks up front.
Also, a good example of why you want to test with asymmetric inputs. I wonder whether things still work with [0.4, 0.6] or whether you'd see a violation of add-to-one constraints?
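As a rough sketch of what such an up-front check could look like: the function name `check_transition_matrix` is hypothetical, and the expected shape `(n_periods - 1, n_states)` is an assumption based on the "one row per period transition" reading in this issue, not on LCM's actual internals.

```python
import numpy as np


def check_transition_matrix(matrix, n_periods, n_states, atol=1e-8):
    """Hypothetical up-front validation of a period-indexed transition matrix.

    Assumes one row per period transition and one column per state.
    """
    arr = np.asarray(matrix, dtype=float)
    expected = (n_periods - 1, n_states)
    if arr.shape != expected:
        raise ValueError(
            f"Transition matrix has shape {arr.shape}, expected {expected}."
        )
    # Each row is a probability distribution over next-period states.
    row_sums = arr.sum(axis=1)
    if not np.allclose(row_sums, 1.0, atol=atol):
        raise ValueError(f"Rows must sum to one, got row sums {row_sums}.")
    return arr


# Passes: two rows for N_PERIODS = 3, each summing to one.
checked = check_transition_matrix([[0.3, 0.7], [0.1, 0.9]], n_periods=3, n_states=2)
```

Under these assumptions, a single row such as [[0.3, 0.7]], an empty matrix, or a (4, 2) matrix would all be rejected with a shape message before any indexing happens.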
Re asymmetry: I originally noticed this for a state in our replication, and the row had asymmetric entries (but I would need to confirm). So I guess the indexing happens at the row level rather than at the individual entry level (at least in this example with just period as argument).
FWIW I also tried different combinations of required rows (i.e., periods) and provided rows, but this didn't seem to have an effect, consistent with the above.
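The row-level reading can be illustrated in isolation. This is a sketch with plain JAX arrays, not the actual LCM lookup: it assumes the period is used as an array index into the first axis of the transition matrix.

```python
import jax.numpy as jnp

# A single-row matrix, as in the "[[0.3 0.7]]" case.
transition = jnp.array([[0.3, 0.7]])

# Look up rows for periods 0, 1 and 2. Only row 0 exists, so the
# out-of-bounds indices 1 and 2 are clamped to it.
periods = jnp.arange(3)
rows = transition[periods]

print(rows.shape)  # (3, 2)
```

Every looked-up row equals [0.3, 0.7], which would explain why too few rows go unnoticed. By the same logic, surplus rows in a (4, 2) matrix would simply never be read.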
Edit:
Dummy state transition: [[0.5 0.5]]
Solved
Dummy state transition: [[0.3 0.7]]
Solved
Dummy state transition: [[0.5]]
Solved
Dummy state transition: [[0.3]]
Solved
Dummy state transition: [[0.3 0.7]
[0.1 0.9]]
Solved
Dummy state transition: []
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (0, 2, 2, 2).
Dummy state transition: [[0.3 0.6 0.1]]
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (3, 2, 2, 2).
Dummy state transition: [[0.3 0.7 0.9]]
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (3, 2, 2, 2).
I came across unexpected behavior for stochastic next functions with the _period argument. For N_PERIODS = 3 and a binary state I thought we needed a 2x2 matrix (one row for each period transition). But the following does not throw an error:
I checked that it's actually included in the model: providing an empty transition matrix raises an indexing error as expected.
Providing a matrix that's "too large" (e.g. (4, 2)) here also doesn't result in an error.