As observed in #2542, the caching mechanism of `torch.distributions.Transform` works well for purely functional `Transform`s but can lead to incorrect results with stateful `TransformModule`s.
Possible fixes include:
- Call `.clear_cache()` in a backward hook via `.register_backward()`. This may be overly aggressive in clearing the cache: is it still compatible with `torch.autograd.grad()`? (A sketch of this idea follows below.)
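A minimal sketch of the idea, with heavy assumptions: the toy `StatefulScale` class, its `clear_cache()` helper (which simply resets the private `_cached_x_y` attribute used by `Transform` when `cache_size=1`), and the use of `Tensor.register_hook` on the parameter as the "backward hook" are all illustrative choices, not existing Pyro/PyTorch API:

```python
import torch
import torch.nn as nn
from torch.distributions import constraints
from torch.distributions.transforms import Transform


class StatefulScale(Transform, nn.Module):
    """Toy stand-in for a TransformModule: a learnable scaling with cache_size=1."""

    domain = constraints.real
    codomain = constraints.real
    bijective = True

    def __init__(self):
        Transform.__init__(self, cache_size=1)
        nn.Module.__init__(self)
        self.log_scale = nn.Parameter(torch.zeros(()))
        # Hypothetical fix: whenever a gradient reaches the parameter during a
        # backward pass, drop the (x, y) cache so the next forward call
        # recomputes y with the soon-to-be-updated parameter.
        self.log_scale.register_hook(lambda grad: self.clear_cache())

    def __hash__(self):
        # Transform defines __eq__, which disables hashing; nn.Module needs it.
        return object.__hash__(self)

    def clear_cache(self):
        # Reset the private cache that Transform keeps when cache_size=1.
        self._cached_x_y = None, None

    def _call(self, x):
        return x * self.log_scale.exp()

    def _inverse(self, y):
        return y * torch.exp(-self.log_scale)

    def log_abs_det_jacobian(self, x, y):
        return self.log_scale.expand(x.shape)


# Without clearing, calling t(x) on the same x object after an optimizer step
# could return the stale cached y; with the hook the cache is emptied each backward.
t = StatefulScale()
x = torch.randn(3)
opt = torch.optim.SGD(t.parameters(), lr=0.1)
loss = t(x).sum()
loss.backward()   # hook fires here and clears the cache
opt.step()
y_fresh = t(x)    # recomputed with the updated parameter
```

One caveat that mirrors the question above: a parameter-level hook only fires when a gradient is actually computed for that parameter, so a call such as `torch.autograd.grad(loss, [x])` that never touches the parameter would leave the stale cache in place.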
cc @stefanwebb