quantumgizmos / ldpc

Software for decoding classical and quantum codes
MIT License
80 stars 28 forks source link

`BpOsdDecoder` in v2 is slower #35

Closed inmzhang closed 8 months ago

inmzhang commented 8 months ago

To test `BpOsdDecoder` in the `ldpc_v2` branch, I sampled and decoded some surface-code circuits with sinter.

For the baseline, I used `stimbposd.sinter_decoders()` (which internally uses `ldpc.bposd_decoder` from v1) to decode the samples. For v2, being aware of #32, I added the `SinterBpOsdDecoder` and `SinterCompiledBpOsdDecoder` classes below for use with sinter:

import stim
import numpy as np
import pathlib
from ldpc.bposd_decoder import BpOsdDecoder
import sinter
from beliefmatching import detector_error_model_to_check_matrices

# BP+OSD settings, chosen to match the defaults used by stimbposd.
# NOTE: per the maintainer, ldpc v1's "ps" uses a faster probability-ratio
# product-sum, while v2's product-sum works with log-likelihood ratios; for a
# like-for-like timing comparison, run v1 with bp_method="ps_log".
MAX_BP_ITERS, BP_METHOD, OSD_ORDER = 30, "ps", 60

class SinterCompiledBpOsdDecoder(sinter.CompiledDecoder):
    """Sinter compiled decoder wrapping an ldpc v2 ``BpOsdDecoder``.

    Each shot's detection events are decoded into a correction vector, which is
    then mapped to predicted observable flips via ``observables_matrix``.
    """

    def __init__(self, bposd: BpOsdDecoder, num_dets: int, observables_matrix: np.ndarray):
        """
        Args:
            bposd: BP+OSD decoder already configured for the check matrix.
            num_dets: Number of detectors, i.e. bits per unpacked shot.
            observables_matrix: Matrix with one row per observable; multiplying
                it by a correction vector (mod 2) gives the observable flips.
        """
        self.bposd = bposd
        self.num_dets = num_dets
        self.observables_matrix = observables_matrix

    def decode_shots_bit_packed(self, *, bit_packed_detection_event_data: np.ndarray) -> np.ndarray:
        """Decode a batch of bit-packed shots.

        Args:
            bit_packed_detection_event_data: Packed detection events, one shot
                per row, bits in little-endian order.

        Returns:
            Packed observable predictions, one shot per row, bits in
            little-endian order.
        """
        shots = np.unpackbits(
            bit_packed_detection_event_data, axis=1, count=self.num_dets, bitorder="little"
        )
        num_obs = self.observables_matrix.shape[0]
        # Preallocate so an empty batch still yields a correctly shaped
        # (0, num_obs) result; np.apply_along_axis raises ValueError when the
        # iteration axis has length zero.
        predictions = np.zeros((shots.shape[0], num_obs), dtype=np.uint8)
        for i, syndrome in enumerate(shots):
            correction = self.bposd.decode(syndrome)
            # Same mapping as (corrections @ observables_matrix.T) % 2, per row.
            predictions[i] = (self.observables_matrix @ correction) % 2
        return np.packbits(predictions, axis=1, bitorder="little")

class SinterBpOsdDecoder(sinter.Decoder):
    """Sinter decoder that runs ldpc v2 BP+OSD on a detector error model.

    The constructor only stores ``BpOsdDecoder`` settings. A decoder is built
    per detector error model, either through ``compile_decoder_for_dem`` or
    through sinter's file-based ``decode_via_files`` protocol.
    """

    def __init__(
        self,
        max_iter=0,
        bp_method="ms",
        ms_scaling_factor=0.625,
        schedule="parallel",
        omp_thread_count=1,
        serial_schedule_order=None,
        osd_method="osd0",
        osd_order=0,
    ):
        """Store the keyword arguments later forwarded to ``BpOsdDecoder``."""
        self.max_iter = max_iter
        self.bp_method = bp_method
        self.ms_scaling_factor = ms_scaling_factor
        self.schedule = schedule
        self.omp_thread_count = omp_thread_count
        self.serial_schedule_order = serial_schedule_order
        self.osd_method = osd_method
        self.osd_order = osd_order

    def _build_bposd(self, check_matrices) -> BpOsdDecoder:
        """Build a ``BpOsdDecoder`` for ``check_matrices`` with the stored settings."""
        # Single construction site so both decoding paths stay configured identically.
        return BpOsdDecoder(
            check_matrices.check_matrix,
            error_channel=list(check_matrices.priors),
            max_iter=self.max_iter,
            bp_method=self.bp_method,
            ms_scaling_factor=self.ms_scaling_factor,
            schedule=self.schedule,
            omp_thread_count=self.omp_thread_count,
            serial_schedule_order=self.serial_schedule_order,
            osd_method=self.osd_method,
            osd_order=self.osd_order,
        )

    def compile_decoder_for_dem(
        self,
        *,
        dem: stim.DetectorErrorModel,
    ) -> sinter.CompiledDecoder:
        """Return a compiled decoder for ``dem``."""
        check_matrices = detector_error_model_to_check_matrices(dem)
        return SinterCompiledBpOsdDecoder(
            self._build_bposd(check_matrices),
            num_dets=dem.num_detectors,
            observables_matrix=check_matrices.observables_matrix,
        )

    def decode_via_files(
        self,
        *,
        num_shots: int,
        num_dets: int,
        num_obs: int,
        dem_path: pathlib.Path,
        dets_b8_in_path: pathlib.Path,
        obs_predictions_b8_out_path: pathlib.Path,
        tmp_dir: pathlib.Path,
    ) -> None:
        """Decode b8 detection data from disk and write b8 observable predictions.

        Side effect: sets ``self.dem``, ``self.matrices`` and ``self.bposd``,
        which ``decode`` relies on.
        """
        self.dem = stim.DetectorErrorModel.from_file(dem_path)
        self.matrices = detector_error_model_to_check_matrices(self.dem)
        self.bposd = self._build_bposd(self.matrices)

        shots = stim.read_shot_data_file(path=dets_b8_in_path, format="b8", num_detectors=num_dets)
        predictions = np.zeros((num_shots, num_obs), dtype=bool)
        for i in range(num_shots):
            predictions[i, :] = self.decode(shots[i, :])

        stim.write_shot_data_file(
            data=predictions,
            path=obs_predictions_b8_out_path,
            format="b8",
            num_observables=num_obs,
        )

    def decode(self, syndrome: np.ndarray) -> np.ndarray:
        """Return predicted observable flips (mod 2) for a single syndrome.

        Requires ``decode_via_files`` to have set ``self.bposd`` and
        ``self.matrices`` first.
        """
        corr = self.bposd.decode(syndrome)
        return (self.matrices.observables_matrix @ corr) % 2

def bposd_sinter_decoder():
    """Return the custom sinter decoders for ldpc v2, keyed by decoder name.

    The single entry, ``"bposdv2"``, uses the module-level BP settings with
    OSD-CS post-processing.
    """
    decoder = SinterBpOsdDecoder(
        osd_order=OSD_ORDER,
        osd_method="OSD_CS",
        bp_method=BP_METHOD,
        max_iter=MAX_BP_ITERS,
    )
    return dict(bposdv2=decoder)

The values of MAX_BP_ITERS, BP_METHOD and OSD_ORDER were chosen to match the defaults in stimbposd.

Then, I ran the sampling with the following script:

import stim
import sinter
import stimbposd

import bposdtest # where I defined the sinter decoder for bposd v2

def generate_example_tasks(
    ps=(0.004, 0.006, 0.008, 0.01, 0.012),
    distances=(3, 5, 7),
):
    """Yield sinter tasks for rotated surface-code X-memory circuits.

    One task is produced per (p, d) pair, iterating ``ps`` in the outer loop.
    Each circuit uses ``rounds == distance`` and applies the same probability
    ``p`` to all four noise channels.

    Args:
        ps: Physical error rates to sweep.
        distances: Code distances to sweep.

    Yields:
        ``sinter.Task`` objects tagged with ``json_metadata={"p": p, "d": d}``.
    """
    for p in ps:
        for d in distances:
            circuit = stim.Circuit.generated(
                code_task="surface_code:rotated_memory_x",
                rounds=d,
                distance=d,
                after_clifford_depolarization=p,
                after_reset_flip_probability=p,
                before_measure_flip_probability=p,
                before_round_data_depolarization=p,
            )
            yield sinter.Task(
                circuit=circuit,
                json_metadata={"p": p, "d": d},
            )

if __name__ == "__main__":
    # Names passed to `decoders` must exactly match keys of `custom_decoders`,
    # so derive them from the dict rather than retyping them.
    custom_decoders = bposdtest.bposd_sinter_decoder()
    # To run the ldpc v1 baseline instead, use
    #   custom_decoders = stimbposd.sinter_decoders()
    # which provides the "bposd" decoder.
    sinter.collect(
        num_workers=120,  # I ran the simulation on a server
        max_shots=500_000,
        max_errors=500,
        tasks=generate_example_tasks(),
        decoders=list(custom_decoders),
        custom_decoders=custom_decoders,
        print_progress=True,
        save_resume_filepath="bposd.csv",
    )

I compared the decoding accuracy and time per shot of the v1 and v2 implementations; here are the results:

[Image] [Image] (result plots; not preserved in this copy)

We can clearly see that, for surface codes with distance > 3, the v2 implementation decodes more slowly than v1. I'm not sure whether all conditions and arguments were set identically for v1 and v2 in this benchmark, so I may have missed something important.

The test environment was Ubuntu 22.04 LTS with Python 3.11.4.

quantumgizmos commented 8 months ago

Hi. Could you try setting bp_method='ps_log' for the v1 simulations? I think the speeds of the two versions should be more comparable once this change has been made.

The bp_method='ps' method uses a version of BP in which messages are computed from probability ratios rather than log-probability ratios. This is quicker, as it avoids computing the arctanh(x) function needed for the LLR (log-likelihood ratio) version. I haven't implemented the optimised version of product-sum for v2 yet, but it is on my to-do list.

Cheers, Joschka

inmzhang commented 8 months ago

Thanks for your reply. After I set bp_method='ps_log' for v1, the speeds of the two versions are close.