StanfordASL / hj_reachability

Hamilton-Jacobi reachability analysis in JAX.
MIT License

Incompatibility with latest JAX version #14

Open mattkiim opened 3 months ago

mattkiim commented 3 months ago

Hi, it appears that JAX deprecated `jax.experimental.host_callback` as of March 21, 2024 (google/jax/issues/20385). Below is a working replacement for the `TqdmWrapper` class in `solver.py` that uses `jax.experimental.io_callback` instead. It resolved all the errors I was seeing.

```python
from jax.experimental import io_callback


class TqdmWrapper:

    def __init__(self, tqdm, reference_time, total, *args, **kwargs):
        self.reference_time = reference_time
        # Create the progress bar on the host; `total` arrives as a NumPy value.
        io_callback(lambda total: self._create_tqdm(tqdm, total, *args, **kwargs), None, total)

    def _create_tqdm(self, tqdm, total, *args, **kwargs):
        self._tqdm = tqdm.tqdm(*args, total=total, **kwargs)

    def _update_to(self, n):
        # Advance the bar to absolute position `n`. This function has no
        # `return`, so it yields None: `tqdm.update` can return True, which
        # would not match the declared (empty) callback result.
        self._tqdm.update(n - self._tqdm.n)

    def update_to(self, n):
        io_callback(self._update_to, None, n)

    def close(self):
        io_callback(lambda _: self._tqdm.close(), None, None)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
```
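A standalone way to check the `io_callback` pattern, without hj_reachability or tqdm, is a small jitted loop that reports progress to the host. This is a minimal sketch assuming a recent JAX release: `integrate`, `_report`, and `progress_log` are hypothetical names, and the loop body is a placeholder rather than the solver's actual time step. Passing `None` as `result_shape_dtypes` declares that the callback produces no outputs, so the host function should return `None`; `tqdm`'s `update` method, for instance, returns `True` when it redraws the bar, so it is safer to call it from a function that discards that value. `ordered=True` asks JAX to run the callbacks in program order.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

# Hypothetical host-side log of the times reported from inside the compiled loop.
progress_log = []


def _report(t):
    # Runs on the host with `t` as a NumPy value. It returns None, matching
    # result_shape_dtypes=None (no outputs declared).
    progress_log.append(float(t))


@jax.jit
def integrate(t_final, dt):
    # Hypothetical fixed-step loop standing in for a solver's time stepping.
    def cond(state):
        t, _ = state
        return t < t_final

    def body(state):
        t, x = state
        t = jnp.minimum(t + dt, t_final)
        x = 0.9 * x  # placeholder update, not a real solver step
        io_callback(_report, None, t, ordered=True)  # side effect only
        return t, x

    init = (jnp.float32(0.0), jnp.float32(1.0))
    return jax.lax.while_loop(cond, body, init)


t_end, x_end = integrate(1.0, 0.25)
x_end.block_until_ready()
jax.effects_barrier()  # wait for pending host callbacks to finish
print(progress_log)  # [0.25, 0.5, 0.75, 1.0]
```

Because JAX dispatches asynchronously, the sketch waits on the result and calls `jax.effects_barrier()` before reading the log on the host.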