PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

Performance degradation with vmap and large data #1153

Closed: dime10 closed this issue 3 weeks ago

dime10 commented 1 month ago

I ran across a problem that I thought was a non-terminating program, but it turns out to be severe performance degradation when using vmap on large amounts of data, such as a QNode returning qml.probs() on 20 qubits. The performance issue is present even when the batch dimension is 1. The following circuit starts to noticeably slow down at around the 17-qubit mark:

import pennylane as qml
import jax.numpy as jnp
from catalyst import qjit, vmap

@qjit
@vmap(in_axes=0)
@qml.qnode(qml.device("lightning.qubit", wires=17))
def empty_circuit(a):
    return qml.probs()

print(empty_circuit(jnp.zeros((1,))))

We can compare the runtime of this circuit with and without vmap, since the amount of work should be the same (batch dimension 1):

[plot: vmap_scaling_qubits]

Worse yet, this is not a constant overhead, but seems to scale with the batch dimension anyway. The following plot compares vmap to a non-vmap implementation that aggregates results in Python (jnp.array([empty_circuit(jnp.zeros((1,))) for _ in range(n)])), holding the qubit number at 15 and varying the batch dimension n:

[plot: vmap_scaling_batch_size]

Not necessarily the cause, but looking at the IR from the example at the top (the n=17 case), there are some oddities:

  • a very large amount of constant data is embedded into the IR during the hlo lowering process

The plots were generated using the following code, with minor modifications for the second plot (wires=15 and Python aggregation; a sketch of that variant follows the listing):

import timeit

import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pennylane as qml
from catalyst import qjit, vmap

def with_vmap(n):

    @qjit
    @vmap(in_axes=0)
    @qml.qnode(qml.device("lightning.qubit", wires=n))
    def empty_circuit(a):
        return qml.probs()

    return empty_circuit(jnp.zeros((1,)))

def without_vmap(n):

    @qjit
    @qml.qnode(qml.device("lightning.qubit", wires=n))
    def empty_circuit(a):
        return qml.probs()

    return empty_circuit(jnp.zeros((1,)))

xx = np.arange(1, 18, 1)
yy = np.array([timeit.timeit(lambda: without_vmap(n), number=5) for n in xx])
yy_v = np.array([timeit.timeit(lambda: with_vmap(n), number=5) for n in xx])

plt.plot(xx, yy, label="no vmap")
plt.plot(xx, yy_v, label="with vmap")
plt.xlabel("number of qubits")
plt.ylabel("runtime [s]")
plt.title("Vmap scaling with amount of data vs raw execution")
plt.legend()
plt.show()
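
For reference, a sketch of the "Python aggregation" variant used for the second plot, reconstructed from the description above (function names are illustrative; wires fixed at 15, batch dimension n varied):

def with_vmap_batched(n):

    @qjit
    @vmap(in_axes=0)
    @qml.qnode(qml.device("lightning.qubit", wires=15))
    def empty_circuit(a):
        return qml.probs()

    # vmap over a batch of n inputs in a single compiled call
    return empty_circuit(jnp.zeros((n,)))

def without_vmap_aggregated(n):

    @qjit
    @qml.qnode(qml.device("lightning.qubit", wires=15))
    def empty_circuit(a):
        return qml.probs()

    # Aggregate n independent executions in Python and stack them into one array
    return jnp.array([empty_circuit(jnp.zeros((1,))) for _ in range(n)])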

erick-xanadu commented 1 month ago

Is it just copying data? At first I wasn't sure what was happening, but yes, it looks like it is just copying data during the scatter lowering.

paul0403 commented 1 month ago
  • a very large amount of constant data is embedded into the IR during the hlo lowering process

This big tensor is generated during scattering, and it is followed by a loop whose trip count equals the tensor's size (2^num_of_wires):

[image: IR showing the large constant tensor followed by the loop]

Now, the major chunk of the runtime could be spent either on the creation of this giant tensor in memory or on the loop after it.

Realizing that the big tensor is just a container for indices and isn't actually necessary, we try throwing it away; that is, the following (shown with 6 wires for illustration) has the same functionality:

[image: IR with the constant index tensor removed]

This change does not decrease the runtime for 17 wires.

Therefore, the loop below the tensor must be the culprit.

Script (uncleaned) for investigation: vmap.py.txt, unknown_17.zip

erick-xanadu commented 1 month ago

I did a little bit of digging here regarding vmap. I discovered that vmap will always call the function at least once, and then save the result of this zero-th iteration.

While this is correct, I do not like it too much. I fixed it with this patch:

diff --git a/frontend/catalyst/api_extensions/function_maps.py b/frontend/catalyst/api_extensions/function_maps.py
index 4ebe07f92..c3650b05b 100644
--- a/frontend/catalyst/api_extensions/function_maps.py
+++ b/frontend/catalyst/api_extensions/function_maps.py
@@ -226,7 +226,11 @@ class VmapCallable(CatalystCallable):
         fn_args = tree_unflatten(args_tree, fn_args_flat)

         # Run 'fn' one time to get output-shape
-        init_result = self.fn(*fn_args, **kwargs)
+        _, shape = jax.make_jaxpr(self.fn, return_shape=True)(*fn_args, **kwargs)
+        shapes, init_result_tree = tree_flatten(shape)
+        init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape in shapes]
+        init_result = tree_unflatten(init_result_tree, init_result_flat)
+        

         # Check the validity of the output w.r.t. out_axes
         out_axes_deep_struct = tree_structure(self.out_axes, is_leaf=lambda x: x is None)
@@ -238,8 +242,6 @@ class VmapCallable(CatalystCallable):
                 f"{out_axes_deep_struct} axis specifiers and {init_result_deep_struct} results."
             )

-        init_result_flat, init_result_tree = tree_flatten(init_result)
-
         num_axes_out = len(init_result_flat)

         if isinstance(self.out_axes, int):
@@ -264,7 +266,7 @@ class VmapCallable(CatalystCallable):
             batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])

         # Apply mapping batched_args[1:] ---> fn(args)
-        @for_loop(1, batch_size, 1)
+        @for_loop(0, batch_size, 1)
         def loop_fn(i, batched_result_list):
             fn_args_flat = args_flat
             for loc in batch_loc:
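
For context on the replacement: jax.make_jaxpr with return_shape=True traces the function abstractly and reports its output shapes/dtypes without executing it, which is what lets us build a zero-initialized placeholder instead of running the circuit. A minimal standalone sketch of the idea (plain JAX, my own illustrative function, not Catalyst code):

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def fn(x):
    return jnp.sin(x), jnp.sum(x)

# Tracing only: fn is not executed concretely here
_, shape = jax.make_jaxpr(fn, return_shape=True)(jnp.zeros((4,)))
shapes, result_tree = tree_flatten(shape)

# Build zero-filled placeholders with the same pytree structure as fn's output
init_result = tree_unflatten(result_tree, [jnp.zeros(s.shape, s.dtype) for s in shapes])
print(init_result)   # (zeros of shape (4,), a scalar 0.0)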

The stablehlo.scatter operation is more general than tensor.insert_slice. But if we look at the places in our JAX code that generate the stablehlo.scatter operation in vmap, we see that they only ever set values at a given index, which can easily be expressed with tensor.insert_slice instead.

# Locations that produce scatter:

batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])

# ...

# Update the list of results
for j in range(num_axes_out):
    batched_result_list[j] = batched_result_list[j].at[i].set(res_flat[j])
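
To see why these updates are really "insert a full slice at offset i" rather than a general scatter, here is a small standalone check in plain JAX (my own example; lax.dynamic_update_slice is simply the closest JAX-level analogue of tensor.insert_slice):

import jax.numpy as jnp
from jax import lax

buf = jnp.zeros((3, 4))   # stands in for one entry of batched_result_list
row = jnp.ones((4,))      # stands in for one entry of res_flat
i = 1

# The indexed update as written in vmap (this is what produces stablehlo.scatter)
via_at = buf.at[i].set(row)

# The same result expressed as "insert this slice at row i"
via_slice = lax.dynamic_update_slice(buf, row[None, :], (i, 0))

assert jnp.array_equal(via_at, via_slice)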

So, we can just create a new JAX primitive for tensor.insert_slice. The tensor dialect is not included in JAX by default, but we ship our own Python-generated files. This is a bit hardcoded:

diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py
index 03a68697b..46d5277f7 100644
--- a/frontend/catalyst/jax_primitives.py
+++ b/frontend/catalyst/jax_primitives.py
@@ -97,6 +97,7 @@ from mlir_quantum.dialects.quantum import (
     VarianceOp,
 )
 from mlir_quantum.dialects.quantum import YieldOp as QYieldOp
+from mlir_quantum.dialects.tensor import InsertSliceOp, ScatterOp as TensorScatterOp, insert_slice

 from catalyst.compiler import get_lib_path
 from catalyst.jax_extras import (
@@ -299,6 +300,7 @@ set_basis_state_p = jax.core.Primitive("set_basis_state")
 set_basis_state_p.multiple_results = True
 quantum_kernel_p = core.CallPrimitive("quantum_kernel")
 quantum_kernel_p.multiple_results = True
+tensor_insert_slice_p = jax.core.Primitive("tensor_insert_slice_p")

 def _assert_jaxpr_without_constants(jaxpr: ClosedJaxpr):
@@ -413,6 +415,22 @@ def _print_lowering(jax_ctx: mlir.LoweringRuleContext, *args, string=None, memre
     return PrintOp(val=val, const_val=None, print_descriptor=memref).results

+@tensor_insert_slice_p.def_abstract_eval
+def _abs_eval(source_tensor, dest_tensor, indices):
+    return dest_tensor
+
+def _tensor_insert_slice_lowering(ctx, source_tensor, dest_tensor, indices):
+    dyn = ir.ShapedType.get_dynamic_size()
+    off = ir.DenseI64ArrayAttr.get([dyn, 0])
+
+    siz = ir.RankedTensorType(source_tensor.type).shape 
+    siz = ir.DenseI64ArrayAttr.get([1, *siz])
+    stri = ir.DenseI64ArrayAttr.get([1, 1])
+    p = TensorExtractOp(ir.RankedTensorType(indices.type).element_type, indices, []).result  # tensor<i64> -> i64
+    p = IndexCastOp(ir.IndexType.get(), p).result  # i64 -> index
+    x = InsertSliceOp(source_tensor, dest_tensor, [p], [], [], off, siz, stri)
+    #x = TensorScatterOp(dest_tensor.type, source_tensor, dest_tensor, indices, [0, 1], unique=True)
+    return x.results
 #
 # transform dialect lowering
 #
@@ -2393,6 +2411,7 @@ CUSTOM_LOWERING_RULES = (
     (sin_p, _sin_lowering2),
     (cos_p, _cos_lowering2),
     (quantum_kernel_p, _quantum_kernel_lowering),
+    (tensor_insert_slice_p, _tensor_insert_slice_lowering),
 )
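
As a side note, the general pattern above is the standard recipe for a new JAX primitive: declare a Primitive, give it an abstract evaluation rule so tracing knows the result type, and register a lowering (or an impl) for it. A generic, standalone sketch with illustrative names, using a pure-Python impl in place of the MLIR lowering (this is not the actual Catalyst code):

from jax import core
import jax.numpy as jnp

# Declare the primitive
my_insert_slice_p = core.Primitive("my_insert_slice")

@my_insert_slice_p.def_abstract_eval
def _my_insert_slice_abs_eval(source, dest, index):
    # The result has the same shape/dtype as the destination tensor
    return dest

@my_insert_slice_p.def_impl
def _my_insert_slice_impl(source, dest, index):
    # Reference semantics only; Catalyst instead registers a custom
    # lowering to tensor.insert_slice for its primitive
    return dest.at[index].set(source)

out = my_insert_slice_p.bind(jnp.ones((4,)), jnp.zeros((3, 4)), 1)
print(out)   # row 1 of the 3x4 zero tensor replaced by ones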

And you need to modify the Python file for the InsertSliceOp like this:

  def __init__(self, source, dest, offsets, sizes, strides, static_offsets, static_sizes, static_strides, *, loc=None, ip=None):
    operands = []
    results = [dest.type] # <----
    attributes = {}
    regions = None
    operands.append(_get_op_result_or_value(source))
    operands.append(_get_op_result_or_value(dest))
    operands.append(_get_op_results_or_values(offsets))
    operands.append(_get_op_results_or_values(sizes))
    operands.append(_get_op_results_or_values(strides))
    _ods_context = _ods_get_default_loc_context(loc)
    attributes["static_offsets"] = (static_offsets if (
    isinstance(static_offsets, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('DenseI64ArrayAttr')) else
      _ods_ir.AttrBuilder.get('DenseI64ArrayAttr')(static_offsets, context=_ods_context))
    attributes["static_sizes"] = (static_sizes if (
    isinstance(static_sizes, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('DenseI64ArrayAttr')) else
      _ods_ir.AttrBuilder.get('DenseI64ArrayAttr')(static_sizes, context=_ods_context))
    attributes["static_strides"] = (static_strides if (
    isinstance(static_strides, _ods_ir.Attribute) or
    not _ods_ir.AttrBuilder.contains('DenseI64ArrayAttr')) else
      _ods_ir.AttrBuilder.get('DenseI64ArrayAttr')(static_strides, context=_ods_context))
    _ods_successors = None
    super().__init__(self.build_generic(attributes=attributes, operands=operands, results=results, successors=_ods_successors, regions=regions, loc=loc, ip=ip)) #<---- add results

Finally, you can replace the indexed updates with the new tensor_insert_slice primitive:

diff --git a/frontend/catalyst/api_extensions/function_maps.py b/frontend/catalyst/api_extensions/function_maps.py
index c3650b05b..53420ece3 100644
--- a/frontend/catalyst/api_extensions/function_maps.py
+++ b/frontend/catalyst/api_extensions/function_maps.py
@@ -32,6 +32,7 @@ from catalyst.api_extensions.control_flow import for_loop
 from catalyst.tracing.contexts import EvaluationContext
 from catalyst.tracing.type_signatures import get_stripped_signature
 from catalyst.utils.callables import CatalystCallable
+from catalyst.jax_primitives import tensor_insert_slice_p

 ## API ##
@@ -263,7 +264,8 @@ class VmapCallable(CatalystCallable):
                 else (batch_size, *init_result_flat[j].shape)
             )
             batched_result_list.append(jnp.zeros(shape=out_shape, dtype=init_result_flat[j].dtype))
-            batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])
+            batched_result_list[j] = tensor_insert_slice_p.bind(init_result_flat[j], batched_result_list[j], 0)
+            #batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])

         # Apply mapping batched_args[1:] ---> fn(args)
         @for_loop(0, batch_size, 1)
@@ -279,8 +281,12 @@ class VmapCallable(CatalystCallable):
             res_flat, _ = tree_flatten(res)

             # Update the list of results
             for j in range(num_axes_out):
-                batched_result_list[j] = batched_result_list[j].at[i].set(res_flat[j])
+                batched_result_list[j] = tensor_insert_slice_p.bind(res_flat[j], batched_result_list[j], i)

             return batched_result_list

For the same example:

import pennylane as qml
import catalyst
import jax.numpy as jnp

@qml.qjit
@catalyst.vmap(in_axes=0)
@qml.qnode(qml.device("lightning.qubit", wires=17))
def empty_circuit(a):
    return qml.probs()

print(empty_circuit(jnp.zeros((1,))))

(env) ubuntu@touched-stingray:~/code/catalyst$ time python test.py
[[1. 0. 0. ... 0. 0. 0.]]

real    0m5.978s
user    0m6.701s
sys     0m2.971s

Instead of lowering stablehlo.scatter to a for loop and updating one index at a time, tensor.insert_slice gets lowered to some copies:

      %21 = func.call @empty_circuit_0(%alloc_10) : (memref<f64>) -> memref<131072xf64>
      %alloc_12 = memref.alloc() {alignment = 64 : i64} : memref<3x131072xf64>
      memref.copy %arg2, %alloc_12 : memref<3x131072xf64> to memref<3x131072xf64>
      memref.dealloc %arg2 : memref<3x131072xf64>
      %subview_13 = memref.subview %alloc_12[%arg1, 0] [1, 131072] [1, 1] : memref<3x131072xf64> to memref<131072xf64, strided<[1], offset: ?>>
      memref.copy %21, %subview_13 : memref<131072xf64> to memref<131072xf64, strided<[1], offset: ?>>
      %alloc_14 = memref.alloc() : memref<3x131072xf64>
      memref.copy %alloc_12, %alloc_14 : memref<3x131072xf64> to memref<3x131072xf64>
      memref.dealloc %alloc_12 : memref<3x131072xf64>
      scf.yield %alloc_14 : memref<3x131072xf64>

This could still be optimized further. I'm not sure why it copies %arg2 and then deallocates it immediately, instead of reusing %arg2 directly.

For the second example:

import pennylane as qml
import catalyst
import jax.numpy as jnp

@qml.qjit
@catalyst.vmap(in_axes=0)
@qml.qnode(qml.device("lightning.qubit", wires=15))
def empty_circuit(a):
    return qml.probs()

print(empty_circuit(jnp.zeros((7,))))

(env) ubuntu@touched-stingray:~/code/catalyst$ time python test.py
[[1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 ...
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]
 [1. 0. 0. ... 0. 0. 0.]]

real    0m6.121s
user    0m6.833s
sys     0m2.984s

dime10 commented 4 days ago

So I ran the above two benchmarks again and they look massively better! 🚀

[plot: qubit-scaling benchmark, rerun after the fix]

[plot: batch-size scaling benchmark, rerun after the fix]

erick-xanadu commented 4 days ago

We can keep improving them :)

dime10 commented 3 days ago

I mean, I think they look great! Most notably, we can now see an actual advantage of using vmap over accumulating results in a list and casting the list to an array (plot 2), whereas before vmap scaled massively worse than the vmap-less approach :)

The other things of note:

erick-xanadu commented 2 days ago

It's curious that at the tail end of the first benchmark vmap appears to scale faster than no vmap. Increasing the qubit number does scale the amount of data exponentially (qml.probs() on n qubits returns 2^n float64 values, i.e. about 1 MiB per batch entry at 17 qubits), so if we are observing that exponential scaling it would make sense; the only questions are whether its magnitude makes sense and why we don't see it in the no-vmap case.

I think what's happening comes from the extra copy I mentioned above:

      %21 = func.call @empty_circuit_0(%alloc_10) : (memref<f64>) -> memref<131072xf64>
      %alloc_12 = memref.alloc() {alignment = 64 : i64} : memref<3x131072xf64>
      memref.copy %arg2, %alloc_12 : memref<3x131072xf64> to memref<3x131072xf64>
      memref.dealloc %arg2 : memref<3x131072xf64>
      %subview_13 = memref.subview %alloc_12[%arg1, 0] [1, 131072] [1, 1] : memref<3x131072xf64> to memref<131072xf64, strided<[1], offset: ?>>
      memref.copy %21, %subview_13 : memref<131072xf64> to memref<131072xf64, strided<[1], offset: ?>>
      %alloc_14 = memref.alloc() : memref<3x131072xf64>
      memref.copy %alloc_12, %alloc_14 : memref<3x131072xf64> to memref<3x131072xf64>
      memref.dealloc %alloc_12 : memref<3x131072xf64>
      scf.yield %alloc_14 : memref<3x131072xf64>

%arg2 is where we store all the results. It gets copied and then freed, then we write into that copy, and then we make yet another copy. That is too many copies in my opinion; I think we are doing three times the necessary amount of copying.