mattkiim opened this issue 3 months ago
Hi, it appears that JAX has deprecated `jax.experimental.host_callback` as of March 21, 2024 (google/jax/issues/20385). Below is a working replacement for the `TqdmWrapper` class in `solver.py` that uses `jax.experimental.io_callback` instead. This resolved all issues for me.
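```python
from jax.experimental import io_callback


class TqdmWrapper:
    """Drives a tqdm progress bar from inside jitted code via io_callback."""

    def __init__(self, tqdm, reference_time, total, *args, **kwargs):
        # `tqdm` is the tqdm module (e.g. `import tqdm`), not the tqdm class.
        self.reference_time = reference_time
        # The bar is created on the host when this callback actually runs.
        # ordered=True keeps create/update/close in program order
        # (drop it if you need to vmap the solver: ordered callbacks can't be batched).
        io_callback(
            lambda total: self._create_tqdm(tqdm, total, *args, **kwargs),
            None,
            total,
            ordered=True,
        )

    def _create_tqdm(self, tqdm, total, *args, **kwargs):
        # Callback arguments arrive as 0-d arrays; .item() gives tqdm a Python number.
        self._tqdm = tqdm.tqdm(*args, total=total.item(), **kwargs)

    def update_to(self, n):
        # tqdm.update() takes a delta, so advance by the distance to position `n`.
        io_callback(
            lambda n: self._tqdm.update(n.item() - self._tqdm.n),
            None,
            n,
            ordered=True,
        )

    def close(self):
        io_callback(lambda _: self._tqdm.close(), None, None, ordered=True)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()
```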
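For anyone who wants to check the callback ordering in isolation, here is a minimal standalone sketch of the mechanism the wrapper relies on: `io_callback` with `ordered=True` inside a jitted function. Without `ordered=True`, JAX doesn't guarantee the order in which these side-effecting callbacks run, so a `close()` could in principle fire before the last `update_to()`. A Python list stands in for the tqdm bar so the order is easy to inspect; `progress_log`, `_record`, `solve` and the update rule `x * 0.5 + 1.0` are made up for illustration and are not part of `solver.py`.

```python
import jax
import jax.numpy as jnp
from jax.experimental import io_callback

# Host-side record of progress updates (a stand-in for a tqdm bar).
progress_log = []


def _record(step):
    # Runs on the host; `step` arrives as a 0-d array, so convert it.
    progress_log.append(int(step))


@jax.jit
def solve(x):
    # The Python loop is unrolled at trace time into five update steps.
    for step in range(1, 6):
        x = x * 0.5 + 1.0
        # ordered=True makes the host see the calls in program order.
        io_callback(_record, None, jnp.int32(step), ordered=True)
    return x


result = solve(jnp.float32(0.0))
jax.effects_barrier()  # wait for all pending callbacks to finish
print(progress_log)  # [1, 2, 3, 4, 5]
print(float(result))  # 1.9375
```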