OpenSourceEconomics / lcm

Solution and simulation of life cycle models in Python with GPU acceleration.
Apache License 2.0

ENH: Add linear extrapolation to `map_coordinates` #83

Closed: timmens closed this 1 month ago

timmens commented 4 months ago

Why

Linear extrapolation beyond the grid boundaries is useful for several reasons. For example, (1) it can make the solution more robust if a state transition requests values outside the original grid, and (2) it helps when simulating stochastic processes whose node points lie outside the grid.

When

I have asked about this in JAX's discussion forum. If the JAX maintainers decide not to add this feature, we can simply implement a custom map_coordinates function for LCM.

Implementation

When map_coordinates is called with order=1 (linear interpolation), it internally uses the following helper to compute, along each axis, the two neighbouring grid nodes and their weights:

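import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike

def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]:
  lower = jnp.floor(coordinate)  # position of the grid node just below the coordinate
  upper_weight = coordinate - lower  # distance to that node, in [0, 1)
  lower_weight = 1 - upper_weight
  index = lower.astype(jnp.int32)
  # (node index, weight) pairs for the two neighbouring grid nodes
  return [(index, lower_weight), (index + 1, upper_weight)]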
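A minimal sketch of how these (index, weight) pairs turn into an interpolated value in one dimension may help. It is an illustration, not JAX's code: NumPy stands in for jax.numpy so the snippet runs without JAX, the names linear_indices_and_weights and interp_1d are hypothetical, and the boundary handling that map_coordinates performs through its mode argument is deliberately left out.

```python
import numpy as np


def linear_indices_and_weights(coordinate):
    """NumPy mirror of JAX's internal helper (hypothetical name)."""
    lower = np.floor(coordinate)
    upper_weight = coordinate - lower  # distance from the lower node
    lower_weight = 1 - upper_weight
    index = lower.astype(np.int32)
    return [(index, lower_weight), (index + 1, upper_weight)]


def interp_1d(values, coordinate):
    """Weighted sum over the two nodes.

    Assumes both indices are in bounds: NumPy would silently wrap a
    negative index around to the end of the array.
    """
    values = np.asarray(values, dtype=float)
    coordinate = np.asarray(coordinate, dtype=float)
    return sum(weight * values[index]
               for index, weight in linear_indices_and_weights(coordinate))


values = np.array([1.0, 2.0, 4.0])
print(interp_1d(values, 0.5))   # 1.5
print(interp_1d(values, 1.25))  # 2.5
# Outside the grid the nodes themselves leave [0, len(values) - 1]:
print([int(i) for i, _ in linear_indices_and_weights(np.asarray(-0.5))])  # [-1, 0]
```

For a coordinate of -0.5 the lower node is -1, which is outside the grid. map_coordinates resolves such nodes through mode (for example, filling them with cval under the default "constant"), and none of the available modes extrapolates linearly.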

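To allow linear extrapolation for coordinate values below 0 or above the last grid index input_size - 1 (where input_size is the number of grid points along that axis), one only has to change the first line so that the lower node is clipped to the range [0, input_size - 2]. The helper then also needs input_size, which it does not currently receive as an argument: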

...
lower = jnp.clip(jnp.floor(coordinate), min=0, max=input_size - 2)
...
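A quick check of why the clip yields linear extrapolation: for a coordinate x < 0 the lower node is pinned to 0, so upper_weight = x is negative and lower_weight = 1 - x, giving (1 - x) * v[0] + x * v[1] = v[0] + x * (v[1] - v[0]), which is the line through the first two grid values evaluated at x. Above input_size - 1 the last two values are used in the same way, and inside the grid the result is unchanged. The sketch below mirrors the modified helper in NumPy; the extra input_size argument and the names linear_indices_and_weights_extrapolate and interp_1d_extrapolate are assumptions for illustration, not JAX or LCM API.

```python
import numpy as np


def linear_indices_and_weights_extrapolate(coordinate, input_size):
    """Hypothetical variant of the helper that extrapolates linearly.

    Clipping the lower node to [0, input_size - 2] keeps both nodes
    inside the grid; the weights then fall outside [0, 1] for
    coordinates beyond the boundaries.
    """
    lower = np.clip(np.floor(coordinate), 0, input_size - 2)
    upper_weight = coordinate - lower
    lower_weight = 1 - upper_weight
    index = lower.astype(np.int32)
    return [(index, lower_weight), (index + 1, upper_weight)]


def interp_1d_extrapolate(values, coordinate):
    """1D linear interpolation inside the grid, linear extrapolation outside."""
    values = np.asarray(values, dtype=float)
    coordinate = np.asarray(coordinate, dtype=float)
    nodes = linear_indices_and_weights_extrapolate(coordinate, values.shape[0])
    return sum(weight * values[index] for index, weight in nodes)


values = np.array([1.0, 2.0, 4.0])
coords = np.array([-1.0, 0.5, 2.0, 3.0])
print(interp_1d_extrapolate(values, coords).tolist())  # [0.0, 1.5, 4.0, 6.0]
```

Inside the grid this agrees with np.interp; the two differ only beyond the boundaries, where np.interp by default clamps to the edge values instead of continuing the line.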