BasisResearch / chirho

An experimental language for causal reasoning
https://basisresearch.github.io/chirho/getting_started.html

Issue composing `make_empirical_fisher_vp`, conjugate gradients, and `NMCLogPredictiveLikelihood` #407

Closed: agrawalraj closed this issue 11 months ago

agrawalraj commented 11 months ago

In the tests so far in #405, I've been focusing on the MLE case (i.e., when the guide is just a point estimate), since I have closed-form analytical expressions to test against. While trying to construct tests for more general guides, I found what I believe is an issue composing `make_empirical_fisher_vp`, conjugate gradients, and `NMCLogPredictiveLikelihood`. There are several problems, all rooted in the fact that each call to `NMCLogPredictiveLikelihood` is random: if `log_prob = NMCLogPredictiveLikelihood(model, guide)`, `call_one = log_prob(x_n)`, and `call_two = log_prob(x_n)`, then `call_one != call_two` unless the guide is a point estimate.

Issue 1: Our empirical Fisher information estimate is

$$\frac{1}{N} \sum_{n=1}^{N} \nabla \texttt{log\_prob}(x_n)\, \nabla \texttt{log\_prob}(x_n)^{\top}.$$

Since `log_prob(x_n) != log_prob(x_n)` across two successive calls, the empirical Fisher matrix is no longer symmetric. In general, vanilla conjugate gradients does not work for non-symmetric matrices, though more general CG-like methods exist. For a counterexample, see https://math.stackexchange.com/questions/4795473/conjugate-gradient-descent-when-a-is-non-singular-but-not-symmetric.
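To make this concrete, here is a toy sketch (not chirho code; `grad_log_prob` is a hypothetical stand-in for the gradient of a randomized NMC log-likelihood) showing the asymmetry:

```python
import torch

torch.manual_seed(0)
dim, N = 3, 5


def grad_log_prob(x: torch.Tensor) -> torch.Tensor:
    # Hypothetical stand-in: a deterministic part depending on x plus fresh
    # Monte Carlo noise on every call, mimicking NMCLogPredictiveLikelihood.
    return x + 0.1 * torch.randn(dim)


data = [torch.randn(dim) for _ in range(N)]
# Each outer product calls grad_log_prob twice, so the two gradient factors
# see different noise and the resulting matrix is not symmetric.
fisher = sum(torch.outer(grad_log_prob(x), grad_log_prob(x)) for x in data) / N
print(torch.allclose(fisher, fisher.T))  # False
```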

Issue 2: Conjugate gradients assumes a fixed matrix A when solving Ax = b. In our implementation, A is not fixed across CG iterations, because `fvp(v) != fvp(v)` for two successive calls, where `fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data)`. Again, this occurs because `log_prob` is randomized.
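For reference, here is a minimal textbook sketch of vanilla conjugate gradients (not chirho's solver) that makes the violated assumption explicit: `matvec` must act as one fixed symmetric positive-definite matrix A on every call.

```python
from typing import Callable

import torch


def conjugate_gradient(
    matvec: Callable[[torch.Tensor], torch.Tensor],
    b: torch.Tensor,
    num_iters: int = 10,
) -> torch.Tensor:
    x = torch.zeros_like(b)
    r = b - matvec(x)  # residual; assumes matvec(x) == A @ x for one fixed A
    p = r.clone()
    for _ in range(num_iters):
        Ap = matvec(p)  # a stochastic matvec gives each iteration a different "A"
        alpha = (r @ r) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        beta = (r_new @ r_new) / (r @ r)
        p = r_new + beta * p
        r = r_new
    return x
```

With a stochastic `fvp` plugged in as `matvec`, every iteration effectively solves against a different matrix, so the usual convergence guarantees no longer apply.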

Potential solution: One option would be to sample the particles in `NMCLogPredictiveLikelihood` in `__init__`, so that successive calls are deterministic. That would fix Issue 2. Assuming we make this change, there are two options for addressing Issue 1.

Option 1: Use $\frac{1}{N} \sum_{n=1}^{N} \nabla \texttt{log\_prob}(x_n)\, \nabla \texttt{log\_prob}(x_n)^{\top}$, where `log_prob` has its particles sampled in `__init__` instead of in the `forward` method. This matrix is symmetric but biased, because $\mathbb{E}[\nabla \texttt{log\_prob}(x_n)\, \nabla \texttt{log\_prob}(x_n)^{\top}] \neq \mathbb{E}[\nabla \texttt{log\_prob}(x_n)]\, \mathbb{E}[\nabla \texttt{log\_prob}(x_n)]^{\top}$ when the same particles are used for both gradient factors. For example, if each gradient call returns $g = \mu + \epsilon$ with zero-mean Monte Carlo noise $\epsilon$, then $\mathbb{E}[g g^{\top}] = \mu \mu^{\top} + \mathrm{Cov}(\epsilon) \neq \mu \mu^{\top}$.

Option 2: Instantiate `log_prob` twice. Here, we could use $\frac{1}{N} \sum_{n=1}^{N} \nabla \texttt{log\_prob}_1(x_n)\, \nabla \texttt{log\_prob}_2(x_n)^{\top}$ as the estimate. Then $\mathbb{E}[\nabla \texttt{log\_prob}_1(x_n)\, \nabla \texttt{log\_prob}_2(x_n)^{\top}] = \mathbb{E}[\nabla \texttt{log\_prob}_1(x_n)]\, \mathbb{E}[\nabla \texttt{log\_prob}_2(x_n)]^{\top}$, since the particles used by `log_prob1` are drawn independently of those used by `log_prob2`. However, the empirical Fisher matrix is then not symmetric in general, so we could not use vanilla conjugate gradients.

Current test: This simple test shows how unstable conjugate gradients are under our current implementation: https://github.com/BasisResearch/chirho/blob/ra-sw-fisher-tests/tests/robust/test_internals_compositions.py.

This issue will be closed once #408 and #428 are complete.

eb8680 commented 11 months ago

Would switching from estimating the Fisher with a product of first-order gradients to estimating Fisher-vector products with Hessian-vector products in `make_empirical_fisher_vp` address issue 1?
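For concreteness, here is a minimal sketch of that formulation using `torch.func` (with a hypothetical scalar `log_prob_fn`; this is not the actual `make_empirical_fisher_vp`):

```python
import torch
from torch.func import grad, jvp


def fisher_vp_via_hvp(log_prob_fn, params, data, v):
    # Under standard regularity conditions, Fisher = -E[Hessian of the
    # log-likelihood], so a Fisher-vector product is the negative average of
    # Hessian-vector products. jvp over grad computes H @ v (forward-over-
    # reverse) without ever materializing H.
    def hvp(x, p, tangent):
        return jvp(grad(lambda q: log_prob_fn(q, x)), (p,), (tangent,))[1]

    return -sum(hvp(x, params, v) for x in data) / len(data)
```

Since each sampled Hessian is symmetric, this estimator has no analogue of the two-independent-gradient asymmetry above.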

agrawalraj commented 11 months ago

> Would switching from a product of first-order gradients to Hessian-vector products in `make_empirical_fisher_vp` address issue 1?

That's a good point! I want to think a bit more about this, but I think you're right that it could also address Issue 1.

eb8680 commented 11 months ago

I think for issue 2 we can either pin the samples used in `NMCLogPredictiveLikelihood` at the level of a single `linearize` call, as you suggest, or use a different solver that tolerates stochasticity.

Pinning the samples is conceptually easier, but it might be tricky to get PyTorch to propagate gradients through them correctly in multiple Fisher-vector product calls.

Using a different solver will probably require more reading, although I think vanilla stochastic gradient descent would work correctly as a baseline at the cost of a significant slowdown relative to CG.
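As a sketch of that baseline (hypothetical code, assuming the expected operator is symmetric positive-definite): solve Fx = b by gradient descent on the quadratic $\frac{1}{2} x^{\top} F x - b^{\top} x$, whose gradient Fx - b needs only an unbiased matvec estimate per step.

```python
import torch


def sgd_solve(matvec, b: torch.Tensor, lr: float = 0.01, num_steps: int = 1000) -> torch.Tensor:
    # Each step uses one (possibly stochastic) matrix-vector product. Unlike
    # CG, noise in matvec perturbs but does not derail the iteration on
    # average, at the cost of many more steps.
    x = torch.zeros_like(b)
    for _ in range(num_steps):
        x = x - lr * (matvec(x) - b)
    return x
```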

agrawalraj commented 11 months ago

Ok cool. I'm going to make a simple modification to `NMCLogPredictiveLikelihood` and test to see how pinning samples composes with PyTorch gradients.

agrawalraj commented 11 months ago

> Would switching from estimating the Fisher with a product of first-order gradients to estimating Fisher-vector products with Hessian-vector products in `make_empirical_fisher_vp` address issue 1?

I think the Hessian formulation could be quite nice. With the current score formulation of the Fisher, the number of outer Monte Carlo samples must, at minimum, be at least the dimension of the guide parameters for the empirical Fisher matrix to be invertible. With the Hessian formulation, a single Monte Carlo sample may suffice for the inverse to exist, although we would still usually want more samples for a good approximation.
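The rank constraint behind this is easy to check in isolation: a sum of N rank-one outer products has rank at most N, so with fewer samples than parameters the score-based estimate cannot be invertible (toy check, not chirho code):

```python
import torch

torch.manual_seed(0)
dim, N = 10, 3  # fewer Monte Carlo samples than guide parameters

grads = [torch.randn(dim) for _ in range(N)]
fisher = sum(torch.outer(g, g) for g in grads) / N
print(torch.linalg.matrix_rank(fisher))  # at most 3, so the 10x10 matrix is singular
```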

agrawalraj commented 11 months ago

@eb8680 Here's one simple way to pin down the samples without breaking gradients. The idea is to set the seed in `forward` and then reset the RNG back to its previous state, as in `pyro.poutine.seed_messenger.SeedMessenger`. The change amounts to a few new lines relative to our old implementation (marked with "# NEW" comments below):

```python
import math
from typing import Any, Callable, Generic, Optional

import pyro
import torch
from pyro.util import get_rng_state, set_rng_seed, set_rng_state

# P, T, Point, guess_max_plate_nesting, _UnmaskNamedSites, and condition are
# the same chirho internals used by the existing implementation.


class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module):
    model: Callable[P, Any]
    guide: Callable[P, Any]
    num_samples: int
    max_plate_nesting: Optional[int]

    def __init__(
        self,
        model: torch.nn.Module,
        guide: torch.nn.Module,
        *,
        num_samples: int = 1,
        max_plate_nesting: Optional[int] = None,
        rng_seed: int = 123,  # NEW ARGUMENT: pins the importance-sampling particles
    ):
        super().__init__()
        self.model = model
        self.guide = guide
        self.num_samples = num_samples
        self.max_plate_nesting = max_plate_nesting
        self.rng_seed = rng_seed

    def forward(
        self, data: Point[T], *args: P.args, **kwargs: P.kwargs
    ) -> torch.Tensor:
        old_state = get_rng_state()  # NEW LINE: save the caller's RNG state
        set_rng_seed(self.rng_seed)  # NEW LINE: same particles on every call
        if self.max_plate_nesting is None:
            self.max_plate_nesting = guess_max_plate_nesting(
                self.model, self.guide, *args, **kwargs
            )

        masked_guide = pyro.poutine.mask(mask=False)(self.guide)
        masked_model = _UnmaskNamedSites(names=set(data.keys()))(
            condition(data=data)(self.model)
        )
        log_weights = pyro.infer.importance.vectorized_importance_weights(
            masked_model,
            masked_guide,
            *args,
            num_samples=self.num_samples,
            max_plate_nesting=self.max_plate_nesting,
            **kwargs,
        )[0]
        set_rng_state(old_state)  # NEW LINE: restore the caller's RNG state
        return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples)
```

Note: I first tried to just wrap the guide in `SeedMessenger`, but that did not compose with `make_functional_call`. In particular, if you call `make_functional_call` on `SeedMessenger(123)(guide)`, then the `log_prob_params` returned is an empty dictionary.
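For example, a usage sketch along these lines (hypothetical model and guide; assumes the class above and its chirho dependencies are importable):

```python
import pyro
import pyro.distributions as dist
import torch


def model():
    loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
    pyro.sample("x", dist.Normal(loc, 1.0))


def guide():
    m = pyro.param("m", torch.tensor(0.0))
    pyro.sample("loc", dist.Normal(m, 1.0))


log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=100, max_plate_nesting=1)
data = {"x": torch.tensor(1.0)}
# Successive calls now agree because the particles are pinned by rng_seed,
# while the RNG state outside forward is left untouched.
assert torch.allclose(log_prob(data), log_prob(data))
```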

eb8680 commented 11 months ago

Does that code using `set_rng_state` work correctly with `vmap`/`jvp`/`vjp`/`grad`?

agrawalraj commented 11 months ago

Yes, I believe so. I just fed it directly into `make_empirical_fisher_vp`, which uses `vmap`, `jvp`, and `vjp`, and the result was non-zero. So it seems like gradients are propagating.

eb8680 commented 11 months ago

Do they match across calls?

agrawalraj commented 11 months ago

Yes, they do. I'm going to commit the test now!

agrawalraj commented 11 months ago

@eb8680 The test is called `test_nmc_likelihood_seeded` in https://github.com/BasisResearch/chirho/blob/ra-pinned-nmc/tests/robust/test_internals_compositions.py

eb8680 commented 11 months ago

Great, go ahead and make a PR that includes that test.

agrawalraj commented 11 months ago

Ok cool, I'll do that now. @eb8680 I made the PR here: https://github.com/BasisResearch/chirho/pull/408.

agrawalraj commented 11 months ago

Closed by #428 and #408.