Currently you cannot do something like

```python
import torch
import posteriors

# log_posterior is assumed to be defined already (not shown here)
n_ensemble = 10
sghmc_transform = posteriors.sgmcmc.sghmc.build(
    log_posterior, lr=5e-2, alpha=1.0, temperature=0.0
)
states = torch.vmap(sghmc_transform.init, randomness='different')(torch.randn(n_ensemble, 2))
# ValueError: vmap(functools.partial(<function init at 0x3027cf420>, momenta=None), ...):
# `functools.partial(<function init at 0x3027cf420>, momenta=None)` must only return Tensors,
# got type <class 'posteriors.sgmcmc.sghmc.SGHMCState'>. Did you mean to set out_dim= to None for output?
```
I think we need to register `TransformState` as a pytree node with `torch.utils._pytree`, following https://github.com/pytorch/functorch/issues/475.
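A minimal sketch of what that registration could look like, using a hypothetical `ToyState` dataclass and a hypothetical `init` function as stand-ins for `SGHMCState`/`TransformState` and the real SGHMC init (neither is reproduced here). It relies on the private `torch.utils._pytree` module: recent PyTorch releases expose `register_pytree_node`, while older ones only have `_register_pytree_node`, so the exact call is an assumption about the installed version.

```python
from dataclasses import dataclass

import torch
import torch.utils._pytree as pytree


# Hypothetical stand-in for a posteriors state class (not the real TransformState)
@dataclass
class ToyState:
    params: torch.Tensor
    momenta: torch.Tensor


def _flatten(state):
    # Return the tensor leaves plus a context (nothing extra needed here)
    return [state.params, state.momenta], None


def _unflatten(values, context):
    # Rebuild the state from its (possibly batched) tensor leaves
    params, momenta = values
    return ToyState(params=params, momenta=momenta)


# Assumption: use the public name if this PyTorch version has it, else the older private one
_register = getattr(pytree, "register_pytree_node", None) or pytree._register_pytree_node
_register(ToyState, _flatten, _unflatten)


def init(params):
    # Hypothetical init: zero momenta, mirroring the shape of the params
    return ToyState(params=params, momenta=torch.zeros_like(params))


n_ensemble = 10
states = torch.vmap(init, randomness="different")(torch.randn(n_ensemble, 2))
```

Once the class is registered, `vmap` flattens the returned state into its tensor leaves and rebuilds a batched `ToyState`, so each ensemble member gets its own row in `states.params` instead of raising the `must only return Tensors` error.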