WardBrian closed this issue 7 months ago.
Here is a demo of what I'm describing, using this model:

```python
import timeit

import bridgestan

model = bridgestan.StanModel('sir.stan', 'sir.data.json')

with open('sir.init.json') as f:
    params = model.param_unconstrain_json(f.read())

def one():
    model.log_density(params, propto=True)

def two():
    model.log_density(params, propto=False)

timeit.timeit(one, number=1000)  # warms up caches etc.
# Each timing is the total wall time, in seconds, for 1000 calls
time_one = timeit.timeit(one, number=1000)
time_two = timeit.timeit(two, number=1000)
print(f"propto=T: {time_one*1000:.1f}ms")
print(f"propto=F: {time_two*1000:.1f}ms")
```
This prints:

```
propto=T: 329.7ms
propto=F: 81.4ms
```
So `propto=True` is over 4 times slower for this model!
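The slowdown comes from `propto=True` needing autodiff variables. As a rough illustration, here is a toy reverse-mode sketch: evaluating with an autodiff type records a tape during the forward pass whether or not a gradient is ever requested. `Var`, `TAPE`, `grad`, and `log_density` are hypothetical names for illustration only; this is not how Stan Math's `var` or BridgeStan work internally, just the same general technique.

```python
TAPE = []  # every Var created during the forward pass is recorded here


class Var:
    """Toy reverse-mode autodiff scalar (hypothetical; not stan::math::var)."""

    def __init__(self, val, parents=()):
        self.val = val
        self.adj = 0.0           # adjoint, filled in only by grad()
        self.parents = parents   # (parent Var, local partial derivative) pairs
        TAPE.append(self)        # bookkeeping paid even if grad() is never called

    def __add__(self, other):
        if isinstance(other, Var):
            return Var(self.val + other.val, ((self, 1.0), (other, 1.0)))
        return Var(self.val + other, ((self, 1.0),))

    __radd__ = __add__

    def __mul__(self, other):
        if isinstance(other, Var):
            return Var(self.val * other.val, ((self, other.val), (other, self.val)))
        return Var(self.val * other, ((self, other),))

    __rmul__ = __mul__

    def __sub__(self, other):
        return self + (-1.0) * other


def grad(out):
    """Reverse sweep: propagate adjoints from `out` back through the tape."""
    out.adj = 1.0
    for node in reversed(TAPE):  # creation order is a valid topological order
        for parent, partial in node.parents:
            parent.adj += partial * node.adj


def log_density(mu, ys):
    """Toy unnormalized log density -0.5 * sum((mu - y)^2); mu may be a float or a Var."""
    total = 0.0
    for y in ys:
        diff = mu - y
        total = total + (-0.5) * diff * diff
    return total
```

Starting from an empty `TAPE` with `ys = [0.0, 1.0, 2.0]`, `log_density(0.5, ys)` returns `-1.375` and records nothing, while `log_density(Var(0.5), ys)` returns a `Var` with the same value after recording 13 tape nodes (1 for `mu`, 4 per data point). Only a later `grad(...)` call would consume them; in BridgeStan's `log_density`, as described in this thread, that reverse sweep never happens, so the recording is pure overhead.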
I think this is essentially absent from the Stan documentation, since calculating gradients is taken for granted everywhere in Stan.
It's discussed in the efficiency section of the User's Guide and at length in the Reference Manual, which covers which things get autodiffed and which ones are just `double`-based.
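The "autodiffed vs. `double`-based" distinction is what makes `propto` matter. Stan's built-in densities drop a term under `propto` only when none of that term's inputs are autodiff variables, so with plain `double` inputs every term is droppable. The sketch below imitates that rule at run time; `Param`, `value`, `is_param`, and this `normal_lpdf` are hypothetical stand-ins (Stan Math makes the decision at compile time from template types), so treat it as an illustration rather than the library's code.

```python
import math
from dataclasses import dataclass


@dataclass
class Param:
    """Hypothetical marker for 'this value is an autodiff variable' (think: a var)."""
    val: float


def value(x):
    """Underlying number, whether x is a Param or a plain float."""
    return x.val if isinstance(x, Param) else x


def is_param(x):
    return isinstance(x, Param)


def normal_lpdf(y, mu, sigma, propto=False):
    """Toy normal log density that drops constant terms under propto, Stan-style.

    A term is kept when propto is False or when at least one of its inputs is a
    Param; with only plain floats, propto=True drops every term.
    """
    yv, mv, sv = value(y), value(mu), value(sigma)
    lp = 0.0
    if not propto:
        lp += -0.5 * math.log(2.0 * math.pi)  # depends on no argument at all
    if not propto or is_param(sigma):
        lp += -math.log(sv)
    if not propto or any(is_param(a) for a in (y, mu, sigma)):
        lp += -0.5 * ((yv - mv) / sv) ** 2
    return lp
```

With plain floats, `normal_lpdf(1.0, 0.0, 1.0, propto=True)` is `0.0`: every term was dropped, which is why a `double`-only evaluation cannot honor `propto=True` and the parameters have to be passed as autodiff types. Marking `mu` as a parameter, `normal_lpdf(1.0, Param(0.0), 1.0, propto=True)` keeps only the quadratic term, `-0.5`. `log_density_gradient` needs autodiff types regardless, so there `propto=True` does not force any extra machinery; `log_density` with `propto=False` can stay on plain doubles.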
Is there any place that would obviously lead the reader to the conclusion this discusses? For example, that for the `log_density` function, `propto` can have dramatically different performance implications than it does for `log_density_gradient`?
Good point. We talk about everything you would need to draw this conclusion yourself, but I don't think we ever connect the dots. We probably should. I added an issue for the User's Guide efficiency chapter:
This is a bit of a weird thing I've been thinking about recently that would be good to alert users to. It's related to #165 and #180.
Basically, `propto=True` requires that we pass `var`s to Stan. Before #165, we were even calling `grad`, wasting a lot of computation. Since #165, we no longer call `grad`, but it may still lead to more work than you'd expect: the Stan Math library assumes (justifiably) that if you're calling a function with `var`s you will want gradients, so it can do some pre-computation for you. Because we never call or use `grad()` in `log_density`, this is wasted effort. The big offenders are the higher-order functions like `reduce_sum`, which basically calculate their entire gradients in the "forward pass".

I wrote up a docs page on implementation details and added a section about this. I'm not sure if it should be linked other places or not.
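To see why higher-order functions like `reduce_sum` are the big offenders, here is a toy, sequential Python sketch. `reduce_sum_toy`, `Param`, and `ValueWithGrad` are hypothetical names, and the real `reduce_sum` differs in detail (parallel chunks, nested autodiff). The only point illustrated is the one made in this thread: when the parameter is an autodiff type, the derivative work happens inside the forward call, so a caller that never asks for the gradient, like `log_density` with `propto=True`, still pays for it.

```python
from dataclasses import dataclass


@dataclass
class Param:
    """Hypothetical marker for an autodiff (var) argument."""
    val: float


@dataclass
class ValueWithGrad:
    """Hypothetical result: a value whose gradient was already computed in the forward pass."""
    val: float
    grad: float


def reduce_sum_toy(f, df, data, theta, grainsize=2):
    """Chunked sum of f(x, theta) over data; eagerly computes d/dtheta if theta is a Param.

    Returns (result, derivative_evaluations). Sequential sketch only; the real
    reduce_sum runs chunks in parallel with nested autodiff.
    """
    needs_grad = isinstance(theta, Param)
    t = theta.val if needs_grad else theta
    total, total_grad, grad_evals = 0.0, 0.0, 0
    for start in range(0, len(data), grainsize):
        chunk = data[start:start + grainsize]
        total += sum(f(x, t) for x in chunk)
        if needs_grad:
            # Done now, during the "forward pass", whether or not anyone asks for it.
            total_grad += sum(df(x, t) for x in chunk)
            grad_evals += len(chunk)
    if needs_grad:
        return ValueWithGrad(total, total_grad), grad_evals
    return total, grad_evals


def f(x, t):
    return -0.5 * (x - t) ** 2


def df(x, t):
    return x - t  # derivative of f with respect to t
```

Calling `reduce_sum_toy(f, df, [0.0, 1.0, 2.0, 3.0], 1.0)` does no derivative evaluations; passing `Param(1.0)` instead returns the same value, `-3.0`, plus a gradient of `2.0`, at the cost of 4 derivative evaluations performed before any reverse pass is requested.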