cscherrer / Soss.jl

Probabilistic programming via source rewriting
https://cscherrer.github.io/Soss.jl/stable/
MIT License

Performance for Large Models #200

Open · cscherrer opened this issue 4 years ago

cscherrer commented 4 years ago

@DilumAluthge suggested in https://github.com/cscherrer/Soss.jl/issues/161 that we add a Bayesian neural net example. There's a lot we'll need to do to get good performance here. Let's have some discussion here, then create a meta-issue for these tasks. We can close this issue once discussion has wrapped up and the meta-issue exists.

So far, we've mostly focused on DynamicHMC with ForwardDiff, the default AD backend. This is fast for small models, but in higher dimensions we really need reverse-mode AD. The obvious choice here is Zygote, but I'm also really impressed with the performance benchmarks of [Yota.jl](https://github.com/dfdx/Yota.jl). Currently, Zygote uses ChainRules.jl, while Yota doesn't (yet) and instead uses its own rule-writing system. Yota's system looks very nice, but it would require us to write rules for all of the distributions, which is probably too much. Zygote's big win here comes from crowdsourcing: it gets the shared ChainRules.jl rule set for free.
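The forward- vs reverse-mode trade-off can be made concrete with a toy sketch. This is Python rather than Julia and is not how ForwardDiff, Zygote, or Yota are implemented; the classes `Dual`, `Tape`, `Var` and the test function `f` are hypothetical. Forward mode with dual numbers needs one evaluation of the function per input dimension to assemble a full gradient, while a reverse-mode tape gets the whole gradient from one forward evaluation plus one backward sweep. (ForwardDiff cuts the pass count by propagating a chunk of partials at once, but its total cost still grows with the input dimension.)

```python
class Dual:
    """Forward-mode AD: a value plus one directional derivative."""
    def __init__(self, val, der=0.0):
        self.val, self.der = val, der

    def __add__(self, other):
        return Dual(self.val + other.val, self.der + other.der)

    def __mul__(self, other):
        return Dual(self.val * other.val,
                    self.der * other.val + self.val * other.der)


class Tape:
    """Reverse-mode AD: records nodes in creation order, which is topological."""
    def __init__(self):
        self.nodes = []

    def var(self, val, parents=()):
        v = Var(self, val, parents)
        self.nodes.append(v)
        return v

    def backward(self, out):
        # Seed the output adjoint, then push adjoints to parents in reverse order.
        out.grad = 1.0
        for v in reversed(self.nodes):
            for parent, local in v.parents:
                parent.grad += local * v.grad


class Var:
    """A tape node; parents holds (parent node, local partial derivative) pairs."""
    def __init__(self, tape, val, parents):
        self.tape, self.val, self.parents, self.grad = tape, val, parents, 0.0

    def __add__(self, other):
        return self.tape.var(self.val + other.val, ((self, 1.0), (other, 1.0)))

    def __mul__(self, other):
        return self.tape.var(self.val * other.val,
                             ((self, other.val), (other, self.val)))


def f(xs, zero):
    """Hypothetical test function f(x) = sum_i x_i * x_i, generic over Dual or Var."""
    acc = zero
    for x in xs:
        acc = acc + x * x
    return acc


def grad_forward(xs):
    # One full evaluation of f per input dimension: n passes for an n-vector.
    grads, passes = [], 0
    for i in range(len(xs)):
        seeded = [Dual(x, 1.0 if j == i else 0.0) for j, x in enumerate(xs)]
        grads.append(f(seeded, Dual(0.0)).der)
        passes += 1
    return grads, passes


def grad_reverse(xs):
    # One recorded forward evaluation plus one backward sweep, regardless of n.
    tape = Tape()
    vs = [tape.var(x) for x in xs]
    out = f(vs, tape.var(0.0))
    tape.backward(out)
    return [v.grad for v in vs], 1


xs = [0.5 * k for k in range(1, 6)]
print(grad_forward(xs))  # → ([1.0, 2.0, 3.0, 4.0, 5.0], 5)
print(grad_reverse(xs))  # → ([1.0, 2.0, 3.0, 4.0, 5.0], 1)
```

Both return the same gradient (2x), but the forward-mode pass count equals the number of parameters, which is what hurts for something like a neural net node with many weights.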

Still, we'll need to

Soss models are typically small at the top level, though a given node could be large (e.g., a neural net). In the long term, we should be able to leverage this, something like

Since Zygote uses ChainRules, most of the gradient work is already done. But we've added lots of distribution combinators, and these will need their own ChainRules `rrule`s.
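As a rough illustration of what an `rrule` provides, here is a Python analogue (hypothetical names throughout, not Soss or ChainRules code). In ChainRulesCore, `rrule(f, args...)` returns the primal value together with a pullback that maps an output cotangent to cotangents for the inputs; the real pullback also returns a tangent for `f` itself (usually `NoTangent()`), which this sketch omits. The rule is written for a plain Normal log-density, and `iid_normal_logpdf_rrule` is a hypothetical combinator showing how a rule for an iid product could reuse the elementwise pullbacks and sum them.

```python
import math


def normal_logpdf(mu, sigma, x):
    """Log-density of Normal(mu, sigma) at x (plain primal, used for checking)."""
    z = (x - mu) / sigma
    return -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * z * z


def normal_logpdf_rrule(mu, sigma, x):
    """Hypothetical rrule-style rule: returns (value, pullback)."""
    z = (x - mu) / sigma
    y = -0.5 * math.log(2 * math.pi) - math.log(sigma) - 0.5 * z * z

    def pullback(ybar):
        # Local partials, scaled by the incoming output cotangent ybar:
        #   d/dmu = z / sigma, d/dsigma = (z^2 - 1) / sigma, d/dx = -z / sigma
        return ybar * z / sigma, ybar * (z * z - 1.0) / sigma, -ybar * z / sigma

    return y, pullback


def iid_normal_logpdf_rrule(mu, sigma, xs):
    """Hypothetical combinator rule: log-density of xs drawn iid from Normal(mu, sigma)."""
    results = [normal_logpdf_rrule(mu, sigma, x) for x in xs]
    y = sum(r[0] for r in results)

    def pullback(ybar):
        # Shared parameters accumulate cotangents; each x gets its own.
        dmu, dsigma, dxs = 0.0, 0.0, []
        for _, back in results:
            a, b, c = back(ybar)
            dmu += a
            dsigma += b
            dxs.append(c)
        return dmu, dsigma, dxs

    return y, pullback


def central_diff(fn, args, i, h=1e-6):
    """Central finite difference of fn with respect to argument i."""
    up, dn = list(args), list(args)
    up[i] += h
    dn[i] -= h
    return (fn(*up) - fn(*dn)) / (2 * h)


args = (0.3, 1.7, -0.8)
y, back = normal_logpdf_rrule(*args)
for i, g in enumerate(back(1.0)):
    print(i, g, central_diff(normal_logpdf, args, i))
```

Checking each hand-written pullback against finite differences like this is a cheap way to catch mistakes before a rule is trusted inside HMC.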

There's probably more; please add other concerns here to roll into the meta-issue.

cc: @millerjoey

cscherrer commented 4 years ago

Oh, and we'll need

cscherrer commented 4 years ago

This could be an easier way to get us there: https://github.com/mcabbott/Tullio.jl