mohamed82008 closed this 4 years ago.
Yeah, it seems it just wasn't updated. This would be a nice one to have.
I already implemented it temporarily in DistributionsAD, since ReverseDiff does not seem to be very actively maintained. I can make a PR here if someone is willing to review and merge it.
@yingboma and I can review it. I don't know whether @jrevels is looking to merge, or at least to bring in maintainers. I technically can merge, but would like his input first.
Sounds good. I can also port a few improvements to getindex, vcat, hcat, cat, and some linear algebra functions from DistributionsAD. I will make some time next week to open a few PRs.
BTW, do you know of a good way to convert an Array{TrackedReal} into a TrackedArray in RD?
Not a great one, but reduce with vcat should work if you define custom adjoints for vcat.
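A custom adjoint for vcat only has to route the output gradient back to the arguments: each argument receives the slice of the gradient at the positions it occupies in the result. The following is a minimal plain-Julia sketch of that rule, assuming 1-based vectors. vcat_pullback is a hypothetical name, and this is not ReverseDiff's actual API for defining adjoints.

```julia
# Hypothetical helper, not part of ReverseDiff: given the gradient Δ of the
# output of vcat, return the gradients of its inputs as a tuple.

# y = vcat(a, b) == [a, b], so a gets Δ[1] and b gets Δ[2].
vcat_pullback(Δ::AbstractVector, a::Real, b::Real) = (Δ[1], Δ[2])

# y = vcat(v, b): the first length(v) entries of y came from v, the last from b.
function vcat_pullback(Δ::AbstractVector, v::AbstractVector, b::Real)
    n = length(v)
    return (Δ[1:n], Δ[n + 1])
end

# y = vcat(a) == [a], so a gets Δ[1].
vcat_pullback(Δ::AbstractVector, a::Real) = (Δ[1],)

# Example: gradient [10.0, 20.0, 30.0] flowing back through vcat([1.0, 2.0], 3.0).
gv, gb = vcat_pullback([10.0, 20.0, 30.0], [1.0, 2.0], 3.0)
```

Here gv is the slice [10.0, 20.0] for the vector argument and gb is 30.0 for the trailing scalar.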
That sounds like Array{TrackedArray} -> TrackedArray?
Well, yes, similar. We just need adjoints for vcat(::Real, ::Real) and vcat(::AbstractVector, ::Real). Then reduce(vcat, ::Array{<:TrackedReal}) will return a TrackedArray. To make it type stable, one can also pass the initial value as vcat(::TrackedReal), where the input is the first number in the array. That would also require an adjoint for vcat(::Real), which returns a vector because vcat(1) == [1]. Then we can use Iterators.drop to reduce over the rest of the array.
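The seeded reduction described in the previous comment can be sketched with plain Float64 values standing in for TrackedReals, so no adjoints or tracking are involved; the helper name stack_scalars is hypothetical. Seeding init with vcat(first(xs)) means the accumulator is already a one-element Vector before any element of the rest of the array is folded in. The sketch assumes xs is non-empty, because first throws on an empty collection.

```julia
# Plain Float64 values stand in for TrackedReals here; no AD is involved.
xs = [1.0, 2.0, 3.0]

# Direct version: reduce repeatedly applies vcat, so the scalars end up in one Vector.
direct = reduce(vcat, xs)

# Seeded version (hypothetical helper): start from vcat(first(xs)), which is
# already a one-element Vector, then reduce vcat over the remaining elements.
# Assumes xs is non-empty.
function stack_scalars(xs)
    init = vcat(first(xs))                  # vcat(1.0) == [1.0]
    return reduce(vcat, Iterators.drop(xs, 1); init = init)
end

seeded = stack_scalars(xs)
```

Both direct and seeded hold [1.0, 2.0, 3.0]; for a one-element input, stack_scalars returns the seed itself because the dropped iterator is empty.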
We got the 👍 from @jrevels, so I'll make you a maintainer here given the work you've done. @YingboMa and I can help review when you need it.
Cool, thanks!
Yeah, please feel free to ping me on GitHub or Slack when you want me to review it.
Fixed by #134.
Currently, broadcasting causes a TrackedArray to become an array of TrackedReals. It seems that the broadcasting code in ReverseDiff (RD) wasn't updated to the latest Julia version: broadcast and broadcast! are supported in RD, but they don't get called, because since Julia 0.7 dot syntax lowers to Broadcast.broadcasted and Broadcast.materialize rather than calling broadcast directly. Perhaps it is worth porting and adapting the broadcasting code from Tracker.jl. I will try working on that.