ivannz / cplxmodule

Complex-valued neural networks for pytorch and Variational Dropout for real and complex layers.
MIT License
138 stars, 27 forks

FIX: defective replication in multi-GPU setting #3

Closed: ivannz closed this issue 4 years ago

ivannz commented 4 years ago

When trying to use *Masked layers in a multi-GPU setting, I ran into an issue: parts of the model were not properly replicated onto each GPU. The following code

import torch
from cplxmodule.nn.masked import LinearMasked
from torch.nn.parallel.data_parallel import data_parallel

net, x = LinearMasked(15, 30), torch.randn(1024, 1024, 15)
net.mask = torch.randint(0, 2, size=(30, 1))

data_parallel(net.cuda(0), x.cuda(0), [0, 1])

failed with a device-mismatch exception:

RuntimeError: Caught RuntimeError in replica 1 on device 1.
  ...
  File "./cplxmodule/nn/masked/base.py", line 134, in weight_masked
    return self.weight * self.mask
RuntimeError: expected device cuda:1 but got device cuda:0

I inspected the key steps of data_parallel: replicate, scatter_kwargs, and parallel_apply. The following code revealed that replication was defective:

from torch.nn.parallel.replicate import replicate

replicas = replicate(net.cuda(0), [0, 1])
print([m.weight.device for m in replicas])
print([m.mask.device for m in replicas])

printed

[device(type='cuda', index=0), device(type='cuda', index=1)]
[device(type='cuda', index=0), device(type='cuda', index=0)]

In the section of replicate.py#L161-L163 responsible for associating the buffer copies with the model replicas, the following line does the assignment:

# key = 'mask', buffer_copies[j][buffer_idx] is a torch.Tensor on a device
setattr(replica, key, buffer_copies[j][buffer_idx])

Since replica is an instance of the BaseMasked base class, the setattr call ultimately invokes this class's __setattr__ method. This method in turn dispatches all calls to torch.nn.Module, unless the attribute being set happens to be named "mask"; manipulations of masks are handled by the special .mask_ method.

The .mask_ method does the necessary reallocation and cleanup when masks are set, updated, or removed. Since the replica is a copy of the original network, all of its masks had already been properly set and were located on device('cuda:0'). Therefore, mask_ in fact ended up in the "mask update" branch:
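The dispatch described above can be sketched with a small layer. MaskedLinearSketch, its mask_ and its weight_masked are hypothetical, modelled on the description in this issue rather than taken from cplxmodule's actual BaseMasked code; the update branch deliberately reproduces the pre-fix in-place copy.

```python
import torch
from torch import nn


class MaskedLinearSketch(nn.Linear):
    """Hypothetical stand-in for a BaseMasked layer (not cplxmodule's code)."""

    def __setattr__(self, name, value):
        # Only "mask" is intercepted; everything else goes to nn.Module.
        if name == "mask":
            self.mask_(value)
        else:
            super().__setattr__(name, value)

    def mask_(self, mask):
        current = getattr(self, "mask", None)
        if mask is None:
            # "remove" branch: drop the mask buffer
            self.register_buffer("mask", None)
        elif current is None:
            # "set" branch: allocate a new buffer matching the weight's device/dtype
            self.register_buffer("mask", mask.detach().to(self.weight, copy=True))
        else:
            # "update" branch (pre-fix): in-place copy, so the existing buffer
            # keeps whatever device it was already on
            current.copy_(mask.detach())
        return self

    @property
    def weight_masked(self):
        return self.weight * self.mask
```

In the replication scenario, the replica already carries a mask, so the buffer copy destined for cuda:1 is routed into the update branch and copied into the old cuda:0 tensor.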

self.mask.copy_(mask.detach())

Now, the .copy_ method of torch.Tensor only copies values into the existing tensor, transferring between host and device and casting dtypes as necessary; the destination keeps its own device and dtype. Therefore, the mask update ignored the intended device of the buffer copy: it simply overwrote the mask's values with those of the copy and left the mask on its original device. The dtype seems to be unaffected by replication.
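The in-place semantics of copy_ can be checked on the CPU, with dtype standing in for device: the destination keeps its dtype and its storage, and only the values change. The two-GPU check at the end is guarded and only runs on a machine with at least two GPUs.

```python
import torch

# Destination: float32 tensor (stands in for the mask already on cuda:0).
dst = torch.zeros(4, dtype=torch.float32)
ptr_before = dst.data_ptr()

# Source: float64 tensor (stands in for the buffer copy meant for cuda:1).
src = torch.arange(4, dtype=torch.float64)

dst.copy_(src)

# Values were copied and cast, but dst kept its own dtype and storage.
print(dst.tolist())                  # [0.0, 1.0, 2.0, 3.0]
print(dst.dtype)                     # torch.float32
print(dst.data_ptr() == ptr_before)  # True

# The same holds for devices: the destination stays where it was.
if torch.cuda.device_count() >= 2:
    a = torch.zeros(4, device="cuda:0")
    a.copy_(torch.ones(4, device="cuda:1"))
    assert a.device == torch.device("cuda:0")
```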

Possible solutions included:

  1. adding .to(self.weight.device) to the masking action
  2. slightly modifying the logic of .mask_ so that, upon a mask update, it is properly synchronised with the weight as if it were set for the first time
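Both options can be sketched as plain functions. The names weight_masked_option1 and mask_option2 are hypothetical and only illustrate where the device move happens in each case: on every forward pass, or once when the mask is (re)assigned.

```python
import torch


def weight_masked_option1(weight, mask):
    # Option 1: move the mask to the weight's device on every call.
    # .to is a no-op when devices already match, but it runs each forward pass.
    return weight * mask.to(weight.device)


def mask_option2(weight, new_mask):
    # Option 2: on every set *or* update, allocate a fresh mask next to the
    # weight, instead of copying into the existing buffer in place.
    return new_mask.detach().to(device=weight.device, dtype=weight.dtype, copy=True)
```

Option 2 pays the relocation cost only when a mask is assigned, which is why it relies on the weight already being on the correct device at that moment.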

Option 1 would incur the cost of a .to call on every forward pass, since .weight_masked is called each time, and this would be significant. However, it would guarantee that the mask is always on the same device as the weight. Option 2 would not incur such overhead, since masks are set much less often, but it could potentially desync the weight and the mask.

I went with option 2, since torch users move models between host and device via the .to, .cpu and .cuda APIs, all of which move buffers and parameters in sync. Thus this interaction seems unlikely to cause a device desync. replicate does not use these methods; instead, it modifies the internal dictionaries of nn.Module directly, or through setattr. However, it does so in a fixed order: first the modules' children are reassigned, then the parameters, and finally the buffers. Therefore, unless replicate significantly changes its logic, it is safe to go with option 2: take the device from the replica's already reassigned weight and bypass the costly .to() call.
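The argument about ordering can be sketched by simulating it. FixedMaskedLinear and assign_in_replicate_order are hypothetical: the helper only mimics the order described above (parameters before buffers, buffers through setattr) and is not torch's actual replicate. Because the weight is reassigned first, the option-2 mask_ picks up the new device from it.

```python
import copy

import torch
from torch import nn


class FixedMaskedLinear(nn.Linear):
    """Hypothetical masked layer with the option-2 fix (not cplxmodule's code)."""

    def __setattr__(self, name, value):
        if name == "mask":
            self.mask_(value)
        else:
            super().__setattr__(name, value)

    def mask_(self, mask):
        if mask is None:
            self.register_buffer("mask", None)
        else:
            # Set and update both reallocate on the weight's *current* device.
            self.register_buffer("mask", mask.detach().to(self.weight, copy=True))
        return self


def assign_in_replicate_order(replica, device):
    # Parameters are reassigned first...
    for name, param in list(replica.named_parameters(recurse=False)):
        setattr(replica, name, nn.Parameter(param.detach().to(device)))
    # ...and only then buffers, which go through the overridden __setattr__.
    for name, buf in list(replica.named_buffers(recurse=False)):
        setattr(replica, name, buf.to(device))
    return replica


if torch.cuda.device_count() >= 2:
    net = FixedMaskedLinear(15, 30).cuda(0)
    net.mask = torch.randint(0, 2, size=(30, 1))
    rep = assign_in_replicate_order(copy.deepcopy(net), torch.device("cuda:1"))
    assert rep.mask.device == rep.weight.device == torch.device("cuda:1")
```

If the buffers were reassigned before the parameters, mask_ would read the stale weight device, which is exactly the dependency on replicate's ordering noted above.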