Hi @JoeGibbs88, when using AD functionality through Zygote, a specific syntax is sometimes required. In your example, a few things need to be changed to make it work. First, when using vcat to build the circuit, you should also include square brackets around each new gate:
function variationalcircuit(params)
    circuit = Tuple[]
    circuit = vcat(circuit, [("X", 1)])
    circuit = vcat(circuit, [("G", (1, 2), (θ = params[1],))])
    circuit = vcat(circuit, [("Phase", 1, (ϕ = params[2],))])
    circuit = vcat(circuit, [("Phase", 2, (ϕ = params[3],))])
    circuit = vcat(circuit, [("G", (1, 2), (θ = -params[1],))])
    return circuit
end
Next, the loss function also needs a couple of changes. First, differentiation through an MPS constructor is not yet defined in ITensors (see https://github.com/ITensor/ITensors.jl/issues/776), so for now you need to create the initial state outside the loss function. The other change is in the fidelity call, which is also currently not differentiable. For the time being, you could instead use the inner function in ITensor, which in this case computes the inner product of two MPSs, and build the fidelity from it. The reason fidelity is non-differentiable is that we wrote it to be numerically stable for large MPSs, which requires some additional operations on top of inner whose derivatives are not yet taken care of (though we will fix this soon). So the loss function becomes:
psi0 = productstate(hilbert)
function loss(params)
    circuit = variationalcircuit(params)
    psi = runcircuit(psi0, circuit)
    return 1 - abs2(inner(psi, target))
end
With the above modifications, the code runs and outputs:
0.05545295586765009
[ Info: LBFGS: initializing with f = 0.055452955868, ‖∇f‖ = 2.8750e-01
[ Info: LBFGS: iter 1: f = 0.001045920807, ‖∇f‖ = 3.7467e-02, α = 4.30e-01, m = 0, nfg = 2
[ Info: LBFGS: iter 2: f = 0.000157296908, ‖∇f‖ = 1.4674e-03, α = 1.00e+00, m = 1, nfg = 1
[ Info: LBFGS: iter 3: f = 0.000155638615, ‖∇f‖ = 7.1383e-04, α = 1.00e+00, m = 2, nfg = 1
[ Info: LBFGS: iter 4: f = 0.000154660387, ‖∇f‖ = 8.0494e-04, α = 1.00e+00, m = 3, nfg = 1
[ Info: LBFGS: iter 5: f = 0.000148931986, ‖∇f‖ = 2.0502e-03, α = 1.00e+00, m = 4, nfg = 1
[ Info: LBFGS: iter 6: f = 0.000137419061, ‖∇f‖ = 3.7131e-03, α = 1.00e+00, m = 5, nfg = 1
[ Info: LBFGS: iter 7: f = 0.000099576772, ‖∇f‖ = 5.6274e-03, α = 1.00e+00, m = 6, nfg = 1
[ Info: LBFGS: iter 8: f = 0.000014151556, ‖∇f‖ = 4.3621e-03, α = 5.72e-01, m = 7, nfg = 3
[ Info: LBFGS: iter 9: f = 0.000010465556, ‖∇f‖ = 1.5371e-03, α = 4.85e-01, m = 8, nfg = 2
[ Info: LBFGS: iter 10: f = 0.000008838465, ‖∇f‖ = 1.6771e-03, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter 11: f = 0.000002607997, ‖∇f‖ = 1.8504e-03, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter 12: f = 0.000000235906, ‖∇f‖ = 2.6223e-05, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter 13: f = 0.000000007655, ‖∇f‖ = 7.5840e-05, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter 14: f = 0.000000000038, ‖∇f‖ = 7.0155e-06, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter 15: f = 0.000000000000, ‖∇f‖ = 7.6904e-08, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: converged after 16 iterations: f = 0.000000000000, ‖∇f‖ = 7.0827e-09
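For reference, the relation used in the loss above, infidelity = 1 - |<psi|target>|^2, can be checked on ordinary vectors without any MPS machinery. This is only a minimal sketch, assuming normalized states and using LinearAlgebra.dot as a stand-in for the inner function; the two-component vectors are made up for illustration and are not the states from this issue.

```julia
using LinearAlgebra

# Two normalized single-qubit states, chosen arbitrarily for illustration.
psi = normalize(ComplexF64[1, im])
target = normalize(ComplexF64[1, 1])

# dot conjugates its first argument, so this is the overlap <psi|target>.
overlap = dot(psi, target)

# Infidelity written in the same form as the loss function above.
infidelity = 1 - abs2(overlap)
```

The loss reaches zero only when the two states agree up to a global phase, which is why abs2 is applied to the overlap rather than comparing the states directly.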
Hope this helps, and let us know if you still have issues with this!
Thank you for the detailed answer! After adding in your substitutions the code now runs without error. Strangely however, yours converges to 0 nicely while mine plateaus after one step...
First 10 iterations:
0.04319948746660329
[ Info: LBFGS: initializing with f = 0.043199487467, ‖∇f‖ = 1.8831e+00
[ Info: LBFGS: iter 1: f = 0.025405693379, ‖∇f‖ = 2.6732e-01, α = 1.00e+00, m = 0, nfg = 1
[ Info: LBFGS: iter 2: f = 0.024239523406, ‖∇f‖ = 1.4714e-02, α = 5.51e-01, m = 1, nfg = 2
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter 3: f = 0.024240523406, ‖∇f‖ = 1.4697e-02, α = 2.70e-03, m = 2, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter 4: f = 0.024241523406, ‖∇f‖ = 1.4681e-02, α = 1.06e-03, m = 3, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter 5: f = 0.024242523406, ‖∇f‖ = 1.4665e-02, α = 9.91e-04, m = 4, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter 6: f = 0.024243523406, ‖∇f‖ = 1.4650e-02, α = 8.86e-04, m = 5, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter 7: f = 0.024244523406, ‖∇f‖ = 1.4635e-02, α = 7.95e-04, m = 6, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter 8: f = 0.024245523406, ‖∇f‖ = 1.4620e-02, α = 7.16e-04, m = 7, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter 9: f = 0.024246523406, ‖∇f‖ = 1.4606e-02, α = 6.48e-04, m = 8, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter 10: f = 0.024247523406, ‖∇f‖ = 1.4593e-02, α = 5.89e-04, m = 8, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
Package versions:
(@v1.6) pkg> status
Status `~/.julia/environments/v1.6/Project.toml`
[587475ba] Flux v0.12.9
[9136182c] ITensors v0.2.15
[d41bc354] NLSolversBase v7.8.2
[76087f3c] NLopt v0.6.3
[429524aa] Optim v1.6.2
[77e91f04] OptimKit v0.3.1
[30b07047] PastaQ v0.0.19
[e88e6eb3] Zygote v0.6.35
[de0858da] Printf
[9a3f8284] Random
This may be due to initial conditions. Have you tried running with a different random number generator seed? Or is this systematic?
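One way to test this is to fix the seed explicitly before drawing the initial parameters, then repeat the run with a few different seeds. The snippet below is a rough sketch using the Random standard library; the seed values and the names params_a, params_b and params_c are arbitrary choices for illustration.

```julia
using Random

# Fix the global RNG seed so the initial parameters are reproducible.
Random.seed!(1234)
params_a = 2π .* rand(3)

# Re-seeding with the same value reproduces the same starting point.
Random.seed!(1234)
params_b = 2π .* rand(3)

# A different seed gives a different starting point to compare against.
Random.seed!(5678)
params_c = 2π .* rand(3)
```

If the optimisation stalls for every seed, the starting point is probably not the cause.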
Thanks @GTorlai, I was going to comment that the main issue appears to be differentiating through the MPS constructor, which we will fix but in the meantime a simple workaround is to move the MPS constructor outside of the cost function as you proposed.
I want to clarify that the brackets are not needed in vcat:
julia> function variationalcircuit(params)
           circuit = Tuple[]
           circuit = vcat(circuit, [("X", 1)])
           circuit = vcat(circuit, [("G", (1, 2), (θ = params[1],))])
           circuit = vcat(circuit, [("Phase", 1, (ϕ = params[2],))])
           circuit = vcat(circuit, [("Phase", 2, (ϕ = params[3],))])
           circuit = vcat(circuit, [("G", (1, 2), (θ = -params[1],))])
           return circuit
       end
variationalcircuit (generic function with 1 method)
julia> function variationalcircuit2(params)
           circuit = Tuple[]
           circuit = vcat(circuit, ("X", 1))
           circuit = vcat(circuit, ("G", (1, 2), (θ = params[1],)))
           circuit = vcat(circuit, ("Phase", 1, (ϕ = params[2],)))
           circuit = vcat(circuit, ("Phase", 2, (ϕ = params[3],)))
           circuit = vcat(circuit, ("G", (1, 2), (θ = -params[1],)))
           return circuit
       end
variationalcircuit2 (generic function with 1 method)
julia> params_0 = 2π .* rand(3)
3-element Vector{Float64}:
4.280782156715938
3.8414080811580575
6.178409079753362
julia> variationalcircuit(params_0)
5-element Vector{Tuple}:
("X", 1)
("G", (1, 2), (θ = 4.280782156715938,))
("Phase", 1, (ϕ = 3.8414080811580575,))
("Phase", 2, (ϕ = 6.178409079753362,))
("G", (1, 2), (θ = -4.280782156715938,))
julia> variationalcircuit2(params_0)
5-element Vector{Tuple}:
("X", 1)
("G", (1, 2), (θ = 4.280782156715938,))
("Phase", 1, (ϕ = 3.8414080811580575,))
("Phase", 2, (ϕ = 6.178409079753362,))
("G", (1, 2), (θ = -4.280782156715938,))
Additionally, you can use the following syntax as a compact alternative to vcat:
julia> function variationalcircuit3(params)
           circuit = Tuple[]
           circuit = [circuit; ("X", 1)]
           circuit = [circuit; ("G", (1, 2), (θ = params[1],))]
           circuit = [circuit; ("Phase", 1, (ϕ = params[2],))]
           circuit = [circuit; ("Phase", 2, (ϕ = params[3],))]
           circuit = [circuit; ("G", (1, 2), (θ = -params[1],))]
           return circuit
       end
variationalcircuit3 (generic function with 1 method)
julia> variationalcircuit3(params_0)
5-element Vector{Tuple}:
("X", 1)
("G", (1, 2), (θ = 4.280782156715938,))
("Phase", 1, (ϕ = 3.8414080811580575,))
("Phase", 2, (ϕ = 6.178409079753362,))
("G", (1, 2), (θ = -4.280782156715938,))
Internally it just calls vcat. I've started to prefer that syntax when building circuits, but of course it is a personal preference.
See this section of the Julia documentation: https://docs.julialang.org/en/v1/manual/arrays/#man-array-concatenation for more information on Array concatenation.
@GTorlai No, whenever I run the script with a new initialisation I get the same behaviour: it does one 'good' iteration, then the loss starts to slowly increase...
[ Info: LBFGS: initializing with f = 0.748913103426, ‖∇f‖ = 8.7004e-01
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:236
[ Info: LBFGS: iter 1: f = 0.748914103426, ‖∇f‖ = 8.7004e-01, α = 1.58e-06, m = 0, nfg = 54
[ Info: LBFGS: iter 2: f = 0.000755341443, ‖∇f‖ = 1.0978e+00, α = 1.00e+00, m = 1, nfg = 1
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:236
[ Info: LBFGS: iter 3: f = 0.000756341443, ‖∇f‖ = 1.3207e+00, α = 9.43e-02, m = 2, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:236
[ Info: LBFGS: iter 4: f = 0.000757341443, ‖∇f‖ = 1.3208e+00, α = 5.58e-05, m = 2, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
I'm confused about why this would happen when running the same code as you. Are you using different package versions?
@mtfishman Thanks for the info. Interestingly, when I use '@show' to inspect 'variationalcircuit' and 'variationalcircuit2', they appear to return the same object. However, when I put variationalcircuit2/3 into the loss function, I get the following error (whereas variationalcircuit runs without error):
variationalcircuit(params_0) = Tuple[("X", 1), ("G", (1, 2), (θ = 1.8968763377256133,)), ("Phase", 1, (ϕ = 5.55160224854473,)), ("Phase", 2, (ϕ = 0.15762143907714976,)), ("G", (1, 2), (θ = -1.8968763377256133,))]
variationalcircuit2(params_0) = Tuple[("X", 1), ("G", (1, 2), (θ = 1.8968763377256133,)), ("Phase", 1, (ϕ = 5.55160224854473,)), ("Phase", 2, (ϕ = 0.15762143907714976,)), ("G", (1, 2), (θ = -1.8968763377256133,))]
ERROR: LoadError: BoundsError: attempt to access 5-element Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}} at index [5:7]
Stacktrace:
[1] throw_boundserror(A::Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}}, I::Tuple{UnitRange{Int64}})
@ Base ./abstractarray.jl:651
[2] checkbounds
@ ./abstractarray.jl:616 [inlined]
[3] getindex(A::Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}}, I::UnitRange{Int64})
@ Base ./array.jl:811
[4] (::Zygote.var"#513#518"{Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}}, Tuple{Bool}})(x::Tuple{String, Tuple{Int64, Int64}, NamedTuple{(:θ,), Tuple{Float64}}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/array.jl:127
[5] map
@ ./tuple.jl:214 [inlined]
[6] #511
@ ~/.julia/packages/Zygote/cCyLF/src/lib/array.jl:123 [inlined]
[7] #2532#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:73 [inlined]
[8] #207
@ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:203 [inlined]
[9] #1748#back
@ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
[10] Pullback
@ ./abstractarray.jl:1710 [inlined]
[11] Pullback
@ ~/Documents/Programming/Julia/2q_XY.jl:47 [inlined]
[12] (::typeof(∂(variationalcircuit2)))(Δ::Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[13] Pullback
@ ~/Documents/Programming/Julia/2q_XY.jl:54 [inlined]
[14] (::typeof(∂(loss)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
[15] (::Zygote.var"#55#56"{typeof(∂(loss))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:41
[16] (::Zygote.var"#57#58"{typeof(loss)})(x::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:83
[17] loss_n_grad(x::Vector{Float64})
@ Main ~/Documents/Programming/Julia/2q_XY.jl:69
[18] optimize(fg::typeof(loss_n_grad), x::Vector{Float64}, alg::LBFGS{Float64, HagerZhangLineSearch{Rational{Int64}}}; precondition::typeof(OptimKit._precondition), finalize!::typeof(OptimKit._finalize!), retract::Function, inner::typeof(OptimKit._inner), transport!::typeof(OptimKit._transport!), scale!::typeof(OptimKit._scale!), add!::typeof(OptimKit._add!), isometrictransport::Bool)
@ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/lbfgs.jl:21
[19] optimize(fg::Function, x::Vector{Float64}, alg::LBFGS{Float64, HagerZhangLineSearch{Rational{Int64}}})
@ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/lbfgs.jl:20
[20] top-level scope
@ ~/Documents/Programming/Julia/2q_XY.jl:70
in expression starting at /Users/joe/Documents/Programming/Julia/2q_XY.jl:70
Code:
using ITensors
using PastaQ
using Printf
using OptimKit
using Zygote
import PastaQ: gate
N = 2
hilbert = qubits(N)
gate(::GateName"G"; θ::Number) = [
1 0 0 0
0 cos(θ / 2) -sin(θ / 2) 0
0 sin(θ / 2) cos(θ / 2) 0
0 0 0 1
]
gate(::GateName"Ryy"; ϕ::Number) = [
cos(ϕ / 2) 0 0 im*sin(ϕ / 2)
0 cos(ϕ / 2) -im*sin(ϕ / 2) 0
0 -im*sin(ϕ / 2) cos(ϕ / 2) 0
im*sin(ϕ / 2) 0 0 cos(ϕ / 2)
]
Trot_gates = Tuple[("Rxx", (1, 2), (ϕ = 0.1,)),
("Ryy", (1, 2), (ϕ = 0.1,))]
target = runcircuit(hilbert, vcat(Tuple[("X", 1)], Trot_gates))
function variationalcircuit(params)
    circuit = Tuple[]
    circuit = vcat(circuit, [("X", 1)])
    circuit = vcat(circuit, [("G", (1, 2), (θ = params[1],))])
    circuit = vcat(circuit, [("Phase", 1, (ϕ = params[2],))])
    circuit = vcat(circuit, [("Phase", 2, (ϕ = params[3],))])
    circuit = vcat(circuit, [("G", (1, 2), (θ = -params[1],))])
    return circuit
end
function variationalcircuit2(params)
    circuit = Tuple[]
    circuit = vcat(circuit, ("X", 1))
    circuit = vcat(circuit, ("G", (1, 2), (θ = params[1],)))
    circuit = vcat(circuit, ("Phase", 1, (ϕ = params[2],)))
    circuit = vcat(circuit, ("Phase", 2, (ϕ = params[3],)))
    circuit = vcat(circuit, ("G", (1, 2), (θ = -params[1],)))
    return circuit
end
psi0 = productstate(hilbert)
function loss(params)
    circuit = variationalcircuit2(params)
    psi = runcircuit(psi0, circuit)
    return 1 - abs2(inner(psi, target))
end
params_0 = 2π .* rand(3)
@show variationalcircuit(params_0)
@show variationalcircuit2(params_0)
optimizer = LBFGS(maxiter = 500, verbosity=2)
loss_n_grad(x) = (loss(x), convert(Vector, loss'(x)))
θ⃗, fs, gs, niter, normgradhistory = optimize(loss_n_grad, params_0, optimizer)
While both are valid circuit definitions, and both run properly with the runcircuit function, only the one with square brackets is currently differentiable, most likely because of a bug in Zygote. We do encounter a number of those when we build rrules.
About the numerical instability: we are using the same package versions, and the only difference is that you are using Julia 1.6 (instead of the latest, 1.7). You could try updating, though I am not 100% sure as of now that this is the cause.
I ran your example code with many random number seeds and it converged every time.
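A line search that repeatedly fails the Wolfe conditions can sometimes point to a gradient that is inconsistent with the function value, so a quick sanity check is to compare the AD gradient against finite differences. The sketch below assumes nothing about PastaQ: fd_gradient and toy_loss are hypothetical names introduced here, and toy_loss is a stand-in with a known gradient rather than the loss from this issue.

```julia
# Hypothetical helper: central finite-difference gradient of f at x.
function fd_gradient(f, x; h = 1e-6)
    g = zeros(Float64, length(x))
    for i in eachindex(x)
        step = zeros(Float64, length(x))
        step[i] = h
        # Symmetric difference quotient along coordinate i.
        g[i] = (f(x .+ step) - f(x .- step)) / (2h)
    end
    return g
end

# Stand-in loss with a known exact gradient, 2x.
toy_loss(x) = sum(abs2, x)

x0 = [0.3, -1.2, 2.0]
g_fd = fd_gradient(toy_loss, x0)
g_exact = 2 .* x0

# Largest componentwise disagreement between the two gradients.
max_error = maximum(abs.(g_fd .- g_exact))
```

In the script from this thread, the analogous comparison would be fd_gradient(loss, params_0) against loss'(params_0); a large mismatch would suggest the problem lies in the installed AD stack rather than in the optimiser.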
I deleted the packages I was importing, and after re-downloading them the optimisation magically started working as expected, like yours. Hopefully all sorted now, thanks for the help :)
That's great to hear, glad we could help!
I am trying to maximise the fidelity of two states, and when I try to use the optimiser as used in your VQE example, I get an error seemingly originating from Zygote via the loss' call. Any suggestions for why this is happening?
Code:
Full Error: