OpenSourceEconomics / lcm

Solution and simulation of life cycle models in Python with GPU acceleration.
Apache License 2.0

ENH: Support arbitrary discrete integer grids #82

Closed: timmens closed this 1 month ago

timmens commented 4 months ago

Problem

Currently, discrete grids need to coincide with their indices; e.g., valid_grid = [0, 1, 2, 3] is supported, but invalid_grid = [2, 3] is not.

What we want

Support arbitrary integer grids; e.g., a grid such as invalid_grid = [2, 3] should be supported.

Implementation

Before

Before #81, we tried to solve this in lcm.function_evaluator.get_label_translator(), using:

# Map each grid value (label) to its position in the grid.
val_to_pos = dict(zip(labels, range(len(labels)), strict=True))

@with_signature(args=[in_name])
def translate_label(*args, **kwargs):
    kwargs = all_as_kwargs(args, kwargs, arg_names=[in_name])
    return val_to_pos[kwargs[in_name]]

The problem, however, was that the lookup val_to_pos[x] is not vmappable/jittable in JAX (this might have been different in an earlier JAX version). The underlying reasons are that (1) inside traced code JAX has no Python scalars, only 0-dimensional arrays, whereas val_to_pos has int keys (e.g., 0 vs. jnp.array(0)); and (2) we cannot use jnp.array(x) as a dictionary key, because JAX arrays are not hashable.

Bad solution

Add a new label translator that is JAX compatible:

labels_array = jnp.array(labels)  # grid values as a JAX array

@with_signature(args=[in_name])
def translate_label(*args, **kwargs):
    kwargs = all_as_kwargs(args, kwargs, arg_names=[in_name])
    # Position of the label equal to the input; size=1 keeps the output shape static.
    return jnp.nonzero(labels_array == kwargs[in_name], size=1, fill_value=None)[0][0]

The problem here is that this function is much slower (10-100 times) for arbitrary grids than for index grids. Its computational cost also grows linearly with the number of labels, since every call compares the input against the entire label array.
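A self-contained version of this translator shows that it does work under jit and vmap, while making the linear cost visible: the comparison labels_array == value touches every label on each call. This is a sketch that drops the LCM-specific wrappers; the grid [2, 3, 7] is illustrative.

```python
import jax
import jax.numpy as jnp

# An arbitrary (non-index) integer grid.
labels_array = jnp.array([2, 3, 7])


def translate_label(value):
    # Compares `value` against every label, so each call costs O(len(labels)).
    # `size=1` gives nonzero a static output shape, which jit and vmap require.
    return jnp.nonzero(labels_array == value, size=1)[0][0]


values = jnp.array([7, 2, 3])
positions = jax.vmap(jax.jit(translate_label))(values)
print(positions.tolist())  # [2, 0, 1]
```

One further caveat of this approach: with size=1, a value that is not among the labels is padded with fill_value, which defaults to 0, so it silently maps to position 0 instead of raising an error.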

Better solution

The idea is to solve this problem during model processing rather than in the function evaluator. As an example, consider a health state variable that takes the values [-1, 0, 1]. In the model specification, the user would write

...
"states": {
   "health": {"options": [-1, 0, 1]},
   ...
}
...

During model processing, we would not add health as a state variable but __health_index__, which lives on the grid [0, 1, 2]. Since the user wrote functions that depend on health and not on the index, we also need to add a dynamically defined function

import jax.numpy as jnp

health_options = jnp.array([-1, 0, 1])  # must be a JAX array for jit/vmap

def health(__health_index__):
    return health_options[__health_index__]

Given this, arbitrary discrete integer grids should be supported.
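The processing step described above could look roughly like the following sketch. The helper process_discrete_states and its return format are hypothetical, not LCM's API. For simplicity it returns the index name alongside a positional function, rather than generating a function whose parameter is literally named __health_index__ (which in LCM would presumably be handled with something like with_signature, as in the snippets above). The key point is that the lookup becomes plain integer indexing, which JAX can jit and vmap at constant cost per call.

```python
import jax
import jax.numpy as jnp


def process_discrete_states(states):
    """Hypothetical model-processing step for discrete states.

    Each state `name` with arbitrary integer options is replaced by an index
    state `__name_index__` on the grid 0, ..., n - 1, plus a function that maps
    the index back to the option value the user's functions expect.
    """
    index_grids = {}
    label_funcs = {}
    for name, spec in states.items():
        options = jnp.array(spec["options"])
        index_name = f"__{name}_index__"
        index_grids[index_name] = jnp.arange(len(spec["options"]))

        # Bind `options` as a default so each closure keeps its own array.
        def to_label(index, options=options):
            # Integer indexing is a gather: constant cost, jit/vmap compatible.
            return options[index]

        label_funcs[name] = (index_name, to_label)
    return index_grids, label_funcs


states = {"health": {"options": [-1, 0, 1]}}
index_grids, label_funcs = process_discrete_states(states)

index_name, health = label_funcs["health"]
health_values = jax.vmap(jax.jit(health))(index_grids[index_name])
print(index_name, health_values.tolist())  # __health_index__ [-1, 0, 1]
```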

hmgaudecker commented 4 months ago

If we could come up with a solution that allowed

"states": {
   "health": {"options": ["fair", "good", "excellent"]},
   ...
}

that would be even nicer, of course. Isn't JAX smart enough to do away with that?

timmens commented 4 months ago

To my understanding, JAX cannot deal with string data, since the underlying XLA compiler does not support strings (see, e.g., JAX Issue #2084 or JAX Issue #3045).

I think we could come up with a solution in LCM that handles this during model processing. But this would require large changes (e.g., before passing the updated model to the LCM internals, we would have to dynamically rewrite user functions to use the integer-based representation of the grids instead of the strings). We can open a feature request for this, but given the complexity of this change, I think we should focus on other things for now.
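To illustrate why such a rewrite would be needed, the sketch below first shows that JAX refuses to build an array of strings, and then what a user function would have to turn into: the string labels stay in plain Python, the string-to-code lookup is resolved before tracing, and the traced function only compares integer codes. The names code_of, EXCELLENT, and is_healthy are hypothetical and only serve the example; this is not how LCM processes models.

```python
import jax
import jax.numpy as jnp

# Strings cannot become JAX arrays: XLA has no string dtype.
try:
    jnp.array(["fair", "good", "excellent"])
    strings_supported = True
except (TypeError, ValueError):
    strings_supported = False

# Hypothetical preprocessing: keep the string labels in plain Python and
# hand JAX only their integer codes 0, 1, 2.
health_labels = ["fair", "good", "excellent"]
code_of = {label: code for code, label in enumerate(health_labels)}

# A user function written against labels, e.g.
#     def is_healthy(health): return health == "excellent"
# would have to be rewritten to compare integer codes, with the
# string-to-code lookup done in Python before tracing.
EXCELLENT = code_of["excellent"]


@jax.jit
def is_healthy(health_code):
    return health_code == EXCELLENT


result = jax.vmap(is_healthy)(jnp.arange(len(health_labels)))
print(strings_supported, result.tolist())  # False [False, False, True]
```

Doing this automatically means rewriting every user function that compares against a string label, which is the complexity referred to above.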

hmgaudecker commented 4 months ago

Yeah, would be a nice-to-have more than anything else. Thanks for the detailed explanation!