normal-computing / posteriors

Uncertainty quantification with PyTorch
https://normal-computing.github.io/posteriors/
Apache License 2.0

`TransformState` doesn't work with `torch.vmap` #83

Open · SamDuffield opened this issue 5 months ago

SamDuffield commented 5 months ago

Currently, you cannot do something like the following:

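```python
import torch
import posteriors

# Hypothetical toy log posterior (2D standard normal) so the snippet runs; returns (log_prob, aux)
def log_posterior(params, batch):
    return -0.5 * (params**2).sum(), torch.tensor([])

n_ensemble = 10

sghmc_transform = posteriors.sgmcmc.sghmc.build(
    log_posterior, lr=5e-2, alpha=1.0, temperature=0.0
)

# Initialise one SGHMC state per ensemble member in a single vmapped call
states = torch.vmap(sghmc_transform.init, randomness="different")(torch.randn(n_ensemble, 2))
# ValueError: vmap(functools.partial(<function init at 0x3027cf420>, momenta=None), ...):
# `functools.partial(<function init at 0x3027cf420>, momenta=None)` must only return Tensors,
# got type <class 'posteriors.sgmcmc.sghmc.SGHMCState'>. Did you mean to set out_dim= to None for output?
```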

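I think we need to register `TransformState` as a pytree node with `torch.utils._pytree`, following https://github.com/pytorch/functorch/issues/475. `vmap` only accepts outputs that are tensors or pytrees of tensors, so once the state class is registered it can be flattened into its tensor leaves and batched.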
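A minimal sketch of that registration, assuming a recent PyTorch where `torch.utils._pytree.register_pytree_node` exists (older releases only expose the private `_register_pytree_node`, so the sketch falls back to it). `ToyState` and `init` are hypothetical stand-ins for a `TransformState` subclass and a transform's `init`, not the posteriors classes:

```python
from dataclasses import dataclass, fields

import torch
import torch.utils._pytree as pytree


# Hypothetical stand-in for a TransformState subclass (not the posteriors class)
@dataclass
class ToyState:
    params: torch.Tensor
    momenta: torch.Tensor


def _flatten(state):
    # Children are the tensor fields vmap should batch; the context records field names
    names = tuple(f.name for f in fields(state))
    return [getattr(state, n) for n in names], names


def _unflatten(values, names):
    # Rebuild the state from (possibly batched) leaves in the same field order
    return ToyState(**dict(zip(names, values)))


# Prefer the newer name; fall back to the private one on older PyTorch releases
_register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node
_register(ToyState, _flatten, _unflatten)


def init(params):
    # Hypothetical init: zero momenta shaped like the parameters
    return ToyState(params=params, momenta=torch.zeros_like(params))


n_ensemble = 10
states = torch.vmap(init, randomness="different")(torch.randn(n_ensemble, 2))
```

Once the class is registered, `vmap` flattens each returned state into its tensor leaves, batches them along dimension 0, and rebuilds a single `ToyState` whose fields carry the ensemble dimension (here `states.params.shape == (10, 2)`).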

SamDuffield commented 2 months ago

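To fully support this, I think we might need to require `aux` to be a `TensorTree`, since every leaf of a registered state, including `aux`, would have to be a tensor for `vmap` to batch it.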
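A small illustration of why `aux` matters, using plain tuples and dicts instead of posteriors states (both functions are hypothetical): `vmap` batches every leaf of the returned pytree, so an `aux` made only of tensors goes through, while a Python string inside `aux` triggers the same kind of `ValueError` as in the original report.

```python
import torch


def step_with_tensor_aux(x):
    # aux is a TensorTree (dict of tensors): every leaf can be batched by vmap
    return x * 2, {"grad_norm": x.abs().sum()}


def step_with_python_aux(x):
    # aux holds a plain Python string, which vmap cannot stack along dim 0
    return x * 2, {"note": "not a tensor"}


xs = torch.randn(5, 3)

# Works: outputs come back with a leading batch dimension of 5
out, aux = torch.vmap(step_with_tensor_aux)(xs)

# Fails: vmap rejects the non-tensor leaf in the output
try:
    torch.vmap(step_with_python_aux)(xs)
    failed = False
except ValueError:
    failed = True
```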