Open chris-n-self opened 3 years ago
Yes, this seems to be a known regression with jax (https://github.com/google/jax/issues/5043), where jitted functions can't be pickled, even by cloudpickle. I don't have any immediate solution other than to try a different autodiff_backend.
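To check which pieces of a setup fail to serialize before handing them to an executor, a small helper along these lines may be useful. This is only a sketch: `find_unpicklable` and the `candidates` dict are hypothetical names, and a lambda is used as a stand-in for a jitted function because the standard-library pickler rejects it too. It does not reproduce the jax failure itself.

```python
import math
import pickle


def find_unpicklable(named_objs, dumps=pickle.dumps):
    """Return the names of the objects that `dumps` fails to serialize."""
    failures = []
    for name, obj in named_objs.items():
        try:
            dumps(obj)
        except Exception:  # picklers raise several different exception types
            failures.append(name)
    return failures


# Hypothetical stand-ins: the lambda plays the role of a jitted loss function,
# since the stdlib pickler cannot serialize it either.
candidates = {
    "plain_builtin": math.sqrt,
    "stand_in_for_jitted_loss": lambda x: x ** 2,
}

bad = find_unpicklable(candidates)
print(bad)
```

If cloudpickle is installed, passing `dumps=cloudpickle.dumps` should test the objects against that serializer instead, which is closer to what a dask executor would rely on.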
I don't seem to be able to use a jax autodiff backend with a TNOptimizer object while also using a dask executor to parallelise over a multi-component loss function.
Versions:
quimb -- 1.3.0+276.gee67688
cotengra -- up to date with GitHub version
jax -- 0.2.10
jaxlib -- 0.1.62
dask -- 2021.3.0
cloudpickle -- 1.6.0
Example that gives me the problem:
The end of the Traceback is: