google-deepmind / optax

Optax is a gradient processing and optimization library for JAX.
https://optax.readthedocs.io
Apache License 2.0

Add an assignment problem solver #954

Open carlosgmartin opened 2 months ago

carlosgmartin commented 2 months ago

Feature request: Add a GPU/TPU-friendly solver for the assignment problem. For context, see:

  1. scipy.optimize.linear_sum_assignment
  2. https://github.com/google/jax/issues/10403
  3. https://github.com/google/jax/pull/16974

The last of these links (google/jax#16974) contains the following comment:

There is a TPU-friendly implementation of the Hungarian algorithm here: https://github.com/google-research/scenic/blob/main/scenic/model_lib/matchers/hungarian_cover.py
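The problem is: given an n x n cost matrix C, find a permutation sigma that minimizes sum_i C[i, sigma(i)]. The code below is a minimal sketch of what a jit-compatible solver could look like. It assumes a square matrix with integer-valued costs, and it uses Bertsekas' auction algorithm inside `jax.lax.while_loop` so that all shapes stay static. Note that this is a different technique from the Hungarian algorithm linked above. `auction_assignment` is a hypothetical name, not an existing Optax or JAX API, and the code has not been tuned for accelerator performance.

```python
import jax
import jax.numpy as jnp
from jax import lax


def auction_assignment(cost, eps=None):
    """Hypothetical sketch: min-cost assignment via Bertsekas' auction algorithm.

    Args:
      cost: (n, n) array of costs; optimality is guaranteed for integer costs.
      eps: bid increment; must be < 1/n for the integer-cost optimality guarantee.

    Returns:
      (n,) int32 array where row i is assigned to column row_to_col[i].
    """
    cost = jnp.asarray(cost, dtype=jnp.float32)
    n = cost.shape[0]
    benefit = -cost  # the auction algorithm maximizes total benefit
    if eps is None:
        eps = 1.0 / (n + 1)

    def cond(state):
        row_to_col, _, _ = state
        return jnp.any(row_to_col < 0)  # loop while some row is unassigned

    def body(state):
        row_to_col, col_to_row, prices = state
        # Gauss-Seidel variant: the first unassigned row places a bid.
        i = jnp.argmax(row_to_col < 0)
        values = benefit[i] - prices
        j = jnp.argmax(values)
        best = values[j]
        second = jnp.max(jnp.where(jnp.arange(n) == j, -jnp.inf, values))
        # Raise column j's price by the bid increment (best - second + eps).
        prices = prices.at[j].add(best - second + eps)
        # Evict column j's previous owner (if any), then give column j to row i.
        prev = col_to_row[j]
        row_to_col = jnp.where(prev >= 0, row_to_col.at[prev].set(-1), row_to_col)
        row_to_col = row_to_col.at[i].set(j)
        col_to_row = col_to_row.at[j].set(i)
        return row_to_col, col_to_row, prices

    init = (
        jnp.full((n,), -1, dtype=jnp.int32),  # row -> column (-1: unassigned)
        jnp.full((n,), -1, dtype=jnp.int32),  # column -> row (-1: unowned)
        jnp.zeros((n,), dtype=jnp.float32),   # column prices
    )
    row_to_col, _, _ = lax.while_loop(cond, body, init)
    return row_to_col
```

For a square matrix, the output follows the same convention as the `col_ind` array returned by `scipy.optimize.linear_sum_assignment`: row `i` is assigned to column `row_to_col[i]`. This makes scipy a convenient reference for testing. Because the array shapes depend only on `n`, the function can be wrapped in `jax.jit`. This Gauss-Seidel variant places one bid per loop iteration. A Jacobi variant, in which all unassigned rows bid simultaneously, would expose more parallelism on GPU/TPU.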

Potentially relevant: