Is it just copying data? At first I wasn't sure what was happening, but yes, it looks like it is just copying data during the scatter lowering.
- a very large amount of constant data is embedded into the IR during the hlo lowering process
This big tensor is generated during the scatter lowering and is followed by a loop with as many iterations as the tensor has elements (2^num_of_wires):
So the major chunk of runtime could be spent either on creating this giant vector in memory or on the loop that follows it.
Realizing that the big tensor is just a container for indices and isn't actually necessary, we tried throwing it away; the following (with 6 wires for illustration) has the same functionality:
This treatment does not decrease the runtime for 17 wires, so the loop below must be the culprit.
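For a rough sense of scale (a back-of-the-envelope of mine, not part of the original investigation), assuming the embedded constant is just the index range 0 .. 2^num_of_wires - 1 and the results are f64:

# Back-of-the-envelope only; assumes f64 results, matching the memrefs shown further below.
num_wires = 17
n = 2**num_wires
print(n)              # 131072 loop iterations, one scattered element per iteration
print(n * 8 / 2**20)  # ~1.0 MiB of probability data per batch element

The data volume itself is modest, which is consistent with the conclusion that the loop, rather than the tensor, is the culprit.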
Script (uncleaned) for investigation: vmap.py.txt, unknown_17.zip
I did a little bit of digging here regarding vmap. I discovered that vmap always calls the function at least once eagerly, and then saves the result of this zeroth iteration.
While this is correct, I don't like the extra eager call, so I fixed it with this patch:
diff --git a/frontend/catalyst/api_extensions/function_maps.py b/frontend/catalyst/api_extensions/function_maps.py
index 4ebe07f92..c3650b05b 100644
--- a/frontend/catalyst/api_extensions/function_maps.py
+++ b/frontend/catalyst/api_extensions/function_maps.py
@@ -226,7 +226,11 @@ class VmapCallable(CatalystCallable):
fn_args = tree_unflatten(args_tree, fn_args_flat)
# Run 'fn' one time to get output-shape
- init_result = self.fn(*fn_args, **kwargs)
+ _, shape = jax.make_jaxpr(self.fn, return_shape=True)(*fn_args, **kwargs)
+ shapes, init_result_tree = tree_flatten(shape)
+ init_result_flat = [jnp.zeros(shape=shape.shape, dtype=shape.dtype) for shape in shapes]
+ init_result = tree_unflatten(init_result_tree, init_result_flat)
+
# Check the validity of the output w.r.t. out_axes
out_axes_deep_struct = tree_structure(self.out_axes, is_leaf=lambda x: x is None)
@@ -238,8 +242,6 @@ class VmapCallable(CatalystCallable):
f"{out_axes_deep_struct} axis specifiers and {init_result_deep_struct} results."
)
- init_result_flat, init_result_tree = tree_flatten(init_result)
-
num_axes_out = len(init_result_flat)
if isinstance(self.out_axes, int):
@@ -264,7 +266,7 @@ class VmapCallable(CatalystCallable):
batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])
# Apply mapping batched_args[1:] ---> fn(args)
- @for_loop(1, batch_size, 1)
+ @for_loop(0, batch_size, 1)
def loop_fn(i, batched_result_list):
fn_args_flat = args_flat
for loc in batch_loc:
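For context, a small standalone sketch of the technique the patch relies on (plain JAX, no Catalyst; fn and its shapes are made up for illustration): jax.make_jaxpr(..., return_shape=True) traces the function abstractly and returns its output structure as jax.ShapeDtypeStruct leaves, so zeros of the right shape and dtype can stand in for the eagerly computed init_result.

import jax
import jax.numpy as jnp
from jax.tree_util import tree_flatten, tree_unflatten

def fn(x):
    return jnp.sin(x) * 2.0

# Abstract trace only: fn is not executed on concrete data.
_, out_shape = jax.make_jaxpr(fn, return_shape=True)(jnp.zeros((4,)))
leaves, tree = tree_flatten(out_shape)        # leaves are jax.ShapeDtypeStruct objects
zeros = [jnp.zeros(l.shape, l.dtype) for l in leaves]
init_result = tree_unflatten(tree, zeros)     # placeholder with fn's output structure

jax.eval_shape would give the same ShapeDtypeStruct pytree directly, without building the jaxpr.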
The stablehlo.scatter operation is more general than just tensor.insert_slice. But if we look at our uses of JAX which generate the stablehlo.scatter operation in vmap, we see that they are only setting values, which can easily be changed to tensor.insert_slice.
# Locations that produce scatter:
batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])
# ...
# Update the list of results
for j in range(num_axes_out):
batched_result_list[j] = batched_result_list[j].at[i].set(res_flat[j])
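To see the update pattern in isolation, here is a minimal sketch in plain JAX (the function and the shapes are mine, for illustration only):

import jax
import jax.numpy as jnp

def update_row(batched, row, i):
    # Same pattern as the vmap code above: write one full row into the
    # batched result buffer at batch index i.
    return batched.at[i].set(row)

# Inspect the traced program for this update.
print(jax.make_jaxpr(update_row)(jnp.zeros((3, 8)), jnp.ones((8,)), 0))

Per the discussion above, this is the update that ends up as stablehlo.scatter in the vmap-generated program, even though it only ever writes one contiguous row, which is exactly what tensor.insert_slice expresses.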
So we can just create a new JAX primitive for tensor.insert_slice. The tensor dialect is not included by default in JAX, but we ship our own Python-generated files. This is a bit hardcoded:
diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py
index 03a68697b..46d5277f7 100644
--- a/frontend/catalyst/jax_primitives.py
+++ b/frontend/catalyst/jax_primitives.py
@@ -97,6 +97,7 @@ from mlir_quantum.dialects.quantum import (
VarianceOp,
)
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp
+from mlir_quantum.dialects.tensor import InsertSliceOp, ScatterOp as TensorScatterOp, insert_slice
from catalyst.compiler import get_lib_path
from catalyst.jax_extras import (
@@ -299,6 +300,7 @@ set_basis_state_p = jax.core.Primitive("set_basis_state")
set_basis_state_p.multiple_results = True
quantum_kernel_p = core.CallPrimitive("quantum_kernel")
quantum_kernel_p.multiple_results = True
+tensor_insert_slice_p = jax.core.Primitive("tensor_insert_slice_p")
def _assert_jaxpr_without_constants(jaxpr: ClosedJaxpr):
@@ -413,6 +415,22 @@ def _print_lowering(jax_ctx: mlir.LoweringRuleContext, *args, string=None, memre
return PrintOp(val=val, const_val=None, print_descriptor=memref).results
+@tensor_insert_slice_p.def_abstract_eval
+def _abs_eval(source_tensor, dest_tensor, indices):
+ return dest_tensor
+
+def _tensor_insert_slice_lowering(ctx, source_tensor, dest_tensor, indices):
+ dyn = ir.ShapedType.get_dynamic_size()
+ off = ir.DenseI64ArrayAttr.get([dyn, 0])
+
+ siz = ir.RankedTensorType(source_tensor.type).shape
+ siz = ir.DenseI64ArrayAttr.get([1, *siz])
+ stri = ir.DenseI64ArrayAttr.get([1, 1])
+ p = TensorExtractOp(ir.RankedTensorType(indices.type).element_type, indices, []).result # tensor<i64> -> i64
+ p = IndexCastOp(ir.IndexType.get(), p).result # i64 -> index
+ x = InsertSliceOp(source_tensor, dest_tensor, [p], [], [], off, siz, stri)
+ #x = TensorScatterOp(dest_tensor.type, source_tensor, dest_tensor, indices, [0, 1], unique=True)
+ return x.results
#
# transform dialect lowering
#
@@ -2393,6 +2411,7 @@ CUSTOM_LOWERING_RULES = (
(sin_p, _sin_lowering2),
(cos_p, _cos_lowering2),
(quantum_kernel_p, _quantum_kernel_lowering),
+ (tensor_insert_slice_p, _tensor_insert_slice_lowering),
)
And you need to modify the Python file for InsertSliceOp like this:
def __init__(self, source, dest, offsets, sizes, strides, static_offsets, static_sizes, static_strides, *, loc=None, ip=None):
operands = []
results = [dest.type]  # <---- add the result type
attributes = {}
regions = None
operands.append(_get_op_result_or_value(source))
operands.append(_get_op_result_or_value(dest))
operands.append(_get_op_results_or_values(offsets))
operands.append(_get_op_results_or_values(sizes))
operands.append(_get_op_results_or_values(strides))
_ods_context = _ods_get_default_loc_context(loc)
attributes["static_offsets"] = (static_offsets if (
isinstance(static_offsets, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('DenseI64ArrayAttr')) else
_ods_ir.AttrBuilder.get('DenseI64ArrayAttr')(static_offsets, context=_ods_context))
attributes["static_sizes"] = (static_sizes if (
isinstance(static_sizes, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('DenseI64ArrayAttr')) else
_ods_ir.AttrBuilder.get('DenseI64ArrayAttr')(static_sizes, context=_ods_context))
attributes["static_strides"] = (static_strides if (
isinstance(static_strides, _ods_ir.Attribute) or
not _ods_ir.AttrBuilder.contains('DenseI64ArrayAttr')) else
_ods_ir.AttrBuilder.get('DenseI64ArrayAttr')(static_strides, context=_ods_context))
_ods_successors = None
super().__init__(self.build_generic(attributes=attributes, operands=operands, results=results, successors=_ods_successors, regions=regions, loc=loc, ip=ip)) #<---- add results
Finally, you can replace the indexed updates with InsertSliceOp:
diff --git a/frontend/catalyst/api_extensions/function_maps.py b/frontend/catalyst/api_extensions/function_maps.py
index c3650b05b..53420ece3 100644
--- a/frontend/catalyst/api_extensions/function_maps.py
+++ b/frontend/catalyst/api_extensions/function_maps.py
@@ -32,6 +32,7 @@ from catalyst.api_extensions.control_flow import for_loop
from catalyst.tracing.contexts import EvaluationContext
from catalyst.tracing.type_signatures import get_stripped_signature
from catalyst.utils.callables import CatalystCallable
+from catalyst.jax_primitives import tensor_insert_slice_p
## API ##
@@ -263,7 +264,8 @@ class VmapCallable(CatalystCallable):
else (batch_size, *init_result_flat[j].shape)
)
batched_result_list.append(jnp.zeros(shape=out_shape, dtype=init_result_flat[j].dtype))
- batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])
+ batched_result_list[j] = tensor_insert_slice_p.bind(init_result_flat[j], batched_result_list[j], 0)
+ #batched_result_list[j] = batched_result_list[j].at[0].set(init_result_flat[j])
# Apply mapping batched_args[1:] ---> fn(args)
@for_loop(0, batch_size, 1)
@@ -279,8 +281,12 @@ class VmapCallable(CatalystCallable):
res_flat, _ = tree_flatten(res)
# Update the list of results
for j in range(num_axes_out):
- batched_result_list[j] = batched_result_list[j].at[i].set(res_flat[j])
+ batched_result_list[j] = tensor_insert_slice_p.bind(res_flat[j], batched_result_list[j], i)
return batched_result_list
For the same example:
import pennylane as qml
import catalyst
import jax.numpy as jnp
@qml.qjit
@catalyst.vmap(in_axes=0)
@qml.qnode(qml.device("lightning.qubit", wires=17))
def empty_circuit(a):
return qml.probs()
print(empty_circuit(jnp.zeros((1,))))
(env) ubuntu@touched-stingray:~/code/catalyst$ time python test.py
[[1. 0. 0. ... 0. 0. 0.]]
real 0m5.978s
user 0m6.701s
sys 0m2.971s
Instead of lowering stablehlo.scatter to a for loop and updating one index at a time, tensor.insert_slice gets lowered to some copies:
%21 = func.call @empty_circuit_0(%alloc_10) : (memref<f64>) -> memref<131072xf64>
%alloc_12 = memref.alloc() {alignment = 64 : i64} : memref<3x131072xf64>
memref.copy %arg2, %alloc_12 : memref<3x131072xf64> to memref<3x131072xf64>
memref.dealloc %arg2 : memref<3x131072xf64>
%subview_13 = memref.subview %alloc_12[%arg1, 0] [1, 131072] [1, 1] : memref<3x131072xf64> to memref<131072xf64, strided<[1], offset: ?>>
memref.copy %21, %subview_13 : memref<131072xf64> to memref<131072xf64, strided<[1], offset: ?>>
%alloc_14 = memref.alloc() : memref<3x131072xf64>
memref.copy %alloc_12, %alloc_14 : memref<3x131072xf64> to memref<3x131072xf64>
memref.dealloc %alloc_12 : memref<3x131072xf64>
scf.yield %alloc_14 : memref<3x131072xf64>
This could still be optimized further. I'm not sure why it copies and then immediately deallocates instead of reusing %arg2.
For the second example:
import pennylane as qml
import catalyst
import jax.numpy as jnp
@qml.qjit
@catalyst.vmap(in_axes=0)
@qml.qnode(qml.device("lightning.qubit", wires=15))
def empty_circuit(a):
return qml.probs()
print(empty_circuit(jnp.zeros((7,))))
(env) ubuntu@touched-stingray:~/code/catalyst$ time python test.py
[[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
...
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]]
real 0m6.121s
user 0m6.833s
sys 0m2.984s
So I ran the above two benchmarks again and they look massively better! 🚀
We can keep improving them :)
I mean, I think they look great! Most notably, we can now see an actual advantage of using vmap over accumulating results in a list and casting the list to an array (plot 2), whereas before vmap scaled far worse than the vmap-less approach :)
The other things of note:
It's curious that at the tail end of the first benchmark vmap appears to scale faster than no-vmap. Increasing the qubit number does scale the amount of data exponentially, so if we are observing that exponential scaling it would make sense; the only question is whether its magnitude makes sense and why we don't see it in the no-vmap case.
I think what's happening comes from the extra copy I mentioned above:
%21 = func.call @empty_circuit_0(%alloc_10) : (memref<f64>) -> memref<131072xf64>
%alloc_12 = memref.alloc() {alignment = 64 : i64} : memref<3x131072xf64>
memref.copy %arg2, %alloc_12 : memref<3x131072xf64> to memref<3x131072xf64>
memref.dealloc %arg2 : memref<3x131072xf64>
%subview_13 = memref.subview %alloc_12[%arg1, 0] [1, 131072] [1, 1] : memref<3x131072xf64> to memref<131072xf64, strided<[1], offset: ?>>
memref.copy %21, %subview_13 : memref<131072xf64> to memref<131072xf64, strided<[1], offset: ?>>
%alloc_14 = memref.alloc() : memref<3x131072xf64>
memref.copy %alloc_12, %alloc_14 : memref<3x131072xf64> to memref<3x131072xf64>
memref.dealloc %alloc_12 : memref<3x131072xf64>
scf.yield %alloc_14 : memref<3x131072xf64>
%arg2 is the place where we store all the results. It gets copied and then freed, we write into that copy, and then we make yet another copy. That's too many copies in my opinion; I think we are doing three times the amount of copying necessary.
I ran across a problem that I thought was a non-terminating program, but it turns out to be severe performance degradation when using vmap on large amounts of data, such as a QNode returning qml.probs() on 20 qubits. The performance issue is present even when the batch dimension is 1. The following circuit starts to noticeably slow down at around the 17-qubit mark:

We can compare the runtime of this circuit with and without vmap, since the amount of work should be the same (batch dimension 1):

Worse yet, this is not a constant overhead, but seems to scale with the batch dimension anyway. The following plot compares vmap to a non-vmap implementation that aggregates results in Python (jnp.array([empty_circuit(jnp.zeros((1,))) for _ in range(n)])), holding the qubit number at 15 and varying the batch dimension n:

Not necessarily the cause, but looking at the IR from the example at the top (the 17-wire case), there are some oddities:
- probs data is generated during the hlo lowering

The plots were generated using the following code, with minor modifications for the second plot (wires=15 and Python aggregation):
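The original script is not included above, so what follows is only a minimal reconstruction sketch of what such a benchmark could look like, not the author's code: the wire range, the time.perf_counter timing, and the matplotlib plotting are assumptions, and only the vmapped circuit is timed (the actual plots also compared the non-vmap approach).

# Hedged reconstruction, not the author's script: time the vmapped empty circuit
# for a range of wire counts and plot the runtimes.
import time

import jax.numpy as jnp
import matplotlib.pyplot as plt
import pennylane as qml

import catalyst


def runtime(wires):
    @qml.qjit
    @catalyst.vmap(in_axes=0)
    @qml.qnode(qml.device("lightning.qubit", wires=wires))
    def empty_circuit(a):
        return qml.probs()

    start = time.perf_counter()
    empty_circuit(jnp.zeros((1,)))  # batch dimension 1, as in the examples above
    return time.perf_counter() - start


wire_counts = range(10, 18)  # assumed range; the issue reports slowdown around 17 wires
plt.plot(list(wire_counts), [runtime(w) for w in wire_counts], marker="o")
plt.xlabel("number of wires")
plt.ylabel("runtime (s)")
plt.show()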