OpenSourceEconomics / lcm

Solution and simulation of life cycle models in Python with GPU acceleration.
Apache License 2.0

BUG: Transition matrices with incorrect shapes do not throw error #63

Open buddejul opened 7 months ago

buddejul commented 7 months ago

I came across unexpected behavior for stochastic next functions with the _period argument.

For N_PERIODS = 3 and a binary state, I thought we would need a 2x2 matrix (one row for each period transition). But the following does not throw an error:

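```python
import jax.numpy as jnp
import lcm

@lcm.mark.stochastic
def next_dummy_state(_period):
    pass  # body stays empty: LCM reads the probabilities from the transition matrix

# Only one row, although two rows (one per period transition) were expected
DUMMY_STATE_TRANSITION = jnp.array([[0.5, 0.5]])
```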

I checked that the state is actually included in the model: providing an empty transition matrix raises an indexing error, as expected.

Providing a matrix that is "too large" (e.g. of shape (4, 2)) does not result in an error either.
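A minimal sketch of an up-front shape check that would reject both the (1, 2) and the (4, 2) matrix. `check_transition_shape` is a hypothetical helper, not part of LCM, and the expected number of rows (`N_PERIODS - 1`, one per period transition) is the expectation stated above, not something LCM is confirmed to require:

```python
import jax.numpy as jnp


def check_transition_shape(transition, n_rows, n_states):
    """Raise ValueError unless `transition` has shape (n_rows, n_states).

    Hypothetical helper for illustration; not an LCM function.
    """
    expected = (n_rows, n_states)
    actual = tuple(jnp.shape(transition))
    if actual != expected:
        raise ValueError(
            f"Transition matrix has shape {actual}, expected {expected}: "
            f"{n_rows} rows (one per period transition) and "
            f"{n_states} columns (one per value of the next state)."
        )


N_PERIODS = 3  # from the issue
N_STATES = 2  # binary dummy state

# Two rows, two columns: passes silently.
check_transition_shape(jnp.array([[0.5, 0.5], [0.4, 0.6]]), N_PERIODS - 1, N_STATES)

# The (1, 2) and (4, 2) matrices from the issue are rejected with a clear message.
for bad in (jnp.array([[0.5, 0.5]]), jnp.full((4, 2), 0.5)):
    try:
        check_transition_shape(bad, N_PERIODS - 1, N_STATES)
    except ValueError as err:
        print(err)
```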

timmens commented 7 months ago

Thank you for opening the issue!

This seems to be related to JAX's out-of-bounds indexing. The following code runs without error:

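```python
import jax.numpy as jnp

jnp.arange(10)[11]  # index 11 is out of bounds for an array of length 10

# > Array(9, dtype=int32)  (the index is clamped to the last element)
```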

For reference, see: https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#out-of-bounds-indexing
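For comparison, a small sketch of the two retrieval behaviors described on that page: plain indexing clamps the index, while `.at[...].get(mode="fill", fill_value=...)` returns the fill value for an out-of-bounds index instead of a clamped entry:

```python
import jax.numpy as jnp

x = jnp.arange(10.0)

# Plain indexing: index 11 is out of bounds and gets clamped to index 9.
clamped = x[11]

# Explicit out-of-bounds mode: return fill_value instead of a clamped entry.
filled = x.at[11].get(mode="fill", fill_value=jnp.nan)

print(clamped, filled)
```

Filling with NaN would make a shape mismatch show up as NaNs in the results rather than as an error, so it would complement, not replace, an explicit check on the LCM side.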

However, I still have to check whether this points to a more serious problem. In any case, we have to throw an informative error on the LCM side.

hmgaudecker commented 7 months ago

Great detective work. Sounds like we'll need to add bounds checks up front.

Also, a good example of why you want to test with asymmetric inputs. I wonder whether things still work with [0.4, 0.6], or whether you'd see a violation of the sum-to-one constraint?
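A minimal sketch of a row-sum check that could sit next to a shape check. `check_rows_sum_to_one` is a hypothetical helper, not an LCM function, and the tolerance is an arbitrary choice for float32 inputs:

```python
import jax.numpy as jnp


def check_rows_sum_to_one(transition, atol=1e-6):
    """Raise ValueError unless every row is a probability vector.

    Hypothetical helper for illustration; not an LCM function.
    """
    transition = jnp.asarray(transition)
    if bool(jnp.any(transition < 0)):
        raise ValueError("Transition probabilities must be non-negative.")
    row_sums = transition.sum(axis=-1)
    if not bool(jnp.allclose(row_sums, 1.0, atol=atol)):
        raise ValueError(f"Each row must sum to one, got row sums {row_sums}.")


# Asymmetric rows are fine as long as each one sums to one.
check_rows_sum_to_one(jnp.array([[0.4, 0.6], [0.3, 0.7]]))

# A row summing to 1.1 is rejected.
try:
    check_rows_sum_to_one(jnp.array([[0.4, 0.7]]))
except ValueError as err:
    print(err)
```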

buddejul commented 7 months ago

Re asymmetry: I originally noticed this for a state in our replication, and that row had asymmetric entries (but I would need to confirm). So I guess the clamping happens at the level of whole rows rather than individual entries (at least in this example, where `_period` is the only argument).

FWIW, I also tried different combinations of required rows (i.e., periods) and provided rows, but this didn't seem to have an effect, consistent with the above.
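A minimal JAX sketch of that row-level behavior, assuming (the thread does not show LCM's internals) that the transition row is selected with the period as an index. `row_for_period` is a hypothetical stand-in. A missing row is replaced by the last available row as a whole, so the entries within a row stay together, asymmetric or not, and extra rows are simply never read:

```python
import jax
import jax.numpy as jnp


@jax.jit
def row_for_period(transition, period):
    # Hypothetical stand-in for selecting the transition row by `_period`.
    # Inside jit the index is traced; out-of-bounds rows are clamped, not rejected.
    return transition[period]


# One row, although two periods index into it.
too_small = jnp.array([[0.3, 0.7]])
# Four rows, although only two are used.
too_large = jnp.array([[0.3, 0.7], [0.1, 0.9], [0.2, 0.8], [0.6, 0.4]])

for period in range(2):
    # Period 1 reuses the whole (asymmetric) row 0 of too_small;
    # rows 2 and 3 of too_large are never read.
    print(period, row_for_period(too_small, period), row_for_period(too_large, period))
```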


Edit:

```
Dummy state transition: [[0.5 0.5]]
Solved
Dummy state transition: [[0.3 0.7]]
Solved
Dummy state transition: [[0.5]]
Solved
Dummy state transition: [[0.3]]
Solved
Dummy state transition: [[0.3 0.7]
 [0.1 0.9]]
Solved
Dummy state transition: []
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (0, 2, 2, 2).
Dummy state transition: [[0.3 0.6 0.1]]
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (3, 2, 2, 2).
Dummy state transition: [[0.3 0.7 0.9]]
Failed with error: mul got incompatible shapes for broadcasting: (2, 2, 2, 2), (3, 2, 2, 2).
```
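The failing cases are consistent with the number of columns being checked only implicitly, through broadcasting, while the number of rows is clamped. The sketch below reproduces the error message with the shapes from the log; how LCM builds these four-dimensional arrays is not shown in the thread, and reading the leading axis as the number of columns is an inference from the messages. It also suggests why a single column (as in `[[0.5]]` and `[[0.3]]`) could pass silently: a dimension of size 1 broadcasts.

```python
import jax.numpy as jnp

expected = jnp.ones((2, 2, 2, 2))

# Shapes taken from the error messages above: 3 columns and an empty matrix.
errors = []
for shape in [(3, 2, 2, 2), (0, 2, 2, 2)]:
    try:
        expected * jnp.ones(shape)
    except TypeError as err:  # JAX reports incompatible broadcasting as a TypeError
        errors.append(str(err))

# A leading dimension of size 1 broadcasts against 2, so no error is raised.
single_column = expected * jnp.ones((1, 2, 2, 2))

print(errors)
print(single_column.shape)
```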