tencent-quantum-lab / tensorcircuit

Tensor network based quantum software framework for the NISQ era
https://tensorcircuit.readthedocs.io
Apache License 2.0

MPS circuit with JIT compilation #204

Closed Muzhou-Ma closed 5 months ago

Muzhou-Ma commented 8 months ago

Issue Description

When I use TensorFlow as the backend and construct an MPSCircuit, I use jit for compilation. However, JIT does not seem to be working, and I get warnings like these:

WARNING:tensorflow:Using a while_loop for converting Qr cause there is no registered converter for this op.
WARNING:tensorflow:Using a while_loop for converting SVD cause there is no registered converter for this op.

Proposed Solution

Make the operators in MPSCircuit jittable.

refraction-ray commented 8 months ago

Please attach a demo that reproduces the issue.

Muzhou-Ma commented 5 months ago

Hi, here is a demo that reproduces the issue:

import tensorcircuit as tc
import tensorflow as tf
import numpy as np

tc.set_backend("tensorflow")
tc.set_dtype("complex64")
def Hamiltonian(c: tc.MPSCircuit, n: int):
    e = 0.0
    for i in range(n):
        e += -1 * tf.cast(c.expectation_ps(z=[i]), tf.float64)
    return -tc.backend.real(e)

def vqe(params, n):
    circuit = tc.MPSCircuit(n)
    circuit.set_split_rules({"max_singular_values": 50})

    for i in range(n):
        circuit.rx(i,theta=params[i][0])
        circuit.ry(i,theta=params[i][1])
        circuit.rz(i,theta=params[i][2])

    energy = Hamiltonian(circuit, n)
    return energy

vqe_vvag = tc.backend.jit(
    tc.backend.vectorized_value_and_grad(vqe, vectorized_argnums = (0,)), static_argnums=(1,)
)

if __name__=="__main__":
    batch = 16
    n = 8
    maxiter = 100
    params = tf.Variable(
            initial_value=tf.concat(
                [tf.random.normal(shape=[int(batch/4), n, 3], mean=0, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[int(batch/4), n, 3], mean=np.pi/4, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[int(batch/4), n, 3], mean=np.pi/2, stddev=0.2, dtype=getattr(tf, tc.rdtypestr)),
                tf.random.normal(shape=[int(batch/4), n, 3], mean=np.pi*3/4, stddev=0.2, dtype=getattr(tf, tc.rdtypestr))
                ],0)
        )
    opt = tf.keras.optimizers.legacy.Adam(1e-2)
    for i in range(maxiter):
        energy, grad = vqe_vvag(params, n)
        opt.apply_gradients([(grad, params)])
        print(energy)

Thanks a lot!

refraction-ray commented 5 months ago

Thanks for providing the demo. I can run it successfully with no error. My environment info is attached below:

>>> tc.about()
OS info: macOS-10.15.7-x86_64-i386-64bit
Python version: 3.10.0
Numpy version: 1.24.3
Scipy version: 1.10.1
Pandas version: 2.0.3
TensorNetwork version: 0.5.0
Cotengra version: 0.6.0
TensorFlow version: 2.13.0
TensorFlow GPU: []
TensorFlow CUDA infos: {'is_cuda_build': False, 'is_rocm_build': False, 'is_tensorrt_build': False}
Jax version: 0.4.14
Jax installation doesn't support GPU
JaxLib version: 0.4.14
PyTorch version: 2.0.1
PyTorch GPU support: False
PyTorch GPUs: []
Cupy is not installed
Qiskit version: 0.45.1
Cirq version: 1.2.0
TensorCircuit version 0.12.0

refraction-ray commented 5 months ago

Ah, you mean the warning. I do see the warning, but I believe it doesn't affect the results. I will investigate further whether the warning has a negative effect on jit and whether we can get rid of it.

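I have checked now! The warning is not related to jit but to vmap. If we use value_and_grad instead of vvag (vectorized_value_and_grad), the warning is gone. The reason for the warning is that TensorFlow has no vectorized implementation registered for QR (the same applies to the SVD warning), so vectorization falls back to a while_loop for these ops.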
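To make the point above concrete, here is a minimal sketch that uses plain TensorFlow only, with no TensorCircuit involved. It assumes that `tf.vectorized_map` behaves like the vmap path discussed in this thread; the helper name `qr_single` and the toy shapes are made up for illustration. On a TensorFlow version such as the 2.13.0 listed above, this is expected to emit the same "Using a while_loop for converting Qr" warning, while still returning correct results.

```python
import tensorflow as tf

# Hypothetical toy input: a batch of 4 random 3x3 matrices.
batch = tf.random.normal([4, 3, 3], dtype=tf.float32)


def qr_single(m):
    # QR decomposition of ONE matrix; vectorized_map maps this over the batch.
    q, r = tf.linalg.qr(m)
    return q, r


# Vectorizing over the leading axis. If no vectorized converter is registered
# for the Qr op, TensorFlow falls back to a while_loop and logs a warning.
q, r = tf.vectorized_map(qr_single, batch)

# The fallback only affects how the op is executed, not the result:
# q @ r should still reconstruct the input batch.
max_error = float(tf.reduce_max(tf.abs(tf.matmul(q, r) - batch)))
print("q shape:", q.shape, "r shape:", r.shape, "max error:", max_error)
```

If the warning appears here too, that supports the conclusion that it comes from vectorization rather than from jit.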

refraction-ray commented 5 months ago

If you feel tf is not fast enough, you can always try the following snippet with your actual circuit and hyperparameters to determine which backend is more suitable (tf vs. jax):

import tensorcircuit as tc
import numpy as np
import time

tc.set_dtype("complex64")

def Hamiltonian(c: tc.MPSCircuit, n: int):
    e = 0.0
    for i in range(n):
        e += -1 * c.expectation_ps(z=[i])
    return -tc.backend.real(e)

def vqe(params, n):
    circuit = tc.MPSCircuit(n)
    circuit.set_split_rules({"max_singular_values": 50})

    for i in range(n):
        circuit.rx(i, theta=params[i][0])
        circuit.ry(i, theta=params[i][1])
        circuit.rz(i, theta=params[i][2])
    for i in range(n-1):
        circuit.cx(i, i+1)

    energy = Hamiltonian(circuit, n)
    return energy

if __name__ == "__main__":
    batch = 16
    n = 16
    maxiter = 100
    params0 = np.random.uniform(size=[batch, n, 3])

    for b in ["tensorflow", "jax"]:
        with tc.runtime_backend(b):
            vqe_vvag = tc.backend.jit(
                tc.backend.vectorized_value_and_grad(vqe, vectorized_argnums=(0,)),
                static_argnums=(1,),
            )
            print("benchmarking backend: %s" % b)
            time0 = time.time()
            params = tc.backend.convert_to_tensor(params0)
            energy, grad = vqe_vvag(params, n)
            print(energy, grad)
            print("jit time", time.time() - time0)
            time0 = time.time()
            for _ in range(5):
                energy, grad = vqe_vvag(params, n)
            print("running time", (time.time() - time0) / 5)

Muzhou-Ma commented 5 months ago

Aha, I see. Thanks a lot! So it seems that we can't use vvag for speeding up with tf as backend.

Muzhou-Ma commented 5 months ago

I will close this issue, many thanks!

refraction-ray commented 5 months ago

> Aha, I see. Thanks a lot! So it seems that we can't use vvag for speeding up with tf as backend.

On this point, I don't know. Maybe you can run some microbenchmarks comparing vvag over the batch against a naive for loop with the tf backend. It is also possible that other operations are vectorized, in which case vvag may still be more efficient than a for loop.
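A microbenchmark of this kind could be sketched as below. This is only an illustrative timing harness, not something from TensorCircuit: the helper names `time_call` and `compare` are hypothetical, and the workloads at the bottom are stand-ins so that the sketch runs without any backend installed. The one design choice worth noting is the untimed warm-up call, which keeps jit tracing and compilation time out of the measurement.

```python
import time
from typing import Callable, Dict


def time_call(fn: Callable[[], object], repeats: int = 5) -> float:
    """Return average wall-clock seconds per call, after one untimed warm-up."""
    fn()  # warm-up: for a jitted function this absorbs tracing/compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats


def compare(
    batched: Callable[[], object],
    looped: Callable[[], object],
    repeats: int = 5,
) -> Dict[str, float]:
    """Time a batched call against an equivalent explicit loop."""
    return {
        "batched": time_call(batched, repeats),
        "looped": time_call(looped, repeats),
    }


# Stand-in workloads (assumption: replace these with the real functions).
rows = [list(range(200)) for _ in range(16)]


def batched_stub():
    # Plays the role of one vectorized call over the whole batch.
    return [sum(row) for row in rows]


def looped_stub():
    # Plays the role of calling a per-sample function once per batch element.
    out = []
    for row in rows:
        out.append(sum(row))
    return out


timings = compare(batched_stub, looped_stub, repeats=3)
print(timings)
```

For the case in this thread, `batched` would be something like `lambda: vqe_vvag(params, n)`, and `looped` would call a jitted `tc.backend.value_and_grad(vqe)` once per batch element. If a backend dispatches work asynchronously, the callables should also convert the result to a host array so that the timing covers the full computation.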