Closed: agrawalraj closed this issue 11 months ago.
Would switching from estimating the Fisher with a product of first-order gradients to estimating Fisher-vector products with Hessian-vector products in make_empirical_fisher_vp address Issue 1?

> Would switching from a product of first-order gradients to Hessian-vector products in make_empirical_fisher_vp address Issue 1?
That's a good point! I want to think a bit more about this, but I think you're right that it could also address "Issue 1".
I think for Issue 2 we can either pin the samples used in NMCLogPredictiveLikelihood at the level of a single linearize call, as you suggest, or use a different solver that tolerates stochasticity. Pinning the samples is conceptually easier, but it might be tricky to get PyTorch to propagate gradients through them correctly across multiple Fisher-vector product calls. Using a different solver will probably require more reading, although I think vanilla stochastic gradient descent would work correctly as a baseline, at the cost of a significant slowdown relative to CG.
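For reference, here is a minimal self-contained sketch of that second option: plain SGD on the induced quadratic objective, with a noisy matrix-vector product standing in for the stochastic Fisher-vector product. The names `sgd_solve` and `stochastic_Av` are illustrative, not chirho APIs.

```python
import torch

def sgd_solve(stochastic_Av, b, lr=1e-2, num_steps=5000):
    """Solve A x = b given only a noisy, unbiased estimate of v -> A v."""
    x = torch.zeros_like(b)
    for _ in range(num_steps):
        # stochastic_Av(x) - b is an unbiased estimate of the gradient of
        # 0.5 * x^T A x - b^T x, so this is plain SGD on that quadratic.
        x = x - lr * (stochastic_Av(x) - b)
    return x

# Toy check with a symmetric positive-definite A and additive noise.
A = torch.tensor([[2.0, 0.5], [0.5, 1.0]])
b = torch.tensor([1.0, -1.0])
noisy_Av = lambda v: A @ v + 0.01 * torch.randn(2)
print(sgd_solve(noisy_Av, b))       # close to the exact solution
print(torch.linalg.solve(A, b))
```

The many cheap iterations here are exactly the slowdown relative to CG mentioned above.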
Ok cool. I'm going to make a simple modification to NMCLogPredictiveLikelihood and test to see how pinning samples composes with PyTorch gradients.
> Would switching from estimating the Fisher with a product of first-order gradients to estimating Fisher-vector products with Hessian-vector products in make_empirical_fisher_vp address Issue 1?
I think the Hessian formulation might be pretty nice. With the current score-based formulation of the Fisher, the number of outer Monte Carlo samples must, at a minimum, exceed the dimension of the guide parameters for the empirical Fisher matrix to be invertible. With the Hessian formulation, a single Monte Carlo sample might suffice for the inverse to exist, although in practice we would still want more samples for a good approximation.
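For concreteness, here is a rough sketch of a Hessian-vector-product-based Fisher-vector product using torch.func. It is not chirho's make_empirical_fisher_vp; it assumes a deterministic per-datapoint log_prob(params, x), and the function and variable names are illustrative.

```python
import math
import torch
from torch.func import grad, jvp, vmap

def make_hessian_fisher_vp(log_prob, params, data):
    """Return v -> -(1/N) sum_n H_n v, where H_n is the Hessian of log_prob at x_n."""
    def single_hvp(x, v):
        # A jvp through the gradient function is a Hessian-vector product.
        _, hvp_out = jvp(lambda p: grad(log_prob)(p, x), (params,), (v,))
        return hvp_out

    return lambda v: -vmap(single_hvp, in_dims=(0, None))(data, v).mean(dim=0)

# Toy check: Gaussian location model log N(x | theta, 1), whose Fisher information is 1.
log_prob = lambda theta, x: -0.5 * (x - theta) ** 2 - 0.5 * math.log(2 * math.pi)
theta = torch.tensor(0.0)
data = torch.randn(1000)
fvp = make_hessian_fisher_vp(log_prob, theta, data)
print(fvp(torch.tensor(1.0)))  # close to 1.0 (exactly 1 here, since the Hessian is constant)
```

The rank point above is visible in this formulation: each per-sample Hessian can already be full rank, whereas each score outer product grad log_prob(x_n) grad log_prob(x_n)^T has rank one.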
@eb8680 Here's one simple way to pin down samples without breaking gradients. The idea is to set the seed in forward and then reset the RNG back to the old state, as in pyro.poutine.seed_messenger.SeedMessenger. There are only 3 new line additions to our old implementation (marked with "# NEW" comments below):
```python
import math
from typing import Any, Callable, Generic, Optional

import pyro
import torch
from pyro.util import get_rng_state, set_rng_seed, set_rng_state

# Point, P, T, condition, _UnmaskNamedSites, and guess_max_plate_nesting are
# imported exactly as in the existing chirho.robust implementation.


class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module):
    model: Callable[P, Any]
    guide: Callable[P, Any]
    num_samples: int
    max_plate_nesting: Optional[int]

    def __init__(
        self,
        model: torch.nn.Module,
        guide: torch.nn.Module,
        *,
        num_samples: int = 1,
        max_plate_nesting: Optional[int] = None,
        rng_seed: int = 123,  # NEW ARGUMENT
    ):
        super().__init__()
        self.model = model
        self.guide = guide
        self.num_samples = num_samples
        self.max_plate_nesting = max_plate_nesting
        self.rng_seed = rng_seed
        self.old_state = get_rng_state()  # RNG state restored after each forward call

    def forward(
        self, data: Point[T], *args: P.args, **kwargs: P.kwargs
    ) -> torch.Tensor:
        set_rng_seed(self.rng_seed)  # NEW LINE: pin the Monte Carlo samples
        if self.max_plate_nesting is None:
            self.max_plate_nesting = guess_max_plate_nesting(
                self.model, self.guide, *args, **kwargs
            )
        masked_guide = pyro.poutine.mask(mask=False)(self.guide)
        masked_model = _UnmaskNamedSites(names=set(data.keys()))(
            condition(data=data)(self.model)
        )
        log_weights = pyro.infer.importance.vectorized_importance_weights(
            masked_model,
            masked_guide,
            *args,
            num_samples=self.num_samples,
            max_plate_nesting=self.max_plate_nesting,
            **kwargs,
        )[0]
        set_rng_state(self.old_state)  # NEW LINE: restore the global RNG state
        return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples)
```
Note: I first tried to just wrap the guide in SeedMessenger, but that did not compose with make_functional_call. In particular, if you call make_functional_call on SeedMessenger(123)(guide), then the log_prob_params returned is the empty dictionary.
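For what it's worth, here is a minimal standalone illustration of the pin-and-restore pattern itself (a toy Monte Carlo estimate, not the chirho class; `seeded_mc_estimate` is an illustrative name): repeated calls return identical values, and gradients still flow to the parameter.

```python
import torch
from pyro.util import get_rng_state, set_rng_seed, set_rng_state

def seeded_mc_estimate(loc, rng_seed=123, num_samples=1000):
    old_state = get_rng_state()
    set_rng_seed(rng_seed)            # pin the Monte Carlo noise
    eps = torch.randn(num_samples)
    set_rng_state(old_state)          # restore the global RNG state
    return ((loc + eps) ** 2).mean()  # differentiable in loc

loc = torch.tensor(0.5, requires_grad=True)
assert seeded_mc_estimate(loc) == seeded_mc_estimate(loc)   # successive calls match
print(torch.autograd.grad(seeded_mc_estimate(loc), loc))    # roughly 2 * loc
```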
Does that code using set_rng_state work correctly with vmap/jvp/vjp/grad?
Yes, I believe so. I just fed it directly into make_empirical_fisher_vp, which uses vmap, jvp, and vjp, and the result was non-zero. So it seems like gradients are propagating.
Do they match across calls?
Yes, they do. I'm going to commit the test now!
@eb8680 the test is called test_nmc_likelihood_seeded in https://github.com/BasisResearch/chirho/blob/ra-pinned-nmc/tests/robust/test_internals_compositions.py
Great, go ahead and make a PR that includes that test.
Ok cool, I'll do that now. @eb8680 I made the PR here: https://github.com/BasisResearch/chirho/pull/408.
Closed by #428, #408
In the tests so far in #405, I've been focusing on the MLE case (i.e., when the guide is just a point estimate), since I have closed-form analytical expressions for testing purposes. I've been trying to construct more tests for more general guides, and I think there's an issue composing make_empirical_fisher_vp, conjugate gradients, and NMCLogPredictiveLikelihood in general. There are several issues, all basically rooted in the fact that each call to NMCLogPredictiveLikelihood is random: if log_prob = NMCLogPredictiveLikelihood(model, guide), call_one = log_prob(x_n), and call_two = log_prob(x_n), then call_one != call_two unless the guide is just a point estimate guide.

Issue 1: Our empirical Fisher info equals 1/N \sum_{n=1}^N grad log_prob(x_n) grad log_prob(x_n)^T. Since two successive evaluations of log_prob(x_n) return different values, the two gradient factors in each outer product differ, and the empirical Fisher matrix is no longer symmetric. In general, vanilla conjugate gradients do not work for non-symmetric matrices, although there are more general CG methods. For a counterexample, see https://math.stackexchange.com/questions/4795473/conjugate-gradient-descent-when-a-is-non-singular-but-not-symmetric.

Issue 2: Conjugate gradients assumes there is a fixed matrix A and we want to solve Ax = b for x. In our implementation, there is not a fixed A across CG iterations, because fvp(v) != fvp(v) across calls, where fvp = make_empirical_fisher_vp(func_log_prob, log_prob_params, data). Again, this occurs because log_prob is randomized.

Potential solution: One option would be to sample the particles in NMCLogPredictiveLikelihood in __init__ so that successive calls are deterministic. That would fix "Issue 2". Assuming we make this change, there are two options for addressing "Issue 1" (see the toy illustration after the two options).

Option 1: Use 1/N \sum_{n=1}^N grad log_prob(x_n) grad log_prob(x_n)^T, where log_prob has its particles sampled in __init__ instead of in the forward method. This matrix is symmetric but biased, because E[grad log_prob(x_n) grad log_prob(x_n)^T] != E[grad log_prob(x_n)] E[grad log_prob(x_n)]^T when the same particles are used for both factors.

Option 2: Instantiate log_prob twice. Here, we could use 1/N \sum_{n=1}^N grad log_prob1(x_n) grad log_prob2(x_n)^T as the estimate. Then E[grad log_prob1(x_n) grad log_prob2(x_n)^T] = E[grad log_prob1(x_n)] E[grad log_prob2(x_n)]^T, since the particles used for log_prob1 are drawn independently of those used for log_prob2. However, this empirical Fisher matrix is not symmetric in general, so we could not use vanilla conjugate gradients.
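Toy numeric illustration of the bias argument (made-up numbers, not chirho code): a noisy but unbiased gradient g = true_grad + noise stands in for grad log_prob(x_n) computed with one fixed set of particles. Reusing the same draw in both factors of the outer product (Option 1) inflates the diagonal by the noise covariance; independent draws (Option 2) do not.

```python
import torch

torch.manual_seed(0)
true_grad = torch.tensor([1.0, -2.0])
noisy_grad = lambda: true_grad + 0.5 * torch.randn(2)  # unbiased, noisy gradient estimate

n = 100_000
# Option 1: the same noisy gradient in both factors of the outer product.
opt1 = torch.stack([torch.outer(g, g) for g in (noisy_grad() for _ in range(n))]).mean(0)
# Option 2: two independent noisy gradients.
opt2 = torch.stack([torch.outer(noisy_grad(), noisy_grad()) for _ in range(n)]).mean(0)

print(torch.outer(true_grad, true_grad))  # target outer product
print(opt1)   # diagonal inflated by the noise variance (0.25)
print(opt2)   # close to the target, but not exactly symmetric
```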
Current Test: This simple test shows how unstable conjugate gradients are under our current implementation: https://github.com/BasisResearch/chirho/blob/ra-sw-fisher-tests/tests/robust/test_internals_compositions.py.
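For intuition, here is a generic sketch (not the chirho test) of how vanilla CG degrades when the matrix-vector product carries fresh noise on every call, as in Issue 2:

```python
import torch

def conjugate_gradient(Av, b, num_steps=10):
    """Textbook CG for symmetric positive-definite A, given only v -> A v."""
    x = torch.zeros_like(b)
    r = b - Av(x)
    p = r.clone()
    for _ in range(num_steps):
        Ap = Av(p)
        alpha = (r @ r) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        if r_new @ r_new < 1e-12:   # residual small enough: converged
            break
        beta = (r_new @ r_new) / (r @ r)
        p = r_new + beta * p
        r = r_new
    return x

A = torch.tensor([[3.0, 1.0], [1.0, 2.0]])
b = torch.tensor([1.0, -1.0])
exact = torch.linalg.solve(A, b)
print(conjugate_gradient(lambda v: A @ v, b) - exact)                         # essentially zero
print(conjugate_gradient(lambda v: A @ v + 0.1 * torch.randn(2), b) - exact)  # large and erratic
```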
This issue will be closed once #408 and #428 are complete.