QuantumApplicationLab / vqls-prototype

A Variational Quantum Linear Solver Prototype for Qiskit
Apache License 2.0

CircuitError: "invalid param type <class 'numpy.ndarray'> for instruction state_preparation" #5

Closed dmark04 closed 9 months ago

dmark04 commented 1 year ago

What is the expected enhancement?

I tried executing the following very basic example for VQLS:

import numpy as np
from qiskit.primitives import Estimator, Sampler
from qiskit_algorithms.optimizers import COBYLA
from qiskit.circuit.library.n_local.real_amplitudes import RealAmplitudes
from vqls_prototype import VQLS, VQLSLog

if __name__ == "__main__":
    N = 1
    matrix_a = np.random.rand(2**N, 2**N)
    matrix_a += matrix_a.T
    vector_b = np.random.rand(2**N, 1)

    ansatz = RealAmplitudes(num_qubits=int(np.log2(matrix_a.shape[0])), entanglement="full", reps=3,
                            insert_barriers=False)

    # if vector_b.ndim == 2:
    #     vector_b = vector_b.flatten()

    log = VQLSLog([], [])
    estimator = Estimator()
    sampler = Sampler()
    vqls = VQLS(
        estimator,
        ansatz,
        COBYLA(maxiter=250, disp=True),
        sampler=sampler,
        callback=log.update,
    )
    opt = {"use_overlap_test": False, "use_local_cost_function": False}
    res = vqls.solve(matrix_a, vector_b, opt)

This results in the following traceback:

Traceback (most recent call last):
  File "/home/davidm/Projects/qc-quantum-linear-systems/quantum_linear_systems/minimal_test.py", line 35, in <module>
    res = vqls.solve(matrix_a, vector_b, opt)
  File "/home/davidm/Projects/venvs/qls3.10/lib/python3.10/site-packages/vqls_prototype/vqls.py", line 807, in solve
    hdmr_tests_norm, hdmr_tests_overlap = self.construct_circuit(
  File "/home/davidm/Projects/venvs/qls3.10/lib/python3.10/site-packages/vqls_prototype/vqls.py", line 323, in construct_circuit
    self.vector_circuit.prepare_state(vector / vec_norm)
  File "/home/davidm/Projects/venvs/qls3.10/lib/python3.10/site-packages/qiskit/circuit/library/data_preparation/state_preparation.py", line 521, in prepare_state
    StatePreparation(state, num_qubits, label=label, normalize=normalize), qubits
  File "/home/davidm/Projects/venvs/qls3.10/lib/python3.10/site-packages/qiskit/circuit/library/data_preparation/state_preparation.py", line 112, in __init__
    super().__init__(self._name, num_qubits, params, label=self._label)
  File "/home/davidm/Projects/venvs/qls3.10/lib/python3.10/site-packages/qiskit/circuit/gate.py", line 37, in __init__
    super().__init__(name, num_qubits, 0, params, label=label)
  File "/home/davidm/Projects/venvs/qls3.10/lib/python3.10/site-packages/qiskit/circuit/instruction.py", line 105, in __init__
    self.params = params  # must be at last (other properties may be required for validation)
  File "/home/davidm/Projects/venvs/qls3.10/lib/python3.10/site-packages/qiskit/circuit/instruction.py", line 223, in params
    self._params.append(self.validate_parameter(single_param))
  File "/home/davidm/Projects/venvs/qls3.10/lib/python3.10/site-packages/qiskit/circuit/library/data_preparation/state_preparation.py", line 250, in validate_parameter
    raise CircuitError(f"invalid param type {type(parameter)} for instruction  {self.name}")
qiskit.circuit.exceptions.CircuitError: "invalid param type <class 'numpy.ndarray'> for instruction  state_preparation"

This error message is quite confusing.

However, the error is caused solely by the shape of the NumPy array vector_b: it is passed as an ndarray of shape (m, 1) instead of shape (m,). Flattening the input vector (see the commented-out lines in the example above) is enough to make the code run without error.
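The following NumPy-only sketch illustrates the shape difference. It assumes, as the traceback suggests, that Qiskit validates each element obtained by iterating over the parameter array; the names row_types and scalar_types are illustrative and not part of any library.

```python
import numpy as np

vector_b = np.random.rand(2, 1)      # shape (2, 1): a column vector with ndim == 2
flat_b = vector_b.flatten()          # shape (2,): a rank-1 copy with the same values

# Iterating over the 2-D array yields rows, each of which is itself an ndarray.
row_types = {type(x) for x in vector_b}
# Iterating over the 1-D array yields NumPy scalars (np.float64 subclasses float).
scalar_types = {type(x) for x in flat_b}

print(vector_b.shape, row_types)
print(flat_b.shape, scalar_types)
```

Under that assumption, the per-element check sees ndarray objects for the (m, 1) input and plain floating-point scalars for the flattened input, which would explain the "invalid param type <class 'numpy.ndarray'>" message.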

I therefore suggest adding such a check, for example around line 314 of vqls.py, in the function construct_circuit:

        elif isinstance(vector, np.ndarray):
            # ensure the vector is double
            vector = vector.astype("float64")

In that branch, add a check that the vector has the correct rank: if vector.ndim == 2, set vector = vector.flatten().
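A minimal sketch of what such a check could look like as a standalone helper. The function name as_flat_float64 is hypothetical, and this is not necessarily the fix that was eventually merged. It is slightly stricter than the suggestion above, because it only flattens row or column vectors and rejects genuinely two-dimensional input.

```python
import numpy as np


def as_flat_float64(vector):
    """Return `vector` as a rank-1 float64 array (hypothetical helper)."""
    # Ensure the vector is double precision, as the existing branch does.
    vector = np.asarray(vector, dtype="float64")

    # Accept column vectors (m, 1) and row vectors (1, m) by flattening them.
    if vector.ndim == 2 and 1 in vector.shape:
        vector = vector.flatten()

    # Anything that is still not rank 1 is not a valid right-hand side.
    if vector.ndim != 1:
        raise ValueError(
            f"expected a vector of shape (m,), (m, 1) or (1, m), got {vector.shape}"
        )
    return vector


if __name__ == "__main__":
    print(as_flat_float64(np.random.rand(4, 1)).shape)
```

Raising a ValueError with the offending shape gives the caller a clearer message than the CircuitError shown in the traceback.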

NicoRenaud commented 9 months ago

Thanks! It took me a while, but this is now fixed. Thanks for the input. The fix will be merged as part of #9, so I will close this issue.