cgoliver closed this issue 2 years ago.
To reproduce:
import tempfile
from proteinshake.datasets import AlphaFoldDataset
from torch_geometric.loader import DataLoader

with tempfile.TemporaryDirectory() as tmp:
    dset = AlphaFoldDataset(root=tmp, organism='escherichia_coli').to_graph(eps=8).pyg()
    l = DataLoader(dset, num_workers=4, batch_size=8)
    for batch in l:
        pass
Traceback:
Downloading AlphaFoldDataset_escherichia_coli.json.gz:
100%|█████████████████████████████████████████████████████████████████████████| 1.39k/1.39k [00:01<00:00, 9.34MiB/s]
Unzipping...
Converting proteins to graphs: 100%|███████████████████████████████████████████| 4363/4363 [00:20<00:00, 216.73it/s]
Traceback (most recent call last):
  File "<..>/Projects/proteinshake/test.py", line 10, in <module>
    for batch in l:
  File "<..>/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 444, in __iter__
    return self._get_iterator()
  File "<..>/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 390, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "<..>/proteinshake/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1077, in __init__
    w.start()
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'GraphDataset.pyg.<locals>.Dataset'
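This looks like the classic problem of pickling locally defined types. pickle stores a class by reference (its module plus its qualified name), so a class created inside a method, such as GraphDataset.pyg.<locals>.Dataset, cannot be looked up again and therefore cannot be pickled. The traceback goes through popen_spawn_posix, which means the DataLoader workers are started with the spawn method (the default on macOS since Python 3.8). Spawn pickles the dataset to send it to each worker.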
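To see why the local class is the problem, here is a minimal sketch that pickles two objects with ForkingPickler, the pickler at the bottom of the traceback. GraphDatasetDemo, local_dataset, module_dataset and ModuleLevelDataset are hypothetical names for illustration. They are not proteinshake code.

```python
import pickle
from multiprocessing.reduction import ForkingPickler


class GraphDatasetDemo:
    """Hypothetical stand-in for a dataset wrapper (not proteinshake's class)."""

    def local_dataset(self):
        # The class is created inside the method, so its qualified name is
        # 'GraphDatasetDemo.local_dataset.<locals>.Dataset'. pickle stores
        # classes by reference (module + qualified name) and cannot look it up.
        class Dataset:
            pass

        return Dataset()

    def module_dataset(self):
        # This class lives at module level, so pickle can find it by name.
        return ModuleLevelDataset()


class ModuleLevelDataset:
    pass


def can_pickle(obj):
    """Return True if obj survives the pickler used to send data to spawned workers."""
    try:
        ForkingPickler.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # Older Pythons raise AttributeError for local objects; newer ones
        # may raise PicklingError instead.
        return False


local_ok = can_pickle(GraphDatasetDemo().local_dataset())
module_ok = can_pickle(GraphDatasetDemo().module_dataset())
print(local_ok, module_ok)  # False True
```

Only the location of the class definition differs between the two cases. That is why moving the class to module level is enough to fix the error.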
OK, so the problem is fixed by moving class Dataset(InMemoryDataset) outside of the GraphDataset class (to module level). However, this would of course mean that a user who does not want to install PyG would get an import error, since the module would then have to import torch_geometric at the top.
I think the simplest solution would be to split the different frameworks into their own modules, each of which subclasses GraphDataset and overrides the convert() method.
For example, a new file pyg_data.py (it does not exist yet):
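import torch
from torch_geometric.data import Data, InMemoryDataset
from proteinshake.representations import GraphDataset

# Defined at module level, so pickle can find it by its qualified name.
class Dataset(InMemoryDataset):
    def __init__(self, data_list):
        super().__init__()
        # Collate the list of Data objects into a single in-memory storage.
        self.data, self.slices = self.collate(list(data_list))

class PyGData(GraphDataset):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def convert(self, graphs):
        # Convert graphs to a PyG dataset.
        # pdb2pyg: per-protein conversion helper (not defined here).
        graphs = (pdb2pyg(g) for g in graphs)
        return Dataset(graphs)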
This might interfere with the existing .pyg() interface, however.
Fixed with #67. Subclassing with mixins (from PyG) doesn't work too well in Python. I kept your idea of putting the PyG code in a separate file (so that PyG doesn't need to be imported unless necessary) and just reorganized the folders.
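The .pyg() interface and the optional dependency can coexist. The framework-specific class lives at module level in its own file, and that file is imported only inside the method. Here is one way this could look, assuming a lazy import. It is not necessarily what #67 does. GraphDataset here is a hypothetical stand-in. The demo_backend module plays the role of a file that would import torch_geometric; it is written to a temporary directory so the example is self-contained.

```python
import importlib
import pickle
import sys
import tempfile
import textwrap
from pathlib import Path

# Hypothetical optional backend module. In a real package this would be a
# separate file that imports the heavy framework at its top; here it only
# defines a module-level dataset class.
BACKEND_SOURCE = textwrap.dedent("""
    class BackendDataset:
        def __init__(self, graphs):
            self.graphs = list(graphs)

        def __len__(self):
            return len(self.graphs)
""")


class GraphDataset:
    """Hypothetical stand-in for proteinshake's GraphDataset."""

    def __init__(self, graphs):
        self.graphs = graphs

    def pyg(self):
        # Lazy import: the backend module (and anything it imports) is loaded
        # only when this method is called, so the rest of the package works
        # without the optional dependency installed.
        from demo_backend import BackendDataset
        return BackendDataset(self.graphs)


with tempfile.TemporaryDirectory() as tmp:
    Path(tmp, "demo_backend.py").write_text(BACKEND_SOURCE)
    sys.path.insert(0, tmp)
    importlib.invalidate_caches()
    try:
        imported_before = "demo_backend" in sys.modules
        dataset = GraphDataset(["g1", "g2", "g3"]).pyg()
        # The class is found as demo_backend.BackendDataset, so it pickles.
        restored = pickle.loads(pickle.dumps(dataset))
    finally:
        sys.path.remove(tmp)

print(imported_before, len(restored))  # False 3
```

The backend module is not imported until .pyg() is called. The returned object still round-trips through pickle, which is what spawned DataLoader workers need.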
When num_workers > 1 in the DataLoader, it no longer works.
Error message: AttributeError: Can't pickle local object 'GraphDataset.pyg.<locals>.Dataset'
From @claying