Package for Simulation, Tomography and Analysis of Quantum Computers
Gradient Performance Enhancement: Implementing `expect` Function to Exploit Hermitian Nature of H #301

danielalcalde opened 1 year ago

danielalcalde commented 1 year ago

It has come to my attention that the loss function defined in the tutorial may not be optimized for performance:

```julia
function loss(θ)
  circuit = variationalcircuit(N, depth, θ)
  Uψ = runcircuit(ψ, circuit; cutoff, maxdim)
  return inner(Uψ', H, Uψ; cutoff, maxdim)
end
```

Currently, backpropagation through the `inner` function is relatively slow. This is mainly because the function doesn't take into account two crucial facts: `Uψ'` and `Uψ` represent the same state, and `H` is Hermitian.
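To see where the savings can come from, here is a dense-matrix sketch. It uses plain Julia arrays, not ITensors code, and the helper names `grad_bra` and `grad_ket` are hypothetical. The sketch assumes that a generic rule for `inner` returns one cotangent for the bra and one for the ket, and that reverse-mode AD then adds them because both come from the same state. With the bra equal to the ket and `H` Hermitian, the two terms coincide, so a single product with `H` gives the same result as two.

```julia
using LinearAlgebra

# Dense stand-ins (illustration only): a random Hermitian matrix H plays the
# role of the MPO and a random complex vector ψ the role of the MPS.
n = 8
A = randn(ComplexF64, n, n)
H = (A + A') / 2                 # Hermitian by construction
ψ = normalize(randn(ComplexF64, n))

# For f(φ, ψ) = real(φ' * H * ψ), in the gradient convention
# ∂f/∂Re(z) + i ∂f/∂Im(z), the two arguments get separate cotangents:
grad_bra(H, φ, ψ) = H * ψ        # cotangent of the bra φ
grad_ket(H, φ, ψ) = H' * φ       # cotangent of the ket ψ (involves H†)

# A rule that treats bra and ket independently needs two products with H,
# which AD then sums because both arguments come from the same state:
generic = grad_bra(H, ψ, ψ) + grad_ket(H, ψ, ψ)

# If bra == ket and H == H', a single product gives the same result:
shortcut = 2 * (H * ψ)

println(ishermitian(H))          # true
println(generic ≈ shortcut)      # true
```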

The function expect, defined as:

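```julia
using ITensors
using Zygote

function expect(ψ, H; kwargs...)
  return real(inner(ψ', H, ψ; kwargs...))
end

# Custom adjoint: assumes H is Hermitian and bra == ket, so that
# the gradient of ⟨ψ|H|ψ⟩ with respect to ψ is 2Hψ (one MPO-MPS contraction).
Zygote.@adjoint function expect(ψ, H; kwargs...)
  function f̄(ȳ)
    ψbar = contract(H, ψ'; kwargs...)
    return ȳ * 2 * ψbar, nothing  # no gradient for H
  end
  return expect(ψ, H; kwargs...), f̄
end
```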

is designed to exploit these properties and can give a considerable speedup when computing gradients (from 3 s to 200 ms in my simulation).
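The adjoint above relies on the gradient of ⟨ψ|H|ψ⟩ with respect to ψ being 2Hψ for Hermitian `H`. The sketch below checks this with dense arrays and central finite differences. It uses the convention ∂f/∂Re(z) + i ∂f/∂Im(z) for a real-valued f of complex inputs, which is the convention Zygote uses (so the gradient of `abs2` at z is 2z). `expect_dense` and `fd_gradient` are hypothetical helpers written for this illustration only.

```julia
using LinearAlgebra

# Dense stand-ins (illustration only) for the MPO and the MPS.
n = 6
A = randn(ComplexF64, n, n)
H = (A + A') / 2                 # Hermitian by construction
ψ = randn(ComplexF64, n)

# Dense analogue of `expect`: real(⟨ψ|H|ψ⟩).
expect_dense(ψ, H) = real(dot(ψ, H * ψ))

# Central finite-difference gradient in the convention ∂f/∂Re(z) + i ∂f/∂Im(z).
function fd_gradient(f, ψ; δ = 1e-6)
    g = zeros(ComplexF64, length(ψ))
    for k in eachindex(ψ)
        e = zeros(ComplexF64, length(ψ))
        e[k] = 1
        dRe = (f(ψ + δ * e) - f(ψ - δ * e)) / (2δ)            # ∂f/∂Re(ψ[k])
        dIm = (f(ψ + im * δ * e) - f(ψ - im * δ * e)) / (2δ)  # ∂f/∂Im(ψ[k])
        g[k] = dRe + im * dIm
    end
    return g
end

g_fd = fd_gradient(x -> expect_dense(x, H), ψ)
g_rule = 2 * (H * ψ)             # what the custom adjoint returns for ȳ = 1

println(isapprox(g_fd, g_rule; rtol = 1e-6))   # true
```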

The codebase does not seem to have an equivalent function. I suggest adding `expect`, or a similarly named function, to ITensors.jl or PastaQ.jl, which should lead to significant performance improvements.

Additionally, if an equivalent function already exists in the codebase, I recommend updating the tutorial to use it instead of `inner`, making the tutorial faster and more user-friendly.

mtfishman commented 1 year ago

That's an impressive speedup! I was trying to think whether we could do this same kind of optimization automatically in our `inner` rrule, but I don't really see how, besides manually checking whether the bra and ket MPS are the same. That could technically work but is a bit tricky.

The main issue I see with your proposal is that if `H` isn't Hermitian, that derivative wouldn't be correct. We could add an `ishermitian` flag that you can pass to `inner`, but it's a bit funny to have a keyword argument that is only used by the derivative rule. Alternatively, we could just say that `expect` should only be used with Hermitian MPOs, though I'm not sure I like that.
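To make the concern concrete: because `expect` takes the real part, for a general, non-Hermitian operator K it computes real(⟨ψ|K|ψ⟩), whose gradient is (K + K†)ψ rather than 2Kψ. The dense sketch below checks this against a finite-difference directional derivative. `expect_gradient` and its `hermitian` keyword are hypothetical names that mirror the flag idea above. They are not an existing ITensors or PastaQ API.

```julia
using LinearAlgebra

# Dense stand-ins (illustration only): a generic, non-Hermitian K and a state ψ.
n = 6
K = randn(ComplexF64, n, n)
ψ = randn(ComplexF64, n)

# `expect` takes the real part, so it computes f(ψ) = real(ψ' * K * ψ).
f(ψ) = real(dot(ψ, K * ψ))

# Gradient (convention ∂f/∂Re(z) + i ∂f/∂Im(z)) is (K + K')ψ in general and
# reduces to 2Kψ only when K == K'. `hermitian` is a hypothetical flag.
expect_gradient(ψ, K; hermitian = ishermitian(K)) =
    hermitian ? 2 * (K * ψ) : K * ψ + K' * ψ

# Directional derivative along d: d/dt f(ψ + t d) at t = 0 equals real(dot(g, d)).
d = randn(ComplexF64, n)
δ = 1e-6
fd = (f(ψ + δ * d) - f(ψ - δ * d)) / (2δ)

g_true = expect_gradient(ψ, K)                      # (K + K')ψ
g_forced = expect_gradient(ψ, K; hermitian = true)  # 2Kψ shortcut, wrong here

println(isapprox(real(dot(g_true, d)), fd; rtol = 1e-6, atol = 1e-6))  # true
# The forced shortcut does not match the finite difference in general:
println(isapprox(real(dot(g_forced, d)), fd; rtol = 1e-6, atol = 1e-6))
```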

danielalcalde commented 1 year ago

As motivation, here is an example with a roughly 2x speedup. Note that the larger the bond dimension of the Hamiltonian, the larger the speedup.

```julia
using PastaQ
using ITensors
using Random
using Printf
using OptimKit
using Zygote
using BenchmarkTools

N = 10   # number of qubits
J = 1.0  # Ising exchange interaction
h = 0.5  # transverse magnetic field

# Hilbert space
hilbert = qubits(N)

function ising_hamiltonian(N; J, h)
  os = OpSum()
  for j in 1:(N - 1)
    os += -J, "Z", j, "Z", j + 1
  end
  for j in 1:N
    os += -h, "X", j
  end
  return os
end

# define the Hamiltonian
os = ising_hamiltonian(N; J, h)

# build MPO "cost function"
H = MPO(os, hilbert)

cutoff = 1e-10

# layer of single-qubit Ry gates
Rylayer(N, θ) = [("Ry", j, (θ=θ[j],)) for j in 1:N]

# brick-layer of CX gates
function CXlayer(N, Π)
  start = isodd(Π) ? 1 : 2
  return [("CX", (j, j + 1)) for j in start:2:(N - 1)]
end

# variational ansatz
function variationalcircuit(N, depth, θ)
  circuit = Tuple[]
  for d in 1:depth
    circuit = vcat(circuit, CXlayer(N, d))
    circuit = vcat(circuit, Rylayer(N, θ[d]))
  end
  return circuit
end

depth = 20
ψ = productstate(hilbert)

cutoff = 1e-8
maxdim = 200

# cost function (tutorial version, differentiates through `inner`)
function loss(θ)
  circuit = variationalcircuit(N, depth, θ)
  Uψ = runcircuit(ψ, circuit; cutoff, maxdim)
  return inner(Uψ', H, Uψ; cutoff, maxdim)
end

function expect(ψ, H; kwargs...)
  return real(inner(ψ', H, ψ; kwargs...))
end

# custom adjoint: assumes H is Hermitian, so the gradient is 2Hψ
Zygote.@adjoint function expect(ψ, H; kwargs...)
  function f̄(ȳ)
    ψbar = contract(H, ψ'; kwargs...)
    return ȳ * 2 * ψbar, nothing
  end
  return expect(ψ, H; kwargs...), f̄
end

# cost function using `expect`
function loss2(θ)
  circuit = variationalcircuit(N, depth, θ)
  Uψ = runcircuit(ψ, circuit; cutoff, maxdim)
  return expect(Uψ, H; cutoff, maxdim)
end

# initialize parameters
θ₀ = [2π .* rand(N) for _ in 1:depth];

# both cost functions give the same gradient
gradient(loss, θ₀) .≈ gradient(loss2, θ₀)
# (true,)
```

```
@benchmark gradient(loss, θ₀)
BenchmarkTools.Trial: 3 samples with 1 evaluation.
 Range (min … max):  1.441 s …    2.184 s  ┊ GC (min … max):  6.06% … 37.79%
 Time  (median):     1.455 s               ┊ GC (median):     6.00%
 Time  (mean ± σ):   1.694 s ± 425.015 ms  ┊ GC (mean ± σ):  19.57% ± 18.46%

  ██                                                       █  
  ██▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ ▁
  1.44 s         Histogram: frequency by time         2.18 s <

 Memory estimate: 648.61 MiB, allocs estimate: 1518320.

@benchmark gradient(loss2, θ₀)
BenchmarkTools.Trial: 6 samples with 1 evaluation.
 Range (min … max):  685.989 ms …    1.110 s  ┊ GC (min … max): 10.04% … 43.52%
 Time  (median):     840.434 ms               ┊ GC (median):    26.17%
 Time  (mean ± σ):   871.303 ms ± 204.527 ms  ┊ GC (mean ± σ):  28.53% ± 17.36%

  █▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▇ ▁
  686 ms           Histogram: frequency by time          1.11 s <

 Memory estimate: 334.74 MiB, allocs estimate: 1468562
```

danielalcalde commented 1 year ago

What do you think about using this code to check whether the MPO is Hermitian, and throwing an error if the `expect` function is used with a non-Hermitian MPO?

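```julia
using ITensors
using LinearAlgebra
using StatsBase

# Return the unprimed site index that appears twice (as s and s') in an MPO tensor.
function repeat_inds(a::ITensor)
  for (v, i) in countmap(noprime.(inds(a)))
    if i == 2
      return v
    end
  end
  return nothing
end

# An MPO tensor is Hermitian if swapping s <-> s' (links untouched) and conjugating leaves it unchanged.
function LinearAlgebra.ishermitian(t::ITensor)
  s = repeat_inds(t)
  ts = replaceinds(t, (s, s'), (s', s))
  return t ≈ conj(ts)
end

LinearAlgebra.ishermitian(H::MPO) = all([ishermitian(t) for t in H])
```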
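One caveat on the per-tensor test. Requiring every MPO tensor to be Hermitian in its site indices, for each fixed value of the link indices, is a sufficient condition, because each term is then a tensor product of Hermitian operators. It is not a necessary one. A Hermitian operator can be a sum of products of non-Hermitian factors, as in a hopping term σ⁺σ⁻ + σ⁻σ⁺, so an MPO built from such terms may fail the check even though it represents a Hermitian operator (whether it does depends on how the MPO was constructed). The dense sketch below uses plain Julia matrices, not ITensors code, and shows this at the operator level.

```julia
using LinearAlgebra

# Single-qubit operators in the basis {|0⟩, |1⟩}.
Id = Matrix{Float64}(I, 2, 2)
X  = [0.0 1.0; 1.0 0.0]
Z  = [1.0 0.0; 0.0 -1.0]
σp = [0.0 1.0; 0.0 0.0]          # σ⁺ = |0⟩⟨1|, not Hermitian
σm = Matrix(σp')                 # σ⁻ = (σ⁺)†

# Two-site Ising-type operator: every factor is Hermitian, so the sum is too.
H_ising = -kron(Z, Z) - 0.5 * (kron(X, Id) + kron(Id, X))

# Hopping operator: Hermitian overall although its factors are not, so a
# factor-by-factor (or tensor-by-tensor) test would reject it.
H_hop = kron(σp, σm) + kron(σm, σp)

println(ishermitian(H_ising))    # true
println(ishermitian(H_hop))      # true
println(ishermitian(σp))         # false
```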