When trying to use the *Masked layers in a multi-GPU setting, I encountered an issue where parts of the model were not replicated onto each GPU properly. The following code
import torch
from cplxmodule.nn.masked import LinearMasked
from torch.nn.parallel.data_parallel import data_parallel
net, x = LinearMasked(15, 30), torch.randn(1024, 1024, 15)
net.mask = torch.randint(0, 2, size=(30, 1))
data_parallel(net.cuda(0), x.cuda(0), [0, 1])
failed with a mismatched-device exception:
RuntimeError: Caught RuntimeError in replica 1 on device 1.
...
File "./cplxmodule/nn/masked/base.py", line 134, in weight_masked
return self.weight * self.mask
RuntimeError: expected device cuda:1 but got device cuda:0
I inspected the key steps of `data_parallel`: `replicate`, `scatter_kwargs`, and `parallel_apply` (a rough sketch of how they compose is given after the snippet below). The following code revealed that the replication step was defective:
from torch.nn.parallel.replicate import replicate
replicas = replicate(net.cuda(0), [0, 1])
print([m.weight.device for m in replicas])
print([m.mask.device for m in replicas])
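For context, `data_parallel` is essentially a thin composition of those steps. Below is a simplified sketch of that flow; the function name and the simplifications are mine, and the actual torch implementation adds device validation and several edge cases.

```python
# Rough sketch of how data_parallel composes its internal steps.
from torch.nn.parallel import gather, parallel_apply, replicate
from torch.nn.parallel.scatter_gather import scatter_kwargs


def data_parallel_sketch(module, input, device_ids, output_device=None, dim=0):
    output_device = device_ids[0] if output_device is None else output_device
    # 1. split the input batch into one chunk per device
    inputs, kwargs = scatter_kwargs((input,), {}, device_ids, dim)
    used_ids = device_ids[:len(inputs)]
    # 2. copy the module (children, parameters, then buffers) onto each device
    replicas = replicate(module, used_ids)
    # 3. run every replica on its chunk in a separate thread
    outputs = parallel_apply(replicas, inputs, kwargs, used_ids)
    # 4. collect the per-device outputs on a single device
    return gather(outputs, output_device, dim)
```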
In the section of `replicate.py` responsible for associating copies of buffers with the model replicas (replicate.py#L161-L163), the following line does the assignment:
# key = 'mask', buffer_copies[j][buffer_idx] is a torch.Tensor on a device
setattr(replica, key, buffer_copies[j][buffer_idx])
Since `replica` is an instance of the `BaseMasked` base class, the `setattr` call ultimately invokes this class's `__setattr__` method. `__setattr__` in turn dispatches everything to `torch.nn.Module`, unless the attribute being set is named "mask"; all manipulations with masks are handled by the special `.mask_` method.
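As a rough illustration of that dispatch, here is a minimal sketch (not the actual `cplxmodule` source; the class name is mine):

```python
import torch


class BaseMaskedSketch(torch.nn.Module):
    """Minimal sketch of the attribute dispatch described above."""

    def __setattr__(self, name, value):
        if name != "mask":
            # everything except the mask goes through the usual nn.Module logic
            super().__setattr__(name, value)
        else:
            # mask assignment is routed through the dedicated .mask_ method
            self.mask_(value)

    def mask_(self, mask):
        ...  # set, update, or remove the mask buffer
```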
The `.mask_` method does the necessary reallocation and clean-up when a mask is set, updated, or removed. Since `replica` is a copy of the original `net`, all of its masks were already properly set and located on `device('cuda:0')`. Therefore `.mask_` in fact ended up in the "mask update" branch:
self.mask.copy_(mask.detach())
Now, the `.copy_` method of `torch.Tensor` only copies data into the destination tensor, moving it between host and device and casting the dtype if necessary; it never changes the destination's device. Therefore the mask update ignored the intended device of the buffer copy: it simply overwrote the existing mask with the copied data and left its device unchanged. The dtype appears to be unaffected by replication.
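This behaviour of `.copy_` is easy to check in isolation (assuming two CUDA devices are available):

```python
import torch

dst = torch.zeros(3, device="cuda:0")
src = torch.ones(3, device="cuda:1")

# .copy_ transfers the data across devices, but the destination tensor
# keeps its original placement on cuda:0
dst.copy_(src)
print(dst.device)  # cuda:0
```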
Possible solutions included:
1. adding `.to(self.weight.device)` to the masking action;
2. slightly modifying the logic of `.mask_` so that upon a mask update the mask is properly synchronised with the weight, as if it were being set for the first time.
Option 1 would incur the cost of a `.to` call on every forward pass, which is significant, since the `.weight_masked` method is evaluated on each call; on the other hand, it would guarantee that the `mask` is always on the same device as the `weight`. Option 2 would not incur this overhead, since the mask is set much less often, but it could potentially leave the `weight` and the `mask` out of sync.
I went with option 2, since torch users relocate models between host and device through the `.to`, `.cpu`, and `.cuda` API, all of which move buffers and parameters in sync, so this interaction seems unlikely to cause a device desync. `replicate` does not use these methods; instead it modifies the internal dictionaries of `nn.Module` directly or through `setattr`. However, it does so in a depth-first fashion: first the module's children are reassigned, then its parameters, and finally its buffers. Therefore, unless `replicate` significantly changes its logic, it is safe to go with option 2: take the device from the already reassigned `weight` of the replica and bypass the costly `.to()` call on every forward pass.
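In code, the chosen fix looks roughly like the sketch below. This is a hedged illustration of option 2, not the exact `cplxmodule` patch: the class is a stand-in, and only the mask set/update/removal logic relevant to the bug is shown.

```python
import torch


class MaskedLayerSketch(torch.nn.Linear):
    """Illustrative masked layer; not the actual cplxmodule class."""

    def mask_(self, mask):
        if mask is None:
            # removal branch: drop the buffer if it is present
            self._buffers.pop("mask", None)
            return self
        # set *and* update go through the same path: the mask is re-registered
        # on the device of the weight, which replicate() has already reassigned
        # by the time buffers are copied (children and parameters come first)
        self.register_buffer("mask", mask.detach().to(self.weight.device))
        return self
```

Because `replicate` reassigns child modules and parameters before buffers, `self.weight` already lives on the target device when the mask buffer is copied, so the re-registered mask lands on the correct replica device without any `.to` call in the forward path.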