pyro-ppl / pyro

Deep universal probabilistic programming with Python and PyTorch
http://pyro.ai
Apache License 2.0

FR: subsampling-aware PyTorch optimizers #2386

Open fritzo opened 4 years ago

fritzo commented 4 years ago

@martinjankowiak suggested that, now that #1796 makes pyro.param aware of subsampling, we could in principle make PyTorch optimizers whose gradient statistics are updated only for those elements that appear in a subsample. This would give lower-variance gradient estimates and would be cheaper.

@martinjankowiak also points out that an alternative is to amortize the guide.

you can imagine two implementations:

  • one that explicitly makes use of subsample indices
  • one that uses grad != 0 as a proxy

are there circumstances under which the proxy isn’t good enough?
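
To make the difference between the two options concrete, here is a minimal sketch in plain Python. The helper names `mask_from_indices` and `mask_from_grad` and the toy gradient are hypothetical, not part of Pyro or PyTorch. It shows one case where the proxy disagrees with the explicit indices: a row that was subsampled but whose gradient happens to be exactly zero.

```python
def mask_from_indices(num_rows, subsample_indices):
    """Explicit version: a row is active iff it was drawn in this subsample."""
    active = set(subsample_indices)
    return [i in active for i in range(num_rows)]


def mask_from_grad(grad):
    """Proxy version: a row is active iff its gradient is nonzero."""
    return [g != 0.0 for g in grad]


# Hypothetical per-row gradient of a 5-row local parameter.
# Rows 1 and 3 were subsampled, but row 3's gradient is exactly zero
# (for example because a clamp or ReLU upstream is saturated).
subsample_indices = [1, 3]
grad = [0.0, 0.7, 0.0, 0.0, 0.0]

explicit = mask_from_indices(len(grad), subsample_indices)
proxy = mask_from_grad(grad)

# Rows on which the two masks disagree.
disagree = [i for i, (a, b) in enumerate(zip(explicit, proxy)) if a != b]
print(disagree)  # [3]
```

The opposite mismatch is also conceivable: any term that touches every row (say, an L2 penalty on the whole tensor) would give nonzero gradients to rows that were not subsampled, so the proxy would mark them as active.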

fehiepsi commented 4 years ago

I believe this is important for my experiments so I want to fix this soon.

martinjankowiak commented 4 years ago

@fehiepsi i think the easiest way to get something that works would be to add an option to ClippedAdam like ignore_zero_gradient_stats=True/False

you'd then need to keep track of state['step'] on a per-coordinate basis (i.e. the optimizer would need more memory) and then you'd just need to change the updates to do vectorized/masked updates of the statistics
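
A rough sketch of what such a masked update could look like, written in plain Python over lists so the per-coordinate logic is explicit. This is not the ClippedAdam implementation: the function name `masked_adam_step` is hypothetical, gradient clipping and learning-rate decay are left out, and freezing the parameter itself (not just the statistics) on zero-gradient coordinates is an assumption. A real version would do the same thing with vectorized/masked tensor operations.

```python
import math


def masked_adam_step(param, grad, state, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
    """One Adam step that skips every coordinate whose gradient is exactly zero.

    `state` holds per-coordinate lists under 'step', 'exp_avg', 'exp_avg_sq',
    so each coordinate gets its own bias correction.
    """
    beta1, beta2 = betas
    for i, g in enumerate(grad):
        if g == 0.0:
            continue  # coordinate i: parameter and statistics stay untouched
        state["step"][i] += 1
        t = state["step"][i]
        state["exp_avg"][i] = beta1 * state["exp_avg"][i] + (1 - beta1) * g
        state["exp_avg_sq"][i] = beta2 * state["exp_avg_sq"][i] + (1 - beta2) * g * g
        m_hat = state["exp_avg"][i] / (1 - beta1 ** t)      # per-coordinate bias correction
        v_hat = state["exp_avg_sq"][i] / (1 - beta2 ** t)
        param[i] -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return param


param = [1.0, 1.0, 1.0]
state = {
    "step": [0, 0, 0],          # one counter per coordinate (the extra memory)
    "exp_avg": [0.0, 0.0, 0.0],
    "exp_avg_sq": [0.0, 0.0, 0.0],
}
masked_adam_step(param, [0.5, 0.0, -2.0], state)
print(state["step"])  # [1, 0, 1]
```

Coordinate 1 saw a zero gradient, so its step counter, moment estimates, and value are all unchanged, while coordinates 0 and 2 each take a first Adam step of size roughly `lr`.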

fritzo commented 4 years ago

@martinjankowiak can you explain why it is necessary to keep track of coordinate-wise state['step']? I would think that step could be approximated as global, since the Poisson approximation concentrates (in contrast to the other statistics).
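
One way to read the concentration argument, as a small simulation. The assumption here (not stated in the thread) is that each coordinate lands in the subsample independently with probability p at every step, so its per-coordinate count after T steps is Binomial(T, p), with relative spread sqrt((1 - p) / (p T)). The helper names are hypothetical.

```python
import math
import random


def simulate_counts(num_coords, num_steps, p, seed=0):
    """Count how often each coordinate is subsampled, assuming independent
    inclusion with probability p at every step."""
    rng = random.Random(seed)
    counts = [0] * num_coords
    for _ in range(num_steps):
        for i in range(num_coords):
            if rng.random() < p:
                counts[i] += 1
    return counts


def relative_spread(num_steps, p):
    """Standard deviation divided by mean for a Binomial(num_steps, p) count."""
    return math.sqrt((1 - p) / (p * num_steps))


p, num_steps = 0.1, 10_000
counts = simulate_counts(num_coords=20, num_steps=num_steps, p=p)

# Largest relative gap between a per-coordinate count and the
# global approximation p * num_steps.
worst = max(abs(c / (p * num_steps) - 1) for c in counts)
print(relative_spread(num_steps, p), worst)
```

Under these assumptions the per-coordinate counts stay within a few percent of p * T once T is large, which is the sense in which a single global step (rescaled by p) could stand in for the per-coordinate counters. Early in training, or for rarely sampled coordinates, the gap is larger.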

martinjankowiak commented 4 years ago

well you could make that approximation. but i was assuming a pretty generic optimizer that makes few assumptions (just don't update when grad is zero)

fritzo commented 4 years ago

I agree the coordinate-wise step is parsimonious. Another parsimonious assumption could be that the gradient distribution is a zero-inflated Normal, so that with slight modifications, the usual Adam statistics can learn that distribution's parameters. Both versions seem reasonable.
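
One possible reading of the zero-inflated Normal idea, as a worked sketch. The symbols p, \mu, \sigma and the correction by p are assumptions made for illustration; the thread does not spell out the "slight modifications".

```latex
% Assumed model for one coordinate's gradient g:
%   g = 0 with probability 1 - p (coordinate not subsampled),
%   g ~ Normal(\mu, \sigma^2) with probability p.
\begin{align}
  \mathbb{E}[g]   &= p\,\mu, \\
  \mathbb{E}[g^2] &= p\,(\mu^2 + \sigma^2).
\end{align}
% The usual Adam running averages m and v therefore estimate
% p\mu and p(\mu^2 + \sigma^2), and the normalized step becomes
\begin{equation}
  \frac{m}{\sqrt{v}} \approx \sqrt{p}\,\frac{\mu}{\sqrt{\mu^2 + \sigma^2}}.
\end{equation}
% Given an estimate of p (for example a running average of the
% indicator [g \neq 0]), dividing m and v by p recovers the
% conditional moments \mu and \mu^2 + \sigma^2.
```

Under this reading, unmodified Adam shrinks the normalized step by a factor of \sqrt{p} on sparsely sampled coordinates, and tracking p alongside the usual statistics would be enough to undo that.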