Open jay-selby opened 1 week ago
The attached pdf file below explains the different test cases(scenarios) and includes some analysis and comments.
### Scenario A (without qjit) ###
import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
from functools import wraps
from functools import partial
import concurrent.futures as cf
# Create a global pool using the system's default number of worker threads.
exe = cf.ThreadPoolExecutor()

# Decorator that schedules the wrapped function on the executor and returns
# a concurrent.futures.Future instead of executing the function inline.
def async_task(f, executor=exe):
    @wraps(f)
    def wrap(*args, **kwargs):
        # Forward keyword arguments as well; the original dropped **kwargs.
        return executor.submit(f, *args, **kwargs)
    return wrap
# Baseline QNode used to measure single-threaded performance.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    """Three parameterised rotations plus a CNOT on qubits 0/2.

    The classical matmul only simulates extra workload; its result is
    discarded and does not affect the measured expectation value.
    """
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # heavy classical work, value unused
    return qml.expval(qml.PauliZ(2))
@async_task
def parallel_func_1(q_arg, c_arg):
    """Async variant of the baseline circuit; measures PauliZ on qubit 1."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    return qml.expval(qml.PauliZ(1))
@async_task
def parallel_func_2(q_arg, c_arg):
    """Async circuit with a Hadamard/Toffoli layer; measures qubit 0."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))
# 1000x1000 matrix instantiation and initialization with random numbers for testing.
# NOTE(review): both arrays use PRNGKey(0), so they hold identical values.
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
# Rotation angles consumed as q_arg[0..2] by the RX/RY/RZ gates.
q_parameters = jnp.array([0.011, 0.012, 0.13])
def main():
    """Time the serial baseline, then the two async tasks, and report speedup."""
    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)  # warm-up call (excluded from timing)
    # perf_counter is monotonic and high-resolution, unlike time.time()
    start_serial = time.perf_counter()
    # Iterate over the serial function 10 times, then average the execution time
    for _ in range(10):
        serial_func(q_parameters, array_jax1)
    serial_exe_time = (time.perf_counter() - start_serial) / 10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################
    ######### Parallel region (start) ################
    start_parallel = time.perf_counter()
    # Launch the two independent tasks on the thread pool
    future1 = parallel_func_1(q_parameters, array_jax1)
    future2 = parallel_func_2(q_parameters, array_jax2)
    # Block until both tasks are done (re-raises any worker exception)
    future1.result()
    future2.result()
    print("Async tasks finished? ", future1.done() and future2.done())
    parallel_exe_time = time.perf_counter() - start_parallel
    print("Parallel exe time: ", parallel_exe_time)
    # Two tasks ran concurrently, so the serial reference cost is 2x one call
    speedup = (serial_exe_time * 2) / parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    exe.shutdown()

if __name__ == '__main__':
    main()
### Scenario A (qjit) ###
import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
from functools import wraps
from functools import partial
import concurrent.futures as cf
# Create a global thread pool; ThreadPoolExecutor picks a default worker
# count derived from the machine's CPU count.
exe = cf.ThreadPoolExecutor()
# Baseline QNode: captures single-threaded performance without parallelization.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    """Rotations + CNOT circuit; the matmul simulates classical workload
    and its result is intentionally discarded."""
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(2))
def parallel_func_1(q_arg, c_arg):
    """Circuit intended for qjit compilation; measures PauliZ on qubit 1."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    return qml.expval(qml.PauliZ(1))
def parallel_func_2(q_arg, c_arg):
    """Hadamard/Toffoli circuit intended for qjit; measures qubit 0."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))
# 1000x1000 matrix instantiation and initialization with random numbers for testing.
# NOTE(review): both arrays use PRNGKey(0), so they hold identical values.
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
# Rotation angles consumed as q_arg[0..2] by the RX/RY/RZ gates.
q_parameters = jnp.array([0.011, 0.012, 0.13])
def main():
    """Benchmark the serial QNode against qjit-compiled tasks on a thread pool."""
    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)  # warm-up call (excluded from timing)
    start_serial = time.perf_counter()  # monotonic clock for benchmarking
    # Iterate over the serial function 10 times, then average the execution time
    for _ in range(10):
        serial_func(q_parameters, array_jax1)
    serial_exe_time = (time.perf_counter() - start_serial) / 10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################
    ######### Parallel region (start) ################
    parallel_func_1_qjit = qjit(parallel_func_1)
    parallel_func_2_qjit = qjit(parallel_func_2)
    # Warm-up pass: the first qjit call compiles the function.  Without it
    # the timed region measures compile + execute, while the serial baseline
    # above only measures execution — an unfair comparison.
    parallel_func_1_qjit(q_parameters, array_jax1)
    parallel_func_2_qjit(q_parameters, array_jax2)
    start_parallel = time.perf_counter()
    # Calling the two compiled functions concurrently
    future1 = exe.submit(parallel_func_1_qjit, q_parameters, array_jax1)
    future2 = exe.submit(parallel_func_2_qjit, q_parameters, array_jax2)
    # Block until all tasks are done
    future1.result()
    future2.result()
    print("Async tasks finished? ", future1.done() and future2.done())
    parallel_exe_time = time.perf_counter() - start_parallel
    print("Parallel exe time: ", parallel_exe_time)
    # Two tasks ran concurrently, so the serial reference cost is 2x one call
    speedup = (serial_exe_time * 2) / parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    exe.shutdown()

if __name__ == '__main__':
    main()
### Scenario B (without qjit) ###
import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
from functools import wraps
import concurrent.futures as cf
# Create a global pool using the system's default number of worker threads.
exe = cf.ThreadPoolExecutor()

# Decorator that schedules the wrapped function on the executor and returns
# a concurrent.futures.Future instead of executing the function inline.
def async_task(f, executor=exe):
    @wraps(f)
    def wrap(*args, **kwargs):
        # Forward keyword arguments as well; the original dropped **kwargs.
        return executor.submit(f, *args, **kwargs)
    return wrap
# Baseline QNode: captures single-threaded performance without parallelization.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    """Rotations + CNOT circuit; the matmul simulates classical workload
    and its result is intentionally discarded."""
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(2))
# Same circuit as serial_func; runs on the main thread inside the parallel
# region to simulate foreground computation while workers are busy.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func2(q_arg, c_arg):
    """Rotations + CNOT circuit; matmul result intentionally discarded."""
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(2))
@async_task
def parallel_func_1(q_arg, c_arg):
    """Async variant of the baseline circuit; measures PauliZ on qubit 1."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    return qml.expval(qml.PauliZ(1))
@async_task
def parallel_func_2(q_arg, c_arg):
    """Async Hadamard/Toffoli circuit; measures PauliZ on qubit 0."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))
@async_task
def parallel_matmul(c_arg):
    """Square the matrix on a worker thread; returns a Future of the product."""
    return jnp.matmul(c_arg, c_arg)
# 1000x1000 matrix instantiation and initialization with random numbers for testing.
# NOTE(review): all three arrays use PRNGKey(0), so they hold identical values.
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax3 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
# Rotation angles consumed as q_arg[0..2] by the RX/RY/RZ gates.
q_parameters = jnp.array([0.011, 0.012, 0.13])
def main():
    """Benchmark serial QNodes vs a mix of async tasks with a data dependency."""
    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)  # warm-up call (excluded from timing)
    start_serial = time.perf_counter()  # monotonic clock for benchmarking
    # Iterate over the serial function 10 times, then average the execution time
    for _ in range(10):
        serial_func(q_parameters, array_jax1)
    serial_exe_time = (time.perf_counter() - start_serial) / 10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################
    ######### Parallel region (start) ################
    start_parallel = time.perf_counter()
    # Data-independent tasks run concurrently with work on the main thread
    future1 = parallel_func_1(q_parameters, array_jax1)
    future2 = parallel_matmul(array_jax3)
    serial_func2(q_parameters, array_jax2)
    # Wait for the matmul and hand its RESULT to the dependent task.
    # BUG FIX: the original passed the Future object itself as c_arg.
    matmul_result = future2.result()
    future3 = parallel_func_2(q_parameters, matmul_result)
    futures = [future1, future2, future3]
    # Wait for all tasks to finish
    cf.wait(futures)
    print("Async tasks finished? ", future1.done() and future2.done() and future3.done())
    parallel_exe_time = time.perf_counter() - start_parallel
    print("Parallel exe time: ", parallel_exe_time)
    # Four units of work overlap in this region, hence the 4x serial reference
    speedup = (serial_exe_time * 4) / parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    exe.shutdown()

if __name__ == '__main__':
    main()
### Scenario B (qjit) ###
import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
from functools import wraps
import concurrent.futures as cf
# Create a global thread pool; ThreadPoolExecutor picks a default worker
# count derived from the machine's CPU count.
exe = cf.ThreadPoolExecutor()
# Baseline QNode: captures single-threaded performance without parallelization.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    """Rotations + CNOT circuit; the matmul simulates classical workload
    and its result is intentionally discarded."""
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(2))
# Same circuit as serial_func; runs on the main thread inside the parallel
# region to simulate foreground computation while workers are busy.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func2(q_arg, c_arg):
    """Rotations + CNOT circuit; matmul result intentionally discarded."""
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(2))
# NOTE(review): this variant targets the lightning.kokkos device, unlike the
# lightning.qubit baseline — keep in mind when comparing timings.
@qml.qnode(qml.device("lightning.kokkos", wires=3))
def parallel_func_1(q_arg, c_arg):
    """Rotations + CNOT circuit on the kokkos backend; measures qubit 1."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    return qml.expval(qml.PauliZ(1))
def parallel_func_2(q_arg, c_arg):
    """Hadamard/Toffoli circuit intended for qjit; measures qubit 0."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))
def parallel_matmul(c_arg):
    """Return the matrix product of c_arg with itself."""
    return jnp.matmul(c_arg, c_arg)
# 1000x1000 matrix instantiation and initialization with random numbers for testing.
# NOTE(review): all three arrays use PRNGKey(0), so they hold identical values.
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax3 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
# Rotation angles consumed as q_arg[0..2] by the RX/RY/RZ gates.
q_parameters = jnp.array([0.011, 0.012, 0.13])
def main():
    """Benchmark serial QNodes vs qjit-compiled async tasks with a dependency."""
    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)  # warm-up call (excluded from timing)
    start_serial = time.perf_counter()  # monotonic clock for benchmarking
    # Iterate over the serial function 10 times, then average the execution time
    for _ in range(10):
        serial_func(q_parameters, array_jax1)
    serial_exe_time = (time.perf_counter() - start_serial) / 10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################
    ######### Parallel region (start) ################
    parallel_func_1_qjit = qjit(parallel_func_1)
    parallel_func_2_qjit = qjit(parallel_func_2)
    parallel_matmul_qjit = qjit(parallel_matmul)
    start_parallel = time.perf_counter()
    # Data-independent tasks run concurrently with work on the main thread
    future1 = exe.submit(parallel_func_1_qjit, q_parameters, array_jax1)
    future2 = exe.submit(parallel_matmul_qjit, array_jax3)
    serial_func2(q_parameters, array_jax2)
    # Hand the dependent task the matmul RESULT, not the Future object.
    # BUG FIX: the original submitted future2 itself as c_arg.
    matmul_result = future2.result()
    future3 = exe.submit(parallel_func_2_qjit, q_parameters, matmul_result)
    futures = [future1, future2, future3]
    # Wait for all tasks to finish
    cf.wait(futures)
    print("Async tasks finished? ", future1.done() and future2.done() and future3.done())
    parallel_exe_time = time.perf_counter() - start_parallel
    print("Parallel exe time: ", parallel_exe_time)
    # Four units of work overlap in this region, hence the 4x serial reference
    speedup = (serial_exe_time * 4) / parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################
    exe.shutdown()

if __name__ == '__main__':
    main()
### Scenario C (without qjit) ###
import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
import functools
from functools import wraps
import concurrent.futures as cf
# Create a global pool using the system's default number of worker threads.
exe = cf.ThreadPoolExecutor()

# Decorator that schedules the wrapped function on the executor and returns
# a concurrent.futures.Future instead of executing the function inline.
def async_task(f, executor=exe):
    @wraps(f)
    def wrap(*args, **kwargs):
        # Forward keyword arguments as well; the original dropped **kwargs.
        return executor.submit(f, *args, **kwargs)
    return wrap
# Baseline QNode: captures single-threaded performance without parallelization.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    """Rotations + CNOT circuit; the matmul simulates classical workload
    and its result is intentionally discarded."""
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(2))
# Same circuit as serial_func; runs on the main thread inside the parallel
# region to simulate foreground computation while workers are busy.
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func2(q_arg, c_arg):
    """Rotations + CNOT circuit; matmul result intentionally discarded."""
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(2))
@async_task
def parallel_func_1(q_arg, c_arg):
    """Async variant of the baseline circuit; measures PauliZ on qubit 1."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    return qml.expval(qml.PauliZ(1))
@async_task
def parallel_func_2(q_arg, c_arg):
    """Async Hadamard/Toffoli circuit; measures PauliZ on qubit 0."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))
@async_task
def parallel_func_3(q_arg, c_arg):
    """Duplicate of parallel_func_2 kept as a separate definition to avoid
    sharing one capture context across threads; measures qubit 0."""
    _ = jnp.matmul(c_arg, c_arg)  # classical workload, value unused
    qml.Hadamard(wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.Toffoli(wires=[0, 1, 2])
    return qml.expval(qml.PauliZ(0))
@async_task
def parallel_matmul(c_arg):
    """Square the matrix on a worker thread; returns a Future of the product."""
    return jnp.matmul(c_arg, c_arg)
# 1000x1000 matrix instantiation and initialization with random numbers for testing.
# NOTE(review): all three arrays use PRNGKey(0), so they hold identical values.
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax2 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
array_jax3 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
# Rotation angles consumed as q_arg[0..2] by the RX/RY/RZ gates.
q_parameters = jnp.array([0.011, 0.012, 0.13])
def main():
    """Benchmark serial QNodes vs async tasks chained via a done-callback."""
    ######### Serial region (start) ################
    serial_func(q_parameters, array_jax1)  # warm-up call (excluded from timing)
    start_serial = time.perf_counter()  # monotonic clock for benchmarking
    # Iterate over the serial function 10 times, then average the execution time
    for _ in range(10):
        serial_func(q_parameters, array_jax1)
    serial_exe_time = (time.perf_counter() - start_serial) / 10
    print("Serial exe time: ", serial_exe_time)
    ######### Serial region (end) ################
    ######### Parallel region (start) ################
    start_parallel = time.perf_counter()
    future1 = parallel_func_1(q_parameters, array_jax1)
    future2 = parallel_matmul(array_jax3)
    serial_func2(q_parameters, array_jax2)
    # The done-callback receives the finished Future itself, so unwrap it
    # with .result() before forwarding the data.
    # BUG FIX: the original handed the Future object to parallel_func_2 as c_arg.
    # NOTE(review): the callback submits a new task; if future2 completes only
    # after exe.shutdown() below, that submit would raise — confirm ordering.
    future2.add_done_callback(
        lambda fut: parallel_func_2(q_parameters, fut.result()))
    # Other data-independent parallel tasks
    parallel_matmul(array_jax1)
    parallel_matmul(array_jax1)
    parallel_matmul(array_jax1)
    parallel_func_3(q_parameters, array_jax1)
    # shutdown() blocks until every queued task has finished
    exe.shutdown()
    parallel_exe_time = time.perf_counter() - start_parallel
    print("Parallel exe time: ", parallel_exe_time)
    # Eight units of work overlap in this region, hence the 8x serial reference
    speedup = (serial_exe_time * 8) / parallel_exe_time
    print("Speedup = ", speedup)
    ######### Parallel region (end) ################

if __name__ == '__main__':
    main()
Hi @mwasfy, thank you for the submission!
I would be curious to ask you a few follow-up questions about this solution.
Looking at Scenario A:
The observed speedup is ∼1.4X, which could be attributed to the fact that I am running on one device and the overhead of context switching is not insignificant.
How do you explain the speedup when using multi-threading despite Python's global interpreter lock?
Looking at Scenario A (qjit):
I noticed that the base case (serial) does not use qjit, whereas the parallel case does. Aren't we comparing different things then since one version is compiled and the other is not?
To answer one of your comments: one reason we are seeing a slowdown in the qjit version is that it has to compile the function first before running it, and the parallel functions do not receive a "warmup run" like the serial version does, which would alleviate the issue.
Regarding your first comment:
For the non-jitted version, I created a decorative wrapper to conveniently call asynchronous parallel tasks. This wrapper didn’t work with qjitted functions returning an error that the future is not a valid Jax type.
I think this error might pop up if the qjit decorator is placed on top of the async wrapper. Have you tried putting the qjit decorator "inside" the async one?
This is rather minor, but I'm also curious why the test function was copy pasted a few times, rather than re-using one definition. Was there a specific reason for this?
Hi @dime10, thank you for your insightful comments. It cleared many of my concerns.
Initially, I used the `@qjit` decorator rather than `f_qjit = qjit(f)`. What I noticed then is that, when calling the qjitted function multiple times with the `@qjit` decorator, it always returned the error "All measurements must be returned in the order they are measured." But I don't get that error when I use `f_qjit = qjit(f)`.
I couldn't get `@qjit` on top of `@async_task` to work; however, this setup worked:
@async_task
@qjit
def parallel_func_1(q_arg, c_arg):
Copying and pasting functions: mainly I didn't want to run into the "All measurements must be returned in the order they are measured." error. Hence, I used different functions with slightly different syntax.
Hi @mwasfy, thanks for your reply!
Copying and pasting functions: mainly I didn't want to run into the "All measurements must be returned in the order they are measured." error. Hence, I used different functions with slightly different syntax.
I believe this is an issue with the PennyLane library not being thread-safe, since it uses a global context to capture quantum instructions in a QNode. I think we can get around this issue by making only the execution of a qjit-ed function asynchronous, but not its capture/compilation.
Regarding point 1:
I think because under the hood numpy releases the GIL and uses its own machine code.
This is actually a good point, numpy does appear to do that for many of its functions. The functions we are interested in are typically quantum functions, which will execute a quantum circuit on a device using the PennyLane library. Do you think this reasoning applies there as well? For the QJIT case, Catalyst doesn't use numpy library code during execution, do you think multi-threading can help there?
Btw, I noticed that the parallel functions don't have the `@qnode` decorator while the serial function does — doesn't that mean we are comparing different things again?
Hi @dime10, thanks for getting back to me.
I think we can get around this issue by making only the execution of a qjit-ed function asynchronous, but not its capture/compilation.
Actually, I think that is how it was implemented. Please take a look at the following code snippet. Wouldn't that be the case you are describing as a workaround? (Unless I am not understanding it well.) BTW, there was no use of decorators here for qjit or async.
parallel_func_1_qjit = qjit(parallel_func_1)
future1 = exe.submit(parallel_func_1_qjit, q_parameters, array_jax1)
The functions we are interested in are typically quantum functions, which will execute a quantum circuit on a device using the PennyLane library. Do you think this reasoning applies there as well?
We'll be running on a separate device with its own machine code, so yes, I think the same reasoning would apply. Actually, executing on a separate "quantum" device underscores the importance of such an approach for asynchronous tasks even more.
For the QJIT case, Catalyst doesn't use numpy library code during execution, do you think multi-threading can help there?
I think this is more a question of multi-threading vs multi-processing. Intuitively, I would say this is supposed to be a compute intensive function so the answer would be multi-processing. However, if the quantum device acts like an attached co-processor where we send inputs and wait for outputs, it could be considered a case of IO bound computation from the main thread’s perspective where multi-threading would be better. Having a specific answer for that question requires some more internal knowledge of how both Catalyst and PennyLane work and how the quantum device is connected to the host and how they interact with each other.
Btw, I noticed that the parallel functions don't have the @qnode decorator while the serial function does, doesn't that mean we are comparing different things again?
As a matter of fact, I did test all the parallel functions with @qnode; there was no difference in terms of performance, so I neglected them when writing up the final test scenarios I submitted. `lightning.qubit` and `lightning.kokkos` didn't seem to show any difference in execution time for these circuits. But then again, this may be because these are toy circuits, not complex circuits with enough depth where there would be a meaningful difference in performance. In fact, that is why I opted to insert matmul operations in there to simulate long processing time (I didn't want to use sleep functions). Another reason for inserting matmul was that I wanted to create data dependency. To be honest, I am not sure how to create data dependency between two quantum functions (maybe use the evaluation of one circuit to initialize qubits for another circuit, I am not sure).
Actually I think that is how it was implemented. Please take a look at the following code snippet. Wouldn’t that be the case you are describing as a work around. (Unless I am not understating it well). BTW, there was no use of decorators here for qjit or async.
Both the `@qnode` and `@qjit` decorators work in a similar way to "just-in-time compilation". That is, there is a big difference between the first run of the function and subsequent runs. For the QNode, on first run it constructs a data structure of the quantum circuit by executing the Python code in the function, only then does it send this data structure to a device for execution. All subsequent runs then directly jump to the second part (the execution).
So if the first call is executed asynchronously, then the "capture phase" (the construction of the data structure) might happen in parallel leading to the described issue with the global context.
For QJIT, the procedure is very similar. The first call will compile the function to binary code (again by running the Python function, capturing operations within, and compiling them), and then execute the compiled code. Subsequent calls directly execute the binary code.
This is why its so important that all test cases (serial or parallel) use the same setup, otherwise we are measuring very different things.
For instance, the matrix multiplication inside a QNode (without QJIT) will only happen once, during the circuit construction. If the function is not decorated with @qnode
then the multiplication happens every time the function is called.
To be honest, I am not sure how to create data dependency between two quantum functions (may be use the evaluation of one circuit to initialize qubits for another circuit, I am not sure).
Since quantum functions always have classical inputs (e.g. rotation angles) and outputs (e.g. expectation values), a dependency can also be created by using the output of one quantum function as the input of another:
@qml.qnode(qml.device("lightning.qubit", wires=3))
def circuit(phi):
qml.RX(phi, wires=0)
qml.CNOT(wires=[0,2])
return qml.expval(qml.PauliZ(2))
input = 0.7
output = circuit(input)
_ = circuit(output)
Thanks for the clarification. I guess from what you are describing, Catalyst uses lazy compilation — does it support eager compilation as well? I tested the same setup as you suggested in the code below: I ran both the serial and the parallel function once before starting the profiling process. I see that parallel async is actually slower, which suggests that the GIL does in fact get in the way.
import pennylane as qml
from catalyst import qjit, grad
import jax
import jax.numpy as jnp
import time
import concurrent.futures as cf
# Create a global thread pool; ThreadPoolExecutor picks a default worker
# count derived from the machine's CPU count.
exe = cf.ThreadPoolExecutor()
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def serial_func(q_arg, c_arg):
    """Qjit-compiled baseline circuit; measures PauliZ on qubit 0.

    The matmul simulates classical workload and its result is discarded.
    """
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(0))
@qjit
@qml.qnode(qml.device("lightning.qubit", wires=3))
def parallel_func(q_arg, c_arg):
    """Qjit-compiled circuit submitted to worker threads; measures qubit 1.

    The matmul simulates classical workload and its result is discarded.
    """
    qml.RX(q_arg[0], wires=0)
    qml.RY(q_arg[1], wires=1)
    qml.RZ(q_arg[2], wires=2)
    qml.CNOT(wires=[0, 2])
    _ = jnp.matmul(c_arg, c_arg)  # value unused
    return qml.expval(qml.PauliZ(1))
# 1000x1000 matrix instantiation and initialization with random numbers for testing.
array_jax1 = jax.random.normal(jax.random.PRNGKey(0), (1000,1000), dtype=jnp.float32)
# Rotation angles consumed as q_arg[0..2] by the RX/RY/RZ gates.
q_parameters = jnp.array([0.011, 0.012, 0.13])
def main():
    """Time compile-vs-execute calls of qjit functions, serially and threaded."""
    ######### Serial region (start) ################
    # First call triggers compilation + execution
    start_compile_serial = time.perf_counter()  # monotonic benchmarking clock
    serial_func(q_parameters, array_jax1)
    serial_compile = time.perf_counter() - start_compile_serial
    print("First serial call (compile): ", serial_compile)
    # Second call runs the already-compiled binary only
    start_exe_serial = time.perf_counter()
    serial_func(q_parameters, array_jax1)
    serial_exe = time.perf_counter() - start_exe_serial
    print("Second serial call (exe): ", serial_exe)
    ######### Serial region (end) ################
    ######### Parallel region (start) ################
    # First call triggers compilation in a worker thread
    start_compile_parallel = time.perf_counter()
    f1 = exe.submit(parallel_func, q_parameters, array_jax1)
    f1.result()
    parallel_compile = time.perf_counter() - start_compile_parallel
    print("First parallel call (compile): ", parallel_compile)
    # Two concurrent execution-only calls
    start_exe_parallel = time.perf_counter()
    f2 = exe.submit(parallel_func, q_parameters, array_jax1)
    f3 = exe.submit(parallel_func, q_parameters, array_jax1)
    f2.result()
    f3.result()
    parallel_exe = time.perf_counter() - start_exe_parallel
    print("Second parallel call (exe): ", parallel_exe)
    ######### Parallel region (end) ################
    exe.shutdown()
    # Two concurrent calls vs the cost of two serial execution-only calls
    speedup = (serial_exe * 2) / parallel_exe
    print("Speedup = ", speedup)

if __name__ == '__main__':
    main()
One more comment about the code I just shared: sometimes it runs perfectly, and sometimes it gives the error: Error in Catalyst Runtime: Invalid use of the global driver before initialization
Context
Thread-Level Speculation is a technique that has been used in various research to speed up general purpose programs by speculatively executing code downstream of a function call. The idea here is to do this in a similar manner to JAX, see Asynchronous Dispatch in the JAX docs.
JAX does not wait for the operation to complete before returning control to the Python program. Instead, JAX returns a `DeviceArray` value, which is a future, i.e., a value that will be produced in the future on an accelerator device but isn't necessarily available immediately. Only when the value of the `DeviceArray` is queried is a blocking call generated.

Consider the following code snippet. Here, `x` — a device array returned as the result of evaluating `f` — is a future `DeviceArray`, and blocking only occurs when a user requests the value of `x` in Python.

Questions:
The assumption here is that this will lead to speedups in the following situation (this assumption needs to be validated, but should be apparent in an interpreted language):
That is, since `x` is evaluated asynchronously, Python is not blocked awaiting the result of `f` and can simply invoke `g` directly.
directly.Requirements:
qjit
'ted function is executed in parallel with the compiled function.qjit
'ted function.Installation Help
Refer to the Catalyst installation guide for how to install a source build of the project.