TuringLang / Bijectors.jl

Implementation of normalising flows and constrained random variable transformations
https://turinglang.org/Bijectors.jl/
MIT License

fix `Stacked` to work with Zygote #315

Closed. Red-Portal closed this pull request 3 weeks ago.

Red-Portal commented 1 month ago

This fixes #252, as well as the corresponding downstream issue.

I checked the performance of both the forward pass and AD, and it is roughly the same as the previous implementation (slightly better in my setup). Interestingly, `mapreduce(f, vcat, ys)` performed consistently worse than both `reduce(vcat, map(f, ys))` and the previous implementation.
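The PR diff is not reproduced here, so the following is only a rough illustration, assuming a `Stacked`-like transform that applies a separate function to each block of an input vector. In-place array mutation is a common reason code fails under Zygote, which does not support it. The sketch contrasts a mutating version with the two non-mutating concatenation strategies compared above. The functions `stacked_mutating`, `stacked_map_reduce`, and `stacked_mapreduce` are hypothetical and are not Bijectors.jl API.

```julia
# Hypothetical sketch, NOT Bijectors.jl's actual implementation:
# apply fs[i] elementwise to the block x[ranges[i]] and return the results
# concatenated in order.

# Mutating version: writes each block into a preallocated buffer.
# Zygote cannot differentiate through this kind of array mutation.
function stacked_mutating(fs, ranges, x::AbstractVector)
    y = similar(x)
    for (f, r) in zip(fs, ranges)
        y[r] = f.(x[r])
    end
    return y
end

# Non-mutating "map + reduce(vcat)": build every block with `map`,
# then concatenate them all in a single `reduce(vcat, ...)` call.
function stacked_map_reduce(fs, ranges, x::AbstractVector)
    ys = map((f, r) -> f.(x[r]), fs, ranges)
    return reduce(vcat, ys)
end

# Non-mutating "mapreduce(f, vcat)": concatenates the blocks step by step,
# allocating an intermediate array at each step. This is the pattern the
# PR author found consistently slower.
function stacked_mapreduce(fs, ranges, x::AbstractVector)
    return mapreduce(((f, r),) -> f.(x[r]), vcat, zip(fs, ranges))
end

# Example inputs: three blocks of a length-5 vector.
fs = [exp, identity, abs2]
ranges = [1:2, 3:3, 4:5]
x = [0.0, 1.0, -2.0, 3.0, -4.0]

y = stacked_map_reduce(fs, ranges, x)
println(y)
```

All three functions return the same vector. A likely explanation for the timing difference is that Julia's Base has a specialized method for `reduce(vcat, v)` when `v` is a vector of arrays, which allocates the output once. `mapreduce(f, vcat, ...)` falls back to generic repeated concatenation.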

yebai commented 3 weeks ago

Thanks @Red-Portal!