AlexanderKroll / ProSmith

MIT License

Shuffling the data between epochs #6

Closed gtrevnenski closed 4 days ago

gtrevnenski commented 1 month ago

I see that the DataLoaders have the shuffle parameter set to False, presumably because the way data loading is currently handled only allows sequential reads. However, training the algorithm on data points in the same order every epoch is generally bad practice. Is there another reason for this besides the convenience of data loading, and would you expect an increase in accuracy if the algorithm allowed shuffling during training?

davdma commented 1 month ago

Hi @gtrevnenski, @AlexanderKroll can correct me if I am wrong, but looking at the source code of the SMILESProteinDataset class, it appears that the dataset loads each subset (a given number of predefined protein-SMILES pairings) and shuffles it with a new seed at every epoch (the seed being the epoch number):

In the training loop:

train_dataset = SMILESProteinDataset(
    data_path=args.train_dir,
    embed_dir=args.embed_path,
    train=True,
    device=device, 
    gpu=gpu,
    random_state=int(epoch),
    task=args.task,
    extraction_mode=False)

Part of the dataset loading process of SMILESProteinDataset:

def update_subset(self):
    # sections of code skipped here
    all_subset_smiles = list(self.smiles_reprs.keys())
    all_subset_sequences = list(self.protein_repr.keys())

    help_df = self.df.loc[self.df["SMILES"].isin(all_subset_smiles)].copy()
    help_df["index"] = list(help_df.index)
    help_df["Protein sequence"] = [seq[:1018] for seq in help_df["Protein sequence"]]
    help_df = help_df.loc[help_df["Protein sequence"].isin(all_subset_sequences)]
    if self.train:
        help_df = help_df.sample(frac=1, random_state=self.random_state)
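The key line is the epoch-seeded help_df.sample(frac=1, random_state=self.random_state): the same seed always produces the same order, while a different seed (i.e. a different epoch number) almost always produces a different one. A minimal, self-contained demonstration of this behavior on a toy DataFrame (the column names mirror the snippet above; the data itself is made up):

```python
import pandas as pd

# Toy stand-in for help_df: a few protein-SMILES pairs.
df = pd.DataFrame({
    "SMILES": ["CCO", "CCN", "CCC", "C=O"],
    "Protein sequence": ["MKT", "MVL", "MAA", "MGG"],
})

# frac=1 draws every row, i.e. a full shuffle; random_state seeds it.
epoch0 = df.sample(frac=1, random_state=0)
epoch0_again = df.sample(frac=1, random_state=0)
epoch1 = df.sample(frac=1, random_state=1)

# Same seed -> identical order (reproducible across runs).
assert list(epoch0.index) == list(epoch0_again.index)

print("epoch 0 order:", list(epoch0.index))
print("epoch 1 order:", list(epoch1.index))
```

So passing random_state=int(epoch) into the dataset constructor, as the training loop does, is enough to give each epoch its own reproducible within-subset order.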

So while it is not random shuffling across the entire training dataset at each epoch, there is random shuffling within each subset of the training dataset as it is loaded. That at least guarantees that the protein-SMILES pairs are not seen in exactly the same order every epoch. If we wanted more randomness in the order of data points, we could easily shuffle the order of the subsets as well, by shuffling the list of protein dictionaries at each epoch; currently their order is always the same: self.prot_dicts = os.listdir(join(embed_dir, "Protein")).

AlexanderKroll commented 4 days ago

Thanks for the explanation, everything you said is correct. Since the files with the ESM embeddings for all proteins are so large, we cannot load them all at once during model training, but have to load them successively. I am not changing the order of these larger files (which could also be implemented), only the order within these larger protein chunks. Since each chunk is a large part of the total dataset, I don't expect much difference compared to completely shuffling the dataset, but of course I don't know for sure. Implementing more data shuffling might help a little.