Creating generators is hard for graph-style datasets.

My first naive attempt is hacky and ultimately fails:
```python
# requires: numpy, pandas, torch, torch_geometric; target_to_graph and
# smile_to_graph are project helpers for building graph features.
import os
import numpy as np
import pandas as pd
import torch
import torch_geometric as torchg

def _data_generator(self, df, batch_size=32):
    start_idx = 0
    self.errors = []  # codes whose SMILES failed to parse
    while start_idx < len(df):
        batch_data = []
        for idx in range(start_idx, min(start_idx + batch_size, len(df))):
            code = df.loc[idx]['code']
            cmap = np.load(self.cmap_p(code))
            pro_seq = df.loc[idx]['prot_seq']
            lig_seq = df.loc[idx]['SMILE']
            label = df.loc[idx]['pkd']
            label = torch.Tensor([[label]])

            _, pro_feat, pro_edge = target_to_graph(pro_seq, cmap,
                                                    threshold=self.cmap_threshold,
                                                    aln_file=self.aln_p(code),
                                                    shannon=self.shannon)
            try:
                _, mol_feat, mol_edge = smile_to_graph(lig_seq)
            except ValueError:
                self.errors.append(code)
                continue

            pro = torchg.data.Data(x=torch.Tensor(pro_feat),
                                   edge_index=torch.LongTensor(pro_edge).transpose(1, 0),
                                   y=label,
                                   code=code,
                                   prot_id=df.loc[idx]['prot_id'])
            lig = torchg.data.Data(x=torch.Tensor(mol_feat),
                                   edge_index=torch.LongTensor(mol_edge).transpose(1, 0),
                                   y=label,
                                   code=code,
                                   prot_id=df.loc[idx]['prot_id'])
            batch_data.append([pro, lig])
        start_idx += batch_size
        yield batch_data
```
`process` then consumes the generator and tries to stitch the per-batch `Batch` objects into one big `Batch` by hand:

```python
def process(self):
    """
    This method is used to create the processed data files after feature extraction.
    """
    if not os.path.isfile(self.processed_paths[0]):
        df = self.pre_process()
    else:
        df = pd.read_csv(self.processed_paths[0])
        print(f'{self.processed_paths[0]} file found, using it to create the dataset')
    print(f'Number of codes: {len(df)}')

    # creating the dataset:
    data_gen = self._data_generator(df, batch_size=self.batch_size_gen)
    self._data_pro = None
    self._data_mol = None
    max_ptr_pro = 0
    max_ptr_mol = 0
    for i, batch_data in enumerate(data_gen):
        batchpro = torchg.data.Batch.from_data_list([data[0] for data in batch_data])
        batchmol = torchg.data.Batch.from_data_list([data[1] for data in batch_data])
        if self._data_pro is None:
            self._data_pro = batchpro
            self._data_mol = batchmol
            max_ptr_mol = batchmol.ptr[-1]
            max_ptr_pro = batchpro.ptr[-1]
        else:
            # adjust batch and ptr so that they are correct after concatenation
            # (assumes every batch holds exactly batch_size_gen graphs):
            batchpro.batch += i * self.batch_size_gen
            batchpro.ptr = batchpro.ptr[1:] + max_ptr_pro  # drop leading 0, then offset
            max_ptr_pro = batchpro.ptr[-1]
            batchmol.batch += i * self.batch_size_gen
            batchmol.ptr = batchmol.ptr[1:] + max_ptr_mol
            max_ptr_mol = batchmol.ptr[-1]

            # concat:
            self._data_pro = torchg.data.Batch(
                x=torch.cat([self._data_pro.x, batchpro.x], axis=0),
                edge_index=torch.cat([self._data_pro.edge_index, batchpro.edge_index], axis=1),
                y=torch.cat([self._data_pro.y, batchpro.y], axis=0),
                code=self._data_pro.code + batchpro.code,
                prot_id=self._data_pro.prot_id + batchpro.prot_id,
                batch=torch.cat([self._data_pro.batch, batchpro.batch], axis=0),
                ptr=torch.cat([self._data_pro.ptr, batchpro.ptr], axis=0))
            self._data_mol = torchg.data.Batch(
                x=torch.cat([self._data_mol.x, batchmol.x], axis=0),
                edge_index=torch.cat([self._data_mol.edge_index, batchmol.edge_index], axis=1),
                y=torch.cat([self._data_mol.y, batchmol.y], axis=0),
                code=self._data_mol.code + batchmol.code,
                prot_id=self._data_mol.prot_id + batchmol.prot_id,
                batch=torch.cat([self._data_mol.batch, batchmol.batch], axis=0),
                ptr=torch.cat([self._data_mol.ptr, batchmol.ptr], axis=0))
```
This causes an issue when you go to index the `Batch`:

```
RuntimeError: Cannot reconstruct 'Data' object from 'Batch' because 'Batch' was not created via 'Batch.from_data_list()'
```

Might need to consider restructuring how I am collating.
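The direct `Batch(...)` constructor doesn't record the per-example slice metadata that `Batch.from_data_list()` does, which is exactly what indexing needs to reconstruct a `Data` object. The usual PyG pattern sidesteps the manual `batch`/`ptr` bookkeeping entirely: accumulate one flat list of `Data` objects and collate once. A minimal sketch, assuming the dataset subclasses `InMemoryDataset` (the output paths are placeholders):

```python
import pandas as pd
import torch

def process(self):
    df = pd.read_csv(self.processed_paths[0])
    pro_list, mol_list = [], []
    for batch_data in self._data_generator(df, batch_size=self.batch_size_gen):
        for pro, lig in batch_data:
            pro_list.append(pro)
            mol_list.append(lig)
    # collate once: returns one big Data object plus the slices dict
    # that indexing uses to recover individual examples
    pro_data, pro_slices = self.collate(pro_list)
    mol_data, mol_slices = self.collate(mol_list)
    torch.save((pro_data, pro_slices), 'data_pro.pt')  # placeholder paths
    torch.save((mol_data, mol_slices), 'data_mol.pt')
```

The wrinkle here is that each sample carries two graphs, which the stock single-dataset `InMemoryDataset` recipe doesn't cover; hence the dictionary approach below.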
Resolved...

The solution is to just store the protein and mol data as dictionaries and access them after indexing the df object (see the BaseDataset class).
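Roughly what that looks like; every name here is illustrative rather than copied from the actual BaseDataset:

```python
import pandas as pd
import torch_geometric as torchg

class BaseDataset(torchg.data.InMemoryDataset):  # sketch only
    def process(self):
        self.df = pd.read_csv(self.processed_paths[0])
        self._data_pro = {}  # code -> protein Data
        self._data_mol = {}  # code -> ligand Data
        for batch_data in self._data_generator(self.df, batch_size=32):
            for pro, lig in batch_data:
                self._data_pro[pro.code] = pro
                self._data_mol[lig.code] = lig

    def __getitem__(self, idx):
        # index the df first, then look the graphs up by code;
        # no collated Batch ever needs to be sliced
        row = self.df.iloc[idx]
        return {'protein': self._data_pro[row['code']],
                'ligand': self._data_mol[row['code']],
                'y': row['pkd']}
```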
This is still an issue with the ESM embeddings but that should be handled on the model side of things for scalability reasons.
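Model-side handling would mean embedding sequences in the forward pass rather than caching them in the processed dataset. A rough sketch with the fair-esm package; the specific checkpoint, representation layer, and wrapper class are my assumptions:

```python
import torch
import esm  # fair-esm package

class ProtEncoder(torch.nn.Module):  # hypothetical wrapper
    def __init__(self):
        super().__init__()
        self.esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
        self.batch_converter = alphabet.get_batch_converter()
        self.esm_model.eval()  # embeddings only; no fine-tuning here

    @torch.no_grad()
    def forward(self, codes, seqs):
        # tokenize raw sequences on the fly instead of storing embeddings on disk
        _, _, tokens = self.batch_converter(list(zip(codes, seqs)))
        out = self.esm_model(tokens, repr_layers=[33])
        return out['representations'][33]  # (batch, seq_len + 2, 1280)
```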
During dataset creation we run out of memory at the collate fn call. Not sure why, but this is only an issue with the Shannon entropy versions of Davis and the large KIBA dataset...

Options: