rafaol opened this issue 2 years ago.
Hi @rafaol, tl;dr: don't use the jit; it isn't useful.
I have found that JitTrace_ELBO only works on very simple models and provides little or no speed improvement. I believe this is because the PyTorch team's goals for the jit are mainly about executing models outside a Python runtime, whereas the JAX team's goals are more performance-oriented. Historically, torch.jit.trace slightly sped up some models, but the jit has gradually grown slower, so it is no longer useful in Pyro. I've also found that torch.jit changes so often that I can't rely on my jitted models remaining jittable across minor PyTorch releases.
Thanks @fritzo! I was actually planning to use the JIT on a larger, more complex model (a hierarchical/random-effects model), which got a significant (10x) speed-up in SVI when I tried the JIT with simpler guides (AutoMultivariateNormal and AutoLaplaceApproximation). I wish I could get the same kind of speed-up with normalizing flows, since they're far more flexible and expressive, but I'll look at other options within Pyro.
Issue Description
Hi all, I'm having issues using normalizing flows with the JIT and Pyro's SVI. The code runs fine if I use the standard Trace_ELBO (or even TraceEnum_ELBO in models with discrete variables), but it fails if I replace the ELBO with its JIT-compiled version. The issue seems to be caused by the use of weak references (weakref) when checking unconstrained parameter values. Is anyone else having the same issue, and is a fix on the horizon?
Environment
Code Snippet
Here is a minimal example. The following code runs normally if I use Trace_ELBO as the loss function for SVI, but fails with JitTrace_ELBO.
Console output
When I check the runtime environment with the debugger, I see that constrained_param.unconstrained() and unconstrained_param match in value, but constrained_param.unconstrained() holds a weakref to a Parameter object, while unconstrained_param points to a Tensor. The gradient's grad_fn also seems to be mismatched.
Values
Pointer addresses:
weakref.ref(constrained_param.unconstrained()) = <weakref at 0x7fa425962180; to 'Parameter' at 0x7fa425959040>
weakref.ref(unconstrained_param) = <weakref at 0x7fa42408e090; to 'Tensor' at 0x7fa426555090>
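The mismatch described above (equal values, but a Parameter on one side and a plain Tensor on the other) comes down to object identity versus value equality. The sketch below is a stdlib-only mimic, not Pyro's code: ToyParamStore and the Tensor and Parameter classes are hypothetical stand-ins for Pyro's param store, torch.Tensor and torch.nn.Parameter. It assumes, as the issue suggests, that the constrained value carries a weak reference back to the stored unconstrained parameter, and that the traced code ends up holding a different object with the same value.

```python
import weakref


class Tensor:
    """Hypothetical stand-in for torch.Tensor; plain Python objects accept weakrefs."""

    def __init__(self, data):
        self.data = data


class Parameter(Tensor):
    """Hypothetical stand-in for torch.nn.Parameter."""


class ToyParamStore:
    """Hypothetical store that owns each unconstrained Parameter."""

    def __init__(self):
        self._params = {}

    def set(self, name, data):
        self._params[name] = Parameter(data)

    def get(self, name):
        raw = self._params[name]
        constrained = Tensor(abs(raw.data))  # stand-in for a constraint transform
        # Non-owning back-reference: calling it returns the stored Parameter object.
        constrained.unconstrained = weakref.ref(raw)
        return constrained


store = ToyParamStore()
store.set("scale", -1.5)
constrained_param = store.get("scale")

# Assumption: a different object with the same value, as the traced code might receive.
unconstrained_param = Tensor(-1.5)

same_value = constrained_param.unconstrained().data == unconstrained_param.data
same_object = constrained_param.unconstrained() is unconstrained_param
print(same_value, same_object)  # True False

# A weakref's repr names the referent's type and address, as in the output above.
print(weakref.ref(constrained_param.unconstrained()))
print(weakref.ref(unconstrained_param))
```

Under those assumptions an identity check fails even though the data match, which is consistent with the debugger output. The weakref.ref(...) wrappers only serve to print type names and addresses; comparing with is (or with id()) answers the identity question directly.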