Open omarkhaled-28 opened 8 months ago
Can you provide more context for your code? The following works for me:
import sidechainnet as scn
dataloaders = scn.load(casp_version=12, casp_thinning=30, with_pytorch="dataloaders")
batch = next(iter(dataloaders['train']))
It gives an error when trying to put the training data into a batch:

batch = next(iter(dataloader['train']))
print("Protein IDs\n ", batch.ids)
print("Sequences\n ", batch.seqs.shape)
print("Evolutionary Data\n ", batch.evolutionary.shape)
print("Secondary Structure\n ", batch.secondary.shape)
print("Angle Data\n ", batch.angles.shape)
print("Coordinate Data\n ", batch.coords.shape)
print("X-ray Resolution\n ", batch.resolutions)
print("Integer sequence")
print("\tShape:", batch.seqs_int.shape)
print("\tEx:", batch.seqs_int[0,:3])
print("1-hot sequence")
print("\tShape:", batch.seqs.shape)
print("\tEx:\n", batch.seqs[0,:3])

This is the output once I run it:

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
in <cell line: 1>()
----> 1 batch = next(iter(dataloader['train']))
      2 print("Protein IDs\n ", batch.ids)
      3 print("Sequences\n ", batch.seqs.shape)
      4 print("Evolutionary Data\n ", batch.evolutionary.shape)
      5 print("Secondary Structure\n ", batch.secondary.shape)

3 frames
/usr/local/lib/python3.10/dist-packages/torch/_utils.py in reraise(self)
    692             # instantiate since we don't know how to
    693             raise RuntimeError(msg) from None
--> 694         raise exception
    695
    696

ValueError: Caught ValueError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
  File "/usr/local/lib/python3.10/dist-packages/sidechainnet/dataloaders/collate.py", line 85, in collate_fn
    padded_crds = pad_for_batch(coords, max_batch_len, 'crd')
  File "/usr/local/lib/python3.10/dist-packages/sidechainnet/dataloaders/collate.py", line 186, in pad_for_batch
    c = np.concatenate((item, z), axis=0)
  File "<__array_function__ internals>", line 180, in concatenate
ValueError: all the input arrays must have same number of dimensions, but the array at index 0 has 3 dimension(s) and the array at index 1 has 2 dimension(s)
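
The last frame of the traceback is a plain NumPy failure: `np.concatenate` is given a 3-dimensional coordinate array and a 2-dimensional padding array. The sketch below reproduces that message outside of SidechainNet and shows one way to build padding that follows the item's own trailing shape. The array shapes, the `(length, atoms per residue, xyz)` layout, and the helper `pad_to_length` are assumptions made for illustration. They are not SidechainNet's actual `pad_for_batch` code, and this does not establish why the coordinates and padding disagree in the installed version.

```python
import numpy as np


def pad_to_length(item, max_len):
    """Zero-pad `item` along axis 0 up to `max_len` rows.

    Hypothetical helper: the padding copies the item's trailing
    dimensions, so it has the same number of dimensions as the item.
    """
    item = np.asarray(item)
    pad_shape = (max_len - item.shape[0],) + item.shape[1:]
    z = np.zeros(pad_shape, dtype=item.dtype)
    return np.concatenate((item, z), axis=0)


# Assumed layout for illustration: (length, atoms per residue, xyz).
coords_3d = np.ones((5, 14, 3))
# Padding built for a flat (rows, xyz) layout instead.
z_2d = np.zeros((2, 3))

try:
    np.concatenate((coords_3d, z_2d), axis=0)
except ValueError as err:
    # Same kind of message as the last line of the traceback above.
    mismatch_message = str(err)
    print(mismatch_message)

# Padding that matches the item's dimensionality concatenates cleanly.
padded = pad_to_length(coords_3d, 7)
print(padded.shape)
```

If the real coordinates are 3-dimensional while the padding is 2-dimensional, it may be worth checking whether the installed SidechainNet version and the downloaded dataset files use the same coordinate layout, and printing the shape of one raw coordinate entry before it reaches the collate function.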