TuringLang / AdvancedHMC.jl

Robust, modular and efficient implementation of advanced Hamiltonian Monte Carlo algorithms
https://turinglang.org/AdvancedHMC.jl/
MIT License

Save gradients #351

Open JaimeRZP opened 10 months ago

JaimeRZP commented 10 months ago

Hi @sethaxen!

This should do what you were suggesting if you use the internal sampling method of AHMC and set drop_warmup=false and keep_gradient=true.

Let me know if it works!

torfjelde commented 10 months ago

Generally speaking, it's not great to have the return value change based on a keyword argument, as it leads to type instabilities :confused: In this scenario it's probably not the worst (sample is generally going to be the outermost caller anyway), but it's still not great style.
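
To make the concern concrete, here is a minimal sketch of the pattern. The functions run_unstable and run_stable are hypothetical stand-ins I made up for illustration, not AdvancedHMC's actual sample code; only the keyword name keep_gradient is taken from this PR.

```julia
# Hypothetical stand-in for a sampler's output; NOT AdvancedHMC's actual code.
function run_unstable(n::Int; keep_gradient::Bool=false)
    samples = [fill(float(i), 2) for i in 1:n]
    grads = [-s for s in samples]
    # The return type depends on the *value* of `keep_gradient`,
    # which the compiler generally cannot know from argument types alone.
    return keep_gradient ? (samples, grads) : samples
end

# One possible alternative: always return the same type and leave the
# gradients field empty when they were not requested.
function run_stable(n::Int; keep_gradient::Bool=false)
    samples = [fill(float(i), 2) for i in 1:n]
    grads = keep_gradient ? [-s for s in samples] : Vector{Vector{Float64}}()
    return (samples=samples, gradients=grads)
end

# Compare the concrete return types for both values of the keyword.
unstable_differs = typeof(run_unstable(3)) != typeof(run_unstable(3; keep_gradient=true))
stable_matches = typeof(run_stable(3)) == typeof(run_stable(3; keep_gradient=true))
```

In the first version a caller cannot know whether it gets a vector or a tuple without knowing the keyword's value; inspecting such a call with @code_warntype would typically show a Union return type. The second version is only one way around it, and the callback approach suggested below avoids touching the return value at all.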

IMO what we should do here is to

  1. Move AdvancedHMC.jl completely over to the AbstractMCMC.jl interface, i.e. we have a step function that takes in a state containing all the necessary information.
  2. Then the gradients can easily be extracted through a callback which just extracts this information from the state.

(1) will also just be a huge gain in general, as it will make things much more modular :)
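
A rough, dependency-free sketch of what (1) and (2) could look like together. Everything here (ToyModel, ToySampler, ToyState, toy_step, run_chain, and the two-argument callback) is a hypothetical toy that only loosely mirrors the step-plus-state shape of AbstractMCMC.jl; it is not the real interface and not AdvancedHMC's code. The update is an unadjusted Langevin step to keep it short, where a real HMC step would do leapfrog integration plus accept/reject.

```julia
using Random

# Hypothetical toy types; NOT AdvancedHMC's or AbstractMCMC's definitions.
struct ToyModel end                      # standard normal target
gradient(::ToyModel, x) = -x             # gradient of the log density -|x|^2 / 2

struct ToySampler
    stepsize::Float64
end

# The state carries everything the next step (or a callback) could need.
struct ToyState
    position::Vector{Float64}
    gradient::Vector{Float64}
end

initial_state(model, x0) = ToyState(x0, gradient(model, x0))

# One transition: takes the previous state, returns (sample, new state).
function toy_step(rng::AbstractRNG, model::ToyModel, spl::ToySampler, state::ToyState)
    h = spl.stepsize
    noise = randn(rng, length(state.position))
    x = state.position .+ (h^2 / 2) .* state.gradient .+ h .* noise
    return x, ToyState(x, gradient(model, x))
end

# Driver loop: the return type never changes; extras go through the callback.
function run_chain(rng, model, spl, x0, n; callback=nothing)
    state = initial_state(model, x0)
    samples = Vector{Vector{Float64}}(undef, n)
    for i in 1:n
        samples[i], state = toy_step(rng, model, spl, state)
        callback === nothing || callback(state, i)
    end
    return samples
end

# Callback that saves gradients by reading them off the state.
saved_gradients = Vector{Float64}[]
save_gradient = (state, i) -> push!(saved_gradients, copy(state.gradient))

samples = run_chain(MersenneTwister(1), ToyModel(), ToySampler(0.1), zeros(2), 5;
                    callback=save_gradient)
```

The point of the design is that run_chain always returns a plain vector of samples, and anyone who wants gradients (or anything else stored in the state) supplies a callback instead of a keyword argument that changes the return value.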