AI4HealthUOL / SSSD-ECG

Repository for the paper: 'Diffusion-based Conditional ECG Generation with Structured State Space Models'
MIT License

Provide model training on synthetic data #9

Closed: Skorik99 closed this 6 months ago

Skorik99 commented 11 months ago

Hello! I was really inspired by your article and want to reproduce the results of Table 1 for the 'capacity to replace real data' metric, i.e. training on synthetic data and testing on real data. However, I did not get comparably good results (macro_auc = 0.586).

For this purpose, I downloaded the synthetic dataset you provided. After unzipping, it has the following structure:

- data
    - ptbxl_test_data.npy
    - ptbxl_train_data.npy
    - ptbxl_validation_data.npy
- labels
    - ptbxl_test_labels.npy
    - ptbxl_train_labels.npy
    - ptbxl_validation_labels.npy      
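
As a quick sanity check of the downloaded files I printed the shape and dtype of each array (a small helper snippet of my own, with a placeholder path; I am not assuming any particular shapes):

import os
import numpy as np

synth_path = 'path/to/Dataset'  # placeholder path to the unzipped synthetic dataset

# print shape and dtype of every array in the data/ and labels/ folders
for sub in ('data', 'labels'):
    for fname in sorted(os.listdir(os.path.join(synth_path, sub))):
        arr = np.load(os.path.join(synth_path, sub, fname))
        print(sub, fname, arr.shape, arr.dtype)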

The next step was to train the neural network. For this I used the recommended benchmark https://github.com/helme/ecg_ptbxl_benchmarking/tree/master. To run it on the synthetic dataset I needed to change the prepare method of the code.experiments.scp_experiments.SCP_Experiment class: I removed the code that builds the train/val/test splits (up to this line) and replaced it with loading of the synthetic data:

synth_path = 'path/to/Dataset/'

# load the synthetic labels and signals
self.y_train = np.load(os.path.join(synth_path, 'labels', 'ptbxl_train_labels.npy'))
self.y_val = np.load(os.path.join(synth_path, 'labels', 'ptbxl_validation_labels.npy'))
self.y_test = np.load(os.path.join(synth_path, 'labels', 'ptbxl_test_labels.npy'))
self.X_train = np.load(os.path.join(synth_path, 'data', 'ptbxl_train_data.npy'))
self.X_val = np.load(os.path.join(synth_path, 'data', 'ptbxl_validation_data.npy'))
self.X_test = np.load(os.path.join(synth_path, 'data', 'ptbxl_test_data.npy'))

# transpose so the synthetic signals match the (samples, timesteps, channels) layout of the real PTB-XL data
self.X_train = np.transpose(self.X_train, (0, 2, 1))
self.X_val = np.transpose(self.X_val, (0, 2, 1))
self.X_test = np.transpose(self.X_test, (0, 2, 1))

self.input_shape = self.X_train[0].shape
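
Right after this I also added a small consistency check (my own addition, not part of the benchmark) to make sure all splits share the same signal layout and number of label columns after the transpose:

# sanity check: consistent array layout across splits (my own addition)
assert self.X_train.ndim == 3
assert self.X_train.shape[1:] == self.X_val.shape[1:] == self.X_test.shape[1:]
assert self.y_train.shape[1] == self.y_val.shape[1] == self.y_test.shape[1]
print('X_train:', self.X_train.shape, 'y_train:', self.y_train.shape)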

Next, I ran code.reproduce_results.py with only the xresnet1d50 model and the ptbxl_all task, as in the paper:

def main():

    datafolder = '../data/ptbxl/'
    datafolder_icbeb = '../data/ICBEB/'
    outputfolder = '../output/'

    models = [
        conf_fastai_xresnet1d50,
        # conf_fastai_resnet1d_wang,
        # conf_fastai_lstm,
        # conf_fastai_lstm_bidir,
        # conf_fastai_fcn_wang,
        # conf_fastai_inception1d,
        # conf_wavelet_standard_nn,
        ]

    ##########################################
    # STANDARD SCP EXPERIMENTS ON PTBXL
    ##########################################

    experiments = [
        ('exp0', 'all'),
        # ('exp1', 'diagnostic'),
        # ('exp1.1', 'subdiagnostic'),
        # ('exp1.1.1', 'superdiagnostic'),
        # ('exp2', 'form'),
        # ('exp3', 'rhythm')
       ]

    for name, task in experiments:
        e = SCP_Experiment(name, task, datafolder, outputfolder, models)
        e.prepare()
        e.perform()
        e.evaluate()

    # generate the PTB-XL summary table
    utils.generate_ptbxl_summary_table()

When testing on the synthetic test data I got reasonable results (te_results.csv in the output directory):

,macro_auc
point,0.988766041489155
mean,0.988766041489155
lower,0.988766041489155
upper,0.988766041489155

But when I run the trained model on the real test data, I get the following metrics:

,macro_auc
point,0.5860400175807828
mean,0.5860400175807828
lower,0.5860400175807828
upper,0.5860400175807828

which are significantly lower than those reported in the article and differ little from the baselines. For full reproducibility, I also include the code I used to run the trained model on the real test data.

from utils import utils
# model configs
from configs.fastai_configs import *
from configs.wavelet_configs import *
from models.fastai_model import fastai_model
import numpy as np
import multiprocessing
from itertools import repeat
import pandas as pd

def main():
    # Prepare data
    datafolder = '../data/ptbxl/'
    outputfolder = '../output/'
    experiment_name = 'exp0'
    sf = 100
    task = 'all'
    test_fold = 10
    n_jobs=20

    data, raw_labels = utils.load_dataset(datafolder, sf, name='ptbxl')
    labels = utils.compute_label_aggregations(raw_labels, datafolder, task)

    data, labels, Y, _ = utils.select_data(data, labels, task, min_samples=0, outputfolder='')
    input_shape = data[0].shape

    # use the 10th stratified fold for testing
    X_test = data[labels.strat_fold == test_fold]
    y_test = Y[labels.strat_fold == test_fold]
    n_classes = y_test.shape[1]

    # Load model
    config = conf_fastai_xresnet1d50
    modelname = config['modelname']
    modeltype = config['modeltype']
    modelparams = config['parameters']
    mpath = outputfolder+experiment_name+'/models/'+modelname+'/'
    model = fastai_model(modelname, n_classes, sf, mpath, input_shape, **modelparams)

    # Predict
    y_test_pred = model.predict(X_test)

    # Get metrics
    # a single evaluation "sample" containing all test indices, so the
    # point/mean/lower/upper rows of the result table coincide
    test_samples = np.array([range(len(y_test))])
    rpath = mpath+'results/'
    thresholds = None
    pool = multiprocessing.Pool(n_jobs)

    te_df = pd.concat(pool.starmap(utils.generate_results, zip(test_samples, repeat(y_test), repeat(y_test_pred), repeat(thresholds))))
    te_df_point = utils.generate_results(range(len(y_test)), y_test, y_test_pred, thresholds)
    te_df_result = pd.DataFrame(
        np.array([
            te_df_point.mean().values, 
            te_df.mean().values,
            te_df.quantile(0.05).values,
            te_df.quantile(0.95).values]), 
        columns=te_df.columns, 
        index=['point', 'mean', 'lower', 'upper'])

    pool.close()

    te_df_result.to_csv(rpath+'real_te_results.csv')

if __name__ == '__main__':
    main()
juanlopezcode commented 6 months ago

Hi,

Unfortunately, I'm not a maintainer/developer of the https://github.com/helme/ecg_ptbxl_benchmarking/tree/master repository. But we have released the model used for this paper in this repo: https://github.com/AI4HealthUOL/ECG-MIMIC/blob/main/src/main_ecg.py. You might need to pass the numpy files/paths of the data into TimeseriesDatasetCrops (lines 405-407).
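
For illustration only, a minimal sketch of getting the synthetic .npy files into a PyTorch-style dataset; the SyntheticECGDataset class and paths below are placeholders, not the actual TimeseriesDatasetCrops, whose constructor arguments should be taken from main_ecg.py around the lines mentioned above:

import os
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class SyntheticECGDataset(Dataset):
    """Placeholder dataset serving (signal, label) pairs from the synthetic .npy files."""
    def __init__(self, data_path, labels_path):
        self.X = np.load(data_path).astype(np.float32)
        self.y = np.load(labels_path).astype(np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return torch.from_numpy(self.X[idx]), torch.from_numpy(self.y[idx])

# hypothetical path to the unzipped synthetic dataset
synth_path = 'path/to/Dataset'
train_ds = SyntheticECGDataset(
    os.path.join(synth_path, 'data', 'ptbxl_train_data.npy'),
    os.path.join(synth_path, 'labels', 'ptbxl_train_labels.npy'))
train_dl = DataLoader(train_ds, batch_size=32, shuffle=True)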

I hope this helps