cagatayyildiz opened this issue 3 years ago.
Yeah, this is because of the way the variational distribution is stored (i.e., it is cached). I haven't quite figured out yet whether it is a bug.
Here are a couple of workarounds that should be fine. I'm not entirely sure which one would be preferred:
```python
import torch

# `model` is the sparse GP model from the original report
opt = torch.optim.Adam(model.parameters())
for i in range(10):
    opt.zero_grad()
    variational_dist = model.variational_strategy.variational_distribution
    loss = -variational_dist.rsample(torch.Size([10])).sum()
    loss.backward()
    print('iter={:d}, \tloss={:.3f}'.format(i, loss.item()))
    # delete the memoize cache, which is what's building up
    del model.variational_strategy._memoize_cache
    opt.step()
```
Alternatively:
```python
import torch

opt = torch.optim.Adam(model.parameters())
for i in range(10):
    opt.zero_grad()
    # use the private method to regenerate the variational distribution
    variational_dist = model.variational_strategy._variational_distribution()
    loss = -variational_dist.rsample(torch.Size([10])).sum()
    loss.backward()
    print('iter={:d}, \tloss={:.3f}'.format(i, loss.item()))
    opt.step()
```
I think this counts as a bug :) In `variational_strategy`, we clear the cache (which stores the current variational distribution) when `__call__` is called, but we do not clear it otherwise. Since you are not calling your model, the cache never gets cleared.
I'm not sure exactly how best to deal with this issue. We could use less aggressive caching, or clear the cache more intelligently. @wjmaddox's workaround should be fine in the meantime.
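The failure mode can be reproduced without GPyTorch. Below is a hypothetical, pure-PyTorch stand-in. The names `CachedModule`, `dist`, `clear_cache` and `run_loop` are illustrative only and are not GPyTorch API. The module memoizes a distribution the way the variational strategy caches its variational distribution. The first `backward()` frees the autograd buffers of the cached tensors, so the second iteration fails unless the cache is cleared every step. Clearing every step is what the workarounds above do, and presumably it is also why calling `model.train()` before each step avoids the error.

```python
import torch


class CachedModule(torch.nn.Module):
    """Hypothetical stand-in for a module that memoizes its distribution."""

    def __init__(self) -> None:
        super().__init__()
        self.mean = torch.nn.Parameter(torch.zeros(3))
        self.log_std = torch.nn.Parameter(torch.zeros(3))
        self._cache = None

    @property
    def dist(self) -> torch.distributions.Normal:
        # Built once and then reused: torch.exp saves its output for the
        # backward pass, and the first backward() call frees that buffer.
        if self._cache is None:
            self._cache = torch.distributions.Normal(self.mean, torch.exp(self.log_std))
        return self._cache

    def clear_cache(self) -> None:
        self._cache = None


def run_loop(clear_each_step: bool, steps: int = 3) -> bool:
    """Return True if backward() succeeds at every step, False otherwise."""
    torch.manual_seed(0)
    module = CachedModule()
    opt = torch.optim.Adam(module.parameters(), lr=0.1)
    for _ in range(steps):
        if clear_each_step:
            module.clear_cache()  # rebuild the distribution from the current parameters
        opt.zero_grad()
        loss = -module.dist.rsample(torch.Size([10])).sum()
        try:
            loss.backward()
        except RuntimeError:
            return False  # autograd hits the freed buffers of the cached graph
        opt.step()
    return True


print(run_loop(clear_each_step=False))  # False: the second iteration fails
print(run_loop(clear_each_step=True))   # True
```

Clearing the cache at the top of each step means the distribution is always rebuilt from the post-`opt.step()` parameters. Each `backward()` therefore walks a fresh graph.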
Hi,
I have been implementing decoupled sampling of GPs and bumped into a rather strange issue. I tried to reduce it to the following piece of code: a sparse GP example in which we minimize a dummy loss defined on the variational posterior:
Now, if we do not call model.train() before each gradient computation, I get the following error after the first iteration:
Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
Why is this happening? I guess either this is a trivial bug or I'm missing something quite trivial.
Thanks, Cagatay.
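For reference, here is a minimal, pure-PyTorch illustration of what the quoted error message means. It is not GPyTorch code. An op such as `torch.exp` saves tensors for its backward pass, and a plain `backward()` frees them. A second `backward()` through the same graph therefore raises a `RuntimeError`, unless the first call passed `retain_graph=True`.

```python
import torch

x = torch.ones(3, requires_grad=True)
y = torch.exp(x)  # ExpBackward saves its output y for the backward pass

y.sum().backward()  # computes x.grad, then frees the saved buffers
try:
    y.sum().backward()  # walks the same graph again, whose buffers are gone
    second_backward_ok = True
except RuntimeError:
    second_backward_ok = False

# With retain_graph=True the buffers survive the first call,
# and the two backward passes accumulate into w.grad.
w = torch.ones(3, requires_grad=True)
z = torch.exp(w)
z.sum().backward(retain_graph=True)
z.sum().backward()

print(second_backward_ok)  # False
print(torch.allclose(w.grad, 2 * torch.exp(torch.ones(3))))  # True
```

`retain_graph=True` is probably not the right fix for the sparse GP loop. Even if the second backward succeeded, the cached distribution's tensors would still hold the values computed before `opt.step()`. Rebuilding or un-caching the distribution each iteration, as in the workarounds above, avoids that.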