CardiacModelling / pcpostprocess

BSD 3-Clause "New" or "Revised" License
1 stars 0 forks source link

Refactor this code into a single loop. There's no need to store each individual ... #5

Open github-actions[bot] opened 6 months ago

github-actions[bot] commented 6 months ago

https://github.com/CardiacModelling/pcpostprocess/blob/cd60bd2dd7d7411b4373478e59869d7e8ab14688/scripts/run_herg_qc.py#L986


                     'qc6.2.subtracted']

    df = pd.DataFrame(np.array(df_rows), columns=column_labels)

    missing_wells_dfs = []
    # Add onboard qc to dataframe
    for well in args.wells:
        if well not in df['well'].values:
            onboard_qc_df = pd.DataFrame([[well] + [False for col in
                                                    list(df)[1:]]],
                                         columns=list(df))
            missing_wells_dfs.append(onboard_qc_df)
    df = pd.concat([df] + missing_wells_dfs, ignore_index=True)

    df['protocol'] = savename

    return selected_wells, df

def qc3_bookend(readname, savename, time_strs, args):
    plot_dir = os.path.join(args.output_dir, args.savedir,
                            f"{args.saveID}-{savename}-qc3-bookend")

    filepath_first_before = os.path.join(args.data_directory,
                                         f"{readname}_{time_strs[0]}")
    filepath_last_before = os.path.join(args.data_directory,
                                        f"{readname}_{time_strs[1]}")
    json_file_first_before = f"{readname}_{time_strs[0]}"
    json_file_last_before = f"{readname}_{time_strs[1]}"

    # Each Trace object contains two sweeps
    first_before_trace = Trace(filepath_first_before,
                               json_file_first_before)
    last_before_trace = Trace(filepath_last_before,
                              json_file_last_before)

    times = first_before_trace.get_times()
    voltage = first_before_trace.get_voltage()

    voltage_protocol = first_before_trace.get_voltage_protocol()
    ramp_bounds = detect_ramp_bounds(times,
                                     voltage_protocol.get_all_sections())
    filepath_first_after = os.path.join(args.data_directory,
                                        f"{readname}_{time_strs[2]}")
    filepath_last_after = os.path.join(args.data_directory,
                                        f"{readname}_{time_strs[3]}")
    json_file_first_after = f"{readname}_{time_strs[2]}"
    json_file_last_after = f"{readname}_{time_strs[3]}"

    first_after_trace = Trace(filepath_first_after,
                                json_file_first_after)
    last_after_trace = Trace(filepath_last_after,
                                json_file_last_after)

    # Ensure that all traces use the same voltage protocol
    assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage())
    assert np.all(first_after_trace.get_voltage() == last_after_trace.get_voltage())
    assert np.all(first_before_trace.get_voltage() == first_after_trace.get_voltage())
    assert np.all(first_before_trace.get_voltage() == last_before_trace.get_voltage())

    # Ensure that the same number of sweeps were used
    assert first_before_trace.NofSweeps == last_before_trace.NofSweeps

    first_before_current_dict = first_before_trace.get_trace_sweeps()
    first_after_current_dict = first_after_trace.get_trace_sweeps()
    last_before_current_dict = last_before_trace.get_trace_sweeps()
    last_after_current_dict = last_after_trace.get_trace_sweeps()

    # Do leak subtraction and store traces for each well
    # TODO Refactor this code into a single loop. There's no need to store each individual trace.
    before_traces_first = {}
    before_traces_last = {}
    after_traces_first = {}
    after_traces_last = {}
    first_processed = {}
    last_processed = {}

    # Iterate over all wells
    for well in np.array(all_wells).flatten():
        first_before_current = first_before_current_dict[well][0, :]
        first_after_current = first_after_current_dict[well][0, :]
        last_before_current = last_before_current_dict[well][-1, :]
        last_after_current = last_after_current_dict[well][-1, :]

        before_traces_first[well] = get_leak_corrected(first_before_current,
                                                       voltage, times,
                                                       *ramp_bounds)
        before_traces_last[well] = get_leak_corrected(last_before_current,
                                                      voltage, times,
                                                      *ramp_bounds)

        after_traces_first[well] = get_leak_corrected(first_after_current,
                                                      voltage, times,
                                                      *ramp_bounds)
        after_traces_last[well] = get_leak_corrected(last_after_current,
                                                     voltage, times,
                                                     *ramp_bounds)

        # Store subtracted traces
        first_processed[well] = before_traces_first[well] - after_traces_first[well]
        last_processed[well] = before_traces_last[well] - after_traces_last[well]

    voltage_protocol = VoltageProtocol.from_voltage_trace(voltage, times)

    hergqc = hERGQC(sampling_rate=first_before_trace.sampling_rate,
                    plot_dir=plot_dir,
                    voltage=voltage)

    assert first_before_trace.NofSweeps == last_before_trace.NofSweeps

    voltage_steps = [tstart \
                     for tstart, tend, vstart, vend in
                     voltage_protocol.get_all_sections() if vend == vstart]
    res_dict = {}

    fig = plt.figure(figsize=args.figsize)
    ax = fig.subplots()
    for well in args.wells:
        trace1 = hergqc.filter_capacitive_spikes(
            first_processed[well], times, voltage_steps
        ).flatten()

        trace2 = hergqc.filter_capacitive_spikes(
            last_processed[well], times, voltage_steps
        ).flatten()

        passed = hergqc.qc3(trace1, trace2)

        res_dict[well] = passed 

        save_fname = os.path.join(args.output_dir,
                                  'debug',
                                  f"debug_{well}_{savename}",
                                  'qc3_bookend')

        ax.plot(times, trace1)
        ax.plot(times, trace2)

        fig.savefig(save_fname)
        ax.cla()

    plt.close(fig)
    return res_dict

def get_time_constant_of_first_decay(trace, times, protocol_desc, args, output_path):

    if output_path:
        if not os.path.exists(os.path.dirname(output_path)):
            os.makedirs(os.path.dirname(output_path))

    first_40mV_step_index = [i for i, line in enumerate(protocol_desc) if line[2]==40][0]

    tstart, tend, vstart, vend = protocol_desc[first_40mV_step_index + 1, :]
    assert(vstart == vend)
    assert(vstart==-120.0)

    indices = np.argwhere((times >= tstart) & (times <= tend))

    # find peak current
    peak_current = np.min(trace[indices])
    peak_index = np.argmax(np.abs(trace[indices]))
    peak_time = times[indices[peak_index]]

    indices = np.argwhere((times >= peak_time) & (times <= tend - 50))
    print(indices)

    def fit_func(x):
        a, b, c = x
        prediction = c + a * np.exp((-1.0/b) * (times[indices] - peak_time))

        return np.sum((prediction - trace[indices])**2)

    bounds =  [
        (-np.abs(trace).max()*2, np.abs(trace).max()*2),
        (1e-12, 1e4),
        (-np.abs(trace).max()*2, np.abs(trace).max()*2),
    ]

    # Repeat optimisation with different starting guesses
    x0s = [[np.random.uniform(lower_b, upper_b) for lower_b, upper_b in bounds] for i in range(100)]

    best_res = None
    for x0 in x0s:
        res = scipy.optimize.minimize(fit_func, x0=x0,
                                      bounds=bounds)
        if best_res is None:
            best_res = res
        elif res.fun < best_res.fun and res.success and res.fun != 0:
            best_res = res

    res = best_res

    if not res:
        logging.warning('finding 40mv decay timeconstant failed:' + str(res))

    if output_path and res:
        fig = plt.figure(figsize=args.figsize, constrained_layout=True)
        axs = fig.subplots(2)

        for ax in axs:
            ax.spines[['top', 'right']].set_visible(False)
            ax.set_ylabel(r'$I_\text{obs}$ (pA)')
            ax.set_xlabel(r'$t$ (ms)')

        protocol_ax, fit_ax = axs
        protocol_ax.set_title('a', fontweight='bold')
        fit_ax.set_title('b', fontweight='bold')
        fit_ax.plot(peak_time, tend-50, alpha=.5)

        a, b, c = res.x
        fit_ax.plot(times[indices], c + a * np.exp(-(1.0/b) * (times[indices] - peak_time)),
                    color='red', linestyle='--')

        res_string = r'$\tau_{40\text{mV}} = ' f"{b:.1f}" r'\text{ms}$'
        fit_ax.annotate(res_string, xy=(0.5, 0.05), xycoords='axes fraction')

        protocol_ax.plot(times, trace)
       # protocol_ax.axvspan(peak_time, tend - 50, alpha=.5, color='grey')

        fig.savefig(output_path)
        plt.close(fig)

    return res.x[1], peak_current if res else np.nan, peak_current

def detect_ramp_bounds(times, voltage_sections, ramp_no=0):
    """
    Extract the the times at the start and end of the nth ramp in the protocol.

    @param times: np.array containing the time at which each sample was taken
    @param voltage_sections 2d np.array where each row describes a segment of the protocol: (tstart, tend, vstart, end)
    @param ramp_no: the index of the ramp to select. Defaults to 0 - the first ramp

    @returns tstart, tend: the start and end times for the ramp_no+1^nth ramp
    """

    # Decouple this code from syncropatch_export

    ramps = [(tstart, tend, vstart, vend) for tstart, tend, vstart, vend
             in voltage_sections if vstart != vend]
    try:
        ramp = ramps[ramp_no]
    except IndexError:
        print(f"Requested {ramp_no+1}th ramp (ramp_no={ramp_no}),"
              " but there are only {len(ramps)} ramps")

    tstart, tend = ramp[:2]

    ramp_bounds = [np.argmax(times > tstart), np.argmax(times > tend)]
    return ramp_bounds

if __name__ == '__main__':