Qiskit / qiskit

Qiskit is an open-source SDK for working with quantum computers at the level of extended quantum circuits, operators, and primitives.
https://www.ibm.com/quantum/qiskit
Apache License 2.0

Add post-selection method to `BitArray` #12688

Open ihincks opened 3 days ago

ihincks commented 3 days ago

What should we add?

We should add a method that helps to filter shots by post-selection. The user should be able to make requests such as "I would like all of the shots where bits 15 and 18 are in the state 01".

There is a complication due to the shaped nature of BitArray: each index of the shape might yield a different number of post-selected shots. For example, if the original BitArray has 100 shots, the example above might leave 57 shots at the first index, 17 at the second, and so forth. To get around this, I propose that the post-selection method always flattens the data first. Alternative proposals are welcome. For example, if a BitArray has shape (2,3,4) and num_shots=20, then the method would first flatten it into a BitArray of shape () with `num_shots = 20*2*3*4 = 480`, and apply the post-selection there.
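To make the ragged-shots problem concrete, here is a minimal pure-Python sketch. It does not use Qiskit: shots are modeled as plain integers, the shaped container is modeled as a dict keyed by index, and `postselect_ints` is a hypothetical helper invented for illustration, not part of `BitArray`.

```python
def bit(value: int, index: int) -> int:
    """Return bit `index` of `value`, where index 0 is the least-significant bit."""
    return (value >> index) & 1


def postselect_ints(shots, indices, selection):
    """Keep the shots whose bits at `indices` equal `selection`, element-wise."""
    return [
        s for s in shots
        if all(bit(s, i) == int(v) for i, v in zip(indices, selection))
    ]


# Toy data with shape (2,) and 4 shots per index; each shot is a 3-bit integer.
data = {
    (0,): [0b001, 0b011, 0b100, 0b101],
    (1,): [0b010, 0b011, 0b110, 0b000],
}

# Post-selecting on "bit 0 is 1" separately at each index leaves a different
# number of shots per index, which a rectangular container cannot hold.
per_index_counts = {
    idx: len(postselect_ints(shots, [0], [True])) for idx, shots in data.items()
}

# Flattening first avoids the problem: one list of 2 * 4 = 8 shots, then filter.
flat = [s for shots in data.values() for s in shots]
selected = postselect_ints(flat, [0], [True])
```

Here the per-index counts differ (3 and 1), while the flattened result is simply a list of 4 shots with no shape information left.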

Proposed signature:


def postselect(self, selection: BitArray, idxs: Iterable[int]) -> BitArray:
    """Post-select this bit array based on sliced equality with a given bitstring.

    .. note::
        If this bit array contains any shape axes, it is first flattened into a long list of shots before 
        applying post-selection. This is done because :class:`~BitArray` cannot handle ragged
        numbers of shots across axes.

    Args:
        selection: A bit array containing only one bitstring, specifying what to post-select on.
        idxs: A list of length matching ``selection``, specifying which bits of ``data`` it corresponds to.

    Returns:
        A new bit array with ``shape=(), num_bits=data.num_bits, num_shots<=data.num_shots``.

    Raises:
        ValueError: If ``selection`` has more bits than :attr:`num_bits`.
        ValueError: If the lengths of ``selection`` and ``idxs`` do not match.
        ValueError: If selection contains more than one bitstring.
    """
ihincks commented 3 days ago

cc @aeddins-ibm

aeddins-ibm commented 3 days ago

I wonder if we can make this easier for the user than the proposed signature above. Having selection be a BitArray means a user will need to understand how the BitArray class works, and also what convention Sampler uses when it returns a result. (Is it right that self is almost always going to be a Sampler result?)

Can we simplify the signature to more closely match this?

"I would like all of the shots where bits 15 and 18 are in the state 01"

The values of idxs can be the classical bit indices in the sampled circuit, since users are already familiar with that. This matches the precedent of BitArray.slice_bits(), lines 418-419: https://github.com/Qiskit/qiskit/blob/25c054251f50871ff9ad1dd5d3a7f2de2b2436fc/qiskit/primitives/containers/bit_array.py#L408-L419

Then the ordering of values in selection can simply match the ordering of idxs (idxs[k] goes with selection[k]).

Then postselect() will be set up to take these args and correctly process a Sampler result.
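As a rough illustration of the pairing between index k and value k, the request quoted above could be sketched as follows. This is pure Python on integer shots, not the proposed `BitArray` method; `postselect_shots` is a hypothetical name, and reading "bits 15 and 18 are in the state 01" as "bit 15 is 0 and bit 18 is 1" is an assumption made for the example.

```python
from typing import Iterable, List


def postselect_shots(
    shots: Iterable[int], indices: Iterable[int], selection: Iterable[bool]
) -> List[int]:
    """Keep shots where bit indices[k] equals selection[k] for every k."""
    indices = list(indices)
    selection = [bool(v) for v in selection]
    if len(indices) != len(selection):
        raise ValueError("indices and selection must have the same length")

    # mask marks the bits we care about; target holds their required values.
    mask = 0
    target = 0
    for i, v in zip(indices, selection):
        mask |= 1 << i
        if v:
            target |= 1 << i

    return [s for s in shots if (s & mask) == target]


# Four shots; only the bits at positions 0, 15 and 18 are ever set here.
shots = [1 << 18, (1 << 18) | (1 << 15), 0, (1 << 18) | 1]

# Assumed reading of the request: bit 15 must be 0, bit 18 must be 1.
kept = postselect_shots(shots, indices=[15, 18], selection=[False, True])
```

The first and last shots survive; bits that are not listed in `indices` (such as bit 0) are ignored.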

Modified signature:

def postselect(self, indices: Iterable[int], selection: Iterable[bool]) -> BitArray:
    """Post-select this bit array based on sliced equality with a given bitstring.

    .. note::
        If this bit array contains any shape axes, it is first flattened into a long list of shots before 
        applying post-selection. This is done because :class:`~BitArray` cannot handle ragged
        numbers of shots across axes.

    .. note:: 

             The convention used by this method is that the index ``0`` corresponds to 
             the least-significant bit in the :attr:`~array`, or equivalently 
             the right-most bitstring entry as returned by 
             :meth:`~get_counts` or :meth:`~get_bitstrings`, etc. 

             If this bit array was produced by a sampler, then an index ``i`` corresponds to the 
             :class:`~.ClassicalRegister` location ``creg[i]``. 

    Args:
        indices: A list of the indices of the cbits on which to postselect. 

        selection: A list of bools of length matching ``indices``, with `indices[i]` corresponding to `selection[i]`. 
            Shots will be discarded unless all cbits specified by `indices` have the values given by `selection`.

    Returns:
        A new bit array with ``shape=(), num_bits=data.num_bits, num_shots<=data.num_shots``.

    Raises:
        ValueError: If ``max(indices)`` is greater than or equal to :attr:`num_bits`.
        ValueError: If the lengths of ``selection`` and ``indices`` do not match.
    """
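The indexing convention in the second note can be checked with a small sketch. It works on plain bitstrings of the kind the docstring says `get_counts` and `get_bitstrings` return; `cbit_from_bitstring` is a hypothetical helper written for this example, and the range check reflects the fact that valid indices run from 0 to num_bits - 1.

```python
def cbit_from_bitstring(bitstring: str, index: int) -> bool:
    """Read classical bit `index` from a bitstring whose right-most character is index 0."""
    num_bits = len(bitstring)
    if not 0 <= index < num_bits:
        raise ValueError(f"index {index} is out of range for {num_bits} bits")
    return bitstring[num_bits - 1 - index] == "1"


bitstring = "0011"
values = [cbit_from_bitstring(bitstring, i) for i in range(len(bitstring))]

# The same values come out of treating the bitstring as an integer and
# reading its bits starting from the least-significant end.
as_int = int(bitstring, 2)
values_from_int = [bool((as_int >> i) & 1) for i in range(len(bitstring))]
```

Indices 0 and 1 read the two right-most characters, so they are the ones that come back as `True` for `"0011"`.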
aeddins-ibm commented 2 days ago

For reference in making a PR, I believe this function works, though it could use more testing to confirm.

It is missing the operation to flatten the BitArray.

Since it uses slice_bits, it will suffer from the time overhead and 8x memory overhead of unpacking and repacking the bits until slice_bits is improved.

def postselect(self, indices: Iterable[int], selection: Iterable[bool]) -> BitArray:
    """Post-select this bit array based on sliced equality with a given bitstring.

    .. note::
        If this bit array contains any shape axes, it is first flattened into a long list of shots before 
        applying post-selection. This is done because :class:`~BitArray` cannot handle ragged
        numbers of shots across axes.

    .. note:: 

             The convention used by this method is that the index ``0`` corresponds to 
             the least-significant bit in the :attr:`~array`, or equivalently 
             the right-most bitstring entry as returned by 
             :meth:`~get_counts` or :meth:`~get_bitstrings`, etc. 

             If this bit array was produced by a sampler, then an index ``i`` corresponds to the 
             :class:`~.ClassicalRegister` location ``creg[i]``. 

    Args:
        indices: A list of the indices of the cbits on which to postselect. 

        selection: A list of bools of length matching ``indices``, with `indices[i]` corresponding to `selection[i]`. 
            Shots will be discarded unless all cbits specified by `indices` have the values given by `selection`.

    Returns:
        A new bit array with ``shape=(), num_bits=data.num_bits, num_shots<=data.num_shots``.

    Raises:
        ValueError: If ``max(indices)`` is greater than or equal to :attr:`num_bits`.
        ValueError: If the lengths of ``selection`` and ``indices`` do not match.
    """

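    # Pack the requested values into a one-shot BitArray so they can be compared byte-wise.
    selection = BitArray.from_bool_array([selection], order='little')

    flat_self = ...  # TODO: make flattened (2D) copy of self to avoid ragged array errors

    # Keep only the shots whose sliced bits equal the packed selection in every byte.
    return flat_self[(flat_self.slice_bits(indices).array == selection.array).all(axis=-1)]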
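The 8x memory overhead mentioned above comes from storing one byte per bit after unpacking. The sketch below illustrates that factor in pure Python and shows one possible way to compare against packed bytes directly with a mask. It is not how `slice_bits` works and is not proposed in this thread as the fix; `unpack_bits`, `packed_pattern`, and `matches` are hypothetical helpers, and the byte layout (most-significant byte first, index 0 in the last byte) is an assumption of the example.

```python
def unpack_bits(packed: bytes) -> bytes:
    """Expand to one output byte per bit, most-significant bit first."""
    return bytes((byte >> (7 - k)) & 1 for byte in packed for k in range(8))


def packed_pattern(indices, values, num_bytes):
    """Build (mask, target) byte strings; index 0 is the least-significant bit."""
    mask = bytearray(num_bytes)
    target = bytearray(num_bytes)
    for i, v in zip(indices, values):
        pos = num_bytes - 1 - i // 8  # assumed layout: LSB lives in the last byte
        mask[pos] |= 1 << (i % 8)
        if v:
            target[pos] |= 1 << (i % 8)
    return bytes(mask), bytes(target)


def matches(shot: bytes, mask: bytes, target: bytes) -> bool:
    """Compare only the masked bits of a packed shot against the target."""
    return all((s & m) == t for s, m, t in zip(shot, mask, target))


shot = bytes([0b10100000, 0b00000101])  # 16 bits; bits 15, 13, 2 and 0 are set
unpacked = unpack_bits(shot)            # 16 bytes instead of 2

mask, target = packed_pattern([0, 13], [True, True], num_bytes=2)
hit = matches(shot, mask, target)

mask_miss, target_miss = packed_pattern([1], [True], num_bytes=2)
miss = matches(shot, mask_miss, target_miss)
```

The masked comparison never materializes the unpacked form, which is the property an improved implementation would presumably want.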

Ian, you had asked separately about changing this to from_bool_array([selection[::-1]], order='big'). That appears to be fully equivalent to the above, judging by the definition of from_bool_array: https://github.com/Qiskit/qiskit/blob/25c054251f50871ff9ad1dd5d3a7f2de2b2436fc/qiskit/primitives/containers/bit_array.py#L208-L210
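The equivalence can be sanity-checked with a small model. This sketch assumes that `from_bool_array` handles `order='little'` by reversing the last axis and otherwise treats the input as big-endian, which is my reading of the linked lines; `pack_bools` is a hypothetical stand-in that packs into a Python integer rather than a `BitArray`.

```python
def pack_bools(bits, order: str = "big") -> int:
    """Pack a sequence of bools into an integer under the assumed ordering rule."""
    bits = list(bits)
    if order == "little":
        # Assumed behavior: little-endian input is reversed into big-endian.
        bits = bits[::-1]
    elif order != "big":
        raise ValueError(f"unknown order: {order!r}")

    value = 0
    for b in bits:  # big-endian: the first entry is the most-significant bit
        value = (value << 1) | int(bool(b))
    return value


selection = [True, False, False]

little = pack_bools(selection, order="little")
reversed_big = pack_bools(selection[::-1], order="big")
```

Under that assumption, reversing the list and switching the order flag cancel out, so both calls pack `selection[0]` into the least-significant bit.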