jonathanking / sidechainnet

An all-atom protein structure dataset for machine learning.
BSD 3-Clause "New" or "Revised" License

error value in batch #65

Open omarkhaled-28 opened 8 months ago

omarkhaled-28 commented 8 months ago

It raises an error when I try to load the training data into a batch:

```python
# `dataloader` was created earlier in the notebook (setup not shown in the report)
batch = next(iter(dataloader['train']))
print("Protein IDs\n ", batch.ids)
print("Sequences\n ", batch.seqs.shape)
print("Evolutionary Data\n ", batch.evolutionary.shape)
print("Secondary Structure\n ", batch.secondary.shape)
print("Angle Data\n ", batch.angles.shape)
print("Coordinate Data\n ", batch.coords.shape)
print("X-ray Resolution\n ", batch.resolutions)
print("Integer sequence")
print("\tShape:", batch.seqs_int.shape)
print("\tEx:", batch.seqs_int[0, :3])

print("1-hot sequence")
print("\tShape:", batch.seqs.shape)
print("\tEx:\n", batch.seqs[0, :3])
```

This is the output when I run it:

```
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in <cell line: 1>()
----> 1 batch = next(iter(dataloader['train']))
      2 print("Protein IDs\n ", batch.ids)
      3 print("Sequences\n ", batch.seqs.shape)
      4 print("Evolutionary Data\n ", batch.evolutionary.shape)
      5 print("Secondary Structure\n ", batch.secondary.shape)

3 frames
/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
    692             # instantiate since we don't know how to
    693             raise RuntimeError(msg) from None
--> 694         raise exception
    695 
    696 

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/sidechainnet/dataloaders/collate.py", line 85, in collate_fn
    padded_crds = pad_for_batch(coords, max_batch_len, 'crd')
  File "/usr/local/lib/python3.10/dist-packages/sidechainnet/dataloaders/collate.py", line 186, in pad_for_batch
    c = np.concatenate((item, z), axis=0)
  File "<__array_function__ internals>", line 180, in concatenate
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 2 dimension(s)
```

dkoes commented 7 months ago

Can you provide more context for your code? The following works for me:

```python
import sidechainnet as scn

# Load CASP12 (30% thinning) as a dict of PyTorch DataLoaders keyed by split name
dataloaders = scn.load(casp_version=12, casp_thinning=30, with_pytorch="dataloaders")
batch = next(iter(dataloaders['train']))
```