BorgwardtLab / proteinshake

Protein structure datasets for machine learning.
https://proteinshake.ai
BSD 3-Clause "New" or "Revised" License

num_workers > 1 #57

Closed cgoliver closed 2 years ago

cgoliver commented 2 years ago

With num_workers > 1 in the DataLoader, iterating over the dataset no longer works.

Error msg: AttributeError: Can't pickle local object 'GraphDataset.pyg.<locals>.Dataset'

From @claying

cgoliver commented 2 years ago

To reproduce:

commit

import tempfile
from proteinshake.datasets import AlphaFoldDataset
from torch_geometric.loader import DataLoader

with tempfile.TemporaryDirectory() as tmp:
    dset = AlphaFoldDataset(root=tmp, organism='escherichia_coli').to_graph(eps=8).pyg()
    l = DataLoader(dset, num_workers=4, batch_size=8)
    for batch in l:
        pass

Traceback:

Downloading AlphaFoldDataset_escherichia_coli.json.gz:
100%|█████████████████████████████████████████████████████████████████████████| 1.39k/1.39k [00:01<00:00, 9.34MiB/s]
Unzipping...
Converting proteins to graphs: 100%|███████████████████████████████████████████| 4363/4363 [00:20<00:00, 216.73it/s]
Traceback (most recent call last):
  File "<..>/Projects/proteinshake/test.py", line 10, in <module>
    for batch in l:
  File "<..>/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 444, in __iter__
    return self._get_iterator()
  File "<..>/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 390, in _get_iterator
    return _MultiProcessingDataLoaderIter(self)
  File "<..>/proteinshake/.venv/lib/python3.9/site-packages/torch/utils/data/dataloader.py", line 1077, in __init__
    w.start()
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/context.py", line 224, in _Popen
    return _default_context.get_context().Process._Popen(process_obj)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/usr/local/Cellar/python@3.9/3.9.13_2/Frameworks/Python.framework/Versions/3.9/lib/python3.9/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'GraphDataset.pyg.<locals>.Dataset'

cgoliver commented 2 years ago

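Seems to be the classic problem of pickling locally defined classes: Dataset is defined inside GraphDataset.pyg(), and the worker processes are started with spawn (see popen_spawn_posix in the traceback), which has to pickle the dataset.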
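
For reference, a minimal sketch of the failure mode, independent of proteinshake. pickle stores a class by its qualified name, so a class created inside a function cannot be looked up again when the object is loaded. The names ModuleLevelDataset, make_local_dataset and try_pickle are made up for this illustration.

```python
import pickle


class ModuleLevelDataset:
    """Defined at module level, so pickle can find it by name."""

    def __init__(self, items):
        self.items = items


def make_local_dataset(items):
    """Mimics the pattern from the traceback: the class lives inside a function."""

    class LocalDataset:
        def __init__(self, items):
            self.items = items

    return LocalDataset(items)


def try_pickle(obj):
    """Return True if obj survives a pickle round trip, False otherwise."""
    try:
        pickle.loads(pickle.dumps(obj))
        return True
    except (AttributeError, pickle.PicklingError):
        # Depending on the Python version, a local class raises one or the other.
        return False


if __name__ == "__main__":
    print(try_pickle(ModuleLevelDataset([1, 2, 3])))  # module-level class
    print(try_pickle(make_local_dataset([1, 2, 3])))  # class local to a function
```

Run as a script, the first call succeeds and the second fails, which matches the "Can't pickle local object" error above.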

cgoliver commented 2 years ago

OK, so the problem is fixed by moving the class Dataset(InMemoryDataset) outside of the GraphDataset class. However, this of course means that a user who does not want to install pyg would get an import error.

I think the simplest solution would be to give each framework its own module, containing a class that subclasses GraphDataset and overrides the convert() function.

e.g.

pyg_data.py (new file, does not exist yet):

import torch
from torch_geometric.data import Data, InMemoryDataset

from proteinshake.representations import GraphDataset

# Defined at module level (not inside a method) so that it can be pickled.
class Dataset(InMemoryDataset):
    pass

class PyGData(GraphDataset):
    def __init__(self):
        super().__init__()

    def convert(self, graphs):
        # Convert graphs to a PyG dataset.
        # pdb2pyg is a placeholder for a per-graph converter (not defined here).
        graphs = (pdb2pyg(g) for g in graphs)
        return Dataset(graphs)

However, this might interfere with the .pyg() interface.
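
One way the .pyg() method could be kept while still isolating the dependency is a lazy import: the method imports the framework-specific module only when it is called. This is a rough sketch of that pattern, not the actual proteinshake code. The lazy_import helper, the GraphDataset stand-in and the backend parameter are assumptions, and the stdlib json module merely stands in for an optional dependency so the example runs without PyG.

```python
import importlib


def lazy_import(module_name, install_hint):
    """Import module_name on demand, with a readable error if it is missing."""
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        raise ImportError(
            f"{module_name} is required for this conversion. {install_hint}"
        ) from err


class GraphDataset:
    """Stand-in for the real class; only the dispatch pattern matters here."""

    def __init__(self, graphs):
        self.graphs = graphs

    def pyg(self, backend="json"):
        # In the real package this would import the module that defines the
        # module-level Dataset class (e.g. the proposed pyg_data.py).
        # "json" is a stdlib placeholder so the sketch runs anywhere.
        module = lazy_import(backend, "Install it to use this converter.")
        # json.dumps stands in for constructing the framework-specific dataset.
        return module.dumps(self.graphs)
```

With this layout, importing the base module never touches the optional framework. Only users who call .pyg() need it installed, and they get an explicit message if it is missing.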

timkucera commented 2 years ago

Fixed with #67. Subclassing with mixins (from pyg) doesn't work too well in Python. I kept your idea of putting it in a separate file (so that you don't need to import pyg unless necessary) and just reorganized the folders.