normal-computing / posteriors

Uncertainty quantification with PyTorch
https://normal-computing.github.io/posteriors/
Apache License 2.0

Problem composing `per_samplify` with `functools.partial` #78

Closed: SamDuffield closed this issue 2 months ago

SamDuffield commented 2 months ago
import torch
from functools import partial

from posteriors import per_samplify
from tests.scenarios import TestModel

model = TestModel()
params = dict(model.named_parameters())
batch_inputs = torch.randn(3, 10)
batch_labels = torch.randint(2, (3,)).unsqueeze(-1)
batch_spec = {"inputs": batch_inputs, "labels": batch_labels}

def log_likelihood(params, batch):
    output = torch.func.functional_call(model, params, batch["inputs"])
    return -torch.nn.BCEWithLogitsLoss()(output, batch["labels"].float())

log_likelihood_per_sample = per_samplify(log_likelihood)

Works fine:

log_likelihood_per_sample(params, batch_spec)

Throws ValueError:

partial(log_likelihood_per_sample, batch=batch_spec)(params)
# ValueError: vmap(f_per_sample, in_dims=(None, 0), ...)(<inputs>): in_dims is not
# compatible with the structure of `inputs`. in_dims has structure
# TreeSpec(tuple, None, [*, *]) but inputs has structure
# TreeSpec(tuple, None, [TreeSpec(dict, ['linear.weight', 'linear.bias'], [*, *])]).
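
For this snippet, capturing the batch with a closure instead of functools.partial sidesteps the error, since vmap then only receives positional arguments. A quick workaround sketch, not the eventual fix:

# Capture the batch by closure so vmap sees it as a positional argument
log_lik_at_batch = lambda p: log_likelihood_per_sample(p, batch_spec)
log_lik_at_batch(params)  # works, equivalent to the positional call above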
SamDuffield commented 2 months ago

It appears this is a more general problem of composing `functools.partial` and `torch.vmap`:

import torch
from functools import partial

def f(a, x):
    return a + (x**2).sum()

a = 3.0
x_all = torch.randn(3, 10)
f(a, x_all)  # Works fine
torch.vmap(f, in_dims=(None, 0))(a, x_all)  # Works fine
partial(torch.vmap(f, in_dims=(None, 0)), x=x_all)(a)  # Throws error
# ValueError: vmap(f, in_dims=(None, 0), ...)(<inputs>): in_dims is not compatible
# with the structure of `inputs`. in_dims has structure TreeSpec(tuple, None, [*, *])
# but inputs has structure TreeSpec(tuple, None, [*]).
SamDuffield commented 2 months ago

This is the known issue that in_dims is only matched against positional arguments and does not support kwargs (see the analogous JAX issue: https://github.com/google/jax/issues/20914).
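
Indeed, passing x as a keyword directly, without functools.partial, raises the same error (a minimal check, assuming current torch.vmap behaviour where in_dims is matched only against positional arguments):

torch.vmap(f, in_dims=(None, 0))(a, x=x_all)
# Same ValueError: x arrives as a kwarg, so in_dims sees only one positional input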

This can be fixed simply with a wrapper in per_samplify ensuring all arguments reach vmap positionally, with no kwargs.
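
A minimal sketch of such a wrapper (the name per_samplify_positional and the fixed in_dims choice here are illustrative assumptions, not the actual fix in #79):

import inspect

import torch

def per_samplify_positional(func):
    # Hypothetical helper: bind any kwargs back into their positional
    # slots before calling the vmapped function, so in_dims always
    # lines up with the full argument list
    sig = inspect.signature(func)
    vmapped = torch.vmap(func, in_dims=(None, 0))

    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        return vmapped(*bound.args)  # all arguments now positional

    return wrapper

f_positional = per_samplify_positional(f)
partial(f_positional, x=x_all)(a)  # now works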

SamDuffield commented 2 months ago

Fixed in #79