Open Red-Portal opened 1 year ago
This appears to be more complicated. It seems that `gradient(y -> sum(x -> x^2, y)/10, CUDA.randn(10))` does not hit the `sum(f, x)` `rrule`, while `mean(f, x)` does. This is super weird. I have no idea which `rrule` is being hit for `sum(f, x)`.
Zygote has this rule for `sum(f, xs::CuArray)`, which takes precedence over the one here:
Note also that `sum(x -> x^2, xs)` is equivalent to `sum(abs2, xs)`, which has a special rule. I think that `mean(abs2, xs)` goes here and should call that.
(One example above has `x -> x.^2` with an extra broadcast; there is some chance that this changes which path is taken in the `sum(f, xs)` rule.)
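For reference, the equivalence mentioned above can be checked on a plain CPU array. This is only a minimal sketch: it assumes real-valued input and uses `randn(Float32, 10)` as a stand-in for `CUDA.randn(10)`, so it says nothing about which rule gets dispatched on a `CuArray`.

```julia
using Statistics

# Real-valued CPU array (assumption: stand-in for CUDA.randn(10) from the thread).
xs = randn(Float32, 10)

s_closure = sum(x -> x^2, xs)   # the closure form used in the example
s_abs2    = sum(abs2, xs)       # the form that has a special rule
m_abs2    = mean(abs2, xs)      # the mean counterpart

# For real inputs the forward values should agree.
println(s_closure ≈ s_abs2)
println(m_abs2 ≈ s_abs2 / length(xs))
```

The forward values agreeing does not mean the same `rrule` is hit; the anonymous function and `abs2` are different types, so dispatch can still differ.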
Hi, it seems that the `rrule` for `mean(f, x)` is not vectorized and thus does not play nicely with CUDA:

The problem seems to be that this line does not use `map` or broadcasting. But the comment seems to suggest that we can't do that here. Is there anything we can do?

By the way, `sum(f, x)` for the same `f` works perfectly. So I'm quite curious why the result is different. Both hit the same `rrule`, right?
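A rough way to compare the variants discussed here is to differentiate each one and check it against the analytic gradient `2x/n`. The sketch below is an assumption-laden illustration, not a fix: it runs on a CPU array (swap in `CUDA.randn(10)` to try the GPU case), `f` is a hypothetical stand-in for the function in question, and the two rewrites (`sum(f, v) / length(v)` and `mean(f.(v))`) are just candidate workarounds whose GPU behavior is not verified here.

```julia
using Zygote, Statistics

f(x) = x^2                # hypothetical stand-in for the function being averaged
y = randn(Float32, 10)    # CPU array; assumption: replace with CUDA.randn(10) on a GPU

# Variant from the report: goes through the rule for mean(f, x).
g_mean  = gradient(v -> mean(f, v), y)[1]

# Candidate workaround 1: express the mean through sum(f, x).
g_sum   = gradient(v -> sum(f, v) / length(v), y)[1]

# Candidate workaround 2: broadcast f first, then take a plain mean.
g_bcast = gradient(v -> mean(f.(v)), y)[1]

# Analytic gradient of mean(x -> x^2, y) with respect to y.
expected = 2 .* y ./ length(y)

println(g_mean ≈ expected)
println(g_sum ≈ expected)
println(g_bcast ≈ expected)
```

On the CPU all three should agree, so any difference seen with a `CuArray` would point at dispatch or kernel-compatibility issues rather than at the math.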