TuringLang / Turing.jl

Bayesian inference with probabilistic programming.
https://turinglang.org
MIT License

Reverse-mode AD extremely slow for large number of observations #1642

Closed: anhi closed this issue 1 year ago

anhi commented 3 years ago

When trying to optimize our Turing code, we experimented with the different AD engines. It seems as if the reverse-mode AD engines are extremely slow for large numbers of observations. Our original model has several hundred dimensions, but the effect can be demonstrated on this simple example:

using Turing

setadbackend(:forwarddiff)

@model benchmark_model(x) = begin
    μ ~ TruncatedNormal(1, 2, 0.1, 10)
    σ ~ TruncatedNormal(1, 2, 0.1, 10)

    # broadcasted observe: each element of x is a univariate LogNormal
    x .~ LogNormal(μ, σ)
end

samples = rand(LogNormal(1.5, 0.5), 100000);

@time chains = Turing.sample(benchmark_model(samples), NUTS(0.65), 2000)

On my machine, this takes about 42 seconds: 41.868502 seconds (21.57 M allocations: 1.340 GiB, 1.49% gc time, 18.19% compilation time)

I understand that for such a simple model, forward-mode should be more efficient. But when switching to reverse-mode

setadbackend(:zygote)

@time chains = Turing.sample(benchmark_model(samples), NUTS(0.65), 2000)

it takes several hours on my machine to even arrive at an ETA, which starts out at several days. This seems a little excessive.

Are we doing anything wrong, or is reverse-mode just not usable for large numbers of observations?

torfjelde commented 3 years ago

Try filldist instead of .~. That should speed things up significantly.

Essentially, filldist means that the entire vector x is treated as one multivariate random variable, while .~ gives you length(x) univariate random variables. The former is going to be significantly faster.

yebai commented 3 years ago

@torfjelde @mohamed82008 A side note: any reason we shouldn't automatically translate dot-observe and dot-assume into filldist in DynamicPPL?

anhi commented 3 years ago

@torfjelde thank you for the hint. I'm still a little confused: did you mean

x = filldist(LogNormal(μ, σ), 100000)

which is blazingly fast but yields wrong values for μ and σ, or

x ~ filldist(LogNormal(μ, σ), 100000)

which seems to have a very similar runtime to the original (still waiting for the ETA)?

Or is there another way to use filldist? I was a little confused by the documentation here.

mohamed82008 commented 3 years ago

@yebai dot broadcasting should be fast on observations, even GPU compatible most of the time. That is unless something changed recently.

mohamed82008 commented 3 years ago

@anhi try Zygote and ReverseDiff. You might have better luck with Zygote here because ReverseDiff's performance is a bit brittle depending on whether we use an array of tracked reals (slow) or a tracked array (fast).
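
For reference, switching backends is just a matter of something like the following (a sketch; depending on your Turing version you may need to load ReverseDiff explicitly, and the tape cache is optional):

using ReverseDiff  # may be required for the :reversediff backend, depending on the Turing version

setadbackend(:reversediff)
# Turing.setrdcache(true)  # optional: cache the compiled tape, if your Turing version supports it

@time chains = Turing.sample(benchmark_model(samples), NUTS(0.65), 2000)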

torfjelde commented 3 years ago

I meant the latter, i.e. the one with ~. ~~I'm surprised it's not faster though~~ :confused:
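
In full, something like this (untested sketch; the model name is just for illustration, and length(x) avoids hard-coding the number of observations):

@model benchmark_model_filldist(x) = begin
    μ ~ TruncatedNormal(1, 2, 0.1, 10)
    σ ~ TruncatedNormal(1, 2, 0.1, 10)

    # observe the whole vector against a single product distribution
    x ~ filldist(LogNormal(μ, σ), length(x))
end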

Also, as @mohamed82008 pointed out this is for observations so my argument above doesn't actually hold.

try Zygote and ReverseDiff

It seems like he's using Zygote already?

mohamed82008 commented 3 years ago

It seems like he's using Zygote already?

Yes my bad, didn't see this. So try ReverseDiff then :)

torfjelde commented 3 years ago

One thing I've noticed in the past: if the function being mapped or broadcasted contains if-statements (which the logpdf of LogNormal does), you get a pretty significant slowdown when using Zygote (broadcasting also often leads to type instability in this case).

anhi commented 3 years ago

It seems like he's using Zygote already?

Yes my bad, didn't see this. So try ReverseDiff then :)

:) OK, I'll try... The ETA for Zygote was ~22 hours, btw, while ForwardDiff with filldist took 61 seconds (a little longer than without filldist, which was ~42 seconds).

torfjelde commented 3 years ago

Btw, one thing you can do if you want to go really fast is to use @addlogprob! and just compute the logpdf of LogNormal "by hand". Then you can also remove stuff like the if-statements that check whether it's inside the domain or not. That is, you can replace it with

logx = log.(x)
zval = @. (logx - μ) / σ # `StatsFuns.normlogpdf(μ, σ, x)` has an if-statement in it, so we circumvent this by computing the `zval` ourselves.
@addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx)

which should be the same (maybe check this though), assuming you've done import StatsFuns somewhere.

This should be muuuch faster using Zygote.
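
A quick way to check the "should be the same" part outside of Turing (a sketch; assumes Distributions and StatsFuns are installed):

using Distributions, StatsFuns

x = rand(LogNormal(1.5, 0.5), 10)
μ, σ = 1.2, 0.7

reference = sum(logpdf.(LogNormal(μ, σ), x))

logx = log.(x)
zval = @. (logx - μ) / σ
byhand = sum(StatsFuns.normlogpdf.(zval)) - sum(logx)

reference ≈ byhand  # if this is false, the by-hand expression is missing a term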

anhi commented 3 years ago

@torfjelde

ok, this is getting close...


import StatsFuns  # needed for StatsFuns.normlogpdf

@model benchmark_model_2(x) = begin
    μ ~ Normal(0.1, 10)
    σ ~ Normal(0.1, 10)

    # accumulate the LogNormal log-likelihood "by hand"
    logx = log.(x)
    zval = @. (logx - μ) / σ
    @Turing.addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx)
end

@time chains = Turing.sample(benchmark_model_2(samples), NUTS(0.65), 2000)

I've also changed the TruncatedNormals to Normals because I was not sure if they contain ifs as well...

308.334925 seconds (8.34 G allocations: 381.370 GiB, 19.94% gc time, 11.96% compilation time)

so this is indeed much faster than the standard LogNormal implementation.

I'm running the same experiment with the TruncatedNormals again, and it seems similarly fast.

ReverseDiff seems to have similar problems as Zygote, and similar timings. But I'll try that again later.

anhi commented 3 years ago

Using TruncatedNormals, the sampling took 268 seconds, which indeed looks much better than the several days I started out with. However, I just noticed that the sampling returns wrong values for σ. The data was generated with μ = 1.5 and σ = 0.5, but sampling with non-truncated normals returned a σ of 41, and the TruncatedNormal version a σ of 10. μ was close to 1.5 in both cases...

torfjelde commented 3 years ago

I've also changed the TruncatedNormals to Normals because I was not sure if they contain ifs as well...

Probably unnecessary. This is specifically a problem when you're doing map or broadcasting, i.e. f.(x), over something, not so much an if-statement somewhere in the code :)

However, I just noticed that the sampling returns wrong values for σ. The data was generated with μ = 1.5 and σ = 0.5, but sampling with non-truncated normals returned a σ of 41, and the TruncatedNormal version a σ of 10. μ was close to 1.5 in both cases...

Yeah sorry, this is because I made a mistake in the above (told yah it needed some checking :sweat_smile:). It should be

sum(StatsFuns.normlogpdf.(zval)) - sum(logx) - length(x) * log(σ)

I forgot the log-abs-det-jacobian term from computing the zval :+1:
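
Putting it together, the corrected model would look something like this (untested sketch; the model name is just illustrative):

import StatsFuns

@model benchmark_model_corrected(x) = begin
    μ ~ TruncatedNormal(1, 2, 0.1, 10)
    σ ~ TruncatedNormal(1, 2, 0.1, 10)

    logx = log.(x)
    zval = @. (logx - μ) / σ
    # standard-normal logpdf of zval, plus the change-of-variables terms for log(x) and zval
    @Turing.addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx) - length(x) * log(σ)
end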

yebai commented 1 year ago

Closed in favour of #1934