PennyLaneAI / catalyst

A JIT compiler for hybrid quantum programs in PennyLane
https://docs.pennylane.ai/projects/catalyst
Apache License 2.0

[BUG] `grad` fails when the differentiated circuit has no trainable parameter #813

Open paul0403 opened 1 month ago

paul0403 commented 1 month ago

jax.grad (and catalyst.grad) fails on some qjit-ted functions that contain a qnode. A minimal reproducer:

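import jax
import pennylane as qml
from catalyst import qjit

dev = qml.device("lightning.qubit", wires=1)

@qjit
def cost(x):
   @qml.qnode(dev)
   def circuit(x):
       qml.PauliX(wires=0)
       # qml.RZ(0, wires=0)  # uncommenting this line makes the error disappear
       return qml.probs()
   return circuit(x)[0]

print("grad: ", jax.grad(cost)(1.1))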

>>>
error: <unknown>:0:0: in function preprocess_cost.cloned void (ptr, ptr, i64, ptr, ptr, i64): Enzyme: Cannot deduce type of memset   call void @llvm.memset.p0.i64(ptr noundef nonnull align 8 dereferenceable(16) %.fca.2.gep6, i8 0, i64 16, i1 false) #6
<analysis>
  %11 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8: {[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Pointer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}, intvals: {}
ptr %0: {[-1]:Pointer}, intvals: {}
ptr %1: {[-1]:Pointer}, intvals: {}
i64 %2: {[-1]:Integer}, intvals: {}
ptr %3: {[-1]:Pointer}, intvals: {}
ptr %4: {[-1]:Pointer}, intvals: {}
i64 %5: {[-1]:Integer}, intvals: {}
  %10 = tail call ptr @_mlir_memref_to_llvm_alloc(i64 16) #6: {[-1]:Pointer}, intvals: {}
  %8 = alloca { ptr, ptr, i64 }, align 8: {[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Pointer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}, intvals: {}
i64 2: {[-1]:Integer}, intvals: {2,}
  %.fca.1.gep4 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %9, i64 0, i32 1: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %.fca.2.gep6 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %9, i64 0, i32 2: {[-1]:Pointer}, intvals: {}
  %7 = tail call ptr @_mlir_memref_to_llvm_alloc(i64 0) #6: {[-1]:Pointer}, intvals: {}
i64 1: {[-1]:Integer}, intvals: {1,}
  %9 = alloca { ptr, ptr, i64, [1 x i64], [1 x i64] }, align 8: {[-1]:Pointer, [-1,0]:Pointer, [-1,8]:Pointer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}, intvals: {}
  %.fca.1.gep = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %11, i64 0, i32 1: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %.fca.3.0.gep = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %11, i64 0, i32 3, i64 0: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
  %.fca.4.0.gep10 = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %9, i64 0, i32 4, i64 0: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
  %.fca.2.gep = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %11, i64 0, i32 2: {[-1]:Pointer}, intvals: {}
i64 0: {[-1]:Anything}, intvals: {0,}
  %.fca.1.gep14 = getelementptr inbounds { ptr, ptr, i64 }, ptr %8, i64 0, i32 1: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %.fca.2.gep16 = getelementptr inbounds { ptr, ptr, i64 }, ptr %8, i64 0, i32 2: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
  call void @circuit.quantum(ptr nonnull %8, ptr nonnull %9, ptr nonnull %11) #6: {}, intvals: {}
  %.fca.4.0.gep = getelementptr inbounds { ptr, ptr, i64, [1 x i64], [1 x i64] }, ptr %11, i64 0, i32 4, i64 0: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
  %12 = load double, ptr %10, align 8: {}, intvals: {}
</analysis>

However, if the circuit includes the RZ gate, the error disappears:

dev = qml.device("lightning.qubit", wires=1)

@qjit
def cost(x):
   @qml.qnode(dev)
   def circuit(x):
       qml.PauliX(wires=0)
       qml.RZ(0, wires=0)
       return qml.probs()
   return circuit(x)[0]

print("grad: ", jax.grad(cost)(1.1))

>>>
grad:  0.0

I have tried many different circuits. So far, whether grad succeeds appears to depend only on whether the circuit contains at least one of RX, RY, or RZ. As soon as the circuit contains none of them, grad fails.

I suspect this is an edge-case bug that occurs when a gradient is requested for a circuit with no trainable parameters.
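
As a sanity check on the expected value, here is a minimal hand-rolled single-qubit simulation that does not use PennyLane or Catalyst. It is only a sketch: the helper names (`rz`, `apply`, `central_difference`) are made up for illustration and do not reflect how either library simulates or differentiates circuits. It shows that the cost never depends on `x`, with or without the `RZ(0)` gate, so a finite-difference estimate of the derivative is zero, consistent with the `grad:  0.0` results reported in this issue.

```python
import cmath

# Illustrative helpers only; not PennyLane or Catalyst internals.
PAULI_X = [[0.0, 1.0], [1.0, 0.0]]


def rz(theta):
    """Matrix of RZ(theta) = diag(exp(-i*theta/2), exp(+i*theta/2))."""
    return [[cmath.exp(-0.5j * theta), 0.0], [0.0, cmath.exp(0.5j * theta)]]


def apply(gate, state):
    """Apply a 2x2 gate to a single-qubit statevector."""
    return [
        gate[0][0] * state[0] + gate[0][1] * state[1],
        gate[1][0] * state[0] + gate[1][1] * state[1],
    ]


def cost(x, with_rz=False):
    """Probability of |0> after PauliX (and optionally RZ(0)).

    The argument x is deliberately unused, mirroring the circuits in this report.
    """
    state = apply(PAULI_X, [1.0, 0.0])  # start in |0>
    if with_rz:
        state = apply(rz(0.0), state)
    return abs(state[0]) ** 2


def central_difference(f, x, h=1e-6):
    """Finite-difference estimate of df/dx."""
    return (f(x + h) - f(x - h)) / (2 * h)


print("cost:", cost(1.1))
print("finite-difference grad:", central_difference(cost, 1.1))
```

Because the cost is constant in `x`, the failing qjit case should presumably return zero as well rather than raise a compile-time error.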

Note: without qjit everything works as expected:

dev = qml.device("lightning.qubit", wires=1)

#@qjit
def cost(x):
   @qml.qnode(dev)
   def circuit(x):
       qml.PauliX(wires=0)
       #qml.RZ(0, wires=0)
       return qml.probs()
   return circuit(x)[0]

print("grad: ", jax.grad(cost)(1.1))

>>>
grad:  0.0