zhangyumeng1sjtu / EPACT

GNU General Public License v3.0

New data issue #1

Open night-chen opened 11 hours ago

night-chen commented 11 hours ago

Hi authors, I find this work useful and meaningful. I am trying to fine-tune the model on my own curated pMHC-TCR pairs, but I ran into an issue with the kfold_data setting in the config. I replaced train_pos_data_path and test_data_path in the config with my own data paths, but I don't know what to do with kfold_data_path. If I leave it unchanged and run the code, I get the following error:

Traceback (most recent call last):
  File "EPACT/scripts/train/train_tcr_pmhc_binding.py", line 97, in <module>
    main(args)
  File "EPACT/scripts/train/train_tcr_pmhc_binding.py", line 42, in main
    train_idx = get_epitope_idx_from_fasta(pos_dataset, f'{config.data.kfold_data_path}/train_epitope_fold_{i+1}.fasta')
  File "EPACT/utils/sampling.py", line 48, in get_epitope_idx_from_fasta
    indices = [idx for idx, data in enumerate(dataset) if data['epitope_seq'] in epitope_list]
  File "/opt/conda/envs/EPACT_env/lib/python3.10/site-packages/EPACT/utils/sampling.py", line 48, in <listcomp>
    indices = [idx for idx, data in enumerate(dataset) if data['epitope_seq'] in epitope_list]
  File "/opt/conda/envs/EPACT_env/lib/python3.10/site-packages/EPACT/dataset/data.py", line 196, in __getitem__
    data['mhc_seq'] = self.transform_mhc_allele(data['mhc_allele'])
  File "/opt/conda/envs/EPACT_env/lib/python3.10/site-packages/EPACT/dataset/data.py", line 189, in transform_mhc_allele
    assert mhc_seq is not None
AssertionError

I would very much appreciate it if you could take the time to respond to this issue. Thank you so much!
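The assertion at the end of the traceback is raised inside `transform_mhc_allele`, right after an allele is converted to `mhc_seq`, which suggests that at least one MHC allele in the custom data could not be resolved to a sequence. A minimal pre-flight check is sketched below. It assumes the input CSV has a column named `mhc_allele` (borrowed from the `data['mhc_allele']` key in the traceback; your column may be named differently) and that the allele-to-sequence table EPACT uses has been loaded into a plain Python dict; both are assumptions, not EPACT's API.

```python
import csv
from collections import Counter
from typing import Iterable, Mapping


def find_unmapped_alleles(rows: Iterable[Mapping[str, str]],
                          allele_to_seq: Mapping[str, str],
                          allele_column: str = "mhc_allele") -> Counter:
    """Count how often each allele with no known sequence appears.

    allele_column and allele_to_seq are assumptions: adapt them to your CSV
    header and to the allele table EPACT actually loads.
    """
    missing = Counter()
    for row in rows:
        allele = row[allele_column]  # compared verbatim, so stray spaces also show up
        if allele not in allele_to_seq:
            missing[allele] += 1
    return missing


def check_csv(path: str, allele_to_seq: Mapping[str, str],
              allele_column: str = "mhc_allele") -> Counter:
    """Run find_unmapped_alleles over a CSV file with a header row."""
    with open(path, newline="") as handle:
        # csv.DictReader yields one dict per data row, keyed by the header.
        return find_unmapped_alleles(csv.DictReader(handle), allele_to_seq, allele_column)
```

An empty result means every allele has a sequence; any allele that does show up needs to be renamed to the nomenclature the table uses or added to it before training.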

zhangyumeng1sjtu commented 7 hours ago

Hi Yuchen,

Thanks for your interest in our project. By default, EPACT performs five-fold cross-validation during training, so you may need to prepare, for each fold, a .fasta file containing the training epitopes and a .csv file containing the validation (positive and negative) TCR-pMHC pairs in the kfold_data directory. Please refer to the original data on Zenodo for the expected layout.
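The traceback shows the training script reading `{kfold_data_path}/train_epitope_fold_{i+1}.fasta` and keeping the records whose `epitope_seq` appears in that file, so each fold is defined by a set of training epitopes. A minimal sketch for producing those FASTA files from your own epitopes, assuming an epitope-level split (no epitope shared between a fold's training and validation side) and placeholder FASTA headers; both choices are assumptions, so compare against the Zenodo files:

```python
import os
import random
from typing import List, Sequence, Tuple


def epitope_folds(epitopes: Sequence[str], k: int = 5,
                  seed: int = 0) -> List[Tuple[List[str], List[str]]]:
    """Split the unique epitopes into k disjoint validation folds.

    Returns one (train_epitopes, val_epitopes) pair per fold.
    """
    unique = sorted(set(epitopes))       # deterministic order before shuffling
    random.Random(seed).shuffle(unique)  # seeded shuffle for reproducibility
    folds = []
    for i in range(k):
        val = unique[i::k]               # every k-th epitope, starting at offset i
        val_set = set(val)
        train = [e for e in unique if e not in val_set]
        folds.append((train, val))
    return folds


def write_fasta(path: str, seqs: Sequence[str]) -> None:
    """Write one FASTA record per sequence; header names are placeholders."""
    with open(path, "w") as handle:
        for i, seq in enumerate(seqs):
            handle.write(f">epitope_{i}\n{seq}\n")


def write_kfold_fastas(kfold_dir: str, epitopes: Sequence[str], k: int = 5, seed: int = 0):
    """Write train_epitope_fold_1..k.fasta into kfold_dir and return the folds."""
    os.makedirs(kfold_dir, exist_ok=True)
    folds = epitope_folds(epitopes, k=k, seed=seed)
    for i, (train, _val) in enumerate(folds):
        write_fasta(os.path.join(kfold_dir, f"train_epitope_fold_{i + 1}.fasta"), train)
    return folds
```

The epitope strings must match the `epitope_seq` values in your data exactly. The per-fold validation .csv (positive and negative pairs for that fold's validation epitopes) is not generated here, since its file name and columns are not shown in this thread; follow the Zenodo files for those.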

If you want to train the model on your own data without cross-validation, slightly modifying the training code may help:

```python
from torch.utils.data import DataLoader
# PairedTCRpMHCDataset, PairedCDR123pMHCBatchConverter and
# PairedCDR123pMHCCoembeddingTrainer come from the EPACT package;
# import them the same way the training script does.

# Configure training positive samples and validation data
train_pos_dataset = PairedTCRpMHCDataset(data_path = config.data.train_pos_data_path, ...)
val_dataset = PairedTCRpMHCDataset(data_path = <your validation data path>, ...)

# Configure training and validation data loaders
batch_converter = PairedCDR123pMHCBatchConverter(max_mhc_len = config.model.mhc_seq_len, sample_cdr3 = False)
train_loader = DataLoader(
    dataset = train_pos_dataset, batch_size = config.training.train_batch_size,
    num_workers = config.training.num_workers, shuffle = True,
    collate_fn = batch_converter,
)
val_loader = DataLoader(
    dataset = val_dataset, batch_size = config.training.test_batch_size,
    num_workers = config.training.num_workers, shuffle = False,
    collate_fn = batch_converter,
)

# Train on the single train/validation split
trainer = PairedCDR123pMHCCoembeddingTrainer(config, log_dir = config.training.log_dir)
trainer.fit(train_loader, val_loader)
```

night-chen commented 7 hours ago

Thank you so much for the timely response! Yes, I have figured this out, but I have now run into another problem. I have added my data and updated 'train_pmhc_path', 'train_pos_data_path', and 'test_data_path' in the config file. It seems that 'train_tcr_feat_path' also needs to be updated since I am using my own data. However, the original path 'data/binding/Paired-TCR/train_paired_cdr3_seq.pt' appears to be a PyTorch tensor file, and I am not sure how to generate this feature for my own CDR3 data. Thank you very much!

zhangyumeng1sjtu commented 7 hours ago

Sorry, I forgot that preparing this file is necessary to train the model from scratch. You can use the following code to generate the file referenced by train_tcr_feat_path.

```python
import pandas as pd
import torch

# Each row holds the CDR3 alpha sequence, the CDR3 beta sequence and a
# semicolon-separated list of pMHC indices (columns are read by position).
cdr3_data = pd.read_table('Paired-TCR/TCR-train-data.tsv')
res = []
for i in range(len(cdr3_data)):
    cdr_alpha_seq, cdr_beta_seq, pmhc_idx = cdr3_data.iloc[i, 0], cdr3_data.iloc[i, 1], cdr3_data.iloc[i, 2]
    # str() guards against pandas parsing a column of single indices as integers
    pmhc_idx = torch.tensor([int(idx) for idx in str(pmhc_idx).split(";")])
    res.append({
        'cdr3.alpha': cdr_alpha_seq,
        'cdr3.beta': cdr_beta_seq,
        'pmhc': pmhc_idx,
        'len': [len(cdr_alpha_seq), len(cdr_beta_seq)],
    })
torch.save(res, 'Paired-TCR/train_paired_cdr3_seq.pt')
```
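The script above expects `TCR-train-data.tsv` to have a header line followed by one row per TCR: CDR3 alpha, CDR3 beta, and the pMHC indices joined with `;`. A minimal sketch for building such a file from your own (alpha, beta, pMHC index) records is shown below. It assumes each pMHC index is the row position of that pMHC in the table referenced by train_pmhc_path; that mapping and the header names are hypothetical, so check them against the original data before training.

```python
import csv
from typing import Dict, Iterable, List, Tuple


def group_pairs(pairs: Iterable[Tuple[str, str, int]]) -> Dict[Tuple[str, str], List[int]]:
    """Collect the pMHC indices of each unique (CDR3 alpha, CDR3 beta) pair."""
    grouped: Dict[Tuple[str, str], List[int]] = {}
    for alpha, beta, pmhc_idx in pairs:
        idxs = grouped.setdefault((alpha, beta), [])
        if pmhc_idx not in idxs:  # drop duplicate TCR-pMHC records
            idxs.append(pmhc_idx)
    return grouped


def write_tcr_tsv(path: str, pairs: Iterable[Tuple[str, str, int]],
                  header: Tuple[str, str, str] = ("cdr3_alpha", "cdr3_beta", "pmhc_idx")) -> None:
    """Write one tab-separated row per TCR: alpha, beta, 'i;j;k' pMHC indices."""
    with open(path, "w", newline="") as handle:
        writer = csv.writer(handle, delimiter="\t", lineterminator="\n")
        # pd.read_table treats the first line as the header; the conversion
        # script reads columns by position, so the names are placeholders.
        writer.writerow(header)
        for (alpha, beta), idxs in group_pairs(pairs).items():
            writer.writerow([alpha, beta, ";".join(str(i) for i in idxs)])
```

Grouping by TCR keeps a single row per (alpha, beta) pair, which matches the one-entry-per-TCR list that the conversion script saves.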