gogolgrind opened this issue 4 years ago
It seems cplxmodule doesn't work with nn.DataParallel. The attached minimal example gives the following error:
cplxmodule-bug.py.txt
Thank you for the issue!
Here is a minimal reproducing example:
import torch
from torch import nn
from cplxmodule import cplx
import cplxmodule.nn as cplxnn
from torch.nn.parallel.data_parallel import data_parallel
# the real-valued model works fine with data_parallel
net, x = nn.Conv2d(3, 3, 3), torch.randn(2, 3, 6, 6)
data_parallel(net.cuda(0), x.cuda(0), [0, 1])
# the complex-valued model raises the error reported in this issue
net, x = cplxnn.CplxConv2d(3, 3, 3), cplx.Cplx(torch.randn(1, 3, 6, 6))
data_parallel(net.cuda(0), x.cuda(0), [0, 1])
data_parallel uses three key functions:
from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply
scatter_kwargs is responsible for splitting the input along the batch dimension and moving the shards to the appropriate devices. replicate is responsible for performing in-vivo surgery on the model: taking it apart and rebuilding it with its parameters wrapped in scatter-gather operations. Finally, parallel_apply takes the properly placed inputs and model replicas and runs them in parallel threads, collecting the outputs upon termination.
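Roughly speaking, data_parallel chains them like this (a simplified sketch, not the exact torch implementation, which also handles keyword arguments, buffers and various edge cases):
from torch.nn.parallel.scatter_gather import gather  # plus the three imports above
def data_parallel_sketch(module, input, device_ids, output_device=None):
    # 1. split the input along dim 0 and move each shard to its target device
    inputs, kwargs = scatter_kwargs((input,), {}, device_ids, dim=0)
    # 2. clone the module onto every device
    replicas = replicate(module, device_ids[:len(inputs)])
    # 3. run each replica on its shard in a separate thread
    outputs = parallel_apply(replicas, inputs, kwargs, device_ids[:len(inputs)])
    # 4. collect the per-device outputs back onto a single device
    output_device = device_ids[0] if output_device is None else output_device
    return gather(outputs, output_device, dim=0)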
Replicating the model and moving the inputs manually seems to work ok:
from torch.nn.parallel.replicate import replicate
net = cplxnn.CplxConv2d(3, 3, 3)
x = cplx.Cplx(torch.randn(1, 3, 6, 6), torch.randn(1, 3, 6, 6))
replicas = replicate(net.cuda(0), [0, 1])
# each replica ends up with its parameters on its own device ...
print([m.weight.device for m in replicas])
# ... and happily consumes a Cplx input moved onto that device by hand
print([m(x.to(m.weight.device)).device for m in replicas])
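And if the Cplx inputs are placed onto the matching devices by hand, parallel_apply itself seems to run the replicas without complaint, so the forward pass is not the problem (a sketch continuing the snippet above; devices are passed explicitly rather than letting the helper infer them from the inputs):
from torch.nn.parallel.parallel_apply import parallel_apply
# one argument tuple per replica, each Cplx already on that replica's device
inputs = [(x.to(m.weight.device),) for m in replicas]
outputs = parallel_apply(replicas, inputs, devices=[0, 1])
print([y.device for y in outputs])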
Since the error message from parallel_apply coincides with the one raised by
import torch.nn.functional as F
F.conv2d(torch.randn(1, 3, 6, 6).cuda(0), torch.randn(3, 3, 3, 3).cuda(1))
I think the issue here is that the input is on the wrong device, whilst the model might be on the right one. Thus I investigated scatter_kwargs.
scatter_kwargs calls scatter on the input tensor, and scatter uses internal low-level functionality to split the tensor along the zeroth dimension. Unfortunately, Cplx is a high-level Python object that is duck-typed to behave like a Tensor only at the Python level; it is not binary-compatible with torch.Tensor at the C++ level.
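A quick illustration of the distinction (reusing the names from the snippets above):
x = cplx.Cplx(torch.randn(2, 3, 6, 6), torch.randn(2, 3, 6, 6))
# Cplx exposes Tensor-like properties and methods at the Python level ...
print(x.device, x.to(torch.device("cuda:0")).device)
# ... but it is not a torch.Tensor, so scatter's Tensor-specific code path never sees it
print(isinstance(x, torch.Tensor))  # False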
Unfortunately I can only suggest a workaround wrapper.
def r2r_wrap(model, dim_in=1, dim_out=1):
    return torch.nn.Sequential(
        # convert the real Tensor `B x F*2 x ...` (dim_in=1) into a Cplx as part of the computations
        cplxnn.RealToCplx(dim=dim_in),
        model,
        # convert the Cplx back into a `B x O*2 x ...` (dim_out=1) Tensor as part of the pipeline
        cplxnn.CplxToReal(dim=dim_out),
    )
model_complex = r2r_wrap(cplxtest_net()).to(device)
This is a thin Real-to-Real wrapper around the whole model, which makes the conversion from torch Tensors to Cplx and back a part of the model structure, and thus bypasses the scatter issue. The key nuance is that the input to and the output of the model are plain Tensors, not cplx.Cplx objects:
# put real-imag pairs along the 1st dim
complex_data = torch.randn(1, 3*2, 224, 224).to(device)
# real-imag pairs are assumed to alternate the 1st dim
cplx_data = cplxnn.RealToCplx(dim=1)(complex_data)
assert torch.allclose(cplx_data.real, complex_data[:, 0::2])
assert torch.allclose(cplx_data.imag, complex_data[:, 1::2])
The output is handled symmetrically:
tensor_output = model_complex(complex_data)
# the output uses the same interleaved real-imag layout along the 1st dim
cplx_output = cplxnn.RealToCplx(dim=1)(tensor_output)
assert torch.allclose(cplx_output.real, tensor_output[:, 0::2])
assert torch.allclose(cplx_output.imag, tensor_output[:, 1::2])
These Tensors are just a way to store the complex numbers and in no way affect their arithmetic or operations inside the Cplx network or cplxmodule itself.
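With the wrapper in place the data-parallel machinery only ever sees real Tensors, so it can scatter and gather them as usual. A sketch of the intended usage, assuming cplxtest_net and device come from your attached script and device is cuda:0, the first entry in device_ids:
parallel_model = nn.DataParallel(model_complex, device_ids=[0, 1])
out = parallel_model(torch.randn(8, 3 * 2, 224, 224).to(device))
print(out.shape, out.device)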