liutianlin0121 closed this issue 3 years ago.
Hi, nice to hear you find the library useful!
When running debug.py with --num_workers > 0, I get the error:

AttributeError: Can't pickle local object 'RayTransform.adjoint.<locals>.RayBackProjection'

It fails because of the pickling performed by the multiprocessing library, which can't handle the local class definition in RayTransform.adjoint.
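Here is the failure in isolation, as a minimal sketch that uses only the standard pickle module. The class and function names are made up for illustration. Pickle refers to a class by its module and qualified name, and a name containing <locals> cannot be looked up again. The helper catches both AttributeError, as in the error above, and pickle.PicklingError, which some newer Python versions may raise instead.

```python
import pickle


class ModuleLevel:
    """Defined at module scope, so pickle can find it by its qualified name."""

    def __init__(self, value):
        self.value = value


def make_local_instance(value):
    # This class only exists inside the function body, which mirrors how
    # RayBackProjection is defined inside RayTransform.adjoint.
    class Local:
        def __init__(self, value):
            self.value = value

    return Local(value)


def can_pickle(obj):
    """Return True if pickle can serialize obj, False otherwise."""
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError):
        return False
    return True


print(can_pickle(ModuleLevel(1)))
print(can_pickle(make_local_instance(1)))
```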
I just tried a dirty hot-fix that works by patching the internal attribute ray_trafo._adjoint:
from odl import Operator

# Class copied and modified from
# odl.tomo.operators.ray_trafo.RayTransform.adjoint.
# This main-scope class definition is necessary for multiprocessing, because
# pickling the local class in RayTransform.adjoint fails.
class RayBackProjection(Operator):
    """Adjoint of the discrete Ray transform between L^p spaces."""

    def __init__(self, ray_trafo, **kwargs):
        self.ray_trafo = ray_trafo
        super().__init__(**kwargs)

    def _call(self, x, out=None, **kwargs):
        """Backprojection.

        Parameters
        ----------
        x : DiscretizedSpaceElement
            A sinogram. Must be an element of
            `RayTransform.range` (domain of `RayBackProjection`).
        out : `RayBackProjection.domain` element, optional
            A volume to which the result of this evaluation is
            written.
        **kwargs
            Extra keyword arguments, passed on to the
            implementation backend.

        Returns
        -------
        DiscretizedSpaceElement
            Result of the transform in the domain
            of `RayProjection`.
        """
        return self.ray_trafo.get_impl(
            self.ray_trafo.use_cache
        ).call_backward(x, out, **kwargs)

    @property
    def geometry(self):
        return self.ray_trafo.geometry

    @property
    def adjoint(self):
        return self.ray_trafo


def patch_ray_trafo_for_mp(ray_trafo):
    # Pre-set the cached adjoint with the module-level (picklable) class,
    # so RayTransform.adjoint returns it instead of building its local class.
    kwargs = ray_trafo._extra_kwargs.copy()
    kwargs['domain'] = ray_trafo.range
    ray_trafo._adjoint = RayBackProjection(
        ray_trafo, range=ray_trafo.domain, linear=True, **kwargs
    )
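The same idea can be tried without ODL or dival. This sketch uses hypothetical Forward, LocalBackward and ModuleBackward classes as stand-ins, not ODL's actual code. Like the hot-fix, it assumes the operator returns a pre-set _adjoint instead of building a new local class. Once the cache holds a module-level instance, the operator pickles, and the forward/adjoint reference cycle survives a round trip.

```python
import pickle


class ModuleBackward:
    """Module-level replacement for the locally defined adjoint class."""

    def __init__(self, forward):
        self.forward = forward

    @property
    def adjoint(self):
        return self.forward


class Forward:
    """Hypothetical operator that builds and caches its adjoint lazily."""

    def __init__(self):
        self._adjoint = None

    @property
    def adjoint(self):
        if self._adjoint is None:
            # Local class: instances of it cannot be pickled.
            class LocalBackward:
                def __init__(self, forward):
                    self.forward = forward

            self._adjoint = LocalBackward(self)
        return self._adjoint


def patch_for_mp(op):
    # Same pattern as patch_ray_trafo_for_mp: fill the cache up front.
    op._adjoint = ModuleBackward(op)


def picklable(obj):
    try:
        pickle.dumps(obj)
    except (AttributeError, pickle.PicklingError):
        return False
    return True


op = Forward()
_ = op.adjoint  # creates and caches the local adjoint
print(picklable(op))

patched = Forward()
patch_for_mp(patched)
_ = patched.adjoint  # returns the cached module-level adjoint
print(picklable(patched))
```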
Then, in CTdatasets, you can patch the two ray trafo objects by inserting the last two lines shown here:
def CTdatasets(...):
    ...
    standard_dataset = get_standard_dataset('lodopab', impl=IMPL)
    ray_trafo = standard_dataset.get_ray_trafo(impl=IMPL)
    patch_ray_trafo_for_mp(standard_dataset.ray_trafo)
    patch_ray_trafo_for_mp(ray_trafo)
    ...
This is quite cumbersome, but at least seems to work.
Hope it helps!
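When worker processes are started with the spawn method, the dataset has to be pickled. A quick check that the patch covers everything is to pickle the dataset up front and, if that fails, walk its attributes to locate the culprit. find_unpicklable is a hypothetical helper, not part of dival, ODL or PyTorch. It only follows instance attributes (vars), not the contents of lists or dicts.

```python
import pickle


def find_unpicklable(obj, path="obj", seen=None):
    """Return the attribute paths under obj that fail to pickle.

    Recurses into instance attributes of failing objects and reports the
    deepest failing ones; `seen` guards against reference cycles.
    """
    if seen is None:
        seen = set()
    if id(obj) in seen:
        return []
    seen.add(id(obj))
    try:
        pickle.dumps(obj)
        return []
    except Exception:
        pass
    children = getattr(obj, "__dict__", None)
    if not children:
        return [path]
    found = []
    for name, value in children.items():
        found.extend(find_unpicklable(value, f"{path}.{name}", seen))
    # If every attribute pickles on its own, the object itself is the problem.
    return found or [path]


class Holder:
    pass


def make_local():
    class Local:
        pass

    return Local()


holder = Holder()
holder.ok = [1, 2, 3]
holder.bad = make_local()
print(find_unpicklable(holder))
```

In debug.py it could be called as find_unpicklable(datasets['train']) right after CTdatasets(), before the DataLoader is created.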
Thanks a lot! I will try this fix later ;-)
Hi,
Thanks again for your suggestions! I have used the patch function as suggested, but for some reason it does not seem to work. To see whether this is a problem with my local environment, could you perhaps run the following Python script, which includes the patching function you provided? If it works for you, I must have a local problem. Many thanks!
python debug.py --num_workers 2
where debug.py is the following script (please modify the default cache_dir in the CTdatasets function below):
import os, torch
import numpy as np
import random
from os import path
import torch.distributed as dist
import torch.multiprocessing as mp
import argparse
from torch.utils.data import DataLoader
from dival import get_standard_dataset
from dival.datasets.fbp_dataset import get_cached_fbp_dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset as TorchDataset
parser = argparse.ArgumentParser(description='debug')
# Dataset settings
parser.add_argument('--BATCH_SIZE', type=int, default=4, help='mini-batch size for training.')
# DistributedDataParallel settings
parser.add_argument('--num_workers', type=int, default=8, help='')
parser.add_argument("--gpu_devices", type=int, nargs='+', default=[0,1], help="")
parser.add_argument('--gpu', default=None, type=int, help='GPU id to use.')
parser.add_argument('--dist-url', default='tcp://127.0.0.1:3456', type=str, help='')
parser.add_argument('--dist-backend', default='nccl', type=str, help='')
parser.add_argument('--rank', default=0, type=int, help='')
parser.add_argument('--world_size', default=1, type=int, help='')
parser.add_argument('--distributed', action='store_true', help='')
args = parser.parse_args()
gpu_devices = ','.join([str(id) for id in args.gpu_devices])
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_devices
def set_random_seeds(random_seed=0):
    torch.manual_seed(random_seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    np.random.seed(random_seed)
    random.seed(random_seed)
def main():
    set_random_seeds()
    args = parser.parse_args()
    ngpus_per_node = torch.cuda.device_count()
    args.world_size = ngpus_per_node * args.world_size
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
def main_worker(gpu, ngpus_per_node, args):
    args.gpu = gpu
    ngpus_per_node = torch.cuda.device_count()
    print("Use GPU: {} for training".format(args.gpu))
    args.rank = args.rank * ngpus_per_node + gpu
    dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
                            world_size=args.world_size, rank=args.rank)
    datasets = CTdatasets()
    dataset_train = datasets['train']
    train_sampler = torch.utils.data.distributed.DistributedSampler(dataset_train)
    train_loader = DataLoader(dataset_train, batch_size=args.BATCH_SIZE,
                              shuffle=(train_sampler is None), num_workers=args.num_workers,
                              sampler=train_sampler)
    print('start iterating')
    x, d = next(iter(train_loader))
    print(x.flatten()[:10])
class RandomAccessTorchDataset(TorchDataset):
    def __init__(self, dataset, part, reshape=None):
        self.dataset = dataset
        self.part = part
        self.reshape = reshape or (
            (None,) * self.dataset.get_num_elements_per_sample())

    def __len__(self):
        return self.dataset.get_len(self.part)

    def __getitem__(self, idx):
        arrays = self.dataset.get_sample(idx, part=self.part)
        mult_elem = isinstance(arrays, tuple)
        if not mult_elem:
            arrays = (arrays,)
        tensors = []
        for arr, s in zip(arrays, self.reshape):
            t = torch.from_numpy(np.asarray(arr))
            if s is not None:
                t = t.view(*s)
            tensors.append(t)
        return tuple(tensors) if mult_elem else tensors[0]
def CTdatasets(IMPL='astra_cuda', cache_dir='/home/liu0003/Desktop/projects/dival/', **kwargs):
    CACHE_FILES = {'train': (path.join(cache_dir, 'cache_lodopab_train_fbp.npy'), None),
                   'validation': (path.join(cache_dir, 'cache_lodopab_validation_fbp.npy'), None)}
    standard_dataset = get_standard_dataset('lodopab', impl=IMPL)
    ray_trafo = standard_dataset.get_ray_trafo(impl=IMPL)
    patch_ray_trafo_for_mp(standard_dataset.ray_trafo)
    patch_ray_trafo_for_mp(ray_trafo)
    dataset = get_cached_fbp_dataset(standard_dataset, ray_trafo, CACHE_FILES)
    # create PyTorch datasets
    dataset_train = RandomAccessTorchDataset(dataset=dataset,
        part='train', reshape=((1,) + dataset.space[0].shape,
                               (1,) + dataset.space[1].shape))
    dataset_test = RandomAccessTorchDataset(dataset=dataset,
        part='test', reshape=((1,) + dataset.space[0].shape,
                              (1,) + dataset.space[1].shape))
    datasets = {'train': dataset_train, 'test': dataset_test}
    return datasets
from odl import Operator

# Class copied and modified from
# odl.tomo.operators.ray_trafo.RayTransform.adjoint.
# This main-scope class definition is necessary for multiprocessing, because
# pickling the local class in RayTransform.adjoint fails.
class RayBackProjection(Operator):
    """Adjoint of the discrete Ray transform between L^p spaces."""

    def __init__(self, ray_trafo, **kwargs):
        self.ray_trafo = ray_trafo
        super().__init__(**kwargs)

    def _call(self, x, out=None, **kwargs):
        """Backprojection.

        Parameters
        ----------
        x : DiscretizedSpaceElement
            A sinogram. Must be an element of
            `RayTransform.range` (domain of `RayBackProjection`).
        out : `RayBackProjection.domain` element, optional
            A volume to which the result of this evaluation is
            written.
        **kwargs
            Extra keyword arguments, passed on to the
            implementation backend.

        Returns
        -------
        DiscretizedSpaceElement
            Result of the transform in the domain
            of `RayProjection`.
        """
        return self.ray_trafo.get_impl(
            self.ray_trafo.use_cache
        ).call_backward(x, out, **kwargs)

    @property
    def geometry(self):
        return self.ray_trafo.geometry

    @property
    def adjoint(self):
        return self.ray_trafo


def patch_ray_trafo_for_mp(ray_trafo):
    # Pre-set the cached adjoint with the module-level (picklable) class,
    # so RayTransform.adjoint returns it instead of building its local class.
    kwargs = ray_trafo._extra_kwargs.copy()
    kwargs['domain'] = ray_trafo.range
    ray_trafo._adjoint = RayBackProjection(
        ray_trafo, range=ray_trafo.domain, linear=True, **kwargs
    )
if __name__ == '__main__':
    main()
When running python debug.py --num_workers 2, I did not see any complaint about RayBackProjection, which makes sense; but there are some generic error messages related to multiprocessing:
python debug.py --num_workers 2
Use GPU: 1 for training
Use GPU: 0 for training
start iterating
start iterating
Traceback (most recent call last):
  File "debug.py", line 187, in <module>
    main()
  File "debug.py", line 56, in main
    mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 200, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 158, in start_processes
    while not context.join():
  File "/home/liu0003/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 106, in join
    raise Exception(
Exception: process 0 terminated with signal SIGKILL
(base) liu0003@dmi-20-pc-09:~$ /home/liu0003/anaconda3/lib/python3.8/multiprocessing/resource_tracker.py:216: UserWarning: resource_tracker: There appear to be 22 leaked semaphore objects to clean up at shutdown
  warnings.warn('resource_tracker: There appear to be %d '
I am not sure how to proceed with debugging. Any suggestions are greatly appreciated!
Hi,
I cannot look into this currently, sorry, but I'm happy to check on the 5th of October if that would still be useful.
Best wishes
Thanks so much! That will be great.
Hi,
Your code works fine in my environment.
Maybe you had some processes still running from a previous run of debug.py? When I start the script, abort it, and then start it again, I get an error message similar to yours.
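One hedged way to test the leftover-process theory before restarting: with a tcp:// init method, the rank-0 process listens on the --dist-url address (tcp://127.0.0.1:3456 by default in debug.py). If that port still accepts connections, an earlier run is probably still alive. rendezvous_port_in_use is a hypothetical helper built on the standard socket module. It briefly opens and closes a TCP connection.

```python
import socket
from urllib.parse import urlparse


def rendezvous_port_in_use(dist_url="tcp://127.0.0.1:3456", timeout=1.0):
    """Return True if something accepts TCP connections at dist_url."""
    parsed = urlparse(dist_url)
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.settimeout(timeout)
        # connect_ex returns 0 on success instead of raising on refusal.
        return sock.connect_ex((parsed.hostname, parsed.port)) == 0


print(rendezvous_port_in_use())
```

If it returns True, look for and terminate stale python debug.py processes before launching again.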
As a side note, I needed to change the pickle protocol to version 4 with a workaround, but at first sight this does not seem to be related to the other issue.
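The thread does not say what the protocol-4 workaround was, and the following is not it. For reference only, this sketch shows how the standard pickle module selects protocol 4 explicitly and how to tell which protocol a payload was written with. Per the Python docs, protocol 4 adds support for very large objects and is the default from Python 3.8 on. dumps_protocol4 is a hypothetical helper.

```python
import pickle


def dumps_protocol4(obj):
    """Serialize obj with pickle protocol 4 explicitly."""
    return pickle.dumps(obj, protocol=4)


payload = dumps_protocol4({"x": list(range(5))})

# Pickles written with protocol 2 or newer start with the PROTO opcode
# (0x80) followed by the protocol number.
print(payload[:2])
print("default:", pickle.DEFAULT_PROTOCOL, "highest:", pickle.HIGHEST_PROTOCOL)
```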
Thanks so much!! I will debug locally then.
Hi! Thanks for these awesome datasets and library!
The library works perfectly well for me when training models on a single GPU. But when using two GPUs with torch.multiprocessing, an error related to num_workers and RandomAccessTorchDataset occurs. Basically, whenever I use num_workers > 0, torch.multiprocessing somehow breaks down. A script that showcases this problem is attached at the end. I am not sure whether this is directly related to the library itself, so feel free to close the issue. But it would be great if you have any insights on this problem. Thanks a lot!

The script debug.py attached below appears to be long, but the essential part is just the main_worker function. When I run python debug.py --num_workers 0, the line x, d = next(iter(train_loader)) executes successfully; but when I run python debug.py --num_workers 4, the following error occurs:

Here is the debug.py script:

Please run python debug.py --num_workers 4 to see the error I mentioned earlier. Thanks in advance!