dfdx / Yota.jl

Reverse-mode automatic differentiation in Julia
MIT License

Help with implementing diffrule for element-wise division #68

Closed: alastair-marshall closed this 4 years ago

alastair-marshall commented 4 years ago

Hey there,

I'm looking at using Yota for a project where I need fast reverse-mode AD, but I'm new to writing adjoints. I'm having a hard time implementing an adjoint for the matrix exponential (the Zygote implementation looked hard to port over), so instead I decided to approximate the exponential using a Taylor series.

I run into this error:

┌ Error: Failed to find a derivative for %393 = /(%391, %392)::Array{Complex{Float64},2} at position 1, current state of backpropagation saved to Yota.DEBUG_STATE
└ @ Yota ~/.julia/packages/Yota/6ZZpD/src/grad.jl:164

which refers to a line in the Taylor expansion:

u02 + ((A + A^2 / 2) + (A^3 / 6 + A^4 / 24))

I tried to write a diffrule for this, having successfully written one for abs2, but I'm having a hard time with this one:

import Base./
@diffrule (/)(x::Array{ComplexF64,2}, y::Number) x dy / y

I think this says that the gradient with respect to x is the incoming gradient dy divided by the value y, which is what I want. (Some help here would be much appreciated: I thought I followed the example given for logistic(x), but I'm not sure that I really did!)
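The intent of `x dy / y` is that the gradient flowing back to the array argument equals the incoming gradient divided by the scalar. The plain-Julia sketch below, which is independent of Yota, checks that claim with central finite differences for real inputs and a real scalar. `fd_grad`, `W` (standing in for `dy`) and `c` (the scalar divisor) are hypothetical names chosen for illustration.

```julia
# Sketch (not Yota code): for L(X) = sum(W .* (X / c)), check that dL/dX == W / c,
# i.e. "incoming gradient divided by the scalar". Real inputs and real c only.

# Central finite-difference gradient of a scalar function f at a real matrix X
function fd_grad(f, X; h = 1e-6)
    G = similar(X)
    for k in eachindex(X)
        Xp = copy(X); Xp[k] += h
        Xm = copy(X); Xm[k] -= h
        G[k] = (f(Xp) - f(Xm)) / (2h)
    end
    return G
end

X = [1.0 2.0; 3.0 4.0]
W = [0.5 -1.0; 2.0 0.25]   # plays the role of the incoming gradient dy
c = 3.0                    # the scalar divisor (2, 6 or 24 in the Taylor series)

loss(X) = sum(W .* (X / c))

numeric  = fd_grad(loss, X)
analytic = W / c           # what a rule of the form "dy divided by the scalar" returns
println(maximum(abs.(numeric - analytic)))   # should be close to zero
```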

Thanks a lot!

dfdx commented 4 years ago

Hi! At first glance this looks like correct use of diffrule. Do you have a minimal reproducible example? Maybe the Taylor expansion with some example values of A and u02?

alastair-marshall commented 4 years ago

Hi! Thanks for such a quick response. I'll post two pieces of code: the first is just the series_exp function in question, and the second is essentially a whole minimal working example of what I'm working on (if I can get this to work, then most of the work is done).

Series expansion:

using LinearAlgebra  # provides the identity I

function series_exp(A)
    u02 = Array{ComplexF64,2}(I(2))
    u02 + ((A + A^2 / 2) + (A^3 / 6 + A^4 / 24))
end

Where A is of the form:

sx = [0.0 + 0.0im 1.0 + 0.0im
      1.0 + 0.0im 0.0 + 0.0im]
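A fourth-order Taylor series is only accurate when the norm of A is small, so it can help to compare it with LinearAlgebra's `exp` for matrices of the size that actually occur (in the project below, `-1.0im * 0.1 * sx * x[i]` with `x[i]` in [0, 1)). The sketch also defines `series_exp_mul`, a hypothetical variant that writes the powers as explicit products. For integer n, `A^n` is repeated multiplication, so the two forms agree up to rounding.

```julia
using LinearAlgebra

# Fourth-order Taylor approximation of exp(A), as in the issue
function series_exp(A)
    u02 = Array{ComplexF64,2}(I(2))
    u02 + ((A + A^2 / 2) + (A^3 / 6 + A^4 / 24))
end

# Same series with explicit matrix products instead of powers (hypothetical name)
series_exp_mul(A) = Array{ComplexF64,2}(I(2)) + A + A * A / 2 + A * A * A / 6 + A * A * A * A / 24

sx = [0.0+0.0im 1.0+0.0im
      1.0+0.0im 0.0+0.0im]

# The truncation error shrinks quickly as the norm of A shrinks
for scale in (1.0, 0.1, 0.01)
    A = -1.0im * scale * sx
    err = norm(series_exp(A) - exp(A))   # Frobenius norm of the difference
    println("scale = $scale: truncation error = $err")
end
```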

Using the diffrule I mentioned above, I still get an error:

julia> Yota.back!(tape)
┌ Error: Failed to find a derivative for %393 = /(%391, %392)::Array{Complex{Float64},2} at position 1, current state of backpropagation saved to Yota.DEBUG_STATE
└ @ Yota ~/.julia/packages/Yota/6ZZpD/src/grad.jl:164

The whole file for the project:

using Yota
using LinearAlgebra

# let's write a super basic idea of what we might want
sx = [0.0 + 0.0im 1.0 + 0.0im
      1.0 + 0.0im 0.0 + 0.0im]

# Taylor series of the matrix exponential of A
function series_exp(A)
    u02 = Array{ComplexF64,2}(I(2))
    u02 + ((A + A^2 / 2) + (A^3 / 6 + A^4 / 24))
end

# compute the evolution operator
function evolve(x)
    N = length(x)
    U0 = Array{ComplexF64,2}(I(2))
    for i = 1:N
        ham = -1.0im * 0.1 * sx * x[i]
        U0 = series_exp(ham) * U0
        # U0 = exp(-1.0im * 0.1 * sx * x[i]) * U0
    end
    U0
end

# define some states
ρ = [1.0 + 0.0im, 0.0 + 0.0im]
ρt = [0.0 + 0.0im, 1.0 + 0.0im]

# the error functional
function functional(x)
    U = evolve(x)
    1 - abs2(ρt' * (U * ρ))
end

# first diffrule needed (rules must be defined before the backward pass runs)
@diffrule abs2(x::Number) x real(dy) * (x + x)

# next necessary diffrule
import Base./
@diffrule (/)(x::Array{ComplexF64,2}, y::Number) x dy / y

# testing input
x_input = rand(10)
functional(x_input)
# currently grad() fails, but I can take a tape and debug
val, tape = Yota.itrace(functional, x_input)

Yota.back!(tape)
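The abs2 rule `real(dy) * (x + x)` can be sanity-checked numerically too. The sketch below assumes (this convention is not confirmed for Yota) that the gradient of a real-valued function of a complex x = a + bi is represented as the complex number ∂f/∂a + i·∂f/∂b. Under that assumption, and with dy = 1, the rule should return 2x. `complex_fd_grad` is a hypothetical helper.

```julia
# Sketch: check the abs2 rule numerically.
# Assumed convention (not confirmed for Yota): for real-valued f of complex x = a + bi,
# the gradient is the complex number df/da + i * df/db.

function complex_fd_grad(f, x::Complex; h = 1e-6)
    da = (f(x + h) - f(x - h)) / (2h)          # derivative along the real axis
    db = (f(x + h * im) - f(x - h * im)) / (2h) # derivative along the imaginary axis
    return complex(da, db)
end

x  = 0.3 - 1.2im
dy = 1.0                          # incoming gradient of a scalar loss
rule    = real(dy) * (x + x)      # what the abs2 diffrule returns
numeric = complex_fd_grad(abs2, x)
println(rule, "  vs  ", numeric)  # the two should match closely
```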

Essentially the problem is to use the gradient to update the "control" vector x to minimise the error functional.
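The optimisation loop itself can be sketched without any AD, which is useful as a reference while the diffrules are being debugged. This is a hypothetical sketch, not the project's code: a central finite-difference gradient stands in for Yota's `grad` (2·length(x) evaluations per step), the exact matrix `exp` replaces the Taylor series, and `fd_gradient`, `descend`, the learning rate `lr = 1.0` and `steps = 100` are illustrative choices.

```julia
using LinearAlgebra

sx = [0.0+0.0im 1.0+0.0im
      1.0+0.0im 0.0+0.0im]
ρ  = [1.0+0.0im, 0.0+0.0im]
ρt = [0.0+0.0im, 1.0+0.0im]

# Evolution operator, using the exact matrix exponential instead of the Taylor series
function evolve(x)
    U = Matrix{ComplexF64}(I, 2, 2)
    for xi in x
        U = exp(-1.0im * 0.1 * sx * xi) * U
    end
    return U
end

functional(x) = 1 - abs2(ρt' * (evolve(x) * ρ))

# Central finite-difference gradient (stand-in for an AD gradient)
function fd_gradient(f, x; h = 1e-6)
    g = zero(x)
    for i in eachindex(x)
        e = zero(x); e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2h)
    end
    return g
end

# Plain gradient descent on the control vector x
function descend(f, x0; lr = 1.0, steps = 100)
    x = copy(x0)
    for _ in 1:steps
        x -= lr * fd_gradient(f, x)   # step against the gradient
    end
    return x
end

x0 = fill(0.5, 10)
x_opt = descend(functional, x0)
println("before: ", functional(x0), "  after: ", functional(x_opt))
```

Once Yota's gradient works, it could replace `fd_gradient` in `descend`, and the finite-difference version remains handy for spot-checking it.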

Hope this isn't too much and that some of it is helpful!

Kind regards, Alastair

dfdx commented 4 years ago

Never mind, I spotted the mistake: `y` is a reserved name in @diffrule. For example:

julia> using Yota

# using reserved symbol `y`
julia> @diffrule (/)(x::Array{ComplexF64,2}, y::Number) x dy / y

julia> A = rand(ComplexF64, 2, 2)
2×2 Array{Complex{Float64},2}:
    0.5534+0.774336im   0.450525+0.751458im
 0.0563076+0.792347im  0.0314715+0.408386im

julia> grad(x -> sum(x / 2), A)
┌ Error: Failed to find a derivative for %3 = /(%1, %2)::Array{Complex{Float64},2} at position 1, current state of backpropagation saved to Yota.DEBUG_STATE
└ @ Yota ~/work/Yota/src/grad.jl:163
ERROR: Can't find differentiation rule for (/)(var"%1", var"%2") at 1 with types DataType[Array{Complex{Float64},2}, Int64])
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] deriv_exprs(::Expr, ::Array{DataType,1}, ::Int64) at /home/user/work/Yota/src/diffrules/diffrules.jl:276
 [3] deriv!(::Yota.Tape, ::Yota.Call, ::Int64, ::Yota.Call) at /home/user/work/Yota/src/grad.jl:88
 [4] step_back!(::Yota.Tape, ::Yota.Call, ::Int64) at /home/user/work/Yota/src/grad.jl:157
 [5] back!(::Yota.Tape) at /home/user/work/Yota/src/grad.jl:227
 [6] _grad(::Yota.Tape) at /home/user/work/Yota/src/grad.jl:262
 [7] _grad(::Function, ::Array{Complex{Float64},2}) at /home/user/work/Yota/src/grad.jl:274
 [8] grad(::Function, ::Array{Complex{Float64},2}; dynamic::Bool) at /home/user/work/Yota/src/grad.jl:335
 [9] grad(::Function, ::Array{Complex{Float64},2}) at /home/user/work/Yota/src/grad.jl:327
 [10] top-level scope at REPL[4]:1

# replacing x -> u, y -> v
julia> @diffrule (/)(u::Array{ComplexF64,2}, v::Number) u dy / v

julia> _, g = grad(x -> sum(x / 2), A)
(0.5458522465959658 + 1.363263621537502im, GradResult(1))

julia> g[1]
2×2 Array{Complex{Float64},2}:
 0.5+0.0im  0.5+0.0im
 0.5+0.0im  0.5+0.0im

This is clearly a shortcoming of the documentation, so I created #69 to fix it along with some other details of recent changes. Until then, please use these rules when defining @diffrules:

I took these rules from Espresso.jl (specifically rewrite.jl), with the default placeholders set to Yota.DIFF_PHS. Thanks for drawing my attention to the missing documentation; I'll try to include it in the next improvement round.

alastair-marshall commented 4 years ago

Ah great, I'll keep those rules in mind when I'm writing adjoints in future! Thanks for all of your help! I think I need to do some reading about writing adjoints for the operations that I'm using, but this has definitely been useful!

Kind regards,

Alastair

dfdx commented 4 years ago

Regarding your minimal working example, I see it's still failing because Yota doesn't have rules for complex numbers (which, honestly, I have very little experience with). I tried to define a bunch of such rules, possibly incorrectly, but hopefully they will be helpful for you:

using Yota
using LinearAlgebra

# let's write a super basic idea of what we might want
sx = [0.0 + 0.0im 1.0 + 0.0im
      1.0 + 0.0im 0.0 + 0.0im]

# we don't want to trace into array creation code, so we abstract it out into a function
# and instruct Yota to ignore this path during the reverse pass
make_complex_matrix(sz) = Array{ComplexF64,2}(I(sz))
@nodiff make_complex_matrix(_sz) _sz

# Taylor series of the matrix exponential of A
function series_exp(A)
    u02 = make_complex_matrix(2)
    # note that I replaced A^n with sequences of matrix multiplications of corresponding lengths;
    # not sure this is an equivalent modification, but at least it ran to the end :)
    u02 + ((A + A * A / 2) + (A * A * A / 6 + A * A * A * A / 24))
end

# compute the evolution operator
function evolve(x)
    N = length(x)
    U0 = make_complex_matrix(2)
    for i = 1:N
        ham = -1.0im * 0.1 * sx * x[i]
        U0 = series_exp(ham) * U0
        # U0 = exp(-1.0im * 0.1 * sx * x[i]) * U0
    end
    U0
end

# define some states
ρ = [1.0 + 0.0im, 0.0 + 0.0im]
ρt = [0.0 + 0.0im, 1.0 + 0.0im]

# the error functional
function functional(x)
    U = evolve(x)
    1 - abs2(ρt' * (U * ρ))
end

# first diffrule needed
@diffrule abs2(x::Number) x real(dy) * (x + x)

# next necessary diffrule
import Base./
@diffrule (/)(u::Array{ComplexF64,2}, v::Number) u dy / v

# expand list of rules for (*) to match complex numbers and arrays
@diffrule *(u::Number,        v::Number)        u  v * dy
@diffrule *(u::Number,        v::AbstractArray) u  sum(v .* dy)
@diffrule *(u::AbstractArray, v::Number)        u  v .* dy
@diffrule *(u::AbstractArray, v::AbstractArray) u  dy * transpose(v)

@diffrule *(u::Number,        v::Number)        v  u * dy
@diffrule *(u::Number,        v::AbstractArray) v  u .* dy
@diffrule *(u::AbstractArray, v::Number)        v  sum(u .* dy)
@diffrule *(u::AbstractArray, v::AbstractArray) v  transpose(u) * dy

x_input = rand(10)
_, g = grad(functional, x_input)
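Because these (*) rules are flagged as possibly incorrect, one way to verify them is against finite differences. The sketch below checks the array·array and number·array rules for real inputs only, where `transpose` and the adjoint coincide. For complex inputs the correct choice depends on the gradient convention; ChainRules.jl, for example, uses the conjugate transpose (`dy * v'`) in its matrix-multiplication rule. `fd_grad`, `U`, `V`, `W` (standing in for `dy`) and `s` are hypothetical names.

```julia
# Sketch: finite-difference check of the (*) rules, real inputs only.
# For L = sum(W .* (U * V)):  dL/dU = W * transpose(V),  dL/dV = transpose(U) * W
# For L = sum(W .* (s * V)):  dL/ds = sum(V .* W),       dL/dV = s .* W

# Central finite differences for an array argument...
function fd_grad(f, X::AbstractArray; h = 1e-6)
    G = zero(X)
    for k in eachindex(X)
        E = zero(X); E[k] = h
        G[k] = (f(X + E) - f(X - E)) / (2h)
    end
    return G
end
# ...and for a real scalar argument
fd_grad(f, s::Real; h = 1e-6) = (f(s + h) - f(s - h)) / (2h)

U = [1.0 2.0; -1.0 0.5]
V = [0.3 -0.7; 1.5 2.0]
W = [1.0 -2.0; 0.5 3.0]   # plays the role of the incoming gradient dy
s = 1.7

@assert isapprox(fd_grad(X -> sum(W .* (X * V)), U), W * transpose(V); atol = 1e-6)
@assert isapprox(fd_grad(X -> sum(W .* (U * X)), V), transpose(U) * W; atol = 1e-6)
@assert isapprox(fd_grad(t -> sum(W .* (t * V)), s), sum(V .* W); atol = 1e-6)
@assert isapprox(fd_grad(X -> sum(W .* (s * X)), V), s .* W; atol = 1e-6)
println("product rules agree with finite differences (real case)")
```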

Also note that some changes to the tracer are coming: there's a new one (available under the name irtracer; the prefix "IR" is a reference to IRTools), while itrace will be deprecated soon, as it's terribly slow and doesn't work well on Julia 1.4. To make your code more robust, I recommend using just the name trace(), which will always point to the default tracer.

alastair-marshall commented 4 years ago

Ah, this is great, thanks a lot! I'll take a look at the rules and see if I can verify that they're correct. I chose itrace simply because, I think, the default tracer wasn't working: it was giving a different error, but itrace worked nicely.

Thanks again for all your help!