ivannz / cplxmodule

Complex-valued neural networks for pytorch and Variational Dropout for real and complex layers.
MIT License

Workaround for nn.DataParallel bug #1

Open gogolgrind opened 4 years ago

gogolgrind commented 4 years ago

It seems cplxmodule does not work with nn.DataParallel. The attached minimal example gives the following error:

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 0 does not equal 1 (while checking arguments for cudnn_convolution)

cplxmodule-bug.py.txt

ivannz commented 4 years ago

Thank you for the issue!

Analysis

Here is the minimally reproducing example:

import torch
from torch import nn

from cplxmodule import cplx
import cplxmodule.nn as cplxnn

from torch.nn.parallel.data_parallel import data_parallel

# a real-valued Conv2d parallelizes fine
net, x = nn.Conv2d(3, 3, 3), torch.randn(2, 3, 6, 6)
data_parallel(net.cuda(0), x.cuda(0), [0, 1])

# a complex-valued CplxConv2d raises the device-mismatch RuntimeError
net, x = cplxnn.CplxConv2d(3, 3, 3), cplx.Cplx(torch.randn(1, 3, 6, 6))
data_parallel(net.cuda(0), x.cuda(0), [0, 1])

data_parallel uses three key functions:

from torch.nn.parallel.scatter_gather import scatter_kwargs
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply

scatter_kwargs is responsible for splitting the input along the batch dimension and moving the shards to the appropriate devices. replicate performs in-vivo surgery on the model: it takes it apart and rebuilds it with its parameters wrapped in scatter-gather operations. Finally, parallel_apply takes the properly placed inputs and model replicas, runs them in parallel threads, and collects the outputs upon termination.
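For reference, here is a rough sketch of how data_parallel composes these pieces. The helper name naive_data_parallel is mine, and it only approximates the upstream implementation, but it shows the scatter -> replicate -> parallel_apply -> gather pipeline:

from torch.nn.parallel.scatter_gather import scatter_kwargs, gather
from torch.nn.parallel.replicate import replicate
from torch.nn.parallel.parallel_apply import parallel_apply

def naive_data_parallel(module, inp, device_ids, output_device=None, dim=0):
    if output_device is None:
        output_device = device_ids[0]
    # 1. split the input along `dim` and move each shard to its own device
    inputs, kwargs = scatter_kwargs((inp,), {}, device_ids, dim=dim)
    # 2. clone the module onto every device that actually received a shard
    replicas = replicate(module, device_ids[:len(inputs)])
    # 3. run each replica on its shard in a separate thread
    outputs = parallel_apply(replicas, inputs, kwargs)
    # 4. collect the outputs back on the output device
    return gather(outputs, output_device, dim=dim)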

Replicating the model and moving the inputs manually seems to work fine:

from torch.nn.parallel.replicate import replicate

net = cplxnn.CplxConv2d(3, 3, 3)
x = cplx.Cplx(torch.randn(1, 3, 6, 6), torch.randn(1, 3, 6, 6))

# every replica lands on its own device ...
replicas = replicate(net.cuda(0), [0, 1])
print([m.weight.device for m in replicas])

# ... and works once the input is moved to that device by hand
print([m(x.to(m.weight.device)).device for m in replicas])

Since the error message from parallel_apply coincides with the one produced by

import torch.nn.functional as F

F.conv2d(torch.randn(1, 3, 6, 6).cuda(0), torch.randn(3, 3, 3, 3).cuda(1))

I think the issue here is that the input ends up on the wrong device, while the model replica itself is on the correct one. Thus I investigated scatter_kwargs.

scatter_kwargs calls scatter on the input tensor. scatter uses internal low-level functionality to split the tensor along the zeroth dimension and move the chunks to their target devices. Unfortunately, Cplx is a high-level Python object that is duck-typed to behave like a Tensor only at the Python level; it is not binary compatible with torch.Tensor at the C++ level, so scatter cannot split or move it.
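As far as I can tell, scatter leaves objects it does not recognise alone and simply hands the same reference to every device. The snippet below illustrates this; the commented devices are what I expect on a two-GPU machine:

from torch.nn.parallel.scatter_gather import scatter

x = torch.randn(2, 3, 6, 6).cuda(0)
z = cplx.Cplx(x, torch.zeros_like(x))

# a plain Tensor is chunked along dim 0, one shard per device
print([t.device for t in scatter(x, [0, 1], dim=0)])
# expected: [cuda:0, cuda:1]

# a Cplx object is opaque to scatter: every "shard" is the very same object,
# still living on cuda:0, which later trips up the replica on cuda:1
print([c.real.device for c in scatter(z, [0, 1], dim=0)])
# expected: [cuda:0, cuda:0]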

Solution

Unfortunately I can only suggest a workaround wrapper.

def r2r_wrap(model, dim_in=1, dim_out=1):
    return torch.nn.Sequential(
        # convert the `B x F*2 x ...` input Tensor (dim_in=1) to Cplx as part of the forward pass
        cplxnn.RealToCplx(dim=dim_in),
        model,
        # convert the Cplx output back to a `B x O*2 x ...` Tensor (dim_out=1)
        cplxnn.CplxToReal(dim=dim_out),
    )

# `cplxtest_net` is the complex-valued model from your attached script
model_complex = r2r_wrap(cplxtest_net()).to(device)

This is a thin real-to-real wrapper around the whole model: it makes the conversion from torch Tensors to Cplx and back part of the model itself, and thus bypasses the scatter issue.
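For example, the following sketch (with a single CplxConv2d standing in for your network) should run through data_parallel without the error, because scatter only ever sees plain Tensors:

# stand-in for the actual network: three complex input and output channels
net = cplxnn.CplxConv2d(3, 3, 3)

model = r2r_wrap(net).cuda(0)
x = torch.randn(4, 3 * 2, 6, 6).cuda(0)  # real-imag pairs interleaved along dim 1

out = data_parallel(model, x, [0, 1])  # scatters a plain Tensor, so no device mismatch
print(out.shape)  # expected: torch.Size([4, 6, 4, 4])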

The key nuance is that the input to and the output of the wrapped model are plain Tensors, not cplx.Cplx objects:

# put real-imag pairs along dim=1 (the channel dimension)
complex_data = torch.randn(1, 3*2, 224, 224).to(device)

# real-imag pairs are assumed to alternate along dim=1
cplx_data = cplxnn.RealToCplx(dim=1)(complex_data)
assert torch.allclose(cplx_data.real, complex_data[:, 0::2])
assert torch.allclose(cplx_data.imag, complex_data[:, 1::2])

The output is de-interleaved in the same way:

tensor_output = model_complex(complex_data)
cplx_data = cplxnn.RealToCplx(dim=1)(tensor_output)
assert torch.allclose(cplx_data.real, tensor_output[:, 0::2])
assert torch.allclose(cplx_data.imag, tensor_output[:, 1::2])

These Tensors are just a storage format for the complex numbers; they in no way affect the arithmetic or the operations inside the Cplx network or cplxmodule itself.
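As a sanity check, converting back and forth with the same dim=1 convention is lossless:

z = cplx.Cplx(torch.randn(1, 3, 8, 8), torch.randn(1, 3, 8, 8))

flat = cplxnn.CplxToReal(dim=1)(z)    # a plain 1 x 6 x 8 x 8 Tensor with interleaved pairs
back = cplxnn.RealToCplx(dim=1)(flat)

assert torch.allclose(back.real, z.real)
assert torch.allclose(back.imag, z.imag)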