pyro-ppl / numpyro

Probabilistic programming with NumPy powered by JAX for autograd and JIT compilation to GPU/TPU/CPU.
https://num.pyro.ai
Apache License 2.0

AttributeError: Can't pickle local object 'ESS.DifferentialMove.<locals>.make_differential_move.<locals>.differential_move' #1747

Closed: amunozj closed this issue 4 months ago

amunozj commented 4 months ago

Hello friends. I'm trying to save (pickle) an MCMC run that uses the ESS sampler, but I get this error:

    AttributeError: Can't pickle local object 'ESS.DifferentialMove.<locals>.make_differential_move.<locals>.differential_move'

Any suggestions as to how to fix it?

amunozj commented 4 months ago

Oh, I think I see the problem. ESS.DifferentialMove defines functions within functions, and the standard pickle module cannot serialize such local (nested) functions:

    import jax.numpy as jnp
    from jax import random

    from numpyro.infer.ensemble_util import get_nondiagonal_indices

    # Excerpt: this is defined inside the ESS class (hence the
    # "ESS.DifferentialMove" prefix in the error message).
    def DifferentialMove():
        """
        The `Karamanis & Beutler (2020) <https://arxiv.org/abs/2002.06212>`_ "Differential Move" with parallelization.
        When this Move is used the walkers move along directions defined by random pairs of walkers sampled (with no
        replacement) from the complementary ensemble. This is the default choice and performs well along a wide range
        of target distributions.
        """
        def make_differential_move(n_chains):
            PAIRS = get_nondiagonal_indices(n_chains // 2)

            # Nested inside make_differential_move: pickle sees it as a local object.
            def differential_move(rng_key, inactive, mu):
                n_active_chains, n_params = inactive.shape

                selected_pairs = random.choice(rng_key, PAIRS, shape=(n_active_chains,))
                # Pairwise difference of each selected pair of walkers.
                diffs = jnp.diff(inactive[selected_pairs], axis=1).squeeze(axis=1)

                return 2.0 * mu * diffs

            return differential_move

        return make_differential_move

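
A minimal, standard-library-only reproduction of the same failure may help. `make_scaler` is a hypothetical stand-in for `make_differential_move`: the inner function's qualified name contains `<locals>`, so pickle cannot store it by reference. The last lines show one pickle-friendly alternative, binding the captured value with `functools.partial` on an importable function. This is an illustration under those assumptions, not how numpyro itself defines its moves.

```python
import functools
import operator
import pickle


def make_scaler(mu):
    # Hypothetical factory mirroring make_differential_move: `scale` is
    # defined inside another function, so its qualname contains "<locals>".
    def scale(x):
        return 2.0 * mu * x

    return scale


closure = make_scaler(0.5)
print(closure.__qualname__)  # make_scaler.<locals>.scale

try:
    pickle.dumps(closure)
    closure_pickled = True
except (AttributeError, pickle.PicklingError) as err:
    # Older Pythons raise AttributeError here; newer ones may raise PicklingError.
    closure_pickled = False
    print(type(err).__name__, err)

# Alternative: bind the captured value with functools.partial on an importable,
# module-level function (operator.mul), which pickle stores by reference.
bound = functools.partial(operator.mul, 2.0 * 0.5)
restored = pickle.loads(pickle.dumps(bound))
print(restored(3.0))  # 3.0
```
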
amunozj commented 4 months ago

OK. I can pickle the run if I use dill instead of pickle.
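
A minimal sketch of the dill workaround, assuming the third-party `dill` package is installed (`pip install dill`); `make_scaler` is again a hypothetical stand-in for the nested factory. `dill.dumps`/`dill.loads` mirror `pickle.dumps`/`pickle.loads` (and `dill.dump`/`dill.load` work on files), but dill can serialize nested functions by value. In the case above, the object written would be the MCMC run itself; loading that file later also requires dill to be installed.

```python
try:
    import dill  # third-party: pip install dill
except ImportError:
    dill = None  # keeps this sketch runnable when dill is absent


def make_scaler(mu):
    # Hypothetical nested factory, analogous to make_differential_move.
    def scale(x):
        return 2.0 * mu * x

    return scale


if dill is not None:
    # dill serializes the nested function (its code and the captured `mu`)
    # by value, so this round trip succeeds where pickle.dumps raises.
    payload = dill.dumps(make_scaler(0.5))
    restored = dill.loads(payload)
    print(restored(3.0))  # 3.0
```
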