probabilists / lampe

Likelihood-free AMortized Posterior Estimation with PyTorch
https://lampe.readthedocs.io
MIT License

Differences in batched vs. non-batched FMPE log_prob #18

Closed. LGro closed this issue 5 months ago.

LGro commented 6 months ago

Description

When computing log probabilities with FMPE's log_prob method, the resulting values depend on the other input elements in the batch. The differences I saw were on the order of the third or fourth decimal place.

In any case, thanks a lot already for your work on LAMPE :relaxed:

Reproduce

Following the example, the two ways of computing log probabilities for a given parameter configuration theta and a batch of corresponding simulated results x produce different values:

from itertools import islice

import torch
import torch.nn as nn
import torch.optim as optim
import zuko
from lampe.data import JointLoader
from lampe.inference import FMPE, FMPELoss
from lampe.utils import GDStep
from tqdm import tqdm

LABELS = [r"$\theta_1$", r"$\theta_2$", r"$\theta_3$"]
LOWER = -torch.ones(3)
UPPER = torch.ones(3)

prior = zuko.distributions.BoxUniform(LOWER, UPPER)

def simulator(theta: torch.Tensor) -> torch.Tensor:
    x = torch.stack(
        [
            theta[..., 0] + theta[..., 1] * theta[..., 2],
            theta[..., 0] * theta[..., 1] + theta[..., 2],
        ],
        dim=-1,
    )

    return x + 0.05 * torch.randn_like(x)

theta = prior.sample()
x = simulator(theta)

loader = JointLoader(prior, simulator, batch_size=256, vectorized=True)

estimator = FMPE(3, 2, hidden_features=[64] * 5, activation=nn.ELU)  # 3 parameters, 2 observables

loss = FMPELoss(estimator)
optimizer = optim.AdamW(estimator.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, 128)
step = GDStep(optimizer, clip=1.0)  # gradient descent step with gradient clipping

estimator.train()

with tqdm(range(128), unit="epoch") as tq:
    for epoch in tq:
        losses = torch.stack(
            [
                step(loss(theta, x))
                for theta, x in islice(loader, 256)  # 256 batches per epoch
            ]
        )

        tq.set_postfix(loss=losses.mean().item())

        scheduler.step()

theta_star = prior.sample()
X = torch.stack([simulator(theta_star) for _ in range(10)])

estimator.eval()

with torch.no_grad():
    # e.g. [3.1956, 1.8184, 2.4533, 1.6461, 3.0488, 2.5868, 2.7055, 2.7679, 3.3405, 1.5554]
    log_p_one_batch = estimator.flow(X).log_prob(theta_star.repeat(len(X), 1))

    # e.g. [3.1978, 1.8175, 2.4526, 1.6468, 3.0495, 2.5894, 2.7065, 2.7712, 3.3385, 1.5558]
    log_p_individual = [estimator.flow(x).log_prob(theta_star) for x in X]
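
For reference, the largest per-element difference between the two evaluations can be quantified directly (using only the quantities computed above):

diff = (log_p_one_batch - torch.stack(log_p_individual)).abs()
print(diff.max())  # a few 1e-3 for the example values above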

Expected behavior

I would expect the individual log probability values for one theta and x pair not to be affected by the other entries in the X batch. This is corroborated by the official implementation, which does not show that behaviour when log_prob_batch is evaluated with different subsets of the batch; a corresponding check with LAMPE is sketched below.

In the above example, I would expect both to e.g. result in [3.1978, 1.8175, 2.4526, 1.6468, 3.0495, 2.5894, 2.7065, 2.7712, 3.3385, 1.5558].
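
A corresponding check with LAMPE would be to evaluate overlapping subsets of X and compare the shared entries; given the above, I would expect them to differ as well. A sketch using the estimator and theta_star from the script above:

with torch.no_grad():
    log_p_full = estimator.flow(X).log_prob(theta_star.repeat(len(X), 1))
    log_p_half = estimator.flow(X[:5]).log_prob(theta_star.repeat(5, 1))

print((log_p_full[:5] - log_p_half).abs().max())  # expected to be non-zero, same order as above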

Causes and solution

I have no clear intuition why that would be the case. I suspected a stochastic influence and that the FreeFormJacobianTransform exact mode might help, but the difference appears to be deterministic and setting exact=True did not change it. I also noticed that the LAMPE implementation uses a trigonometric embedding of the time dimension for the vector field computation, whereas the official implementation by the authors does not, but it is not obvious to me that this would explain the difference.

Environment

francois-rozet commented 6 months ago

Hello @LGro,

Thank you for reporting this bug. I think this comes from the tolerances used in FreeFormJacobianTransform, which are much higher in LAMPE/Zuko (1e-5) than in the original implementation (1e-7).

Could you try modifying the atol and rtol of the FreeFormJacobianTransform used by log_prob and repeat your experiments?

Also it might be worth running in double precision (float64).
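
Something along these lines, as an untested sketch; the .transform attribute of the returned flow and its atol/rtol attributes are assumptions about Zuko's internals and may differ across versions:

estimator.double()  # cast the estimator to float64 in place
X64 = X.double()
theta64 = theta_star.double()

with torch.no_grad():
    flow = estimator.flow(X64)
    flow.transform.atol = 1e-7  # assumed attribute, forwarded to the ODE solver
    flow.transform.rtol = 1e-7  # assumed attribute
    log_p_tight = flow.log_prob(theta64.repeat(len(X64), 1))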

LGro commented 6 months ago

Thanks for digging into this issue with me :relaxed:

Indeed, shrinking the tolerances while running the estimator and inputs at float64 precision does reduce the initially observed discrepancy. Do I understand correctly that the discrepancy is not problematic per se as long as its magnitude is irrelevant for one's application?

francois-rozet commented 6 months ago

> does reduce the initially observed discrepancy

It does not vanish with both absolute and relative tolerances at 1e-7?

> Do I understand correctly that the discrepancy is not problematic per se as long as its magnitude is irrelevant for one's application?

Yes, the discrepancy is not an implementation or method issue, but a numerical one. If it is small enough, it should not affect downstream tasks. It could be worth adding an option to modify the tolerances in the FMPE class though, or maybe a warning in the docstring.
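
In the meantime, a rough user-side sketch of that option, again assuming that the flow returned by FMPE.flow exposes its transform and that the transform carries atol/rtol attributes:

class TolerantFMPE(FMPE):
    # Sketch: expose the ODE solver tolerances on the estimator itself.
    def __init__(self, *args, atol=1e-7, rtol=1e-7, **kwargs):
        super().__init__(*args, **kwargs)
        self.atol = atol
        self.rtol = rtol

    def flow(self, x):
        flow = super().flow(x)
        flow.transform.atol = self.atol  # assumed Zuko attribute
        flow.transform.rtol = self.rtol  # assumed Zuko attribute
        return flow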

LGro commented 6 months ago

> It does not vanish with both absolute and relative tolerances at 1e-7?

With tolerances at 1e-9, the differences go down to the order of 1e-5 or 1e-6, which was enough of an indicator for me. I have not tried pushing it to the limit of float64 precision.

francois-rozet commented 6 months ago

How does this compare to the official implementation (at 1e-7)? If, at the same tolerance, the official implementation shows fewer discrepancies between batched and unbatched evaluation, it could be worth investigating further.