Closed by alastair-marshall 2 years ago
Oh, I thought I got rid of them! to_same_device() used to move one array argument to the same device (e.g. CPU or GPU) as the other one. As a quick and dirty fix, you can try redefining these methods without the conversion:
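import Yota
import Yota: ArrayOrBroadcasted  # assumption: the alias is reachable from the Yota module; without it the signatures below hit an UndefVarError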
Yota.unbroadcast_prod_x(x::Number, y::ArrayOrBroadcasted, Δ) = Yota.unbroadcast_prod_x([x], y, Δ)[1]
Yota.unbroadcast_prod_x(x::ArrayOrBroadcasted, y::Number, Δ) = Yota.unbroadcast_prod_x(x, [y], Δ)
Yota.unbroadcast_prod_y(x::ArrayOrBroadcasted, y::Number, Δ) = Yota.unbroadcast_prod_y(x, [y], Δ)[1]
Yota.unbroadcast_prod_y(x::Number, y::ArrayOrBroadcasted, Δ) = Yota.unbroadcast_prod_y([x], y, Δ)
I will try to come up with a better fix and tests in the next couple of days.
Do you have a reproducible example that I can test the fix on?
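The wrap-and-unwrap pattern in the quick fix above can be sketched in plain Julia. This is only an illustration under stated assumptions: unbroadcast_x is a hypothetical stand-in written for this example, not Yota's actual unbroadcast_prod_x, and it assumes the gradient of x .* y with respect to x is Δ .* conj.(y) summed over the broadcast dimensions.

```julia
# Hypothetical array-only helper (NOT Yota's implementation).
# For z = x .* y, pull Δ back to x and sum over the dimensions
# along which x was broadcast.
function unbroadcast_x(x::AbstractArray, y::AbstractArray, Δ)
    g = Δ .* conj.(y)
    # Dimensions where x has size 1 but the gradient does not
    dims = Tuple(d for d in 1:ndims(g) if size(x, d) == 1 && size(g, d) > 1)
    g = isempty(dims) ? g : sum(g; dims=dims)
    return reshape(g, size(x))
end

# Scalar method: wrap the number in a 1-element array, reuse the
# array method, then unwrap the single result with [1].
unbroadcast_x(x::Number, y::AbstractArray, Δ) = unbroadcast_x([x], y, Δ)[1]

y = ComplexF64[1 2; 3 4]
Δ = ones(ComplexF64, 2, 2)
gx = unbroadcast_x(2.0 + 0im, y, Δ)  # a scalar, matching the scalar input
```

The scalar method never touches a device-conversion helper; it only reshapes the problem so that the existing array code path applies.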
This should be fixed by #119. Please let me know if it doesn't help.
Ran into a strange bug in Yota when I tried to broadcast a multiplication of a ComplexF64 to a Matrix of ComplexF64:
I had a look in the source code and I can't see the function to_same_device anywhere inside the package.
P.S. I hit the error in the first place because I was lazy and didn't want to write an rrule for
ChainRulesCore.rrule(::typeof(*), ::ComplexF64, ::Float64, ::Matrix{ComplexF64}) = ...
and so I tried broadcasting the multiplication instead. I'm not sure if it's a code path that anyone else would usually reach.