GTorlai / PastaQ.jl

Package for Simulation, Tomography and Analysis of Quantum Computers
Apache License 2.0

Zygote gradient error #263

Closed: ghost closed this issue 2 years ago

ghost commented 2 years ago

I am trying to maximise the fidelity between two states. When I use the optimiser from your VQE example, I get an error that seems to originate in Zygote when loss' is called. Any suggestions as to why this is happening?

Code:

using ITensors
using PastaQ
using Printf
using OptimKit
using Zygote

import PastaQ: gate

N = 2 
hilbert = qubits(N)

gate(::GateName"G"; θ::Number) = [
  1 0 0 0
  0 cos(θ / 2) -sin(θ / 2) 0
  0 sin(θ / 2) cos(θ / 2) 0
  0 0 0 1
]

gate(::GateName"Ryy"; ϕ::Number) = [
  cos(ϕ / 2) 0 0 im*sin(ϕ / 2)
  0 cos(ϕ / 2) -im*sin(ϕ / 2) 0
  0 -im*sin(ϕ / 2) cos(ϕ / 2) 0
  im*sin(ϕ / 2) 0 0 cos(ϕ / 2)
]

Trot_gates = Tuple[("Rxx", (1, 2), (ϕ = 0.1,)),
                   ("Ryy", (1, 2), (ϕ = 0.1,))]

target = runcircuit(hilbert, vcat(Tuple[("X", 1)], Trot_gates))

function variationalcircuit(params)

  circuit = Tuple[]
  circuit = vcat(circuit, ("X", 1))
  circuit = vcat(circuit, ("G", (1, 2), (θ = params[1],)))
  circuit = vcat(circuit, ("Phase", 1, (ϕ = params[2],)))
  circuit = vcat(circuit, ("Phase", 2, (ϕ = params[3],)))
  circuit = vcat(circuit, ("G", (1, 2), (θ = -params[1],)))

  return circuit
end

function loss(params)
  circuit = variationalcircuit(params)
  psi = runcircuit(hilbert, circuit)
  return 1 - fidelity(psi, target)
end

params_0 = 2π .* rand(3)

println(loss(params_0))

optimizer = LBFGS(maxiter = 500, verbosity=2)

loss_n_grad(x) = (loss(x), convert(Vector, loss'(x)))
θ⃗, fs, gs, niter, normgradhistory = optimize(loss_n_grad, params_0,  optimizer)

Full Error:

ERROR: LoadError: MethodError: no method matching setindex!(::ITensor, ::ITensor, ::Int64)
Closest candidates are:
  setindex!(::ITensor, ::AbstractArray, ::Any...) at /Users/joe/.julia/packages/ITensors/ZMKMP/src/itensor.jl:985
  setindex!(::ITensor, ::Number, ::Integer...) where N at /Users/joe/.julia/packages/ITensors/ZMKMP/src/itensor.jl:950
  setindex!(::ITensor, ::Number, ::Any...) where N at /Users/joe/.julia/packages/ITensors/ZMKMP/src/itensor.jl:976
  ...
Stacktrace:
  [1] macro expansion
    @ ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0 [inlined]
  [2] _pullback(::Zygote.Context, ::typeof(setindex!), ::ITensor, ::ITensor, ::Int64)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:9
  [3] _pullback
    @ ~/.julia/packages/ITensors/ZMKMP/src/mps/mps.jl:336 [inlined]
  [4] _pullback(::Zygote.Context, ::Type{MPS}, ::Type{Float64}, ::Vector{Index{Int64}}, ::Vector{String})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
  [5] _pullback
    @ ~/.julia/packages/ITensors/ZMKMP/src/mps/mps.jl:351 [inlined]
  [6] _pullback(::Zygote.Context, ::Type{MPS}, ::Type{Float64}, ::Vector{Index{Int64}}, ::String)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
  [7] _pullback
    @ ~/.julia/packages/ITensors/ZMKMP/src/mps/mps.jl:377 [inlined]
  [8] _pullback
    @ ~/.julia/packages/PastaQ/dYp8T/src/productstates.jl:104 [inlined]
  [9] _pullback
    @ ~/.julia/packages/PastaQ/dYp8T/src/circuits/runcircuit.jl:322 [inlined]
 [10] _pullback(::Zygote.Context, ::PastaQ.var"##runcircuit#164", ::Bool, ::Bool, ::Nothing, ::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(runcircuit), ::Vector{Index{Int64}}, ::Vector{Tuple})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [11] _pullback
    @ ~/.julia/packages/PastaQ/dYp8T/src/circuits/runcircuit.jl:300 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(runcircuit), ::Vector{Index{Int64}}, ::Vector{Tuple})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [13] _pullback
    @ ~/Documents/Programming/Julia/2q_XY.jl:71 [inlined]
 [14] _pullback(ctx::Zygote.Context, f::typeof(loss), args::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [15] _pullback
    @ ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:34 [inlined]
 [16] pullback
    @ ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:40 [inlined]
 [17] (::Zygote.var"#57#58"{typeof(loss)})(x::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:82
 [18] top-level scope
    @ ~/Documents/Programming/Julia/2q_XY.jl:79
in expression starting at /Users/joe/Documents/Programming/Julia/2q_XY.jl:79

GTorlai commented 2 years ago

Hi @JoeGibbs88, when using AD functionality through Zygote, a specific syntax is sometimes required. In your example a few things need to be changed to make it work. First, when using vcat to build the circuit, you should also wrap each new gate in square brackets:

function variationalcircuit(params)
  circuit = Tuple[]
  circuit = vcat(circuit, [("X", 1)])
  circuit = vcat(circuit, [("G", (1, 2), (θ = params[1],))])
  circuit = vcat(circuit, [("Phase", 1, (ϕ = params[2],))])
  circuit = vcat(circuit, [("Phase", 2, (ϕ = params[3],))])
  circuit = vcat(circuit, [("G", (1, 2), (θ = -params[1],))])
  return circuit
end

Next, the loss function also needs a couple of changes. First, differentiation through an MPS constructor is not yet defined in ITensors (see https://github.com/ITensor/ITensors.jl/issues/776), so for now you need to create the initial state outside the loss function. The other change is in the fidelity call, which is also currently not differentiable. For the time being, you could instead use the inner function in ITensors, which in this case computes the inner product of two MPSs, and build the fidelity from it. fidelity is non-differentiable because we wrote it to be numerically stable for large MPS, which requires some additional operations on top of inner whose derivatives are not yet taken care of (though we will fix this soon). So the loss function becomes:

psi0 = productstate(hilbert)
function loss(params)
  circuit = variationalcircuit(params)
  psi = runcircuit(psi0, circuit)
  return 1 - abs2(inner(psi, target))
end
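
For reference, abs2(inner(psi, target)) can stand in for the fidelity here because, for two normalised pure states, the fidelity is the squared magnitude of their overlap. Below is a minimal sketch using plain Julia vectors rather than MPS. The name pure_fidelity is a hypothetical helper for illustration only, not a PastaQ or ITensors function:

```julia
using LinearAlgebra

# Hypothetical helper (not part of PastaQ or ITensors): fidelity between two
# pure states stored as plain vectors. Dividing by the squared norms makes
# the result independent of how the inputs are normalised.
pure_fidelity(ψ, ϕ) = abs2(dot(ψ, ϕ)) / (norm(ψ)^2 * norm(ϕ)^2)

ψ = ComplexF64[1, 0]                 # the state |0>
ϕ = ComplexF64[1, im] ./ sqrt(2)     # an equal superposition with a relative phase

f_same = pure_fidelity(ψ, ψ)         # a state has unit fidelity with itself
f_half = pure_fidelity(ψ, ϕ)         # the overlap has magnitude 1/sqrt(2)
```

The loss 1 - abs2(inner(psi, target)) follows the same pattern, assuming psi and target are both normalised.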

With the above modifications the code runs and outputs:

0.05545295586765009
[ Info: LBFGS: initializing with f = 0.055452955868, ‖∇f‖ = 2.8750e-01
[ Info: LBFGS: iter    1: f = 0.001045920807, ‖∇f‖ = 3.7467e-02, α = 4.30e-01, m = 0, nfg = 2
[ Info: LBFGS: iter    2: f = 0.000157296908, ‖∇f‖ = 1.4674e-03, α = 1.00e+00, m = 1, nfg = 1
[ Info: LBFGS: iter    3: f = 0.000155638615, ‖∇f‖ = 7.1383e-04, α = 1.00e+00, m = 2, nfg = 1
[ Info: LBFGS: iter    4: f = 0.000154660387, ‖∇f‖ = 8.0494e-04, α = 1.00e+00, m = 3, nfg = 1
[ Info: LBFGS: iter    5: f = 0.000148931986, ‖∇f‖ = 2.0502e-03, α = 1.00e+00, m = 4, nfg = 1
[ Info: LBFGS: iter    6: f = 0.000137419061, ‖∇f‖ = 3.7131e-03, α = 1.00e+00, m = 5, nfg = 1
[ Info: LBFGS: iter    7: f = 0.000099576772, ‖∇f‖ = 5.6274e-03, α = 1.00e+00, m = 6, nfg = 1
[ Info: LBFGS: iter    8: f = 0.000014151556, ‖∇f‖ = 4.3621e-03, α = 5.72e-01, m = 7, nfg = 3
[ Info: LBFGS: iter    9: f = 0.000010465556, ‖∇f‖ = 1.5371e-03, α = 4.85e-01, m = 8, nfg = 2
[ Info: LBFGS: iter   10: f = 0.000008838465, ‖∇f‖ = 1.6771e-03, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter   11: f = 0.000002607997, ‖∇f‖ = 1.8504e-03, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter   12: f = 0.000000235906, ‖∇f‖ = 2.6223e-05, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter   13: f = 0.000000007655, ‖∇f‖ = 7.5840e-05, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter   14: f = 0.000000000038, ‖∇f‖ = 7.0155e-06, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: iter   15: f = 0.000000000000, ‖∇f‖ = 7.6904e-08, α = 1.00e+00, m = 8, nfg = 1
[ Info: LBFGS: converged after 16 iterations: f = 0.000000000000, ‖∇f‖ = 7.0827e-09

Hope this helps, and let us know if you still have issues with this!
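
If the gradient returned by loss' ever looks suspicious, one generic sanity check is to compare it against a central finite-difference estimate. The sketch below is plain Julia on a toy function with a known gradient. fd_gradient, toy_loss and toy_grad are hypothetical names made up for illustration and are not part of PastaQ, Zygote or OptimKit:

```julia
# Hypothetical helper: central finite-difference gradient of a scalar
# function f at the point x, using step size h in each coordinate.
function fd_gradient(f, x; h = 1e-6)
  g = zeros(length(x))
  for i in eachindex(x)
    δ = zeros(length(x))
    δ[i] = h
    g[i] = (f(x .+ δ) - f(x .- δ)) / (2h)
  end
  return g
end

# Toy stand-in for a loss of the form 1 - (overlap)^2, with its analytic gradient.
toy_loss(p) = 1 - cos(p[1] / 2)^2 * cos(p[2])^2
toy_grad(p) = [sin(p[1]) * cos(p[2])^2 / 2, cos(p[1] / 2)^2 * sin(2 * p[2])]

p = [0.3, 1.1]
gradients_agree = isapprox(fd_gradient(toy_loss, p), toy_grad(p); atol = 1e-6)
```

In the script above, the analogous comparison would be between fd_gradient(loss, params_0) and loss'(params_0). A large mismatch would point at the AD rules rather than at the optimiser.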

ghost commented 2 years ago

Thank you for the detailed answer! After adding your substitutions the code now runs without error. Strangely, however, yours converges to 0 nicely while mine plateaus after one step...

First 10 iterations:

0.04319948746660329
[ Info: LBFGS: initializing with f = 0.043199487467, ‖∇f‖ = 1.8831e+00
[ Info: LBFGS: iter    1: f = 0.025405693379, ‖∇f‖ = 2.6732e-01, α = 1.00e+00, m = 0, nfg = 1
[ Info: LBFGS: iter    2: f = 0.024239523406, ‖∇f‖ = 1.4714e-02, α = 5.51e-01, m = 1, nfg = 2
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter    3: f = 0.024240523406, ‖∇f‖ = 1.4697e-02, α = 2.70e-03, m = 2, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter    4: f = 0.024241523406, ‖∇f‖ = 1.4681e-02, α = 1.06e-03, m = 3, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter    5: f = 0.024242523406, ‖∇f‖ = 1.4665e-02, α = 9.91e-04, m = 4, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter    6: f = 0.024243523406, ‖∇f‖ = 1.4650e-02, α = 8.86e-04, m = 5, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter    7: f = 0.024244523406, ‖∇f‖ = 1.4635e-02, α = 7.95e-04, m = 6, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter    8: f = 0.024245523406, ‖∇f‖ = 1.4620e-02, α = 7.16e-04, m = 7, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter    9: f = 0.024246523406, ‖∇f‖ = 1.4606e-02, α = 6.48e-04, m = 8, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:189
[ Info: LBFGS: iter   10: f = 0.024247523406, ‖∇f‖ = 1.4593e-02, α = 5.89e-04, m = 8, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?

Package versions:

(@v1.6) pkg> status
      Status `~/.julia/environments/v1.6/Project.toml`
  [587475ba] Flux v0.12.9
  [9136182c] ITensors v0.2.15
  [d41bc354] NLSolversBase v7.8.2
  [76087f3c] NLopt v0.6.3
  [429524aa] Optim v1.6.2
  [77e91f04] OptimKit v0.3.1
  [30b07047] PastaQ v0.0.19
  [e88e6eb3] Zygote v0.6.35
  [de0858da] Printf
  [9a3f8284] Random

GTorlai commented 2 years ago

This may be due to initial conditions. Have you tried running with a different random number generator seed? Or is this systematic?
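
As a minimal sketch of what trying a different seed looks like (the seed values are arbitrary choices for illustration): fixing the seed of Julia's global random number generator before drawing params_0 makes a run reproducible, and changing the seed gives a different starting point.

```julia
using Random

Random.seed!(1234)            # fix the global RNG state (1234 is arbitrary)
params_a = 2π .* rand(3)

Random.seed!(1234)            # same seed again, so the same draw
params_b = 2π .* rand(3)

Random.seed!(4321)            # a different seed gives a different starting point
params_c = 2π .* rand(3)

same_start = params_a == params_b
```

Running the optimisation from several such seeds shows whether the plateau depends on the starting point or appears every time.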

mtfishman commented 2 years ago

Thanks @GTorlai, I was going to comment that the main issue appears to be differentiating through the MPS constructor. We will fix that, but in the meantime a simple workaround is to move the MPS constructor outside of the cost function, as you proposed.

I want to clarify that the brackets are not needed in vcat to construct the circuit itself; both forms return the same vector:

julia> function variationalcircuit(params)
         circuit = Tuple[]
         circuit = vcat(circuit, [("X", 1)])
         circuit = vcat(circuit, [("G", (1, 2), (θ = params[1],))])
         circuit = vcat(circuit, [("Phase", 1, (ϕ = params[2],))])
         circuit = vcat(circuit, [("Phase", 2, (ϕ = params[3],))])
         circuit = vcat(circuit, [("G", (1, 2), (θ = -params[1],))])
         return circuit
       end
variationalcircuit (generic function with 1 method)

julia> function variationalcircuit2(params)
         circuit = Tuple[]
         circuit = vcat(circuit, ("X", 1))
         circuit = vcat(circuit, ("G", (1, 2), (θ = params[1],)))
         circuit = vcat(circuit, ("Phase", 1, (ϕ = params[2],)))
         circuit = vcat(circuit, ("Phase", 2, (ϕ = params[3],)))
         circuit = vcat(circuit, ("G", (1, 2), (θ = -params[1],)))
         return circuit
       end
variationalcircuit2 (generic function with 1 method)

julia> params_0 = 2π .* rand(3)
3-element Vector{Float64}:
 4.280782156715938
 3.8414080811580575
 6.178409079753362

julia> variationalcircuit(params_0)
5-element Vector{Tuple}:
 ("X", 1)
 ("G", (1, 2), (θ = 4.280782156715938,))
 ("Phase", 1, (ϕ = 3.8414080811580575,))
 ("Phase", 2, (ϕ = 6.178409079753362,))
 ("G", (1, 2), (θ = -4.280782156715938,))

julia> variationalcircuit2(params_0)
5-element Vector{Tuple}:
 ("X", 1)
 ("G", (1, 2), (θ = 4.280782156715938,))
 ("Phase", 1, (ϕ = 3.8414080811580575,))
 ("Phase", 2, (ϕ = 6.178409079753362,))
 ("G", (1, 2), (θ = -4.280782156715938,))

Additionally, you can use the following syntax as a compact alternative to vcat:

julia> function variationalcircuit3(params)
         circuit = Tuple[]
         circuit = [circuit; ("X", 1)]
         circuit = [circuit; ("G", (1, 2), (θ = params[1],))]
         circuit = [circuit; ("Phase", 1, (ϕ = params[2],))]
         circuit = [circuit; ("Phase", 2, (ϕ = params[3],))]
         circuit = [circuit; ("G", (1, 2), (θ = -params[1],))]
         return circuit
       end
variationalcircuit3 (generic function with 1 method)

julia> variationalcircuit3(params_0)
5-element Vector{Tuple}:
 ("X", 1)
 ("G", (1, 2), (θ = 4.280782156715938,))
 ("Phase", 1, (ϕ = 3.8414080811580575,))
 ("Phase", 2, (ϕ = 6.178409079753362,))
 ("G", (1, 2), (θ = -4.280782156715938,))

Internally it just calls vcat. I've started to prefer that syntax when building circuits, but of course it is a personal preference.

See this section of the Julia documentation: https://docs.julialang.org/en/v1/manual/arrays/#man-array-concatenation for more information on Array concatenation.

ghost commented 2 years ago

@GTorlai No, whenever I run the script with a new initialisation I get the same behaviour: it does one 'good' iteration, then the loss starts to slowly increase...

[ Info: LBFGS: initializing with f = 0.748913103426, ‖∇f‖ = 8.7004e-01
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:236
[ Info: LBFGS: iter    1: f = 0.748914103426, ‖∇f‖ = 8.7004e-01, α = 1.58e-06, m = 0, nfg = 54
[ Info: LBFGS: iter    2: f = 0.000755341443, ‖∇f‖ = 1.0978e+00, α = 1.00e+00, m = 1, nfg = 1
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:236
[ Info: LBFGS: iter    3: f = 0.000756341443, ‖∇f‖ = 1.3207e+00, α = 9.43e-02, m = 2, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?
└ @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/linesearches.jl:236
[ Info: LBFGS: iter    4: f = 0.000757341443, ‖∇f‖ = 1.3208e+00, α = 5.58e-05, m = 2, nfg = 54
┌ Warning: Linesearch bracket converged to a point without satisfying Wolfe conditions?

I'm confused about why this would happen when running the same code as you. Are you using different package versions?

ghost commented 2 years ago

@mtfishman Thanks for the info. Interestingly, when I use '@show' to inspect 'variationalcircuit' and 'variationalcircuit2', they appear to return the same object. However, when I put variationalcircuit2/3 into the loss function, I get the following error (whereas variationalcircuit runs without error):

variationalcircuit(params_0) = Tuple[("X", 1), ("G", (1, 2), (θ = 1.8968763377256133,)), ("Phase", 1, (ϕ = 5.55160224854473,)), ("Phase", 2, (ϕ = 0.15762143907714976,)), ("G", (1, 2), (θ = -1.8968763377256133,))]
variationalcircuit2(params_0) = Tuple[("X", 1), ("G", (1, 2), (θ = 1.8968763377256133,)), ("Phase", 1, (ϕ = 5.55160224854473,)), ("Phase", 2, (ϕ = 0.15762143907714976,)), ("G", (1, 2), (θ = -1.8968763377256133,))]
ERROR: LoadError: BoundsError: attempt to access 5-element Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}} at index [5:7]
Stacktrace:
  [1] throw_boundserror(A::Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}}, I::Tuple{UnitRange{Int64}})
    @ Base ./abstractarray.jl:651
  [2] checkbounds
    @ ./abstractarray.jl:616 [inlined]
  [3] getindex(A::Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}}, I::UnitRange{Int64})
    @ Base ./array.jl:811
  [4] (::Zygote.var"#513#518"{Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}}, Tuple{Bool}})(x::Tuple{String, Tuple{Int64, Int64}, NamedTuple{(:θ,), Tuple{Float64}}})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/lib/array.jl:127
  [5] map
    @ ./tuple.jl:214 [inlined]
  [6] #511
    @ ~/.julia/packages/Zygote/cCyLF/src/lib/array.jl:123 [inlined]
  [7] #2532#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:73 [inlined]
  [8] #207
    @ ~/.julia/packages/Zygote/cCyLF/src/lib/lib.jl:203 [inlined]
  [9] #1748#back
    @ ~/.julia/packages/ZygoteRules/AIbCs/src/adjoint.jl:67 [inlined]
 [10] Pullback
    @ ./abstractarray.jl:1710 [inlined]
 [11] Pullback
    @ ~/Documents/Programming/Julia/2q_XY.jl:47 [inlined]
 [12] (::typeof(∂(variationalcircuit2)))(Δ::Vector{Tuple{Nothing, Nothing, Vararg{NamedTuple{names, Tuple{Float64}} where names, N} where N}})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [13] Pullback
    @ ~/Documents/Programming/Julia/2q_XY.jl:54 [inlined]
 [14] (::typeof(∂(loss)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface2.jl:0
 [15] (::Zygote.var"#55#56"{typeof(∂(loss))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:41
 [16] (::Zygote.var"#57#58"{typeof(loss)})(x::Vector{Float64})
    @ Zygote ~/.julia/packages/Zygote/cCyLF/src/compiler/interface.jl:83
 [17] loss_n_grad(x::Vector{Float64})
    @ Main ~/Documents/Programming/Julia/2q_XY.jl:69
 [18] optimize(fg::typeof(loss_n_grad), x::Vector{Float64}, alg::LBFGS{Float64, HagerZhangLineSearch{Rational{Int64}}}; precondition::typeof(OptimKit._precondition), finalize!::typeof(OptimKit._finalize!), retract::Function, inner::typeof(OptimKit._inner), transport!::typeof(OptimKit._transport!), scale!::typeof(OptimKit._scale!), add!::typeof(OptimKit._add!), isometrictransport::Bool)
    @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/lbfgs.jl:21
 [19] optimize(fg::Function, x::Vector{Float64}, alg::LBFGS{Float64, HagerZhangLineSearch{Rational{Int64}}})
    @ OptimKit ~/.julia/packages/OptimKit/xpmbV/src/lbfgs.jl:20
 [20] top-level scope
    @ ~/Documents/Programming/Julia/2q_XY.jl:70
in expression starting at /Users/joe/Documents/Programming/Julia/2q_XY.jl:70

Code:

using ITensors
using PastaQ
using Printf
using OptimKit
using Zygote

import PastaQ: gate

N = 2 
hilbert = qubits(N)

gate(::GateName"G"; θ::Number) = [
  1 0 0 0
  0 cos(θ / 2) -sin(θ / 2) 0
  0 sin(θ / 2) cos(θ / 2) 0
  0 0 0 1
]

gate(::GateName"Ryy"; ϕ::Number) = [
  cos(ϕ / 2) 0 0 im*sin(ϕ / 2)
  0 cos(ϕ / 2) -im*sin(ϕ / 2) 0
  0 -im*sin(ϕ / 2) cos(ϕ / 2) 0
  im*sin(ϕ / 2) 0 0 cos(ϕ / 2)
]

Trot_gates = Tuple[("Rxx", (1, 2), (ϕ = 0.1,)),
                   ("Ryy", (1, 2), (ϕ = 0.1,))]

target = runcircuit(hilbert, vcat(Tuple[("X", 1)], Trot_gates))

function variationalcircuit(params)
  circuit = Tuple[]
  circuit = vcat(circuit, [("X", 1)])
  circuit = vcat(circuit, [("G", (1, 2), (θ = params[1],))])
  circuit = vcat(circuit, [("Phase", 1, (ϕ = params[2],))])
  circuit = vcat(circuit, [("Phase", 2, (ϕ = params[3],))])
  circuit = vcat(circuit, [("G", (1, 2), (θ = -params[1],))])
  return circuit
end

function variationalcircuit2(params)
  circuit = Tuple[]
  circuit = vcat(circuit, ("X", 1))
  circuit = vcat(circuit, ("G", (1, 2), (θ = params[1],)))
  circuit = vcat(circuit, ("Phase", 1, (ϕ = params[2],)))
  circuit = vcat(circuit, ("Phase", 2, (ϕ = params[3],)))
  circuit = vcat(circuit, ("G", (1, 2), (θ = -params[1],)))
  return circuit
end

psi0 = productstate(hilbert)
function loss(params)
  circuit = variationalcircuit2(params)
  psi = runcircuit(psi0, circuit)
  return 1 - abs2(inner(psi, target))
end

params_0 = 2π .* rand(3)

@show variationalcircuit(params_0)
@show variationalcircuit2(params_0)

optimizer = LBFGS(maxiter = 500, verbosity=2)

loss_n_grad(x) = (loss(x), convert(Vector, loss'(x)))
θ⃗, fs, gs, niter, normgradhistory = optimize(loss_n_grad, params_0,  optimizer)

GTorlai commented 2 years ago

While both are valid circuit definitions, and both will run properly with the runcircuit function, only the one with square brackets is currently differentiable, most likely because of a bug in Zygote. We do encounter a number of those when we build rrules.

As for the numerical instability, we are using the same package versions; the only difference is that you are using Julia 1.6 (instead of the latest 1.7). You could try updating, though I am not 100% sure that this is the cause.

I ran your example code with many random number seeds and it converged every time.

ghost commented 2 years ago

I deleted the packages I was importing, and after re-downloading them the optimisation magically started working as expected, like yours. Hopefully all sorted now, thanks for the help :)

GTorlai commented 2 years ago

That's great to hear, glad we could help!