OpenSourceEconomics / lcm

Solution and simulation of life cycle models in Python with GPU acceleration.
Apache License 2.0

Implement simulation function #13

Closed: timmens closed this pull request 1 year ago

timmens commented 1 year ago
codecov[bot] commented 1 year ago

Codecov Report

Merging #13 (ec9cc9b) into main (aaa4c96) will increase coverage by 0.46%. The diff coverage is 99.80%.

@@            Coverage Diff             @@
##             main      #13      +/-   ##
==========================================
+ Coverage   98.29%   98.76%   +0.46%     
==========================================
  Files          25       30       +5     
  Lines        1058     1540     +482     
==========================================
+ Hits         1040     1521     +481     
- Misses         18       19       +1     
Impacted Files Coverage Δ
src/lcm/state_space.py 97.27% <ø> (ø)
src/lcm/simulate.py 99.09% <99.09%> (ø)
src/lcm/argmax.py 100.00% <100.00%> (ø)
src/lcm/dispatchers.py 97.29% <100.00%> (+0.52%) :arrow_up:
src/lcm/entry_point.py 98.46% <100.00%> (+1.40%) :arrow_up:
src/lcm/example_models.py 86.20% <100.00%> (+2.20%) :arrow_up:
src/lcm/model_functions.py 100.00% <100.00%> (ø)
src/lcm/solve_brute.py 100.00% <100.00%> (ø)
tests/test_argmax.py 100.00% <100.00%> (ø)
tests/test_dispatchers.py 100.00% <100.00%> (ø)
... and 5 more
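The headline numbers can be checked against the hits/lines table. The sketch below assumes Codecov's default settings (two decimals, rounding down) and assumes the coverage change is computed from the unrounded ratios before rounding. `coverage_pct` and `coverage_delta` are hypothetical helpers, not part of Codecov or lcm. Under these assumptions it reproduces 98.29%, 98.76% and +0.46%. The 99.80% diff coverage cannot be derived this way, because it counts only the lines touched by the patch, not the net +482 lines.

```python
import math
from fractions import Fraction


def coverage_pct(hits, lines, precision=2):
    """Coverage in percent, rounded down to `precision` decimals (assumed Codecov default)."""
    scale = 10**precision
    # Fraction keeps the ratio exact, so floor() is not disturbed by float error.
    return math.floor(Fraction(hits, lines) * 100 * scale) / scale


def coverage_delta(base, head, precision=2):
    """Change in coverage from the exact ratios, rounded down afterwards (assumption)."""
    scale = 10**precision
    diff = Fraction(*head) - Fraction(*base)
    return math.floor(diff * 100 * scale) / scale


main = (1040, 1058)  # (hits, lines) on main
pr = (1521, 1540)    # (hits, lines) on the PR head

print(coverage_pct(*main))       # 98.29
print(coverage_pct(*pr))         # 98.76
print(coverage_delta(main, pr))  # 0.46
```

Subtracting the already-rounded values (98.76 - 98.29) would give 0.47, which is why the sketch computes the change from the exact ratios.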
timmens commented 1 year ago

Saving this code snippet here because it will disappear from the git history when we squash-merge this PR.

import jax.numpy as jnp


def _create_segment_nd_arange(segment_ids, num_segments, shape):
    """Create an nd-arange for each segment.

    Args:
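        segment_ids (jax.numpy.ndarray): 1d array with segment identifiers. See
            jax.ops.segment_max. Must be sorted, so that each segment occupies a
            contiguous block; otherwise the computed indices are wrong.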
        num_segments (int): Total number of segments. See jax.ops.segment_max.
        shape (tuple): The shape to which the index map should be broadcasted. The first
            dimension must coincide with len(segment_ids).

    Returns:
        jax.numpy.ndarray: An array of shape 'shape' containing an arange in the first
            dimension for each segment, and repeating this value along the remaining
            shape[1:] dimensions.

    """
    bincount = jnp.bincount(segment_ids, length=num_segments)
    cumsum = jnp.cumsum(bincount)

    # shift the cumulative sum to the right and pad with zero at the beginning
    shifted_cumsum = jnp.pad(cumsum[:-1], (1, 0), constant_values=0)

    # create an array of indices for each segment
    segment_arange = jnp.arange(shape[0]) - shifted_cumsum[segment_ids]

    # reshape the array of indices to be broadcastable to the desired shape
    reshaped = segment_arange.reshape(-1, *((1,) * (len(shape) - 1)))
    return jnp.broadcast_to(reshaped, shape)
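To see what the snippet computes, here is a minimal walk-through of the same steps on a small input. It uses NumPy in place of `jax.numpy` so it runs without JAX (`np.bincount(..., minlength=...)` stands in for `jnp.bincount(..., length=...)`). The helpers `segment_arange_1d` and `segment_arange_loop` are hypothetical names used for illustration only. The loop version is a brute-force reference, and the unsorted example shows why the segment ids must be sorted.

```python
import numpy as np


def segment_arange_1d(segment_ids, num_segments):
    """Position of each element within its segment (NumPy mirror of the jnp steps).

    Assumes segment_ids is sorted, so every segment is one contiguous run.
    """
    segment_ids = np.asarray(segment_ids)
    bincount = np.bincount(segment_ids, minlength=num_segments)  # segment sizes
    cumsum = np.cumsum(bincount)  # end offset of each segment
    starts = np.pad(cumsum[:-1], (1, 0), constant_values=0)  # start offset of each segment
    return np.arange(len(segment_ids)) - starts[segment_ids]


def segment_arange_loop(segment_ids):
    """Brute-force reference: how often each id has been seen before this position."""
    seen = {}
    out = []
    for s in segment_ids:
        out.append(seen.get(s, 0))
        seen[s] = seen.get(s, 0) + 1
    return np.array(out)


sorted_ids = [0, 0, 1, 1, 1, 2]
idx = segment_arange_1d(sorted_ids, num_segments=3)
print(idx.tolist())                               # [0, 1, 0, 1, 2, 0]
print(segment_arange_loop(sorted_ids).tolist())   # [0, 1, 0, 1, 2, 0]

# Broadcasting along a trailing dimension, as the last two lines of the snippet do.
print(np.broadcast_to(idx.reshape(-1, 1), (6, 2)).tolist()[3])  # [1, 1]

unsorted_ids = [1, 0, 1, 0]
print(segment_arange_1d(unsorted_ids, num_segments=2).tolist())  # [-2, 1, 0, 3]
print(segment_arange_loop(unsorted_ids).tolist())                # [0, 0, 1, 1]
```

The offset trick replaces a sequential counter with a bincount, a cumulative sum and one gather, which vectorizes well on a GPU. The price is the sortedness requirement.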