JaimeRZP / MicroCanonicalHMC.jl

Implementation of Arxiv:2212.08549 in Julia
MIT License

reverse mode AD #12

Open thorek1 opened 4 months ago

thorek1 commented 4 months ago

I would like to try your algorithm, but my application requires reverse-mode AD (Zygote, to be more specific). Do you support it? My reading of the code is that it supports only ForwardDiff.jl for now.

JaimeRZP commented 4 months ago

Hi Thorek! I am currently using MCHMC with Zygote so it should work. The problem might be at the Turing level.

thorek1 commented 4 months ago

Gotcha. I will try that in the meantime then. Is Turing with Zygote an issue on Turing's end? Thanks

JaimeRZP commented 4 months ago

There's been a lot of work in the last few releases to get Zygote working properly in Turing. I would make sure you are using the latest version (0.32?).

thorek1 commented 4 months ago

I use >=0.32 as well. Just to make sure, here is the alternative (from a user's perspective): samps = Turing.sample(loglikelihood_fn, NUTS(adtype = AutoZygote()), n_samples). Ideally, the microcanonical HMC sampler would support a similar way of switching the AD backend.

JaimeRZP commented 4 months ago

yes, what you are looking for is:

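# Packages this snippet relies on (Zygote must be loaded for AutoZygote)
using Turing, Zygote, MicroCanonicalHMC
# Define sampler: n_adapts adaptation steps, tev = target energy variance
mchmc = MCHMC(n_adapts, tev; adaptive=true)
# Wrap it for Turing and select the AD backend
sampler = externalsampler(mchmc; adtype=AutoZygote())
# Sample
chain = Turing.sample(model, sampler, 10_000)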

hope this helps!
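
For anyone landing here later, a fuller version of the snippet above might look like the sketch below. It runs the same toy model through NUTS and MCHMC, both with Zygote as the AD backend. The regression model, the synthetic data, and the values chosen for n_adapts and tev are illustrative placeholders and do not come from this thread. The sketch assumes a Turing release that still exports AutoZygote (0.32 at the time of writing) and has not been tuned in any way.

```julia
using Turing, Zygote, MicroCanonicalHMC
using LinearAlgebra: I
using Random

Random.seed!(1)

# Synthetic data for a toy linear regression (placeholder, not from the thread)
X = randn(50, 3)
y = X * [0.5, -1.0, 2.0] .+ 0.1 .* randn(50)

# Vectorised model with a 3-dimensional parameter
@model function linreg(X, y)
    β ~ MvNormal(zeros(size(X, 2)), I)
    y ~ MvNormal(X * β, 0.1^2 * I)
end
model = linreg(X, y)

# Baseline: NUTS with Zygote as the AD backend
chain_nuts = sample(model, NUTS(adtype = AutoZygote()), 500)

# MCHMC wrapped as an external sampler, also with Zygote
n_adapts = 1_000   # adaptation steps (assumed value)
tev = 0.01         # target energy variance (assumed value)
mchmc = MCHMC(n_adapts, tev; adaptive=true)
sampler = externalsampler(mchmc; adtype = AutoZygote())
chain_mchmc = sample(model, sampler, 5_000)
```

The AD backend is selected on the externalsampler wrapper and not on MCHMC itself, which is why the switch looks slightly different from the NUTS call.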

thorek1 commented 3 months ago

I understand this isn't an issue with the package. Hence, I would recommend closing the issue.

Now that I have it working, I wanted to share that, for my use case, I need many more adaptation draws and samples to get an ESS comparable to NUTS [MCHMC (50000,20000) vs NUTS (3000)], although MCHMC is faster.

Do you have any hints on what to do when it shows NaN during the tuning phase? I can restart the procedure and it might find a valid point, but even then I have seen it converge to epsilon = 0 and not recover from there. Starting from the/a mode does not help either. NUTS does work in this case.

Check here for the example application I used with both NUTS and MCHMC.