Julia-Tempering / Pigeons.jl

Distributed and parallel sampling from intractable distributions
https://pigeons.run/dev/
GNU Affero General Public License v3.0

AutoMALA without autodiff / with custom autodiff #247

Open sefffal opened 1 week ago

sefffal commented 1 week ago

This might be a documentation request, or a feature request.

Currently the AutoMALA constructor takes the argument default_autodiff_backend, which it uses via LogDensityProblemsAD to differentiate the target. However, the LogDensityProblems interface allows one to provide a gradient function directly. This could be an analytic gradient, or it could calculate the gradient via autodiff itself.

Can AutoMALA use the user-provided gradient function, and only fall back to doing autodiff itself if one is not provided?

I ran into this problem since I was supplying a LogDensityProblem with an Enzyme-backed gradient, but AutoMALA was still using ForwardDiff.
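
For concreteness, a heavily simplified sketch of the kind of target I'm passing (MyProblem is just a stand-in; in my real code the gradient comes from an Enzyme call rather than the analytic expression below):

using LogDensityProblems

struct MyProblem
    data::Vector{Float64}
end

LogDensityProblems.dimension(p::MyProblem) = length(p.data)

# declare first-order capability, i.e. "I supply my own gradient"
LogDensityProblems.capabilities(::Type{MyProblem}) = LogDensityProblems.LogDensityOrder{1}()

LogDensityProblems.logdensity(p::MyProblem, x) = -0.5 * sum(abs2, x .- p.data)

function LogDensityProblems.logdensity_and_gradient(p::MyProblem, x)
    # gradient supplied directly (analytic here; Enzyme-computed in my actual code)
    return LogDensityProblems.logdensity(p, x), -(x .- p.data)
end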

miguelbiron commented 6 days ago

Hey William -- we should probably document this, but the way to do it is to write a method of LogDensityProblemsAD.ADgradient for your own log potential. The best example is our implementation for handling Stan targets with autoMALA:

https://github.com/Julia-Tempering/Pigeons.jl/blob/42c260ecf822a884d20ae717a8b2e69738187933/ext/PigeonsBridgeStanExt/interface.jl#L57-L76

Here you can see that, regardless of the AD symbol passed, we always compute the logdensity and gradient using BridgeStan.
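
For a custom target, the analogous method can be very small. A sketch (MyLogPotential stands in for your own type; note that this simple version ignores the buffers argument, so each gradient call may allocate):

using LogDensityProblems, LogDensityProblemsAD, Pigeons

# for a type that already implements logdensity_and_gradient, bypass autodiff by
# returning the log potential itself regardless of the requested backend symbol
function LogDensityProblemsAD.ADgradient(::Symbol, log_potential::MyLogPotential, ::Pigeons.Augmentation)
    return log_potential
end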

sefffal commented 6 days ago

Thanks for the pointer @miguelbiron! I'll give that a try.

Is this something that could be done automatically when LogDensityProblems.capabilities(logdensityproblem) >= LogDensityProblems.LogDensityOrder{1}()? If I understand correctly, that value, which is a compile-time constant, should let one know whether the target already supplies a gradient or requires autodiff.

miguelbiron commented 6 days ago

ooh that would be really nice to have. I'm not too familiar with the whole LogDensityProblems interface, but it does seem doable.

miguelbiron commented 5 days ago

So I was thinking about how to do this and the easiest approach would be to replace

https://github.com/Julia-Tempering/Pigeons.jl/blob/66026660b4dd9de70a6cd7a8c0c2e37693fd444f/src/explorers/BufferedAD.jl#L61-L62

with

function LogDensityProblemsAD.ADgradient(kind::Symbol, log_potential, buffers::Augmentation)
    cap = LogDensityProblems.capabilities(log_potential)
    if !isnothing(cap) && cap >= LogDensityProblems.LogDensityOrder{1}()
        log_potential  # target already provides its own gradient, so use it directly
    else
        LogDensityProblemsAD.ADgradient(kind, log_potential)  # fall back to autodiff
    end
end

This is fine unless you want it to be non-allocating. Then two complications arise:

  1. The user might have used a buffer to construct their log_potential, which would result in race conditions if exploration is multithreaded, since the buffer would be written at the same time by all replicas.
  2. There is no way to take advantage of Pigeons' buffer system, since log_potential is already constructed at this point.

So it seems like the only way would be for the user to define a method for

function LogDensityProblemsAD.ADgradient(kind::Symbol, log_potential::MyLogPotential, buffers::Augmentation)
    ...
end

that constructs an object which encloses a vector from the buffer system and uses it to implement logdensity_and_gradient in a non-allocating way.
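
Pure sketch of what that could look like (MyLogPotential, MyBufferedGradient and my_gradient! are placeholder names, and I'm assuming a buffer accessor along the lines of the get_buffer helper that BufferedAD uses internally):

using LogDensityProblems, LogDensityProblemsAD, Pigeons

struct MyBufferedGradient{T,V<:AbstractVector}
    log_potential::T
    grad_buffer::V   # replica-specific storage, reused on every gradient call
end

function LogDensityProblemsAD.ADgradient(::Symbol, log_potential::MyLogPotential, buffers::Pigeons.Augmentation)
    dim = LogDensityProblems.dimension(log_potential)
    grad_buffer = Pigeons.get_buffer(buffers, :my_gradient_buffer, dim)  # assumed buffer accessor
    return MyBufferedGradient(log_potential, grad_buffer)
end

LogDensityProblems.dimension(g::MyBufferedGradient) = LogDensityProblems.dimension(g.log_potential)
LogDensityProblems.logdensity(g::MyBufferedGradient, x) = LogDensityProblems.logdensity(g.log_potential, x)

function LogDensityProblems.logdensity_and_gradient(g::MyBufferedGradient, x)
    # my_gradient! fills g.grad_buffer in place and returns the log density, so nothing allocates
    ld = my_gradient!(g.grad_buffer, g.log_potential, x)
    return ld, g.grad_buffer
end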

sefffal commented 5 days ago

Thanks @miguelbiron, great points about buffers and multithreading; I hadn't considered this. In fact, my own autodiff calls wouldn't have been thread-safe!

So I agree that even if the log density model provides a gradient, Pigeons shouldn't use it by default with multithreading. Maybe it could still be used with MPI, though?

As a minimal change, what about adding a log message (with maxlog=1 set) when LogDensityProblems.capabilities(log_potential) >= LogDensityProblems.LogDensityOrder{1}(), to at least alert the user that their gradient isn't being used?
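
Roughly something like this (sketch only):

cap = LogDensityProblems.capabilities(log_potential)
if !isnothing(cap) && cap >= LogDensityProblems.LogDensityOrder{1}()
    @info "The target implements logdensity_and_gradient, but autoMALA is using autodiff instead; " *
          "define a method of LogDensityProblemsAD.ADgradient for your log potential to use it." maxlog=1
end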

Thanks!

miguelbiron commented 5 days ago

Yeah, that's another option. Or we could even allow using the log potential, but also emit a warning saying that you'd better not have a shared buffer in it if you're planning on using multithreading. 'Cause maybe that's fine for some people? Like you say, maybe single-threaded MPI is better suited for that application.

I'm interested in what others think about this too. Another thing to consider is the possibility of us implementing the buffer for some backends, as suggested in #249.

alexandrebouchard commented 5 days ago

  • The user might have used a buffer to construct their log_potential, which would result in race conditions if exploration is multithreaded, since the buffer would be written at the same time by all replicas.

Actually, the buffers are replica-specific if I recall correctly (double-check that, though).

alexandrebouchard commented 5 days ago

The relevant source location, I think, is: https://github.com/Julia-Tempering/Pigeons.jl/blob/822dc8251fd47fe12b321f3c497b8b99ffe1b504/src/explorers/Augmentation.jl#L5

alexandrebouchard commented 5 days ago

And note that multithreaded exploration has the property that no two threads touch the same replica at the same time.

miguelbiron commented 5 days ago

Hey @alexandrebouchard -- that would be the case if we designed the buffered AD ourselves, like we do for Stan targets. But if the user themselves designed their own in-place buffered log potential (with the buffer in the same struct, as @sefffal apparently was trying to do), then for sure you will get race conditions. We don't deepcopy log potentials across replicas, which is what would be needed to prevent that.
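
For concreteness, the problematic pattern would be something like this (MySharedBufferPotential is hypothetical): a single gradient buffer stored inside the log potential struct, which every replica's explorer ends up writing into.

using LogDensityProblems

struct MySharedBufferPotential
    grad_buffer::Vector{Float64}   # one buffer shared by every replica holding this object
end

LogDensityProblems.dimension(p::MySharedBufferPotential) = length(p.grad_buffer)
LogDensityProblems.capabilities(::Type{MySharedBufferPotential}) = LogDensityProblems.LogDensityOrder{1}()
LogDensityProblems.logdensity(p::MySharedBufferPotential, x) = -0.5 * sum(abs2, x)

function LogDensityProblems.logdensity_and_gradient(p::MySharedBufferPotential, x)
    p.grad_buffer .= -x   # two threads exploring different replicas can write here concurrently
    return LogDensityProblems.logdensity(p, x), p.grad_buffer
end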