Try `filldist` instead of `.~`. That should speed things up significantly.
Essentially, `filldist` means that the entire vector `x` is treated as one multivariate random variable, while `.~` makes it so that you get `length(x)` univariate random variables. The former is going to be significantly faster.
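As a rough illustration of the difference, here is a minimal sketch, assuming `using Turing` re-exports `filldist` and the Distributions API; the values of `μ` and `σ` and the synthetic data are made up. It builds the joint distribution once and compares its log-density with the sum of the elementwise log-densities. Both describe the same density; what changes is whether `x` is handled as one multivariate variable or as `length(x)` univariate ones.

```julia
using Turing  # assumed to re-export `filldist` and the Distributions API

μ, σ = 1.5, 0.5                     # illustrative values only
x = rand(LogNormal(μ, σ), 1_000)    # synthetic data

# One multivariate random variable covering all of `x`.
joint = filldist(LogNormal(μ, σ), length(x))
lp_joint = logpdf(joint, x)

# `length(x)` univariate log-densities, summed.
lp_elementwise = sum(logpdf.(LogNormal(μ, σ), x))

println(lp_joint ≈ lp_elementwise)
```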
@torfjelde @mohamed82008 A side note: is there any reason that we shouldn't translate dot observe and dot assume into `filldist` automatically in `DynamicPPL`?
@torfjelde thank you for the hint. I'm still a little confused: did you mean
x = filldist(LogNormal(μ, σ), 100000)
which is blazingly fast but yields wrong values for \mu and \sigma, or
x ~ filldist(LogNormal(μ, σ), 100000)
which seems to have a very similar runtime to the original (still waiting for the ETA)?
Or is there another way to use filldist? I was a little confused by the documentation here.
@yebai dot broadcasting should be fast on observations, even GPU compatible most of the time. That is unless something changed recently.
@anhi try Zygote and ReverseDiff. You might have better luck with Zygote here because ReverseDiff's performance is a bit brittle depending on whether we use an array of tracked reals (slow) or a tracked array (fast).
I meant the latter, i.e. the one with `~`. ~I'm surprised it's not faster though~ :confused:
Also, as @mohamed82008 pointed out this is for observations so my argument above doesn't actually hold.
> try Zygote and ReverseDiff

It seems like he's using Zygote already?

> It seems like he's using Zygote already?

Yes my bad, didn't see this. So try ReverseDiff then :)
One thing I've noticed in the past: if the function being `map`ed or broadcasted contains if-statements (which `LogNormal` does), you get a pretty significant slowdown when using Zygote (broadcasting also often leads to type-instability in this case).
> > It seems like he's using Zygote already?
>
> Yes my bad, didn't see this. So try ReverseDiff then :)

:) ok, I'll try... ETA for Zygote was ~22 hours, btw, while ForwardDiff with `filldist` took 61 seconds (a little longer than without `filldist`, which was ~42 seconds)
Btw, one thing you can do if you want to go really fast is to use `@addlogprob!` and just compute the logpdf of `LogNormal` "by hand". Then you can also remove stuff like the `if` statements that check whether it's inside the domain or not. That is, you can replace it with
logx = log.(x)
zval = @. (logx - μ) / σ # `StatsFuns.normlogpdf(μ, σ, x)` has an if-statement in it, so we circumvent this by computing the `zval` ourselves.
@addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx)
which should be the same (maybe check this though), assuming you've done `import StatsFuns` somewhere.
This should be muuuch faster using Zygote.
@torfjelde
ok, this is getting close...
@model benchmark_model_2(x) = begin
μ ~ Normal(0.1, 10)
σ ~ Normal(0.1, 10)
logx = log.(x)
zval = @. (logx - μ) / σ
@Turing.addlogprob! sum(StatsFuns.normlogpdf.(zval)) - sum(logx)
end
@time chains = Turing.sample(benchmark_model_2(samples), NUTS(0.65), 2000)
I've also changed the TruncatedNormals to Normals because I was not sure if they contain ifs as well...
308.334925 seconds (8.34 G allocations: 381.370 GiB, 19.94% gc time, 11.96% compilation time)
so this is indeed much faster than the standard LogNormal implementation.
I'm running the same experiment with the TruncatedNormals again, and it seems similarly fast.
ReverseDiff seems to have similar problems as Zygote, and similar timings. But I'll try that again later.
Using TruncatedNormals, the sampling took 268 seconds, which indeed looks much better than the several days I started out with. However, I just noticed that the sampling returns wrong values for \sigma. The data was generated with \mu = 1.5 and \sigma = 0.5, but sampling with non-truncated normals returned a \sigma of 41, the TruncatedNormal version a \sigma of 10. \mu was close to 1.5 in both cases...
> I've also changed the TruncatedNormals to Normals because I was not sure if they contain ifs as well...

Probably unnecessary. This is specifically a problem when you're doing `map` or broadcasting, i.e. `f.(x)`, over something, not so much an if-statement somewhere in the code :)
> However, I just noticed that the sampling returns wrong values for \sigma. The data was generated with \mu = 1.5 and \sigma = 0.5, but sampling with non-truncated normals returned a \sigma of 41, the TruncatedNormal version a \sigma of 10. \mu was close to 1.5 in both cases...
Yeah sorry, this is because I made a mistake in the above (told yah it needed some checking :sweat_smile:). It should be
sum(StatsFuns.normlogpdf.(zval)) - sum(logx) - length(x) * log(σ)
I forgot the log-abs-det-jacobian term from computing the `zval`.
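The hand-written expression can be checked against `Distributions.logpdf` directly. The following is a small sketch with made-up values for `μ` and `σ` and synthetic data, not the model from this issue. Note that with a scalar `σ` the Jacobian term is needed once per observation, so it is written here as `length(x) * log(σ)`.

```julia
using Distributions
import StatsFuns

μ, σ = 1.5, 0.5                     # illustrative values only
x = rand(LogNormal(μ, σ), 1_000)    # synthetic data

logx = log.(x)
zval = @. (logx - μ) / σ

# Standard-normal log-density of the z-scores, minus the two Jacobian terms:
# sum(logx) from x -> log(x), and length(x) * log(σ) from the scaling by σ.
by_hand = sum(StatsFuns.normlogpdf.(zval)) - sum(logx) - length(x) * log(σ)

# Reference value computed by Distributions.
reference = sum(logpdf.(LogNormal(μ, σ), x))

println(isapprox(by_hand, reference; rtol = 1e-8))
```

If the two values disagree, a term is missing from the hand-written log-density, which is the kind of error that showed up above as a badly wrong estimate for `σ`.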
:+1:
Closed in favour of #1934
When trying to optimize our Turing code, we experimented with the different AD engines. It seems as if the reverse-mode AD engines are extremely slow for large numbers of observations. Our original model has several hundred dimensions, but the effect can be demonstrated on this simple example:
On my machine, this takes about 42 seconds:
41.868502 seconds (21.57 M allocations: 1.340 GiB, 1.49% gc time, 18.19% compilation time)
I understand that for such a simple model, forward-mode should be more efficient. But when switching to reverse-mode
it takes several hours on my machine to even arrive at an ETA, which starts out at several days. This seems a little excessive.
Are we doing anything wrong, or is reverse-mode just not useable for large numbers of observations?