tristandeleu / pytorch-meta

A collection of extensions and data-loaders for few-shot learning & meta-learning in PyTorch
https://tristandeleu.github.io/pytorch-meta/
MIT License

num_workers > 0 for BatchMetaDataLoader causes an AttributeError to be raised when 'spawn' is used as the multiprocessing start method. #38

Closed tesfaldet closed 4 years ago

tesfaldet commented 4 years ago

I've provided a minimal-ish working example to reproduce the bug. I don't want to use the 'fork' start method for multiprocessing, for a few reasons: it is prone to race conditions and deadlocks, and CUDA isn't supported with it.

import tqdm
import time

from torchmeta.datasets import TripleMNIST
from torchmeta.transforms import ClassSplitter
from torchmeta.utils.data import BatchMetaDataLoader

from torchvision import transforms

import multiprocessing
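# Force the 'spawn' start method (the second positional argument is force=True).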
multiprocessing.set_start_method('spawn', True)

dataset_path = '~/Downloads'
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
dataset = TripleMNIST(root=dataset_path,
                      transform=transform,
                      target_transform=None,
                      num_classes_per_task=1,
                      meta_train=True,
                      download=True)
dataset = ClassSplitter(dataset, shuffle=True,
                        num_train_per_class=5)
dataloader = BatchMetaDataLoader(dataset,
                                 batch_size=1,
                                 pin_memory=True,
                                 num_workers=2)

def train_on_task_stub(batch):
    time.sleep(.20)

episode = 1
num_episodes = 1000
num_epochs = 10
start_epoch = 1
while episode < num_episodes:
    for i, batch in enumerate(dataloader):
        inputs, targets = batch['train']

        pbar = tqdm.tqdm(total=num_epochs, miniters=num_epochs/100,
                         ncols=180, initial=start_epoch)
        for epoch in range(start_epoch, num_epochs + 1):
            train_batch = [inputs[0], targets[0]]
            train_on_task_stub(train_batch)
            mesg = 'Episode {}/{} Epoch {}/{}:\t'.format(episode,
                                                         num_episodes,
                                                         epoch, num_epochs)
            pbar.update(1)
            pbar.set_description(mesg, refresh=False)
        pbar.close()

        if episode >= num_episodes:
            break

        episode += 1
        start_epoch = 1
Traceback (most recent call last):
  File "/Users/mattie/Projects/pytorch-meta/test.py", line 42, in <module>
    for i, batch in enumerate(dataloader):
AttributeError: Can't pickle local object 'batch_meta_collate.<locals>._collate_fn'

I believe the culprit is below. Functions defined inside other functions aren't picklable, and pickling is how data is transported between processes when 'spawn' is used. Would you happen to know the best way to rewrite the batch_meta_collate function to prevent this issue?

https://github.com/tristandeleu/pytorch-meta/blob/794bf82348fbdc2b68b04f5de89c38017d54ba59/torchmeta/utils/data/dataloader.py#L11-L24
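
For reference, the usual workaround for an un-picklable closure like this is to hoist it into a module-level callable class, since top-level classes are pickled by reference rather than by value. A rough sketch of what that could look like here, with the Dataset/OrderedDict handling inferred from the linked snippet (so not necessarily the exact change that was merged):

from collections import OrderedDict

from torch.utils.data import Dataset as TorchDataset
from torch.utils.data.dataloader import default_collate

class BatchMetaCollate(object):
    # A top-level class is picklable by reference, so instances can be
    # shipped to 'spawn'-started worker processes, unlike the inner
    # _collate_fn closure.
    def __init__(self, collate_fn):
        self.collate_fn = collate_fn

    def collate_task(self, task):
        # Collate a single task: either a plain Dataset of examples, or
        # an OrderedDict of splits (e.g. 'train'/'test'), each a Dataset.
        if isinstance(task, TorchDataset):
            return self.collate_fn([task[idx] for idx in range(len(task))])
        elif isinstance(task, OrderedDict):
            return OrderedDict([(key, self.collate_task(subtask))
                                for (key, subtask) in task.items()])
        else:
            raise NotImplementedError()

    def __call__(self, batch):
        return self.collate_fn([self.collate_task(task) for task in batch])

The data loader would then pass collate_fn=BatchMetaCollate(default_collate) instead of collate_fn=batch_meta_collate(default_collate).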

tristandeleu commented 4 years ago

Fixed with #39