MarcSerraPeralta / dem-decoders

Wrapped decoders to work with stim.DetectorErrorModel
MIT License

Performance in the presence of gauge detectors? #7

Open MarcSerraPeralta opened 2 months ago

MarcSerraPeralta commented 2 months ago

Check whether the decoder performance is affected when there are gauge detectors in the circuit (especially for BP_OSD). If the performance decreases, one should either remove them or warn the user. The second option is preferable: removing the gauge detectors changes the number of error mechanisms, so decode_to_faults_array would no longer work correctly.
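A minimal sketch of the warning option, assuming gauge detectors show up in the stim.DetectorErrorModel as error(0.5) instructions flipping a single detector (the helper name warn_if_gauge_detectors is illustrative, not part of the package):

import warnings

import stim


def warn_if_gauge_detectors(dem: stim.DetectorErrorModel) -> None:
    # gauge detectors are assumed to appear as probability-0.5 error
    # mechanisms that flip exactly one detector
    for instr in dem.flattened():
        if instr.type != "error":
            continue
        prob = instr.args_copy()[0]
        dets = [t for t in instr.targets_copy() if t.is_relative_detector_id()]
        if abs(prob - 0.5) < 1e-9 and len(dets) == 1:
            warnings.warn(
                f"DEM contains a gauge detector (D{dets[0].val}); "
                "decoder performance may degrade."
            )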

MarcSerraPeralta commented 2 months ago

The current function util.remove_gauge_detectors only works for DEMs. I believe it would also be useful to be able to remove the gauge detectors from a stim.Circuit. This can be achieved by (1) getting the stim.DetectorErrorModel, (2) searching for the detectors that are gauge detectors, and (3) removing these detectors from the circuit.
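A minimal sketch of these three steps, assuming gauge detectors appear in the DEM as error(0.5) instructions on a single detector and that the circuit can be flattened (the function name is illustrative):

import stim


def remove_gauge_detectors_from_circuit(circuit: stim.Circuit) -> stim.Circuit:
    # (1) build the DEM, allowing gauge detectors so the call does not raise
    dem = circuit.detector_error_model(allow_gauge_detectors=True)

    # (2) collect the indices of detectors flipped by a probability-0.5 error
    gauge_dets = set()
    for instr in dem.flattened():
        if instr.type == "error" and abs(instr.args_copy()[0] - 0.5) < 1e-9:
            dets = [t for t in instr.targets_copy() if t.is_relative_detector_id()]
            if len(dets) == 1:
                gauge_dets.add(dets[0].val)

    # (3) rebuild the circuit, skipping the DETECTOR instructions whose
    #     absolute index corresponds to a gauge detector
    new_circuit = stim.Circuit()
    det_index = 0
    for instr in circuit.flattened():
        if instr.name == "DETECTOR":
            if det_index not in gauge_dets:
                new_circuit.append(instr)
            det_index += 1
        else:
            new_circuit.append(instr)
    return new_circuit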

MarcSerraPeralta commented 2 months ago

# Imports for rot_surf_code, Setup, CircuitNoiseModel and memory_experiment
# are omitted in this snippet.
DISTANCE = 3

layout = rot_surf_code(DISTANCE)
qubit_inds = {q: layout.get_inds([q])[0] for q in layout.get_qubits()}

setup = Setup.from_yaml("config/circ_level_noise.yaml")
setup.set_var_param("prob", 1e-4)
setup.set_var_param("add_assign_error", True)

model = CircuitNoiseModel(setup, qubit_inds=qubit_inds)

circuit = memory_experiment(
    model=model, 
    layout=layout, 
    num_rounds=DISTANCE, 
    data_init=[0,]*DISTANCE,
    meas_reset=True, 
    rot_basis=False,
)

# 20 million samples

[image: results]

For DISTANCE = 7 and prob=1e-6, I get a logical error probability of 0 (no errors occurred in the 1 million samples I took):

[image: results]
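For context, a minimal sketch of how such an estimate can be obtained from the circuit built above, assuming stim's detector sampler and the BP_OSD wrapper from this repository (the import path and decoder parameters are illustrative):

from dem_decoders import BP_OSD  # assumed import path

# DEM of the circuit (allowing gauge detectors so the call does not raise)
dem = circuit.detector_error_model(allow_gauge_detectors=True)
decoder = BP_OSD(dem, max_iter=10, osd_order=0, osd_method="OSD_CS")

# sample detection events and logical-observable flips
sampler = circuit.compile_detector_sampler()
detectors, log_flips = sampler.sample(1_000_000, separate_observables=True)

# decode shot by shot and count logical failures
n_fail = 0
for detector_vec, log_flip in zip(detectors, log_flips):
    prediction = decoder.decode(detector_vec)
    n_fail += (prediction != log_flip).any()

print(f"logical error probability = {n_fail / len(detectors)}")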

MarcSerraPeralta commented 2 months ago

To understand the drop in performance for the case where there are no gauge detectors (Surface-17 with p=1e-4 and num_shots=1e6):

[image: results]

[image: results]

One can see that the performance improves considerably by increasing the osd/lsd order by just 1. Increasing the number of BP iterations does not seem to improve the performance much. However, this is only for distance 3, and I haven't done a grid scan over both the order and the maximum number of iterations.

MarcSerraPeralta commented 2 months ago

For distance 5 with p=1e-4 and num_shots=1e6, we get:

[image: results]

Note: for max_iter=0, the runtime was too long to include in this type of benchmark.

[image: results]

We see that a little bit of order improves the performance a lot (e.g. a factor of 10x). The OSD method used is CS (combination sweep), because it is known to perform better.

MarcSerraPeralta commented 2 months ago

To scan these parameters more systematically, I use:

import time

import numpy as np

# BP_OSD is the wrapped decoder from this repository; DISTANCE, PROB,
# dem_no_gauge, detectors_no_gauge and log_flips_no_gauge are defined in
# the earlier snippets.

# INPUTS
max_iters = [0, 1, 10, 100, 1_000]
osd_orders = [0, 1, 2, 4, 8, 16, 32]
max_time_scan = 3 * 3600 #s
data = [detectors_no_gauge, log_flips_no_gauge]
DECODER = lambda max_iter, osd_order: BP_OSD(dem_no_gauge, max_iter=max_iter, osd_order=osd_order, osd_method="OSD_CS")
max_n_fails = np.inf

OUTPUT_NAME = f"output/BPOSD_d{DISTANCE}_p{PROB}_nogaugedets_maxtime{max_time_scan}_maxfails{max_n_fails}".replace(".", "-") + ".npy"

# RUNNING SCAN
t_init_estimation = time.time()
num_points = len(max_iters) * len(osd_orders)
output = {}

for max_iter in max_iters:
    for osd_order in osd_orders:
        print(f"{max_iter=} {osd_order=}")
        t_init_point = time.time()
        decoder = DECODER(max_iter, osd_order)

        # compute available time
        elapsed_time = t_init_point - t_init_estimation
        remaining_points = num_points - len(output)
        max_duration = (max_time_scan - elapsed_time) / remaining_points

        # sample data
        runtime_decoder, n_total, n_fail = [], 0, 0
        while (time.time() - t_init_point) < max_duration and n_fail < max_n_fails and n_total < len(data[0]):
            detector_vec, log_flip = data[0][n_total], data[1][n_total]

            t0 = time.time()
            prediction = decoder.decode(detector_vec)
            runtime = time.time() - t0

            n_total += 1
            n_fail += (prediction != log_flip).any()
            runtime_decoder.append(runtime)

            print(f"\r    p={n_fail/n_total:0.9f} t={np.average(runtime_decoder):0.6f}s n_fail={n_fail} n_total={n_total}", end="")
        print("")

        # store data
        output[(max_iter, osd_order)] = [n_fail, n_total, runtime_decoder]

np.save(OUTPUT_NAME, output)
print(f"\nTOTAL SCAN TIME: {time.time() - t_init_estimation:.0f}s")
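The saved scan can then be summarised along these lines (a sketch; np.load needs allow_pickle=True because the output is a pickled dict, and OUTPUT_NAME is the file name defined above):

import numpy as np

# load the dict saved by the scan above
output = np.load(OUTPUT_NAME, allow_pickle=True).item()

for (max_iter, osd_order), (n_fail, n_total, runtimes) in output.items():
    log_err_prob = n_fail / n_total if n_total else float("nan")
    avg_runtime = np.average(runtimes) if runtimes else float("nan")
    print(
        f"max_iter={max_iter} osd_order={osd_order} "
        f"p_L={log_err_prob:0.2e} t={avg_runtime:0.6f}s (n={n_total})"
    )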