dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

ERROR: UndefVarError: to_same_device not defined #115

Closed: alastair-marshall closed this issue 2 years ago

alastair-marshall commented 2 years ago

Ran into a strange bug in Yota when I tried to broadcast the multiplication of a ComplexF64 scalar over a Matrix of ComplexF64:

ERROR: UndefVarError: to_same_device not defined
Stacktrace:
  [1] unbroadcast_prod_x(x::ComplexF64, y::Matrix{ComplexF64}, Δ::Matrix{ComplexF64})
    @ Yota ~/.julia/packages/Yota/VCIzN/src/helpers.jl:78

I had a look in the source code and I can't see the function to_same_device anywhere inside the package.

P.S. I hit the error in the first place because I was lazy and didn't want to write an rrule for ChainRulesCore.rrule(::typeof(*), ::ComplexF64, ::Float64, ::Matrix{ComplexF64}) = ..., so I tried broadcasting the multiplication instead. I'm not sure it's a code path that anyone else would usually reach.
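For reference, the gradient that has to come out of this code path is easy to write down. Below is a minimal sketch in plain Julia with no Yota involved; `prod_pullback` is a hypothetical name, and the placement of `conj` assumes the ChainRules convention for complex multiplication (∂x = Δ * conj(y), ∂y = conj(x) * Δ).

```julia
# Hypothetical helper (not part of Yota): pullback of z = x .* Y
# for a scalar x broadcast over an array Y.
function prod_pullback(x::Number, Y::AbstractArray, Δ::AbstractArray)
    dx = sum(Δ .* conj.(Y))   # x touched every element of Y, so its gradient is a sum
    dY = conj(x) .* Δ         # elementwise, same shape as Y
    return dx, dY
end

x = 2.0 + 1.0im
Y = ComplexF64[1 2im; 3 4]
Δ = ones(ComplexF64, size(Y))
dx, dY = prod_pullback(x, Y, Δ)
```

The key point is that the scalar's gradient is a reduction over the whole matrix, which is exactly what the unbroadcast step is responsible for.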

dfdx commented 2 years ago

Oh, I thought I had got rid of all of those! to_same_device() used to move one array argument to the same device (e.g. CPU or GPU) as the other one. As a quick-and-dirty fix, you can try redefining these methods without the conversion:

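import Yota
# ArrayOrBroadcasted is not exported by Yota, so bring it into scope explicitly
using Yota: ArrayOrBroadcasted

# Wrap the scalar in a 1-element vector so the array-array method handles it,
# then unwrap the scalar's gradient with [1]
Yota.unbroadcast_prod_x(x::Number, y::ArrayOrBroadcasted, Δ) = Yota.unbroadcast_prod_x([x], y, Δ)[1]
Yota.unbroadcast_prod_x(x::ArrayOrBroadcasted, y::Number, Δ) = Yota.unbroadcast_prod_x(x, [y], Δ)
Yota.unbroadcast_prod_y(x::ArrayOrBroadcasted, y::Number, Δ) = Yota.unbroadcast_prod_y(x, [y], Δ)[1]
Yota.unbroadcast_prod_y(x::Number, y::ArrayOrBroadcasted, Δ) = Yota.unbroadcast_prod_y([x], y, Δ)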
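The workaround relies on a one-element vector broadcasting like a scalar: the array-array rule does the work, and summing the gradient back down to size `(1,)` reproduces the scalar gradient. A self-contained sketch of that idea follows; `unbroadcast` and `prod_x` are hypothetical stand-ins written for illustration, not Yota's actual internals.

```julia
# Hypothetical helper: reduce a broadcast gradient Δ back to the size of x
# by summing over every dimension along which x was expanded.
function unbroadcast(x::AbstractArray, Δ::AbstractArray)
    size(x) == size(Δ) && return Δ
    # size(x, d) is 1 for trailing dims, so a 1-element vector sums over all dims
    dims = Tuple(d for d in 1:ndims(Δ) if size(x, d) == 1 && size(Δ, d) != 1)
    return reshape(sum(Δ; dims = dims), size(x))
end

# Array-array rule for the first argument of x .* y (ChainRules conj convention assumed)
prod_x(x::AbstractArray, y::AbstractArray, Δ) = unbroadcast(x, Δ .* conj.(y))

# Scalar case: wrap, reuse the array rule, unwrap
prod_x(x::Number, y::AbstractArray, Δ) = prod_x([x], y, Δ)[1]

x = 2.0 + 1.0im
Y = ComplexF64[1 2im; 3 4]
Δ = ones(ComplexF64, size(Y))
g = prod_x(x, Y, Δ)
```

The result matches the direct reduction `sum(Δ .* conj.(Y))`, which is why wrapping with `[x]` and unwrapping with `[1]` is safe once no device conversion is involved.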

I will try to come up with a better fix and tests in the next couple of days.

dfdx commented 2 years ago

Do you have a reproducible example that I can test the fix on?

dfdx commented 2 years ago

This should be fixed by #119; please let me know if it doesn't help.