pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

TransformModule cache is invalid after optimizer step #2564

Open: fritzo opened this issue 4 years ago

fritzo commented 4 years ago

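As observed in #2542, the caching mechanism of torch.distributions.Transform works well for purely functional Transforms but can lead to incorrect results with stateful TransformModules. The cache is keyed on the identity of the input tensor, so after an optimizer step updates a module's parameters in place, calling the transform again on the same tensor returns the stale cached output.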
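A minimal sketch reproducing the problem, assuming a recent PyTorch. The `Scale` class is hypothetical. It copies the inheritance pattern Pyro's TransformModule uses, with `Transform` first and `nn.Module` second, so `Transform.__call__` and its single-entry cache take precedence. With `cache_size=1`, passing the very same `x` object after `opt.step()` returns the pre-step output.

```python
import torch
import torch.nn as nn
from torch.distributions import Transform


class Scale(Transform, nn.Module):
    """Hypothetical stateful transform y = exp(log_scale) * x.

    Transform comes first in the bases, so Transform.__call__ (which
    implements the cache) is used instead of nn.Module.__call__.
    """

    bijective = True
    # Transform defines __eq__, which disables hashing; nn.Module needs it.
    __hash__ = object.__hash__

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)
        self.log_scale = nn.Parameter(torch.zeros(()))

    def _call(self, x):
        return x * self.log_scale.exp()

    def _inverse(self, y):
        return y / self.log_scale.exp()


t = Scale(cache_size=1)
opt = torch.optim.SGD(t.parameters(), lr=0.1)
x = torch.ones(3)

y_before = t(x)                      # computed, then cached as (x, y_before)
opt.zero_grad()
(y_before - 3.0).pow(2).sum().backward()
opt.step()                           # log_scale is updated in place

y_stale = t(x)                       # same x object, so the cached y is returned
y_fresh = x * t.log_scale.exp()      # what the transform should now produce
print(y_stale is y_before)                # True
print(torch.allclose(y_stale, y_fresh))   # False
```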

Possible fixes include:

cc @stefanwebb

stefanwebb commented 4 years ago

I think calling .clear_cache() somehow in a backward hook would be a good solution. I haven't done this before in PyTorch but can look into it...
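One way the backward-hook idea might look, as a sketch under stated assumptions. `Scale`, `clear_cache()` and `_on_grad()` are hypothetical names defined here. `clear_cache()` resets `_cached_x_y`, which is a private attribute of torch's Transform and may change between PyTorch versions. The hook is registered per parameter with `Tensor.register_hook`, so it runs whenever a gradient reaches that parameter.

```python
import torch
import torch.nn as nn
from torch.distributions import Transform


class Scale(Transform, nn.Module):
    """Hypothetical stateful transform y = exp(log_scale) * x that drops its
    cache whenever a gradient reaches one of its parameters."""

    bijective = True
    # Transform defines __eq__, which disables hashing; nn.Module needs it.
    __hash__ = object.__hash__

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)
        self.log_scale = nn.Parameter(torch.zeros(()))
        for p in self.parameters():
            # Fires when p's gradient is computed during backward().
            p.register_hook(self._on_grad)

    def clear_cache(self):
        # Hypothetical helper: reset torch's private single-entry (x, y) cache.
        if self._cache_size == 1:
            self._cached_x_y = None, None

    def _on_grad(self, grad):
        self.clear_cache()
        return None  # leave the gradient unchanged

    def _call(self, x):
        return x * self.log_scale.exp()

    def _inverse(self, y):
        return y / self.log_scale.exp()


t = Scale()
opt = torch.optim.SGD(t.parameters(), lr=0.1)
x = torch.ones(3)

y_before = t(x)
opt.zero_grad()
(y_before - 3.0).pow(2).sum().backward()  # the hook clears the cache here
opt.step()

y_after = t(x)  # recomputed with the updated log_scale
print(torch.allclose(y_after, x * t.log_scale.exp()))  # True
```

Per-parameter tensor hooks are used instead of module hooks such as `register_full_backward_hook`. Module hooks run inside `nn.Module.__call__`, and `Transform.__call__` likely bypasses it here because it comes first in the MRO. The sketch still leaves a stale cache if the transform is called between `backward()` and `opt.step()`, or if parameters change without a backward pass (for example via `load_state_dict`). In those cases an explicit `clear_cache()` call would be needed.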