Hi! At first glance this looks like correct use of @diffrule. Do you have a minimal reproducible example? Maybe the Taylor expansion with some examples of A and u02?
Hi! Thanks for such a quick response. I'll post two pieces of code: the first is just the series_exp function in question, and the second is essentially a whole minimal working example of what I'm working on (if I can get this to work, then most of the work is done).
Series expansion:
function series_exp(A)
    u02 = Array{ComplexF64,2}(I(2))
    u02 + ((A + A^2 / 2) + (A^3 / 6 + A^4 / 24))
end
Where A is of the form:
sx = [0.0 + 0.0im 1.0 + 0.0im
      1.0 + 0.0im 0.0 + 0.0im]
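As a quick sanity check (a sketch, not from the original message), the four retained terms closely match the exact matrix exponential from LinearAlgebra for the small-norm matrices involved here:

using LinearAlgebra
A = -1.0im * 0.1 * sx                  # the kind of argument built in evolve() below
maximum(abs.(series_exp(A) - exp(A)))  # ≈ 8e-8, roughly the size of the dropped A^5/120 term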
Using the diffrule I mentioned above, I still get an error:
julia> Yota.back!(tape)
┌ Error: Failed to find a derivative for %393 = /(%391, %392)::Array{Complex{Float64},2} at position 1, current state of backpropagation saved to Yota.DEBUG_STATE
└ @ Yota ~/.julia/packages/Yota/6ZZpD/src/grad.jl:164
The whole file for the project:
using Yota
using LinearAlgebra
# let's write a super basic idea of what we might want
sx = [0.0 + 0.0im 1.0 + 0.0im
      1.0 + 0.0im 0.0 + 0.0im]
# Taylor series of the matrix exponential of A
function series_exp(A)
    u02 = Array{ComplexF64,2}(I(2))
    u02 + ((A + A^2 / 2) + (A^3 / 6 + A^4 / 24))
end
# compute the evolution operator
function evolve(x)
    N = length(x)
    U0 = Array{ComplexF64,2}(I(2))
    for i = 1:N
        ham = -1.0im * 0.1 * sx * x[i]
        U0 = series_exp(ham) * U0
        # U0 = exp(-1.0im * 0.1 * sx * x[i]) * U0
    end
    U0
end
# define some states
ρ = [1.0 + 0.0im, 0.0 + 0.0im]
ρt = [0.0 + 0.0im, 1.0 + 0.0im]
# the error functional
function functional(x)
    U = evolve(x)
    1 - abs2(ρt' * (U * ρ))
end
# testing input
x_input = rand(10)
functional(x_input)
# currently grad() fails but I can take a tape and debug
val, tape = Yota.itrace(functional, x_input)
Yota.back!(tape)
# first diffrule needed
@diffrule abs2(x::Number) x real(dy) * (x + x)
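# (the rule above is real(dy) * (x + x) == 2 * real(dy) * x, i.e. the standard
#  pullback of |x|^2 at a complex x when the surrounding loss is real-valued)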
# next necessary diffrule
import Base./
@diffrule (/)(x::Array{ComplexF64,2}, y::Number) x dy / y
Essentially the problem is to use the gradient to update the "control" vector x to minimise the error functional.
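For context, the kind of update loop I have in mind looks roughly like this (just a sketch, assuming grad(functional, x) eventually works; the step size 0.1 is arbitrary):

function optimise(x; lr=0.1, steps=100)
    for _ = 1:steps
        _, g = grad(functional, x)
        x = x .- lr .* real.(g[1])   # real.() in case the complex rules leak an imaginary part
    end
    x
end
x_opt = optimise(rand(10))
functional(x_opt)   # should decrease towards 0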
Hope this isn't too much and that some of it is helpful!
Kind regards, Alastair
Never mind, I spotted the mistake - `y` is a reserved name in `@diffrule`. E.g.:
julia> using Yota
# using reserved symbol `y`
julia> @diffrule (/)(x::Array{ComplexF64,2}, y::Number) x dy / y
julia> A = rand(ComplexF64, 2, 2)
2×2 Array{Complex{Float64},2}:
0.5534+0.774336im 0.450525+0.751458im
0.0563076+0.792347im 0.0314715+0.408386im
julia> grad(x -> sum(x / 2), A)
┌ Error: Failed to find a derivative for %3 = /(%1, %2)::Array{Complex{Float64},2} at position 1, current state of backpropagation saved to Yota.DEBUG_STATE
└ @ Yota ~/work/Yota/src/grad.jl:163
ERROR: Can't find differentiation rule for (/)(var"%1", var"%2") at 1 with types DataType[Array{Complex{Float64},2}, Int64])
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] deriv_exprs(::Expr, ::Array{DataType,1}, ::Int64) at /home/user/work/Yota/src/diffrules/diffrules.jl:276
[3] deriv!(::Yota.Tape, ::Yota.Call, ::Int64, ::Yota.Call) at /home/user/work/Yota/src/grad.jl:88
[4] step_back!(::Yota.Tape, ::Yota.Call, ::Int64) at /home/user/work/Yota/src/grad.jl:157
[5] back!(::Yota.Tape) at /home/user/work/Yota/src/grad.jl:227
[6] _grad(::Yota.Tape) at /home/user/work/Yota/src/grad.jl:262
[7] _grad(::Function, ::Array{Complex{Float64},2}) at /home/user/work/Yota/src/grad.jl:274
[8] grad(::Function, ::Array{Complex{Float64},2}; dynamic::Bool) at /home/user/work/Yota/src/grad.jl:335
[9] grad(::Function, ::Array{Complex{Float64},2}) at /home/user/work/Yota/src/grad.jl:327
[10] top-level scope at REPL[4]:1
# replacing x -> u, y -> v
julia> @diffrule (/)(u::Array{ComplexF64,2}, v::Number) u dy / v
julia> _, g = grad(x -> sum(x / 2), A)
(0.5458522465959658 + 1.363263621537502im, GradResult(1))
julia> g[1]
2×2 Array{Complex{Float64},2}:
0.5+0.0im 0.5+0.0im
0.5+0.0im 0.5+0.0im
It's clearly a shortcoming of the documentation, so I created #69 to fix this and some other details of recent changes. Before that, please use these rules when defining `@diffrule`s:

- `y` means the return value of a primal expression, e.g. y = f(x) (see the exp example after this list)
- `dy` means the derivative of a loss function w.r.t. y
- `x`, `u`, `v`, `w`, `t`, `i`, `j` and `k` (all listed in Yota.DIFF_PHS) can be used as names of variables, e.g. @diffrule foo(u, v) u ∇foo(dy, u, v)
- `_` can also be used, e.g. @diffrule bar(u, _state) _state ∇bar(dy, u, _state)
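For instance, since `y` stands for the primal output, a rule for exp can reuse it instead of recomputing (an illustration of the placeholder, not a rule shipped with Yota):

@diffrule exp(u::Real) u y * dy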
I brought these rules from Espresso.jl (concretely, rewrite.jl) with default placeholders set to Yota.DIFF_PHS. Thanks for bringing my attention to not documenting it properly; I'll try to include it in the next improvement round.
Ah great, I'll keep those rules in mind when I'm writing adjoints in the future. Thanks for all of your help! I think I need to do some reading about writing adjoints for the operations that I'm using, but this has definitely been useful!
Kind regards,
Alastair
Regarding your minimal working example, I see it's still failing because Yota doesn't have rules for complex numbers (which I have very little experience with, honestly). I tried to define a bunch of such rules, possibly incorrectly, but hopefully they will be helpful for you:
using Yota
using LinearAlgebra
# let's write a super basic idea of what we might want
sx = [0.0 + 0.0im 1.0 + 0.0im
      1.0 + 0.0im 0.0 + 0.0im]
# we don't want to trace into array creation code, so we abstract it out into a function
# and instruct Yota to ignore this path during the reverse pass
make_complex_matrix(sz) = Array{ComplexF64,2}(I(sz))
@nodiff make_complex_matrix(_sz) _sz
# Taylor series of the matrix exponential of A
function series_exp(A)
    u02 = make_complex_matrix(2)
    # note that I replaced A^n with sequences of matrix multiplications of corresponding lengths
    # not sure this is an equivalent modification, but at least it ran to the end :)
    u02 + ((A + A * A / 2) + (A * A * A / 6 + A * A * A * A / 24))
end
# compute the evolution operator
function evolve(x)
    N = length(x)
    U0 = make_complex_matrix(2)
    for i = 1:N
        ham = -1.0im * 0.1 * sx * x[i]
        U0 = series_exp(ham) * U0
        # U0 = exp(-1.0im * 0.1 * sx * x[i]) * U0
    end
    U0
end
# define some states
ρ = [1.0 + 0.0im, 0.0 + 0.0im]
ρt = [0.0 + 0.0im, 1.0 + 0.0im]
# the error functional
function functional(x)
    U = evolve(x)
    1 - abs2(ρt' * (U * ρ))
end
# first diffrule needed
@diffrule abs2(x::Number) x real(dy) * (x + x)
# next necessary diffrule
import Base./
@diffrule (/)(u::Array{ComplexF64,2}, v::Number) u dy / v
# expand list of rules for (*) to match complex numbers and arrays
@diffrule *(u::Number , v::Number) u v * dy
@diffrule *(u::Number , v::AbstractArray) u sum(v .* dy)
@diffrule *(u::AbstractArray, v::Number) u v .* dy
@diffrule *(u::AbstractArray, v::AbstractArray) u dy * transpose(v)
@diffrule *(u::Number , v::Number) v u * dy
@diffrule *(u::Number , v::AbstractArray) v u .* dy
@diffrule *(u::AbstractArray, v::Number) v sum(u .* dy)
@diffrule *(u::AbstractArray, v::AbstractArray) v transpose(u) * dy
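# a caveat on the two array-array rules above (not from the original snippet): for
# complex matrices the usual pullbacks conjugate as well, i.e. dy * v' and u' * dy
# (adjoint) rather than transpose; this may be one of the "possibly incorrect" spots,
# worth checking against finite differences (see the check below)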
x_input = rand(10)
_, g = grad(functional, x_input)
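If you want to verify the rules numerically, a central finite-difference check against the functional is cheap (a sketch, not part of the snippet above):

function fd_grad(f, x; h=1e-6)
    g = zeros(length(x))
    for i in eachindex(x)
        xp = copy(x); xp[i] += h
        xm = copy(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)   # functional returns a real scalar
    end
    g
end
maximum(abs.(fd_grad(functional, x_input) .- real.(g[1])))  # small ⇒ rules agree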
Also note that some changes to the tracer are coming: there's a new one (available under the name `irtracer`; note the prefix "IR" as a reference to IRTools), while `itrace` will be deprecated soon as it's terribly slow and doesn't work well on Julia 1.4. To make your code more robust, I recommend using just the name `trace()`, which will always point to the default tracer.
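E.g. (assuming trace takes the same arguments as itrace does above):

val, tape = Yota.trace(functional, x_input)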
Ah this is great, thanks a lot! I'll take a look at the rules and see if I can verify that they're correct. I chose `itrace` simply because the default tracer wasn't working, I think; it was giving another error, but `itrace` worked nicely.
Thanks again for all your help!
Hey there,
I'm looking at using Yota for a project where I need fast reverse-mode AD. I'm new to writing adjoints, and I'm having a hard time implementing an adjoint for the matrix exponential (the Zygote implementation looked hard to port over), so instead I decided to approximate the exponential using a Taylor series.
I run into this error (the same "Failed to find a derivative for /(...)" error quoted earlier in this thread), which refers to a line in the Taylor expansion:
u02 + ((A + A^2 / 2) + (A^3 / 6 + A^4 / 24))
I tried to write a diffrule for this, having successfully written one for `abs2`, but I'm having a hard time with this one:

import Base./
@diffrule (/)(x::Array{ComplexF64,2}, y::Number) x dy / y

I think this is trying to say that I want to divide the incoming gradient by the value y, which I think is what I want... (some help here would be much appreciated; I thought I followed the example given for `logistic(x)`, but I'm not sure that I really did!) Thanks a lot!