timmens closed this pull request 1 year ago.
Merging #13 (ec9cc9b) into main (aaa4c96) will increase coverage by 0.46%. The diff coverage is 99.80%.
Coverage diff:

| | main | #13 | +/- |
|---|---|---|---|
| Coverage | 98.29% | 98.76% | +0.46% |
| Files | 25 | 30 | +5 |
| Lines | 1058 | 1540 | +482 |
| Hits | 1040 | 1521 | +481 |
| Misses | 18 | 19 | +1 |
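Line coverage is hits divided by lines, so the percentages in the coverage diff above can be recomputed from its Hits and Lines rows. The sketch below assumes Codecov truncates (rounds down) to two decimals, which is what we believe its default `round: down` setting does; the reported 98.29% and +0.46% are consistent with truncation, whereas round-to-nearest would give 98.30%. The helpers `coverage` and `truncate_percent` are hypothetical names used only for illustration.

```python
from fractions import Fraction
from math import floor


def coverage(hits, lines):
    """Exact line coverage as a fraction in [0, 1] (hits / lines)."""
    return Fraction(hits, lines)


def truncate_percent(ratio):
    """Format a ratio as a percentage truncated (rounded down) to 2 decimals.

    Assumption: this mirrors Codecov's default `round: down` behaviour.
    """
    hundredths = floor(ratio * 10_000)  # exact floor on the Fraction
    return f"{hundredths / 100:.2f}%"


base = coverage(1040, 1058)  # main
head = coverage(1521, 1540)  # PR #13

print(truncate_percent(base))         # 98.29%
print(truncate_percent(head))         # 98.76%
print(truncate_percent(head - base))  # 0.46%
print(f"{float(base) * 100:.2f}%")    # 98.30% (round-to-nearest does not match)
print(1058 - 1040, 1540 - 1521)       # 18 19 (misses = lines - hits)
```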
| Impacted Files | Coverage Δ | |
|---|---|---|
| src/lcm/state_space.py | 97.27% <ø> (ø) | |
| src/lcm/simulate.py | 99.09% <99.09%> (ø) | |
| src/lcm/argmax.py | 100.00% <100.00%> (ø) | |
| src/lcm/dispatchers.py | 97.29% <100.00%> (+0.52%) | :arrow_up: |
| src/lcm/entry_point.py | 98.46% <100.00%> (+1.40%) | :arrow_up: |
| src/lcm/example_models.py | 86.20% <100.00%> (+2.20%) | :arrow_up: |
| src/lcm/model_functions.py | 100.00% <100.00%> (ø) | |
| src/lcm/solve_brute.py | 100.00% <100.00%> (ø) | |
| tests/test_argmax.py | 100.00% <100.00%> (ø) | |
| tests/test_dispatchers.py | 100.00% <100.00%> (ø) | |

... and 5 more
Saving this code snippet here because it will be deleted from git when we squash-merge this PR.
import jax.numpy as jnp


def _create_segment_nd_arange(segment_ids, num_segments, shape):
    """Create an nd-arange for each segment.

    Args:
        segment_ids (jax.numpy.ndarray): 1d array with segment identifiers. See
            jax.ops.segment_max. Must be sorted, so that each segment occupies a
            contiguous block of rows.
        num_segments (int): Total number of segments. See jax.ops.segment_max.
        shape (tuple): The shape to which the index map should be broadcast. The
            first dimension must coincide with len(segment_ids).

    Returns:
        jax.numpy.ndarray: An array of shape 'shape' containing an arange in the
            first dimension for each segment, and repeating this value along the
            remaining shape[1:] dimensions.

    """
    # number of elements in each segment
    bincount = jnp.bincount(segment_ids, length=num_segments)
    cumsum = jnp.cumsum(bincount)
    # shift the cumulative sum to the right and pad with zero at the beginning;
    # this gives the index of the first row of each segment
    shifted_cumsum = jnp.pad(cumsum[:-1], (1, 0), constant_values=0)
    # subtract each row's segment start from its global row index
    segment_arange = jnp.arange(shape[0]) - shifted_cumsum[segment_ids]
    # reshape the array of indices to be broadcastable to the desired shape
    reshaped = segment_arange.reshape(-1, *((1,) * (len(shape) - 1)))
    return jnp.broadcast_to(reshaped, shape)
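To see what the helper returns, here is a NumPy port of the same steps (bincount, exclusive cumulative sum, subtraction, broadcast). It is a sketch for illustration only: `segment_nd_arange` is a hypothetical name, and NumPy's `minlength` stands in for the `length` argument that `jnp.bincount` uses so the output size is static under `jit`.

```python
import numpy as np


def segment_nd_arange(segment_ids, num_segments, shape):
    """NumPy sketch of _create_segment_nd_arange.

    Assumes segment_ids is sorted, so each segment is a contiguous run of rows.
    """
    segment_ids = np.asarray(segment_ids)
    # Number of rows in each segment.
    bincount = np.bincount(segment_ids, minlength=num_segments)
    # Exclusive cumulative sum: index of the first row of each segment.
    segment_starts = np.concatenate(([0], np.cumsum(bincount)[:-1]))
    # Position of every row within its own segment.
    segment_arange = np.arange(shape[0]) - segment_starts[segment_ids]
    # Add trailing singleton axes, then broadcast to the requested shape.
    reshaped = segment_arange.reshape(-1, *((1,) * (len(shape) - 1)))
    return np.broadcast_to(reshaped, shape)


segment_ids = np.array([0, 0, 1, 1, 1, 2])
result = segment_nd_arange(segment_ids, num_segments=3, shape=(6, 2))
print(result[:, 0])  # [0 1 0 1 2 0]
print(result.shape)  # (6, 2)
```

Each row receives its position inside its segment, and that value is repeated along the trailing axis. With unsorted ids such as `[0, 1, 0]` the subtraction produces `[0, -1, 2]`, which is why the sorted-ids requirement matters.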