@DilumAluthge suggested in https://github.com/cscherrer/Soss.jl/issues/161 that we add a Bayesian neural net example. There's a lot we'll need to do to get good performance here. Let's have some discussion here, then create a meta-issue for these tasks. We can close the current issue once we feel discussion is done and we've added the meta-issue.
So far, we've mostly focused on DynamicHMC with the default ForwardDiff backend. This is fast for small models, but forward-mode cost grows with the number of parameters, so for higher dimensions we really need reverse-mode AD. The obvious choice here is Zygote, but I'm also really impressed with the performance benchmarks of [Yota.jl](https://github.com/dfdx/Yota.jl). Zygote currently uses ChainRules.jl; Yota doesn't (yet), and instead has its own rule-writing system. Yota's system looks very nice, but it would require us to write rules for all of the distributions, which is probably too much. Zygote's big win here is crowdsourcing: it benefits from the rules the community has already contributed to ChainRules.
Still, we'll need to:
- [ ] Check that Zygote is set up to work with Soss models
- [ ] Make sure it's easy to use
- [ ] Ideally, have a sensible default that switches between ForwardDiff and Zygote based on parameter dimensionality
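A minimal sketch of what such a default might look like, assuming a crude cost model: forward mode (as in ForwardDiff) needs roughly one sweep per chunk of parameters, while reverse mode (as in Zygote) costs a roughly constant multiple of one log-density evaluation regardless of dimension. `choose_ad_backend`, `chunk`, and `reverse_overhead` are hypothetical names, and the default values are placeholders that would need benchmarking, not measured numbers.

```julia
# Hypothetical heuristic for picking an AD backend from the parameter dimension d.
# Assumed cost model (placeholder numbers, not benchmarks):
#   forward mode: about cld(d, chunk) sweeps, each ~ one log-density evaluation
#   reverse mode: about `reverse_overhead` evaluations, independent of d
function choose_ad_backend(d::Integer; chunk::Integer = 12, reverse_overhead::Real = 20)
    d >= 1 || throw(ArgumentError("dimension must be positive, got $d"))
    forward_cost = cld(d, chunk)
    return forward_cost <= reverse_overhead ? :ForwardDiff : :Zygote
end

for d in (3, 100, 240, 241, 10_000)
    println(d, " => ", choose_ad_backend(d))
end
```

Returning a `Symbol` keeps the sketch free of package dependencies; a real version would presumably return whatever gradient setup we hand to DynamicHMC, and the crossover point would come from benchmarks on actual Soss models.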
Soss models are typically small at the top level, though a given node could be large (e.g., a neural net). In the long term, we should be able to exploit this, with something like:
- [ ] Allow Zygote gradient information on a node to propagate to the top-level model
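To make the idea concrete, here is a minimal sketch, assuming a node exposes its value together with a pullback (the shape ChainRules uses). The top-level model never differentiates through the node's internals: it computes the gradient of its own small log-density with respect to the node's output and hands that cotangent to the node's pullback. The linear node `W * θ`, the standard-normal top-level term, and the names `linear_node` and `toplevel_logdensity_and_grad` are hypothetical stand-ins for, e.g., a neural-net node.

```julia
using LinearAlgebra: dot

# Hypothetical "large" node: z = W * θ, returned with its pullback.
# The pullback maps a cotangent on z to a cotangent on θ (here W' * zbar).
function linear_node(W::AbstractMatrix, θ::AbstractVector)
    z = W * θ
    pullback(zbar) = W' * zbar
    return z, pullback
end

# Hypothetical small top-level model: log p(z) = -‖z‖² / 2 (standard normal, up to a constant).
# Its gradient with respect to z is -z; the node's pullback turns that into ∂logp/∂θ.
function toplevel_logdensity_and_grad(W::AbstractMatrix, θ::AbstractVector)
    z, node_pullback = linear_node(W, θ)
    logp = -dot(z, z) / 2
    zbar = -z                        # ∂logp/∂z, computed at the top level
    θbar = node_pullback(zbar)       # propagated through the node's own rule
    return logp, θbar
end

W = [1.0 2.0; 0.5 -1.0; 3.0 0.0]
θ = [0.2, -0.4]
logp, g = toplevel_logdensity_and_grad(W, θ)
println(logp, " ", g)
```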
Since Zygote uses ChainRules, most of the gradient work is already done. But we've added lots of distribution combinators, which will need ChainRules `rrule`s.
- [ ] Add ChainRules `rrule`s for `For`, `iid`, `Mix`, and `MarkovChain`
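As a rough illustration of what such a rule buys us, here is a minimal hand-written reverse rule for the log-density of `n` iid `Normal(μ, σ)` observations, the kind of term an `iid` combinator produces. It is plain Julia in the `(value, pullback)` shape; with ChainRulesCore the same body would go in a method of `ChainRulesCore.rrule`, whose pullback additionally returns `NoTangent()` for the function itself. The function names here are hypothetical, and a real rule for Soss's `iid` would be written against its own types.

```julia
# Log-density of the vector x under n iid Normal(μ, σ) draws, including the constant term.
function iid_normal_logdensity(μ::Real, σ::Real, x::AbstractVector{<:Real})
    n = length(x)
    return -n * log(σ) - n * log(2π) / 2 - sum(abs2, x .- μ) / (2σ^2)
end

# Hand-written reverse rule: returns the value and a pullback mapping the output
# cotangent dy to cotangents for (μ, σ, x), computed in closed form rather than
# by tracing the sum over observations.
function iid_normal_logdensity_rrule(μ::Real, σ::Real, x::AbstractVector{<:Real})
    n = length(x)
    r = x .- μ                      # residuals, shared by the value and the pullback
    ss = sum(abs2, r)
    y = -n * log(σ) - n * log(2π) / 2 - ss / (2σ^2)
    function pullback(dy)
        dμ = dy * sum(r) / σ^2
        dσ = dy * (-n / σ + ss / σ^3)
        dx = dy .* (-r ./ σ^2)
        return (dμ, dσ, dx)
    end
    return y, pullback
end

x = [0.5, -1.2, 2.0, 0.3]
y, back = iid_normal_logdensity_rrule(0.1, 1.7, x)
dμ, dσ, dx = back(1.0)
```

The pullback touches each observation once with simple arithmetic; a combinator-level rule for `iid` would presumably generalize this by delegating to each component distribution's own rule.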
There's probably more; please add other concerns here to roll into the meta-issue.
cc: @millerjoey