jyaacoub / MutDTA

Improving the precision oncology pipeline by providing binding affinity perturbation predictions on a priori identified cancer driver genes.
https://drive.google.com/drive/folders/1mdiA1gf1IjPZNhk79I2cYUu6pwcH0OTD

Memory issues during dataset creation #7

Closed jyaacoub closed 1 year ago

jyaacoub commented 1 year ago

During dataset creation, we run out of memory at the collate_fn call.

Not sure why, but this is only an issue with the Shannon entropy versions of davis and the large kiba dataset...

Options:

jyaacoub commented 1 year ago

Creating generators is hard for graph-style datasets.

My first naive attempt is very hacky and fails:

    def _data_generator(self, df, batch_size=32):
        # yields lists of [protein, ligand] graph pairs, batch_size rows at a time
        start_idx = 0
        self.errors = []
        while start_idx < len(df):
            batch_data = []
            for idx in range(start_idx, min(start_idx + batch_size, len(df))):
                code = df.loc[idx]['code']
                cmap = np.load(self.cmap_p(code))  # precomputed contact map for this protein
                pro_seq = df.loc[idx]['prot_seq']
                lig_seq = df.loc[idx]['SMILE']
                label = df.loc[idx]['pkd']
                label = torch.Tensor([[label]])

                _, pro_feat, pro_edge = target_to_graph(pro_seq, cmap, 
                                                        threshold=self.cmap_threshold,
                                                        aln_file=self.aln_p(code),
                                                        shannon=self.shannon)
                try:
                    _, mol_feat, mol_edge = smile_to_graph(lig_seq)
                except ValueError:
                    # skip ligands whose SMILES fails to parse into a graph
                    self.errors.append(code)
                    continue

                pro = torchg.data.Data(x=torch.Tensor(pro_feat),
                                    edge_index=torch.LongTensor(pro_edge).transpose(1, 0),
                                    y=label,
                                    code=code,
                                    prot_id=df.loc[idx]['prot_id'])
                lig = torchg.data.Data(x=torch.Tensor(mol_feat),
                                    edge_index=torch.LongTensor(mol_edge).transpose(1, 0),
                                    y=label,
                                    code=code,
                                    prot_id=df.loc[idx]['prot_id'])
                batch_data.append([pro, lig])
            start_idx += batch_size
            yield batch_data

    def process(self):
        """
        This method is used to create the processed data files after feature extraction.
        """
        if not os.path.isfile(self.processed_paths[0]):
            df = self.pre_process()
        else:
            df = pd.read_csv(self.processed_paths[0])
            print(f'{self.processed_paths[0]} file found, using it to create the dataset')
        print(f'Number of codes: {len(df)}')

        # creating the dataset:
        data_gen = self._data_generator(df, batch_size=self.batch_size_gen)
        self._data_pro = None  # running protein Batch
        self._data_mol = None  # running ligand Batch

        max_ptr_pro = 0
        max_ptr_mol = 0
        for i, batch_data in enumerate(data_gen):
            batchpro = torchg.data.Batch.from_data_list([data[0] for data in batch_data])
            batchmol = torchg.data.Batch.from_data_list([data[1] for data in batch_data])

            if self._data_pro is None:
                self._data_pro = batchpro
                self._data_mol = batchmol
                max_ptr_mol = batchmol.ptr[-1]
                max_ptr_pro = batchpro.ptr[-1]
            else:
                # adjust batch and ptr so that they are correct after concatenation:
                batchpro.batch += i*self.batch_size_gen
                batchpro.ptr = batchpro.ptr[1:] + max_ptr_pro
                max_ptr_pro = batchpro.ptr[-1]

                batchmol.batch += i*self.batch_size_gen
                batchmol.ptr = batchmol.ptr[1:] + max_ptr_mol

                max_ptr_mol = batchmol.ptr[-1]

                # concat:
                self._data_pro = torchg.data.Batch(
                    x=torch.cat([self._data_pro.x, batchpro.x], axis=0),
                    edge_index=torch.cat([self._data_pro.edge_index, batchpro.edge_index], axis=1),
                    y=torch.cat([self._data_pro.y, batchpro.y], axis=0),
                    code=self._data_pro.code + batchpro.code,
                    prot_id=self._data_pro.prot_id + batchpro.prot_id,
                    batch=torch.cat([self._data_pro.batch, batchpro.batch], axis=0),
                    ptr=torch.cat([self._data_pro.ptr, batchpro.ptr], axis=0)
                )

                self._data_mol = torchg.data.Batch(
                    x=torch.cat([self._data_mol.x, batchmol.x], axis=0),
                    edge_index=torch.cat([self._data_mol.edge_index, batchmol.edge_index], axis=1),
                    y=torch.cat([self._data_mol.y, batchmol.y], axis=0),
                    code=self._data_mol.code + batchmol.code,
                    prot_id=self._data_mol.prot_id + batchmol.prot_id,
                    batch=torch.cat([self._data_mol.batch, batchmol.batch], axis=0),
                    ptr=torch.cat([self._data_mol.ptr, batchmol.ptr], axis=0)
                )

This causes an issue when you go to index the Batch:

RuntimeError: Cannot reconstruct 'Data' object from 'Batch' because 'Batch' was not created via 'Batch.from_data_list()'

Might need to consider restructuring how I am collating.
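The root cause is that a Batch built directly through its constructor is missing the internal slicing metadata that `from_data_list()` records, so PyG cannot split it back into individual Data objects when indexed. A minimal sketch of the supported route (not from this thread; `collate_whole` is a hypothetical name) would be to flatten the generator's output into plain lists of Data objects and collate once at the end. Note this avoids the RuntimeError but still holds the whole dataset in memory at the final collate, so it does not fix the original memory pressure:

    import torch_geometric as torchg

    def collate_whole(data_gen):
        # hypothetical helper: flatten the generator's [pro, lig] pairs
        # into two plain lists of Data objects
        pro_list, mol_list = [], []
        for batch_data in data_gen:
            pro_list.extend(pair[0] for pair in batch_data)
            mol_list.extend(pair[1] for pair in batch_data)

        # from_data_list records the slice metadata that indexing needs,
        # so batch[i] can reconstruct the i-th Data object later
        data_pro = torchg.data.Batch.from_data_list(pro_list)
        data_mol = torchg.data.Batch.from_data_list(mol_list)
        return data_pro, data_mol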

jyaacoub commented 1 year ago

Resolved. The solution is to just store the protein and mol data as dictionaries and access them after indexing the df object (see the BaseDataset class).
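A rough sketch of that pattern (the actual BaseDataset differs in its details; `build_protein_graph` and `build_ligand_graph` are hypothetical stand-ins for the `target_to_graph`/`smile_to_graph` calls above):

    import torch

    class BaseDataset(torch.utils.data.Dataset):
        # sketch: one Data object per unique protein/ligand, keyed by id,
        # so shared graphs are never duplicated or pre-collated into one Batch
        def __init__(self, df):
            self.df = df
            self._data_pro = {}  # prot_id -> protein Data
            self._data_mol = {}  # SMILES  -> ligand Data
            for idx in range(len(df)):
                row = df.loc[idx]
                if row['prot_id'] not in self._data_pro:
                    self._data_pro[row['prot_id']] = build_protein_graph(row)  # hypothetical
                if row['SMILE'] not in self._data_mol:
                    self._data_mol[row['SMILE']] = build_ligand_graph(row)     # hypothetical

        def __len__(self):
            return len(self.df)

        def __getitem__(self, idx):
            # index the df first, then look up the prebuilt graphs
            row = self.df.loc[idx]
            return self._data_pro[row['prot_id']], self._data_mol[row['SMILE']]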

This is still an issue with the ESM embeddings, but those should be handled on the model side of things for scalability reasons.
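For the ESM case, a hedged sketch of what handling it on the model side could look like with the fair-esm package (the `ProteinBranch` name and the small `esm2_t6_8M_UR50D` checkpoint are illustrative choices, not the repo's actual model):

    import torch
    import esm  # fair-esm; assumed dependency, not confirmed by this thread

    class ProteinBranch(torch.nn.Module):
        """Sketch: compute ESM embeddings inside the model instead of baking
        them into the processed dataset, so the saved files stay small."""
        def __init__(self):
            super().__init__()
            self.esm_model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
            self.batch_converter = alphabet.get_batch_converter()
            self.esm_model.eval()  # embeddings used as fixed features here

        def forward(self, seqs):  # seqs: list of (name, sequence) tuples
            _, _, tokens = self.batch_converter(seqs)
            with torch.no_grad():
                out = self.esm_model(tokens, repr_layers=[6])
            return out['representations'][6]  # (batch, seq_len + 2, embed_dim)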