SarielMa / Robust_DNN_for_ECG


Missing code for attacking the CPSC2018 database with SAP #1

Closed: gongchengguangxue closed this issue 1 year ago

gongchengguangxue commented 1 year ago

Hello, reducing the noise-to-signal ratio to improve the robustness of DNNs is good work. I'm also working on adversarial training for ECG signals, and I would like to obtain the code for attacking the CPSC2018 database with SAP, which is missing from the repository, so that I can run comparison experiments. The missing code belongs in CPSC2018/core/Evaluate_mask.py, where `test_adv` currently handles only `ifgsm` and `pgd`. I hope to hear from you.
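As background for the noise-to-signal ratio mentioned above: a common way to measure it for an adversarial example is the norm of the perturbation divided by the norm of the clean signal, NSR = ||x_adv - x|| / ||x||. The minimal numpy sketch below assumes that definition; the helper name `noise_to_signal_ratio` and the default L2 norm are illustrative assumptions, not taken from the repository.

```python
import numpy as np

def noise_to_signal_ratio(x_clean, x_adv, ord=2):
    """Per-sample NSR = ||x_adv - x_clean|| / ||x_clean|| (assumed definition).

    Inputs have shape (batch, ...); everything after the batch axis is
    flattened, so multi-lead signals of shape (batch, leads, time) also work.
    """
    x_clean = np.asarray(x_clean, dtype=float)
    x_adv = np.asarray(x_adv, dtype=float)
    n = x_clean.shape[0]
    noise = (x_adv - x_clean).reshape(n, -1)
    signal = x_clean.reshape(n, -1)
    return np.linalg.norm(noise, ord=ord, axis=1) / np.linalg.norm(signal, ord=ord, axis=1)

# Two toy "signals", each perturbed by 10% of its own values:
# both rows have an NSR of about 0.1.
print(noise_to_signal_ratio([[3.0, 4.0], [6.0, 8.0]], [[3.3, 4.4], [6.6, 8.8]]))
```

Passing `ord=np.inf` gives the L-infinity variant; which norm the paper's experiments use is not stated in this thread.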

```python
import numpy as np
import torch

# ifgsm_attack, repeated_pgd_attack and get_pgd_loss_fn_by_name are defined
# elsewhere in the repository and are not shown in this excerpt.

def test_adv(model, device, dataloader, num_classes, noise_norm, norm_type, max_iter, step, method,
             targeted=False, clip_X_min=0, clip_X_max=1,
             stop_if_label_change=True, use_optimizer=False, pgd_loss_fn=None, num_repeats=0,
             save_model_output=False, class_balanced_acc=False):
    model.eval()  # set model to evaluation mode
    # confusion matrices for predictions on clean and on adversarial inputs
    confusion_clean = np.zeros((num_classes, num_classes))
    confusion_noisy = np.zeros((num_classes, num_classes))
    sample_count = 0
    adv_sample_count = 0
    sample_idx_wrong = []
    sample_idx_attack = []
    if save_model_output:
        y_list = []
        z_list = []
        yp_list = []
        adv_z_list = []
        adv_yp_list = []
    #---------------------
    print('testing robustness wba ', method, '(', num_repeats, ')', sep='')
    print('norm_type:', norm_type, ', noise_norm:', noise_norm, ', max_iter:', max_iter, ', step:', step, sep='')
    pgd_loss_fn = get_pgd_loss_fn_by_name(pgd_loss_fn)  # resolve the loss function from its name
    print('pgd_loss_fn', pgd_loss_fn)
    #---------------------
    for batch_idx, batch_data in enumerate(dataloader):
        # each batch provides the signals X, the labels Y and a Mask tensor
        X, Y = batch_data[0].to(device), batch_data[1].to(device)
        Mask = batch_data[2].to(device)
        #------------------
        Z = model(X, Mask)  # classify the 'clean' signal X
        if len(Z.size()) <= 1:
            Yp = (Z > 0).to(torch.int64)  # binary/sigmoid
        else:
            Yp = Z.data.max(dim=1)[1]  # multiclass/softmax
        #------------------
        # generate adversarial signals Xn with outputs Zn and predictions Ypn
        if method == 'ifgsm':
            Xn, Zn, Ypn = ifgsm_attack(model, X, Y, Mask, noise_norm=noise_norm, norm_type=norm_type,
                                       max_iter=max_iter, step=step, targeted=targeted,
                                       clip_X_min=clip_X_min, clip_X_max=clip_X_max,
                                       stop_if_label_change=stop_if_label_change,
                                       use_optimizer=use_optimizer, loss_fn=pgd_loss_fn, return_output=True)
        elif method == 'pgd':
            Xn, Zn, Ypn = repeated_pgd_attack(model, X, Y, Mask, noise_norm=noise_norm, norm_type=norm_type,
                                              max_iter=max_iter, step=step, targeted=targeted,
                                              clip_X_min=clip_X_min, clip_X_max=clip_X_max,
                                              stop_if_label_change=stop_if_label_change,
                                              use_optimizer=use_optimizer, loss_fn=pgd_loss_fn,
                                              return_output=True, num_repeats=num_repeats)
        else:
            # there is no 'sap' branch here: this is the missing SAP code for CPSC2018
            raise NotImplementedError("other method is not implemented.")
        # ... (the rest of test_adv is not included in the issue)
```
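`test_adv` allocates `confusion_clean` and `confusion_noisy` and takes a `class_balanced_acc` flag, but the part of the loop that fills these matrices is not quoted in the issue. The sketch below shows, under stated assumptions, how such matrices are typically accumulated and reduced to plain or class-balanced accuracy. It assumes rows index the true label and columns the prediction; the helper names `update_confusion` and `accuracy` are hypothetical, not the repository's.

```python
import numpy as np

def update_confusion(confusion, y_true, y_pred):
    """Add one count per (true, predicted) pair; rows = true class (assumed convention)."""
    # np.add.at is unbuffered, so repeated index pairs are all counted.
    np.add.at(confusion, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return confusion

def accuracy(confusion, class_balanced=False):
    """Overall accuracy, or the mean per-class recall over classes that occur."""
    correct = np.diag(confusion)
    if class_balanced:
        support = confusion.sum(axis=1)
        present = support > 0  # skip classes with no samples instead of counting them as 0
        return float((correct[present] / support[present]).mean())
    return float(correct.sum() / confusion.sum())

c = np.zeros((3, 3))
update_confusion(c, [0, 0, 1, 2, 2, 2], [0, 1, 1, 2, 2, 0])
print(accuracy(c), accuracy(c, class_balanced=True))
```

In `test_adv`, one would presumably call `update_confusion(confusion_clean, Y, Yp)` and `update_confusion(confusion_noisy, Y, Ypn)` per batch after moving the tensors to NumPy; the gap between the two accuracies is the drop caused by the attack.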
SarielMa commented 1 year ago

Hi! Thanks for your feedback. I have now added the missing SAP attack code for CPSC2018. If you have any questions, please feel free to let me know.
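For readers who want to see what an SAP attack does before reading the repository code: assuming SAP refers to a smooth adversarial perturbation, the perturbation is generated from a free parameter vector that is passed through Gaussian smoothing, so the added noise has no sharp spikes. The numpy sketch below illustrates that idea on a single 1-D signal. It is not the code the maintainer added; the function names, the kernel widths, rescaling (instead of clipping) to enforce `eps`, and the `loss_grad_fn`/`predict_fn` interface standing in for PyTorch autograd are all assumptions for illustration.

```python
import numpy as np

def gaussian_kernel(sigma, radius):
    """Discrete 1-D Gaussian of length 2*radius+1, normalized to sum to 1."""
    t = np.arange(-radius, radius + 1, dtype=float)
    k = np.exp(-0.5 * (t / sigma) ** 2)
    return k / k.sum()

def smooth(theta, sigmas=(1.0, 3.0, 5.0), radius=10):
    """Average of Gaussian-smoothed copies of a 1-D signal at several widths.

    The kernels are symmetric and centered, so for signals at least
    2*radius+1 long this linear map equals its own transpose.
    """
    return np.mean([np.convolve(theta, gaussian_kernel(s, radius), mode="same")
                    for s in sigmas], axis=0)

def sap_attack(x, loss_grad_fn, predict_fn, eps, step, max_iter=20,
               sigmas=(1.0, 3.0, 5.0), radius=10, stop_if_label_change=True):
    """Hypothetical SAP-style attack on one 1-D signal x.

    loss_grad_fn(x): gradient, w.r.t. the input, of the loss the attack increases.
    predict_fn(x):   predicted label.
    The perturbation is delta = smooth(theta); theta follows signed gradient
    ascent, and delta is rescaled so that max|delta| <= eps.
    """
    x = np.asarray(x, dtype=float)
    y0 = predict_fn(x)
    theta = np.zeros_like(x)
    delta = np.zeros_like(x)
    for _ in range(max_iter):
        # Chain rule: d loss / d theta = smooth^T(grad_x) = smooth(grad_x).
        g = smooth(loss_grad_fn(x + delta), sigmas, radius)
        theta += step * np.sign(g)
        delta = smooth(theta, sigmas, radius)
        peak = np.abs(delta).max()
        if peak > eps:
            delta *= eps / peak  # uniform rescale keeps the smooth shape
        if stop_if_label_change and predict_fn(x + delta) != y0:
            break
    return x + delta, delta
```

Rescaling the whole perturbation, rather than clipping each sample to [-eps, eps], is a design choice that keeps the smoothed shape intact; clipping would flatten the peaks and reintroduce corners. In the PyTorch setting of `test_adv`, the gradient would come from autograd on the model loss and the smoothing from a 1-D convolution, with a `method == 'sap'` branch dispatching to it alongside `ifgsm` and `pgd`.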

gongchengguangxue commented 1 year ago

Thanks for your work; it has helped me a lot.