OSU-BMBL / scDEAL

Deep Transfer Learning of Drug Sensitivity by Integrating Bulk and Single-cell RNA-seq data
Apache License 2.0
46 stars 11 forks source link

Question about bulk data prediction #15

Open geshuang307 opened 9 months ago

geshuang307 commented 9 months ago

Hi, thanks for your time. Although I get the expected results on the scRNA-seq dataset, I am confused by the bulk prediction results. I have looked at the results you reported in Supplementary Table S3. For the bulk data, I ran bulkmodel.py with checkpoint = 'save/bulk_pre/org/integrate_data_GSE110894_drug_I.BET.762_bottle_512_edim_256,128_pdim_128,64_model_DAE_dropout_0.1_gene_F_lr_0.5_mod_new_sam_upsampling' on the test set of the bulk data, but I cannot reproduce your results. I only got rascore = 0.6259416966917786, f1score = 0.8001726450715422, and apscore = 0.30563990307708294. How should I predict on the bulk test data so that I obtain results like yours? For readability, I rewrote the code used in the bulk prediction process; it was meant to be attached here: [Uploading predict_bulk_org.py…]() (the upload did not complete, so the code is pasted in the next comment).

geshuang307 commented 9 months ago

# --- Imports -------------------------------------------------------------------
# Standard library
import argparse
import json
import logging
import os
import random
import sys
import time
import warnings

# Third-party
import matplotlib
import numpy as np
import pandas as pd
import torch
from scipy.stats import pearsonr
from sklearn import preprocessing
from sklearn.decomposition import PCA
from sklearn.dummy import DummyClassifier
from sklearn.metrics import (average_precision_score, classification_report,
                             confusion_matrix, mean_squared_error, r2_score,
                             roc_auc_score)
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from torch import nn, optim
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader, TensorDataset

# Project-local modules live in the working directory or its parent, so the
# search path has to be extended before they can be imported.
sys.path.append('./')
sys.path.append('../')
import sampling as sam
import utils as ut
import trainers as t
from models import (AEBase, DaNN, PretrainedPredictor, PretrainedVAEPredictor,
                    VAEBase)

# --- Reproducibility -------------------------------------------------------------
seed = 42
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here only affects child processes, not the hashing of this process.
os.environ['PYTHONHASHSEED'] = str(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

device = 'cpu'
# Run identifier: encodes the training hyper-parameters and is used both as the
# checkpoint file name under save/bulk_pre/ and as part of the log file name.
para = 'integrate_data_GSE110894_drug_I.BET.762_bottle_512_edim_256,128_pdim_128,64_model_DAE_dropout_0.1_gene_F_lr_0.5_mod_new_sam_upsampling'

# --- Model -----------------------------------------------------------------------
# Rebuild the bulk-level predictor with a fixed architecture and restore the
# trained weights stored at 'save/bulk_pre/' + para.
# NOTE(review): `para` encodes "bottle_512" (and --bottleneck defaults to 512),
# but latent_dim here is 128. load_state_dict() is strict by default, so if the
# checkpoint's layer shapes differ this raises -- confirm the trained bottleneck.
# NOTE(review): input_dim=15962 must equal the number of gene columns in the
# expression table loaded below -- confirm it matches the data file in use.
latent_dim_ge = 128
model = PretrainedPredictor(input_dim=15962, latent_dim=latent_dim_ge,
                            h_dims=[256, 128], hidden_dims_predictor=[128, 64],
                            output_dim=2, pretrained_weights=None,
                            freezed=False, drop_out=0.1,
                            drop_out_predictor=0.1)
model.to(device)
# map_location remaps tensors onto `device`; without it, a checkpoint that was
# saved from a GPU cannot be loaded on a CPU-only machine.
state_dict = torch.load('save/bulk_pre/' + para, map_location=device)
model.load_state_dict(state_dict)
# Evaluation mode: dropout layers (drop_out=0.1 above) become inactive.
model.eval()

# --- Data ------------------------------------------------------------------------
data_path = 'data/ALL_expression.csv'
label_path = 'data/ALL_label_binary_wf.csv'
# Drug whose response column is evaluated, and the placeholder written into
# missing label entries (same defaults as --drug / --missing_value below).
drug = 'I.BET.762'
missing_value = 1

data_r = pd.read_csv(data_path, index_col=0)
label_r = pd.read_csv(label_path, index_col=0)
print("two databases combine")
label_r = label_r.fillna(missing_value)

# Keep only the samples that carry a real annotation for the selected drug.
selected_idx = label_r.loc[:, drug] != missing_value
data = data_r.loc[selected_idx, :]
label = label_r.loc[selected_idx, drug]
data_r = data_r.loc[selected_idx, :]

# NOTE(review): the scaler is fit on all samples before the split, so test-set
# statistics leak into the features. Left unchanged on purpose: the features
# must be preprocessed exactly as at training time -- confirm against the
# training script before changing it.
mmscaler = preprocessing.MinMaxScaler()
data = mmscaler.fit_transform(data)

# LabelEncoder expects a 1-D array of shape (n_samples,); the Series values
# already have that shape, so no column-vector reshape is needed.
le = LabelEncoder()
label = le.fit_transform(label.values)

# 20% held-out test split, then 20% of the remainder for validation (same
# sizes as the --test_size / --valid_size defaults). With a fixed integer
# random_state the test split is deterministic for identical data and labels;
# only X_test / Y_test are used for evaluation below.
X_train_all, X_test, Y_train_all, Y_test = train_test_split(
    data, label, test_size=0.2, random_state=42)
X_train, X_valid, Y_train, Y_valid = train_test_split(
    X_train_all, Y_train_all, test_size=0.2, random_state=42)
X_train, Y_train = sam.upsampling(X_train, Y_train)
X_testTensor = torch.FloatTensor(X_test).to(device)

# --- Prediction & metrics ----------------------------------------------------------
# Inference only: no_grad() skips building an autograd graph.
with torch.no_grad():
    predictions = model(X_testTensor)

predictions = predictions.cpu().numpy()
print("predictions", predictions.shape)

# Following the original naming, column 1 is taken as the "sensitive" class
# score and column 0 as the other class; the hard label is the arg-max.
# NOTE(review): this relies on LabelEncoder having mapped the sensitive label
# to 1 -- confirm against the label file.
results = {}
results["sens_preds"] = predictions[:, 1]
results["sens_label"] = predictions.argmax(axis=1)
results["rest_preds"] = predictions[:, 0]
results['sensitive'] = Y_test
sens_pb_results = results['sens_preds']
lb_results = results['sens_label']

# labels=[0, 1] forces a 2x2 matrix even when one class is missing from both
# Y_test and the predictions; otherwise the 4-way unpacking would fail.
tn, fp, fn, tp = confusion_matrix(Y_test, lb_results, labels=[0, 1]).ravel()
# Avoid 0/0 when nothing is predicted positive (precision) or the test split
# has no positives (recall); 0.0 matches scikit-learn's zero_division=0.
precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
ap_score = average_precision_score(Y_test, sens_pb_results)
# NOTE(review): roc_auc_score raises if Y_test contains a single class.
ra_score = roc_auc_score(Y_test, sens_pb_results)
report_dict = classification_report(Y_test, lb_results, output_dict=True)
f1score = report_dict['weighted avg']['f1-score']

# --- Log ---------------------------------------------------------------------------
now = time.strftime("%Y-%m-%d-%H-%M-%S")
print(now)
log_dir = 'save/logs/'
# open() does not create missing parent directories.
os.makedirs(log_dir, exist_ok=True)
log_file = log_dir + 'drug_GSE110894' + now + para + '.txt'
metrics = [('f1score', f1score), ('apscore', ap_score), ('rascore', ra_score),
           ('precision', precision), ('recall', recall)]
# One line per metric, "<para>\t<name><value>" (no separator between name and
# value, as in the original log format). The with-statement closes the file.
with open(log_file, 'w') as f:
    f.writelines(para + '\t' + name + str(value) + '\n' for name, value in metrics)

if __name__ == '__main__':
    # Command-line options mirroring the bulk training script. Help texts use
    # argparse's %(default)s so the advertised defaults cannot drift.
    # NOTE(review): the evaluation code above runs at import time with
    # hard-coded settings; the `args` parsed here are not used by it yet.
    parser = argparse.ArgumentParser()

    parser.add_argument('--data', type=str, default='data/ALL_expression.csv',
                        help='Path of the bulk RNA-Seq expression profile')
    parser.add_argument('--label', type=str, default='data/ALL_label_binary_wf.csv',
                        help='Path of the processed bulk RNA-Seq drug screening annotation')
    parser.add_argument('--result', type=str, default='save/results/result_',
                        help='Path of the training result report files')
    parser.add_argument('--drug', type=str, default='I.BET.762',
                        help='Name of the selected drug, should be a column name in the input file of --label')
    parser.add_argument('--missing_value', type=int, default=1,
                        help='The value filled in the missing entries of the drug screening annotation. '
                             'Default: %(default)s')
    parser.add_argument('--test_size', type=float, default=0.2,
                        help='Size of the test set for bulk model training. Default: %(default)s')
    parser.add_argument('--valid_size', type=float, default=0.2,
                        help='Size of the validation set for bulk model training. Default: %(default)s')
    parser.add_argument('--var_genes_disp', type=float, default=None,
                        help='Dispersion of highly variable genes selection when pre-processing the data. '
                             'If None, all genes will be selected. Default: %(default)s')
    parser.add_argument('--sampling', type=str, default='upsampling',
                        help='Sampling method of training data for bulk model training. '
                             'Can be upsampling, downsampling, or SMOTE. Default: %(default)s')
    parser.add_argument('--PCA_dim', type=int, default=0,
                        help='Number of components of PCA reduction before training. '
                             'If 0, no PCA will be performed. Default: %(default)s')

    parser.add_argument('--device', type=str, default="cpu",
                        help='Device to train the model. Can be cpu or gpu. Default: %(default)s')
    parser.add_argument('--bulk_encoder', '-e', type=str, default='save/bulk_encoder/',
                        help='Path of the pre-trained encoder in the bulk level')
    parser.add_argument('--pretrain', type=str, default="True",
                        help='Whether to pre-train the encoder (str). False: do not pre-train, '
                             'True: pre-train. Default: %(default)s')
    parser.add_argument('--lr', type=float, default=0.5,
                        help='Learning rate of model training. Default: %(default)s')
    parser.add_argument('--epochs', type=int, default=500,
                        help='Number of training epochs. Default: %(default)s')
    parser.add_argument('--batch_size', type=int, default=200,
                        help='Batch size when training. Default: %(default)s')
    parser.add_argument('--bottleneck', type=int, default=512,
                        help='Size of the bottleneck layer of the model. Default: %(default)s')
    parser.add_argument('--dimreduce', type=str, default="DAE",
                        help='Encoder model type. Can be AE, DAE, or VAE. Default: %(default)s')
    parser.add_argument('--freeze_pretrain', type=int, default=0,
                        help='Fix the parameters in the pretrained model. 0: do not freeze, '
                             '1: freeze. Default: %(default)s')
    parser.add_argument('--encoder_h_dims', type=str, default="256,128",
                        help='Shape of the encoder. Each number represents the number of neurons in a layer. '
                             'Layers are separated by a comma. Default: %(default)s')
    parser.add_argument('--predictor_h_dims', type=str, default="128,64",
                        help='Shape of the predictor. Each number represents the number of neurons in a layer. '
                             'Layers are separated by a comma. Default: %(default)s')
    parser.add_argument('--VAErepram', type=int, default=1)
    parser.add_argument('--data_name', type=str, default="GSE110894",
                        help='Accession id for testing data, only support pre-built data.')
    parser.add_argument('--checkpoint', type=str,
                        default="save/bulk_pre/integrate_data_GSE110894_drug_I.BET.762_bottle_512_edim_256,128_pdim_128,64_model_DAE_dropout_0.1_gene_F_lr_0.5_mod_new_sam_upsampling",
                        help='Load weight from checkpoint files, can be True, False, or file path. '
                             'Checkpoint files can be paraName1_para1_paraName2_para2... Default: %(default)s')

    parser.add_argument('--bulk_model', '-p', type=str, default='save/bulk_pre/',
                        help='Path of the trained prediction model in the bulk level')
    parser.add_argument('--log', '-l', type=str, default='save/logs/',
                        help='Path of training log')
    parser.add_argument('--load_source_model', type=int, default=0,
                        help='Load a trained bulk level model or not. 0: do not load, 1: load. '
                             'Default: %(default)s')
    # Embed the cell-type label to regularize training; new: add cell type info.
    parser.add_argument('--mod', type=str, default="new",
                        help='Embed the cell type label to regularize the training: new: add cell type info, '
                             'ori: do not add cell type info. Default: %(default)s')
    parser.add_argument('--printgene', type=str, default='F',
                        help='Print the critical gene list: T: print. Default: %(default)s')
    parser.add_argument('--dropout', type=float, default=0.1,
                        help='Dropout of neural network. Default: %(default)s')
    # old: GDSC. new: CCLE.
    parser.add_argument('--bulk', type=str, default='integrate',
                        help='Selection of the bulk database. integrate: both datasets. old: GDSC. '
                             'new: CCLE. Default: %(default)s')
    parser.add_argument('--fix_source', type=int, default=0,
                        help='Fix the bulk level model. Default: %(default)s')
    warnings.filterwarnings("ignore")

    args = parser.parse_args()