danielalcalde opened this issue 1 year ago
That's an impressive speedup! I was trying to think whether we could do the same kind of optimization automatically in our `inner` `rrule`, but I don't really see how, besides manually checking whether the bra and ket MPS are the same. That could technically work but is a bit tricky.
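A minimal dense-array sketch of the "check whether bra and ket are the same" idea, with plain vectors and a matrix standing in for the MPS and MPO (an assumption made so the example runs with only the `LinearAlgebra` standard library). The name `shared_state_gradient` is hypothetical and not part of ITensors or PastaQ. It also hints at one reason such a check could be tricky: `inner(Uψ', H, Uψ)` receives a primed copy as the bra, so an identity test like `===` would not fire there.

```julia
using LinearAlgebra

# Hypothetical helper: gradient of f(ψ) = dot(x, A * y) when the caller
# passes the same real vector ψ as both x and y.
function shared_state_gradient(x::AbstractVector{<:Real}, A::AbstractMatrix{<:Real},
                               y::AbstractVector{<:Real})
    if x === y && issymmetric(A)
        # Same object and symmetric A: one matrix-vector product suffices.
        return 2 .* (A * y)
    else
        # Generic rule: ∂f/∂x = A*y and ∂f/∂y = A'*x, summed because both
        # arguments stand for the same variable ψ (two matrix-vector products).
        return A * y .+ A' * x
    end
end

ψ = [0.6, 0.8]
A = [1.0 2.0; 2.0 -1.0]                           # symmetric
g_fast = shared_state_gradient(ψ, A, ψ)           # fast branch
g_generic = shared_state_gradient(copy(ψ), A, ψ)  # copy(ψ) !== ψ, generic branch
println(g_fast ≈ g_generic)                       # true
```

Both branches agree here, but only the first saves work, and it only triggers when the caller literally passes the same object.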
The main issue I see with your proposal is that if `H` isn't Hermitian, that derivative wouldn't be correct. We could have an `ishermitian` flag that you can pass to `inner`, but it's a bit funny to have a keyword argument that's only used by the derivative rule. Alternatively, we could just say that `expect` should only be used with Hermitian MPOs, though I'm not sure I like that.
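To make the concern about non-Hermitian `H` concrete, here is a small check with a real, dense, non-symmetric matrix standing in for the MPO (an assumption to keep the example self-contained; no ITensors types are involved). For real ψ, the exact gradient of ψᵀHψ is (H + Hᵀ)ψ, which equals the `2Hψ` used by the `ȳ * 2 * ψbar` pullback of the custom `expect` adjoint in this thread only when H is symmetric. `energy` and `fd_gradient` are hypothetical helpers.

```julia
using LinearAlgebra

# Real-valued "expectation value" ψᵀ H ψ for a dense matrix H.
energy(ψ, H) = dot(ψ, H * ψ)

# Hypothetical helper: central finite-difference gradient of f at ψ.
function fd_gradient(f, ψ; h = 1e-6)
    g = similar(ψ)
    for k in eachindex(ψ)
        e = zeros(length(ψ)); e[k] = h
        g[k] = (f(ψ .+ e) - f(ψ .- e)) / (2h)
    end
    return g
end

ψ = [0.3, -0.5, 0.8]
H = [1.0  2.0 0.0;
     0.0 -1.0 3.0;
     1.0  0.0 0.5]                 # deliberately not symmetric

g_fd = fd_gradient(x -> energy(x, H), ψ)
println(g_fd ≈ (H + H') * ψ)       # true: the general gradient
println(g_fd ≈ 2 .* (H * ψ))       # false: the shortcut assumes H == H'
```

An `ishermitian` flag, as suggested above, would amount to choosing between these two formulas in the derivative rule.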
As motivation, here is an example with a 2x speedup. Note that the larger the bond dimension of the Hamiltonian, the larger the speedup.

```julia
using PastaQ
using ITensors
using Random
using Printf
using OptimKit
using Zygote
using BenchmarkTools

N = 10  # number of qubits
J = 1.0 # Ising exchange interaction
h = 0.5 # transverse magnetic field

# Hilbert space
hilbert = qubits(N)

function ising_hamiltonian(N; J, h)
  os = OpSum()
  for j in 1:(N - 1)
    os += -J, "Z", j, "Z", j + 1
  end
  for j in 1:N
    os += -h, "X", j
  end
  return os
end

# define the Hamiltonian
os = ising_hamiltonian(N; J, h)

# build MPO "cost function"
H = MPO(os, hilbert)

# layer of single-qubit Ry gates
Rylayer(N, θ) = [("Ry", j, (θ=θ[j],)) for j in 1:N]

# brick-layer of CX gates
function CXlayer(N, Π)
  start = isodd(Π) ? 1 : 2
  return [("CX", (j, j + 1)) for j in start:2:(N - 1)]
end

# variational ansatz
function variationalcircuit(N, depth, θ)
  circuit = Tuple[]
  for d in 1:depth
    circuit = vcat(circuit, CXlayer(N, d))
    circuit = vcat(circuit, Rylayer(N, θ[d]))
  end
  return circuit
end

depth = 20
ψ = productstate(hilbert)
cutoff = 1e-8
maxdim = 200

# cost function: Zygote differentiates through the generic `inner`
function loss(θ)
  circuit = variationalcircuit(N, depth, θ)
  Uψ = runcircuit(ψ, circuit; cutoff, maxdim)
  return inner(Uψ', H, Uψ; cutoff, maxdim)
end

# expectation value ⟨ψ|H|ψ⟩, keeping only the real part (assumes H is Hermitian)
function expect(ψ, H; kwargs...)
  return real(inner(ψ', H, ψ; kwargs...))
end

# Custom adjoint: for a real state and Hermitian H, the gradient with respect
# to ψ is 2Hψ (computed here by contracting H with the primed ψ'), so a single
# MPO-MPS contraction replaces differentiating through both bra and ket.
# No gradient is returned for H.
Zygote.@adjoint function expect(ψ, H; kwargs...)
  function f̄(ȳ)
    ψbar = contract(H, ψ'; kwargs...)
    return ȳ * 2 * ψbar, nothing
  end
  return expect(ψ, H; kwargs...), f̄
end

# same cost function, but using the custom adjoint
function loss2(θ)
  circuit = variationalcircuit(N, depth, θ)
  Uψ = runcircuit(ψ, circuit; cutoff, maxdim)
  return expect(Uψ, H; cutoff, maxdim)
end

Random.seed!(1234)
# initialize parameters
θ₀ = [2π .* rand(N) for _ in 1:depth];
gradient(loss, θ₀) .≈ gradient(loss2, θ₀)
# (true,)
```
```
@benchmark gradient(loss, θ₀)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
Range (min … max): 1.441 s … 2.184 s ┊ GC (min … max): 6.06% … 37.79%
Time (median): 1.455 s ┊ GC (median): 6.00%
Time (mean ± σ): 1.694 s ± 425.015 ms ┊ GC (mean ± σ): 19.57% ± 18.46%
██ █
██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
1.44 s Histogram: frequency by time 2.18 s <
Memory estimate: 648.61 MiB, allocs estimate: 1518320.
```

```
@benchmark gradient(loss2, θ₀)
BenchmarkTools.Trial: 6 samples with 1 evaluation.
Range (min … max): 685.989 ms … 1.110 s ┊ GC (min … max): 10.04% … 43.52%
Time (median): 840.434 ms ┊ GC (median): 26.17%
Time (mean ± σ): 871.303 ms ± 204.527 ms ┊ GC (mean ± σ): 28.53% ± 17.36%
█
█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▇ ▁
686 ms Histogram: frequency by time 1.11 s <
Memory estimate: 334.74 MiB, allocs estimate: 1468562
```
What do you think about using the following code to check whether the MPO is Hermitian, and throwing an error if `expect` is used with a non-Hermitian MPO?
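```julia
using ITensors
using LinearAlgebra
using StatsBase # provides countmap

# Return the first index that appears twice once prime levels are stripped
# (for an MPO site tensor this is the unprimed site index `s`).
function repeat_inds(a::ITensor)
  for (v, i) in countmap(collect(noprime(inds(a))))
    if i == 2
      return v
    end
  end
  return nothing
end

# Check Hermiticity in the site indices only: swap s <-> s' (link indices
# are left untouched) and compare with the complex conjugate.
function LinearAlgebra.ishermitian(t::ITensor)
  s = repeat_inds(t)
  isnothing(s) && throw(ArgumentError("tensor has no (s, s') index pair"))
  ts = replaceinds(t, (s, s'), (s', s))
  return t ≈ conj(ts)
end

LinearAlgebra.ishermitian(H::MPO) = all(ishermitian, H)
```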
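Two hedged notes on this proposal, sketched with dense matrices standing in for MPO site tensors so the example runs with only `LinearAlgebra`. First, the guard itself could look like the hypothetical `expect_checked` below. Second, a site-by-site check is sufficient but not necessary: gauge freedom on the link indices means a Hermitian operator can be written with non-Hermitian factors, so a per-tensor test can reject valid Hamiltonians. Here a two-site product `kron(A, B)` plays the role of an MPO with bond dimension 1, and `factorwise_hermitian` is a hypothetical stand-in for the per-tensor check.

```julia
using LinearAlgebra

# Hypothetical guarded expectation value for a dense state ψ and operator H.
function expect_checked(ψ::AbstractVector, H::AbstractMatrix)
    ishermitian(H) || throw(ArgumentError("expect_checked requires a Hermitian operator"))
    return real(dot(ψ, H * ψ))
end

X = [0 1; 1 0]
Z = [1 0; 0 -1]

# Per-factor check, analogous to testing each MPO tensor separately.
factorwise_hermitian(factors) = all(ishermitian, factors)

# The same Hermitian operator X ⊗ Z written in two "gauges".
gauge1 = (X, Z)
gauge2 = (im * X, -im * Z)   # im * (-im) = 1, so the product is unchanged

H1 = kron(gauge1...)
H2 = kron(gauge2...)

println(H1 == H2)                      # true
println(ishermitian(H2))               # true
println(factorwise_hermitian(gauge1))  # true
println(factorwise_hermitian(gauge2))  # false: a false negative

ψ = normalize(ComplexF64[1, 0, 1, 0])
println(expect_checked(ψ, H2))         # approximately 1.0
# expect_checked(ψ, kron(im * X, Z))   # would throw an ArgumentError
```

A guard based on the per-tensor check would therefore be conservative: it never accepts a non-Hermitian MPO, but it may reject a Hermitian one whose tensors happen to be in an unusual gauge.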
It has come to my attention that the `loss` function defined in the tutorial may not be optimized for performance:

Currently, backpropagation through the `inner` function is relatively slow, primarily because it doesn't take into account two crucial aspects:

1. `Uψ'` and `Uψ` represent the same state, and
2. `H` is Hermitian.

I defined a function `expect` as follows:

It is designed to exploit these properties and can give a considerable performance boost when computing gradients (in my simulation, from 3 s to 200 ms).
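For a real-valued state (as with the Ry/CX circuits in this thread), the two properties combine as follows. Differentiating through the bra and the ket separately produces the two terms of the general gradient, and Hermiticity (here simply $H = H^{T}$) collapses them into a single contraction $H\psi$. This is a sketch for the real case only; complex amplitudes need extra conjugation bookkeeping.

```math
\begin{aligned}
E(\psi) &= \langle \psi | H | \psi \rangle = \psi^{T} H \psi, \\
\nabla_{\psi} E &= H\psi + H^{T}\psi = 2H\psi \quad \text{if } H = H^{T}.
\end{aligned}
```

The factor of 2 is the `2` in the `ȳ * 2 * ψbar` pullback of the custom `expect` adjoint.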
The codebase does not seem to have an equivalent function. I suggest incorporating `expect` (or a similarly named function) into ITensors.jl or PastaQ.jl, which would lead to significant performance improvements. Additionally, if an equivalent function already exists in the codebase, I recommend updating the tutorial to use it instead of `inner`, making the tutorial more performance-oriented and user-friendly.