jleuschn / dival

Deep Inversion Validation Library
MIT License

Using multiple num_workers in DataLoader in Spawning subprocesses #22

Closed: liutianlin0121 closed this issue 3 years ago

liutianlin0121 commented 4 years ago

Hi! Thanks for these awesome datasets and library!

The library works perfectly for me when training models on a single GPU. But when I use two GPUs with torch.multiprocessing, an error related to num_workers and RandomAccessTorchDataset occurs. In short, whenever I use num_workers > 0, the processes started by torch.multiprocessing crash. A script that reproduces the problem is attached at the end. I am not sure whether this is directly related to the library itself, so feel free to close the issue. But it would be great if you have any insights into this problem. Thanks a lot!

The script debug.py attached below looks long, but the essential part is just the main_worker function. When I run python debug.py --num_workers 0, the line x, d = next(iter(train_loader)) executes successfully; but when I run python debug.py --num_workers 4, the following error occurs:

(base) liu0003@dmi-20-pc-09:~/Desktop/projects/ct$ python debug.py --num_workers 4
Use GPU: 0 for training
Use GPU: 1 for training
start iterating
start iterating
Traceback (most recent call last):
  File "debug.py", line 135, in <module>
    main()
  File "debug.py", line 57, in main
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 200, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 106, in join
    raise Exception(
Exception: process 1 terminated with signal SIGKILL

Here is the debug.py script:

import os, torch
import numpy as np
import random
from os import path

import torch.distributed as dist
import torch.multiprocessing as mp
import argparse

from torch.utils.data import DataLoader
from dival import get_standard_dataset
from dival.datasets.fbp_dataset import get_cached_fbp_dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset as TorchDataset

parser = argparse.ArgumentParser(description='debug')

# Dataset settings
parser.add_argument('--BATCH_SIZE', type=int, default=4, help='mini-batch size for training.')

# DistributedDataParallel settings
parser.add_argument('--num_workers', type=int, default=8, help='')
parser.add_argument("--gpu_devices", type=int, nargs='+', default=[0,1], help="")
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:3456', type=str, help='')
parser.add_argument('--dist-backend', default='nccl', type=str, help='')
parser.add_argument('--rank', default=0, type=int, help='')
parser.add_argument('--world_size', default=1, type=int, help='')
parser.add_argument('--distributed', action='store_true', help='')

args = parser.parse_args()

gpu_devices = ','.join([str(id) for id in args.gpu_devices])
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def main():
    set_random_seeds()
    args = parser.parse_args()

    ngpus_per_node = torch.cuda.device_count()

    args.world_size = ngpus_per_node * args.world_size

    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))

def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    ngpus_per_node = torch.cuda.device_count()    
    print("Use GPU: {} for training".format(args.gpu))

    args.rank = args.rank * ngpus_per_node + gpu    
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)

    datasets = CTdatasets()

    batch_sizes = {'train': args.BATCH_SIZE, 'validation': args.BATCH_SIZE, 'test':1}

    dataloaders = {x: DataLoader(datasets[x], batch_size=batch_sizes[x],
                                 num_workers=args.num_workers, pin_memory=True,
                                 sampler=DistributedSampler(datasets[x]))
                   for x in ['train', 'validation', 'test']}

    train_loader = dataloaders['train']    

    print('start iterating')

    x, d = next(iter(train_loader))

    print(x.flatten()[:10])  # with --num_workers 0 this line executes successfully; with --num_workers > 0 an error occurs before it is reached.

class RandomAccessTorchDataset(TorchDataset):
    def __init__(self, dataset, part, reshape=None):
        self.dataset = dataset
        self.part = part
        self.reshape = reshape or (
            (None,) * self.dataset.get_num_elements_per_sample())

    def __len__(self):
        return self.dataset.get_len(self.part)

    def __getitem__(self, idx):
        arrays = self.dataset.get_sample(idx, part=self.part)
        mult_elem = isinstance(arrays, tuple)
        if not mult_elem:
            arrays = (arrays,)
        tensors = []
        for arr, s in zip(arrays, self.reshape):
            t = torch.from_numpy(np.asarray(arr))
            if s is not None:
                t = t.view(*s)
            tensors.append(t)
        return tuple(tensors) if mult_elem else tensors[0]

def CTdatasets(IMPL = 'skimage', cache_dir = '/home/liu0003/Desktop/projects/dival/', **kwargs):

    CACHE_FILES = {'train':
            (path.join(cache_dir, 'cache_lodopab_train_fbp.npy'), None),
                   'validation':
            (path.join(cache_dir, 'cache_lodopab_validation_fbp.npy'), None)}

    standard_dataset = get_standard_dataset('lodopab', impl=IMPL)
    ray_trafo = standard_dataset.get_ray_trafo(impl=IMPL)
    dataset = get_cached_fbp_dataset(standard_dataset, ray_trafo, CACHE_FILES)

    # create PyTorch datasets
    dataset_train = RandomAccessTorchDataset(dataset = dataset,
        part='train', reshape=((1,) + dataset.space[0].shape,
                               (1,) + dataset.space[1].shape))

    dataset_validation = RandomAccessTorchDataset(dataset = dataset,
        part='validation', reshape=((1,) + dataset.space[0].shape,
                               (1,) + dataset.space[1].shape))

    dataset_test = RandomAccessTorchDataset(dataset = dataset,
        part='test', reshape=((1,) + dataset.space[0].shape,
                               (1,) + dataset.space[1].shape))
    datasets = {'train': dataset_train, 'validation': dataset_validation, 'test': dataset_test}

    return datasets

if __name__=='__main__':
    main()

Please run python debug.py --num_workers 4 to see the error I mentioned earlier. Thanks in advance!

jleuschn commented 4 years ago

Hi, nice to hear you find the library useful!

When running debug.py with --num_workers > 0, I get the error: AttributeError: Can't pickle local object 'RayTransform.adjoint.<locals>.RayBackProjection'.

It fails because of the pickling performed by the multiprocessing library, which cannot handle the class defined locally inside RayTransform.adjoint. I just tried a dirty hot-fix that works by patching the internal attribute ray_trafo._adjoint:

from odl import Operator

# Class copied and modified from
# odl.tomo.operators.ray_trafo.RayTransform.adjoint.
# This main-scope class definition is necessary for multiprocessing, because
# pickling the local class in RayTransform.adjoint fails.
class RayBackProjection(Operator):
    """Adjoint of the discrete Ray transform between L^p spaces.
    """
    def __init__(self, ray_trafo, **kwargs):
        self.ray_trafo = ray_trafo
        super().__init__(**kwargs)

    def _call(self, x, out=None, **kwargs):
        """Backprojection.

        Parameters
        ----------
        x : DiscretizedSpaceElement
            A sinogram. Must be an element of
            `RayTransform.range` (domain of `RayBackProjection`).
        out : `RayBackProjection.domain` element, optional
            A volume to which the result of this evaluation is
            written.
        **kwargs
            Extra keyword arguments, passed on to the
            implementation backend.

        Returns
        -------
        DiscretizedSpaceElement
            Result of the transform in the domain
            of `RayProjection`.
        """
        return self.ray_trafo.get_impl(
            self.ray_trafo.use_cache
        ).call_backward(x, out, **kwargs)

    @property
    def geometry(self):
        return self.ray_trafo.geometry

    @property
    def adjoint(self):
        return self.ray_trafo

def patch_ray_trafo_for_mp(ray_trafo):
    kwargs = ray_trafo._extra_kwargs.copy()
    kwargs['domain'] = ray_trafo.range
    ray_trafo._adjoint = RayBackProjection(
        ray_trafo, range=ray_trafo.domain, linear=True, **kwargs
    )
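The underlying pickling limitation can be reproduced without ODL or PyTorch. pickle stores a class by its importable qualified name, and a name containing <locals> cannot be looked up again, so instances of a class defined inside a function fail to pickle, while the same class defined at module level round-trips fine. That is the reason the hot-fix above moves RayBackProjection to module scope. In this minimal sketch, make_local_operator, ModuleLevelOperator and can_pickle are made-up names for illustration only.

```python
import pickle


class ModuleLevelOperator:
    """Defined at module scope, so pickle can find it again by name."""

    def __init__(self, scale):
        self.scale = scale


def make_local_operator(scale):
    # Mirrors the pattern in RayTransform.adjoint: a class defined inside a function.
    class LocalOperator:
        def __init__(self, scale):
            self.scale = scale

    return LocalOperator(scale)


def can_pickle(obj):
    """Return True if obj survives a pickle round trip."""
    try:
        pickle.loads(pickle.dumps(obj))
    except (AttributeError, pickle.PicklingError, TypeError):
        # Depending on the Python version, a local class raises AttributeError
        # ("Can't pickle local object ...") or PicklingError.
        return False
    return True


print(can_pickle(ModuleLevelOperator(2.0)))  # True
print(can_pickle(make_local_operator(2.0)))  # False
```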

Then, in CTdatasets, you can patch both ray transform objects by inserting the last two lines below:

def CTdatasets(...):
    ...
    standard_dataset = get_standard_dataset('lodopab', impl=IMPL)
    ray_trafo = standard_dataset.get_ray_trafo(impl=IMPL)
    patch_ray_trafo_for_mp(standard_dataset.ray_trafo)
    patch_ray_trafo_for_mp(ray_trafo)
    ...

This is quite cumbersome, but at least seems to work.
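To check whether a patch like this covers every offending attribute, one option is to try pickling each attribute of the object that is handed to the DataLoader workers (for example the dataset wrapped by RandomAccessTorchDataset). The helper find_unpicklable_attrs below is a hypothetical, standard-library-only sketch; it assumes the object has a __dict__ and only inspects one level, so nested objects have to be examined by calling it again on the reported attribute.

```python
import pickle


def find_unpicklable_attrs(obj):
    """Return {attribute name: error message} for instance attributes that fail to pickle."""
    failures = {}
    for name, value in vars(obj).items():
        try:
            pickle.dumps(value)
        except Exception as exc:  # pickling can raise several exception types
            failures[name] = "%s: %s" % (type(exc).__name__, exc)
    return failures


class Example:
    def __init__(self):
        self.part = "train"               # picklable
        self.transform = lambda x: x * 2  # lambdas cannot be pickled


print(sorted(find_unpicklable_attrs(Example())))  # ['transform']
```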

Hope it helps!

liutianlin0121 commented 4 years ago

Thanks a lot! I will try this fix later ;-)

liutianlin0121 commented 4 years ago

Hi,

Thanks again for your suggestions! I have used the patch function as suggested, but for some reason it does not seem to work. To check whether this is a problem with my local environment, could you perhaps run the following Python script, which includes the patching function you provided? If it works for you, the problem must be local on my side. Many thanks!

python debug.py --num_workers 2

where debug.py is the following script (please modify the default cache_dir in the CTdatasets function below)

import os, torch
import numpy as np
import random
from os import path

import torch.distributed as dist
import torch.multiprocessing as mp
import argparse

from torch.utils.data import DataLoader
from dival import get_standard_dataset
from dival.datasets.fbp_dataset import get_cached_fbp_dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset as TorchDataset

parser = argparse.ArgumentParser(description='debug')

# Dataset settings
parser.add_argument('--BATCH_SIZE', type=int, default=4, help='mini-batch size for training.')

# DistributedDataParallel settings
parser.add_argument('--num_workers', type=int, default=8, help='')
parser.add_argument("--gpu_devices", type=int, nargs='+', default=[0,1], help="")
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:3456', type=str, help='')
parser.add_argument('--dist-backend', default='nccl', type=str, help='')
parser.add_argument('--rank', default=0, type=int, help='')
parser.add_argument('--world_size', default=1, type=int, help='')
parser.add_argument('--distributed', action='store_true', help='')

args = parser.parse_args()

gpu_devices = ','.join([str(id) for id in args.gpu_devices])
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices

def set_random_seeds(random_seed=0):

    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)

def main():
    set_random_seeds()
    args = parser.parse_args()

    ngpus_per_node = torch.cuda.device_count()

    args.world_size = ngpus_per_node * args.world_size

    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))

def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    ngpus_per_node = torch.cuda.device_count()    
    print("Use GPU: {} for training".format(args.gpu))

    args.rank = args.rank * ngpus_per_node + gpu    
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)

    datasets = CTdatasets()

    dataset_train = datasets['train']

    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)

    train_loader = DataLoader(dataset_train, batch_size=args.BATCH_SIZE, 
                              shuffle=(train_sampler is None), num_workers=args.num_workers, 
                              sampler=train_sampler)

    print('start iterating')
    x, d = next(iter(train_loader))

    print(x.flatten()[:10])

class RandomAccessTorchDataset(TorchDataset):
    def __init__(self, dataset, part, reshape=None):
        self.dataset = dataset
        self.part = part
        self.reshape = reshape or (
            (None,) * self.dataset.get_num_elements_per_sample())

    def __len__(self):
        return self.dataset.get_len(self.part)

    def __getitem__(self, idx):
        arrays = self.dataset.get_sample(idx, part=self.part)
        mult_elem = isinstance(arrays, tuple)
        if not mult_elem:
            arrays = (arrays,)
        tensors = []
        for arr, s in zip(arrays, self.reshape):
            t = torch.from_numpy(np.asarray(arr))
            if s is not None:
                t = t.view(*s)
            tensors.append(t)
        return tuple(tensors) if mult_elem else tensors[0]

def CTdatasets(IMPL = 'astra_cuda', cache_dir = '/home/liu0003/Desktop/projects/dival/', **kwargs):

    CACHE_FILES = {'train': (path.join(cache_dir, 'cache_lodopab_train_fbp.npy'), None),
                   'validation':(path.join(cache_dir, 'cache_lodopab_validation_fbp.npy'), None)}

    standard_dataset = get_standard_dataset('lodopab', impl=IMPL)
    ray_trafo = standard_dataset.get_ray_trafo(impl=IMPL)
    patch_ray_trafo_for_mp(standard_dataset.ray_trafo)
    patch_ray_trafo_for_mp(ray_trafo)

    dataset = get_cached_fbp_dataset(standard_dataset, ray_trafo, CACHE_FILES)

    # create PyTorch datasets
    dataset_train = RandomAccessTorchDataset(dataset = dataset,
        part='train', reshape=((1,) + dataset.space[0].shape,
                               (1,) + dataset.space[1].shape))

    dataset_test = RandomAccessTorchDataset(dataset = dataset,
        part='test', reshape=((1,) + dataset.space[0].shape,
                               (1,) + dataset.space[1].shape))
    datasets = {'train': dataset_train,   'test': dataset_test}

    return datasets

from odl import Operator

# Class copied and modified from
# odl.tomo.operators.ray_trafo.RayTransform.adjoint.
# This main-scope class definition is necessary for multiprocessing, because
# pickling the local class in RayTransform.adjoint fails.
class RayBackProjection(Operator):
    """Adjoint of the discrete Ray transform between L^p spaces.
    """
    def __init__(self, ray_trafo, **kwargs):
        self.ray_trafo = ray_trafo
        super().__init__(**kwargs)

    def _call(self, x, out=None, **kwargs):
        """Backprojection.

        Parameters
        ----------
        x : DiscretizedSpaceElement
            A sinogram. Must be an element of
            `RayTransform.range` (domain of `RayBackProjection`).
        out : `RayBackProjection.domain` element, optional
            A volume to which the result of this evaluation is
            written.
        **kwargs
            Extra keyword arguments, passed on to the
            implementation backend.

        Returns
        -------
        DiscretizedSpaceElement
            Result of the transform in the domain
            of `RayProjection`.
        """
        return self.ray_trafo.get_impl(
            self.ray_trafo.use_cache
        ).call_backward(x, out, **kwargs)

    @property
    def geometry(self):
        return self.ray_trafo.geometry

    @property
    def adjoint(self):
        return self.ray_trafo

def patch_ray_trafo_for_mp(ray_trafo):
    kwargs = ray_trafo._extra_kwargs.copy()
    kwargs['domain'] = ray_trafo.range
    ray_trafo._adjoint = RayBackProjection(
        ray_trafo, range=ray_trafo.domain, linear=True, **kwargs
    )

if __name__=='__main__':
    main()

When running python debug.py --num_workers 2, I did not see any complaint about RayBackProjection, which makes sense; but there are some generic error messages related to multiprocessing:

python debug.py --num_workers 2
Use GPU: 1 for training
Use GPU: 0 for training
start iterating
start iterating
Traceback (most recent call last):
  File "debug.py", line 187, in <module>
    main()
  File "debug.py", line 56, in main
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 200, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 106, in join
    raise Exception(
Exception: process 0 terminated with signal SIGKILL
(base) liu0003@dmi-20-pc-09:~$ /home/liu0003/anaconda3/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 22 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '

and I am not sure how to proceed with debugging. Any suggestions are greatly appreciated!

jleuschn commented 4 years ago

Hi,

I cannot look into this right now, sorry, but I'm happy to check on the 5th of October if that would still be useful.

Best wishes

liutianlin0121 commented 4 years ago

Thanks so much! That will be great.

jleuschn commented 3 years ago

Hi,

your code works fine in my environment.

Maybe you had some processes still running from a previous run of debug.py? When I start the script, abort it, and then start it again, I get an error message similar to yours.
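One hedged way to test for leftover processes from a previous run: with the tcp:// init method, the rank-0 process listens on the address given by --dist-url (tcp://127.0.0.1:3456 by default in debug.py), so checking whether something still accepts connections there before starting the script can reveal a stale run. The helper address_in_use below is a hypothetical, standard-library-only sketch; a positive result only shows that some process is listening on that port, not which one.

```python
import socket
from urllib.parse import urlsplit


def address_in_use(dist_url, timeout=1.0):
    """Return True if a TCP connection to the host:port in dist_url succeeds."""
    parts = urlsplit(dist_url)  # e.g. 'tcp://127.0.0.1:3456'
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        # connect_ex returns 0 on success instead of raising an exception.
        return sock.connect_ex((parts.hostname, parts.port)) == 0


if __name__ == "__main__":
    url = "tcp://127.0.0.1:3456"  # default --dist-url in debug.py
    if address_in_use(url):
        print("Something is still listening on", url, "- check for leftover processes.")
    else:
        print("Nothing is listening on", url)
```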

As a side note, I needed to change the pickle protocol to version 4 with a workaround, but at first sight this does not seem to be related to the other issue.

liutianlin0121 commented 3 years ago

Thanks so much!! I will debug locally then.