Closed guyao closed 2 weeks ago
PyTorch cummax/cummin return a tuple of two output tensors (values, indices).
JAX cummax/cummin return only the running values, so the results do not match.
This PR implements the operation using associative_scan with a running min/max reduce function.
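For reviewers, the approach described above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: the function names `cummax_with_indices` and `cummin_with_indices` and the helper `_running_extreme` are hypothetical. Ties resolve to the later index, which is assumed (not verified here) to match PyTorch, and NaN propagation is not handled.

```python
import jax
import jax.numpy as jnp


def _running_extreme(x, axis, right_wins):
    """Scan (value, index) pairs along `axis`; `right_wins` picks the survivor."""
    x = jnp.asarray(x)
    axis = axis % x.ndim  # normalize negative axes
    # Position of every element along `axis`, with the same shape as x.
    idx = jax.lax.broadcasted_iota(jnp.int32, x.shape, axis)

    def combine(left, right):
        # `left` holds earlier elements, `right` holds later ones.
        lv, li = left
        rv, ri = right
        take_right = right_wins(rv, lv)
        return jnp.where(take_right, rv, lv), jnp.where(take_right, ri, li)

    # The pairwise "pick the better element" operation is associative,
    # so it can be used with a parallel prefix scan.
    return jax.lax.associative_scan(combine, (x, idx), axis=axis)


def cummax_with_indices(x, axis=0):
    # Hypothetical name; ties keep the later index (assumption).
    return _running_extreme(x, axis, lambda r, l: r >= l)


def cummin_with_indices(x, axis=0):
    # Hypothetical name; ties keep the later index (assumption).
    return _running_extreme(x, axis, lambda r, l: r <= l)


values, indices = cummax_with_indices(jnp.array([1, 3, 2, 3, 0, 5]))
```

Carrying the index alongside the value inside the scan element is what lets a single associative_scan call produce both outputs.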
resolves #7373 resolves #7374