Closed: timmens closed this issue 1 month ago.
> Better solution
>
> The idea is to solve this problem during the model processing and not in the function evaluator. As an example, consider a health state variable that attains values `[-1, 0, 1]`. In the model specification the user will write `"states": {"health": {"options": [-1, 0, 1]}, ...}`.
If we could come up with a solution that allowed

```
"states": {
    "health": {"options": ["fair", "good", "excellent"]},
    ...
}
```

that would be even nicer, of course. Isn't JAX smart enough to do away with that?
To my understanding, JAX cannot deal with string data, since the underlying XLA compiler does not support strings (see, e.g., JAX Issue #2084 or JAX Issue #3045).
I think we can come up with a solution in LCM that handles this during the model processing. However, this would require large changes: for example, we would have to dynamically rewrite user functions so that they use the integer-based representation of the grids instead of the strings, before passing the updated model to the LCM internals. We can open a feature request for this, but given the complexity of this change, I think we should focus on other things for now.
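A minimal sketch of the point above, assuming a recent JAX version: JAX refuses to build an array from Python strings, so string options would have to be mapped to integer codes in plain Python before anything reaches JAX. The helpers `jax_accepts_strings` and `encode_options` are hypothetical and not part of LCM or JAX.

```python
import jax.numpy as jnp


def jax_accepts_strings(options):
    """Return True if JAX can build an array from the given options."""
    try:
        jnp.asarray(options)
    except (TypeError, ValueError):
        # JAX rejects non-numeric (e.g. string) dtypes.
        return False
    return True


def encode_options(options):
    """Map each option to its integer position, in plain Python (no JAX)."""
    label_to_code = {label: code for code, label in enumerate(options)}
    codes = jnp.arange(len(options))  # the integer grid JAX would see
    return label_to_code, codes


options = ["fair", "good", "excellent"]
print(jax_accepts_strings(options))  # False

label_to_code, codes = encode_options(options)
print(label_to_code["good"])  # 1
```

Encoding the grid itself is the easy part. The hard part described above is rewriting user functions that compare against labels, such as `health == "good"`, so that they compare against the code `1` instead.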
Yeah, would be a nice-to-have more than anything else. Thanks for the detailed explanation!
Problem
Discrete grids need to coincide with their indices, i.e., `valid_grid = [0, 1, 2, 3]` is supported, while `invalid_grid = [2, 3]` is not.

What we want
Support arbitrary integer grids, i.e., the following grids should be supported:

- `[2, 3]`
- `[-1, 0, 1]`
- `[1, 10, 100]`
- `[1, 2, 1_000_000]`
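To illustrate why a grid like `[-1, 0, 1]` goes wrong without translation, here is a minimal sketch that assumes, purely for illustration, that a discrete state value is used directly as a position into an array of per-grid-point values. Because negative positions count from the end (as in NumPy), no error is raised and every lookup is silently shifted. The names `values`, `lookup_by_value`, and `lookup_by_position` are hypothetical.

```python
import jax.numpy as jnp

# Hypothetical per-grid-point values for health in [-1, 0, 1]:
# position 0 holds the value for health=-1, position 1 for health=0, ...
health_grid = [-1, 0, 1]
values = jnp.array([100.0, 200.0, 300.0])


def lookup_by_value(values, state_value):
    """Naive lookup: treat the state value itself as the array position."""
    return values[state_value]


def lookup_by_position(values, grid, state_value):
    """Correct lookup: first translate the value to its position on the grid."""
    return values[grid.index(state_value)]


for health in health_grid:
    naive = float(lookup_by_value(values, health))
    correct = float(lookup_by_position(values, health_grid, health))
    print(health, naive, correct)
# -1 300.0 100.0
# 0 100.0 200.0
# 1 200.0 300.0
```

Note that `list.index` only works here because `health` is a plain Python int; inside `jax.jit` or `jax.vmap` the value is a traced array, which is the difficulty discussed under Implementation.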
Implementation
Before
Before #81, we tried to solve this in `lcm.function_evaluator.get_label_translator()`, using a dictionary `val_to_pos` that maps each grid value to its position. The problem was, however, that the lookup `val_to_pos[x]` is not vmappable/jittable in JAX (this might have been different in an earlier JAX version). The underlying problem is that (1) JAX has no scalar values, only 0-dimensional arrays, whereas `val_to_pos` has `int` keys (i.e., `0` vs. `jax.numpy.array(0)`); and (2) we cannot use `jax.numpy.array(x)` as a dictionary key, because JAX arrays are not hashable.

Bad solution
Add a new label translator that is JAX compatible.

The problem here is that this translator is much slower (10 to 100 times) for arbitrary grids than for index grids, and its computational cost grows linearly with the size of the label arrays.
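A minimal sketch contrasting the dictionary lookup and a JAX-compatible translator. The original code of `get_label_translator()` is not reproduced here, so both functions are hypothetical reconstructions: `dict_translator` mimics the `val_to_pos[x]` lookup, and `linear_translator` is one possible JAX-compatible variant that compares `x` against every grid point, which is why its cost grows linearly with the grid size.

```python
import jax
import jax.numpy as jnp


def dict_translator(grid):
    """Translate a grid value to its position via a Python dict lookup."""
    val_to_pos = {val: pos for pos, val in enumerate(grid)}

    def translate(x):
        return val_to_pos[x]  # needs a concrete, hashable Python key

    return translate


def linear_translator(grid):
    """JAX-compatible translation: compare x against every grid point."""
    grid_arr = jnp.asarray(grid)

    def translate(x):
        # O(len(grid)) work per call; silently returns 0 if x is not on the grid.
        return jnp.argmax((grid_arr == x).astype(jnp.int32))

    return translate


def fails_under_jit(translate, x):
    """Return True if jax.jit(translate) cannot handle the traced input x."""
    try:
        jax.jit(translate)(jnp.asarray(x))
    except (TypeError, KeyError):
        # The traced x cannot be used as a dict key.
        return True
    return False


grid = [-1, 0, 1]
dict_t, linear_t = dict_translator(grid), linear_translator(grid)

print(dict_t(1))                   # 2 (plain Python int works)
print(fails_under_jit(dict_t, 1))  # True
print(jax.vmap(linear_t)(jnp.array([1, -1, 0])))  # [2 0 1]
```

The linear variant can be jitted and vmapped, but every call scans the whole grid, whereas an index grid needs no translation at all.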
Better solution
The idea is to solve this problem during the model processing and not in the function evaluator. As an example, consider a health state variable that attains values `[-1, 0, 1]`. In the model specification the user will write `"states": {"health": {"options": [-1, 0, 1]}, ...}`. In the model processing, we would not add `health` as a state variable, but `__health_index__`, which would live on the grid `[0, 1, 2]`. Since the user wrote functions that depend on `health` and not on the index, we also need to add a dynamically defined function that computes `health` from `__health_index__`.

Given this, arbitrary discrete integer grids should be supported.
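A minimal sketch of this model-processing step, assuming states are specified as a dict of the form shown above and that downstream code resolves function arguments by parameter name (as the `__health_index__` naming suggests). `process_discrete_states` and `make_label_function` are hypothetical helpers, not LCM's actual API.

```python
import inspect

import jax
import jax.numpy as jnp


def make_label_function(name, options):
    """Build a function of `__{name}_index__` that returns the grid value."""
    index_name = f"__{name}_index__"
    grid = jnp.asarray(options)

    def label_from_index(**kwargs):
        # Single gather: the index is a position on the grid [0, ..., n-1].
        return grid[kwargs[index_name]]

    # Expose a signature with the dynamically chosen argument name so that
    # name-based dependency resolution can see it.
    label_from_index.__signature__ = inspect.Signature(
        [inspect.Parameter(index_name, inspect.Parameter.KEYWORD_ONLY)]
    )
    label_from_index.__name__ = name
    return label_from_index


def process_discrete_states(states):
    """Replace each discrete state by an index state plus a label function."""
    index_states, label_functions = {}, {}
    for name, spec in states.items():
        options = spec["options"]
        index_states[f"__{name}_index__"] = {"options": list(range(len(options)))}
        label_functions[name] = make_label_function(name, options)
    return index_states, label_functions


states = {"health": {"options": [-1, 0, 1]}}
index_states, label_functions = process_discrete_states(states)
health = label_functions["health"]

print(index_states)               # {'__health_index__': {'options': [0, 1, 2]}}
print(inspect.signature(health))  # (*, __health_index__)
print(jax.vmap(lambda i: health(__health_index__=i))(jnp.arange(3)))  # [-1  0  1]
```

Since `health` is recovered by a single array lookup, the cost does not grow with the size of the grid, unlike the bad solution above.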