JuliaDiff / ReverseDiff.jl

Reverse Mode Automatic Differentiation for Julia

Broadcasting support is broken-ish #122

Closed: mohamed82008 closed this issue 4 years ago

mohamed82008 commented 4 years ago

Currently, broadcasting over a TrackedArray produces an array of TrackedReals instead of a TrackedArray. It seems that the broadcasting code in ReverseDiff (RD) wasn't updated for recent Julia versions: RD defines methods for broadcast and broadcast!, but they never get called, because since Julia 0.7 dot syntax lowers to Broadcast.broadcasted and Broadcast.materialize rather than calling broadcast directly. It may be worth porting and adapting the broadcasting code from Tracker.jl. I will try working on that.

ChrisRackauckas commented 4 years ago

Yeah, it seems it just wasn't updated. This would be a nice one to have.

mohamed82008 commented 4 years ago

I have already implemented it temporarily in DistributionsAD, since ReverseDiff does not seem to be very actively maintained. I can make a PR here if someone is willing to review and merge it.

ChrisRackauckas commented 4 years ago

@yingboma and I can review it. I don't know whether @jrevels is looking to merge, or at least to bring in maintainers. I can technically merge, but I would like his input first.

mohamed82008 commented 4 years ago

Sounds good. I can also port a few improvements in getindex, vcat, hcat, cat, and some linear algebra from DistributionsAD. I will make some time next week to make a few PRs.

ChrisRackauckas commented 4 years ago

BTW, do you know of a good way to change Array{TrackedReal} -> TrackedArray in RD?

mohamed82008 commented 4 years ago

Not a particularly good one, but reduce with vcat should work if you define custom adjoints for vcat.

ChrisRackauckas commented 4 years ago

That sounds like Array{TrackedArray} -> TrackedArray?

ChrisRackauckas commented 4 years ago

We got the 👍 from @jrevels, so I'll make you a maintainer here given the work you've done. @yingboma and I can help review when you need it.

mohamed82008 commented 4 years ago

Well, yes, similar. We just need adjoints for vcat(::Real, ::Real) and vcat(::AbstractVector, ::Real); then reduce(vcat, ::Array{<:TrackedReal}) will return a TrackedArray. To make it type stable, one can also pass the initial value as vcat(::TrackedReal), where the input is the first number in the array. That requires an adjoint for vcat(::Real) as well, which returns a vector because vcat(1) == [1]. Then we can use Iterators.drop to reduce over the rest of the array.
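A minimal plain-Julia sketch of the seeded reduction described above, assuming ordinary Float64 values in place of TrackedReal. No ReverseDiff adjoints are defined here, so it only illustrates the call pattern and the resulting types, not tape recording.

```julia
# Ordinary Float64 values stand in for TrackedReal (assumption for illustration);
# no custom adjoints are defined, so nothing is recorded on a ReverseDiff tape.
x = [1.0, 2.0, 3.0, 4.0]

# vcat of a single scalar already yields a 1-element vector.
@assert vcat(1) == [1]

# Seed the reduction with vcat(first element) so the accumulator is a vector
# from the start. Reducing over a non-array iterator with `init` folds left,
# so every remaining step is a vcat(::AbstractVector, ::Real) call.
init = vcat(first(x))
y = reduce(vcat, Iterators.drop(x, 1); init = init)

println(y)  # [1.0, 2.0, 3.0, 4.0]
```

With TrackedReal elements, vcat(::Real) for the seed and vcat(::AbstractVector, ::Real) for each later step are the calls that would need custom adjoints for the result to come back as a single TrackedArray.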

mohamed82008 commented 4 years ago

Cool, thanks!

YingboMa commented 4 years ago

Yeah, please feel free to ping me on GitHub or Slack when you want me to review it.

mohamed82008 commented 4 years ago

Fixed by #134.