jacobkimmel / pytorch_convgru

Convolutional Gated Recurrent Units implemented in PyTorch
MIT License

nn.DataParallel does not work properly with this model #6

Status: Open. motazalfarraj opened this issue 5 years ago.

motazalfarraj commented 5 years ago

My code:

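import torch
from convgru import ConvGRU  # convgru.py from this repository

device = torch.device("cuda:0")  # DataParallel expects the module on device_ids[0]
model = ConvGRU(input_size=8, hidden_sizes=[32, 64, 16],
                kernel_sizes=[3, 5, 3], n_layers=3)
model = torch.nn.DataParallel(
    model, device_ids=list(range(torch.cuda.device_count())))
model = model.to(device)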

This is the error I get when using 2 GPUs with this model.

RuntimeError: Expected tensor for argument #1 'input' to have the same device as tensor for argument #2 'weight'; but device 1 does not equal 0 (while checking arguments for cudnn_convolution)

motazalfarraj commented 5 years ago

I tried this "hack" in the forward method of ConvGRU:

upd_cell_hidden = nn.parallel.data_parallel(cell, (input_, cell_hidden), range(torch.cuda.device_count()))

It seems to be working for now, but I am not sure whether this is the best way to do it.
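The workaround relies on the functional torch.nn.parallel.data_parallel(module, inputs, device_ids), which scatters input_ and cell_hidden along the batch dimension, runs a replica of cell on each listed GPU, and gathers the outputs back on device_ids[0]. A minimal sketch of that call wrapped in a hypothetical helper run_cell, with a fallback for machines with zero or one GPU; it assumes the outer model is not also wrapped in nn.DataParallel and that the cell and inputs already live on cuda:0 when several GPUs are present:

```python
import torch
import torch.nn as nn


def run_cell(cell, input_, cell_hidden):
    """Hypothetical helper: apply one recurrent cell, splitting the batch over all visible GPUs.

    Assumes `cell`, `input_` and `cell_hidden` are on cuda:0 when more than one GPU exists,
    and that the enclosing model is not itself wrapped in nn.DataParallel.
    """
    device_ids = list(range(torch.cuda.device_count()))
    if len(device_ids) > 1:
        # Scatter inputs along dim 0, replicate `cell` per GPU, gather outputs on device_ids[0]
        return nn.parallel.data_parallel(cell, (input_, cell_hidden), device_ids)
    # Zero or one GPU: nothing to split, so call the cell directly
    return cell(input_, cell_hidden)
```

One cost of this approach: data_parallel replicates the module on every call, so invoking it per cell repeats the replication for each layer in every forward pass, whereas a single outer nn.DataParallel replicates the whole model once per forward pass.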

HamishCPratt commented 4 years ago

I had the same problem and found that it was caused by the way ConvGRU stores its ConvGRUCell instances in __init__:

        cells = []
        for i in range(self.n_layers):
            if i == 0:
                input_dim = self.input_size
            else:
                input_dim = self.hidden_sizes[i-1]

            cell = ConvGRUCell(input_dim, self.hidden_sizes[i], self.kernel_sizes[i])
            name = 'ConvGRUCell_' + str(i).zfill(2)

            setattr(self, name, cell)  # registers the cell as a submodule
            cells.append(getattr(self, name))

        # Plain Python list: DataParallel replicas keep references to the original cells
        self.cells = cells
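To see why the plain list breaks under DataParallel without needing two GPUs, here is a CPU-only sketch that imitates, in simplified form, the shallow copy that PyTorch's replicate step makes for each GPU: it copies the instance __dict__ and swaps in fresh copies of the registered children only. The class PlainListModel and the helper shallow_replicate are hypothetical, nn.Conv2d stands in for ConvGRUCell, and real replication additionally broadcasts the parameters to each device.

```python
import copy

import torch.nn as nn


class PlainListModel(nn.Module):
    """Mimics the pattern above: cells registered via setattr AND kept in a plain list."""

    def __init__(self, n_layers=2):
        super().__init__()
        cells = []
        for i in range(n_layers):
            cell = nn.Conv2d(4, 4, 3, padding=1)  # stand-in for ConvGRUCell
            name = 'ConvGRUCell_' + str(i).zfill(2)
            setattr(self, name, cell)  # registered in self._modules
            cells.append(getattr(self, name))
        self.cells = cells  # ordinary attribute, invisible to module replication


def shallow_replicate(module):
    """Simplified imitation of a DataParallel replica (assumption, CPU-only):
    shallow-copy the instance __dict__, then replace only the registered children."""
    replica = module.__new__(type(module))
    replica.__dict__ = module.__dict__.copy()
    replica._modules = {k: copy.deepcopy(child) for k, child in module._modules.items()}
    return replica


model = PlainListModel()
replica = shallow_replicate(model)
# The registered child was replaced by a copy...
print(replica.ConvGRUCell_00 is model.ConvGRUCell_00)  # False
# ...but the plain list still points at the original cell
print(replica.cells[0] is model.ConvGRUCell_00)  # True
```

In the real multi-GPU case, the replica on GPU 1 therefore iterates over cells whose weights are still on GPU 0, which matches the "device 1 does not equal 0" error in the original report.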

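Because self.cells is a plain Python list, DataParallel does not replace its entries when it replicates the model: the list copied into each replica still points at the original cells, whose weights live on device 0, so the replica on device 1 feeds its inputs into device-0 weights. I got around it by changing ConvGRU to store the cells in an nn.ModuleList instead of a plain Python list: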

        cells = []
        for i in range(self.n_layers):
            if i == 0:
                input_dim = self.input_size
            else:
                input_dim = self.hidden_sizes[i-1]
            cells.append(ConvGRUCell(input_dim, self.hidden_sizes[i], self.kernel_sizes[i]))

        # nn.ModuleList registers each cell, so every DataParallel replica gets its own copies
        self.cells = nn.ModuleList(cells)
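For reference, a self-contained sketch of the ModuleList approach wrapped in nn.DataParallel. The ConvGRUCell here is a simplified stand-in loosely following the usual ConvGRU gate equations (reset, update and candidate convolutions with "same" padding for odd kernel sizes), not this repository's exact code. When no GPU is visible, nn.DataParallel simply calls the wrapped module, so the example also runs on CPU.

```python
import torch
import torch.nn as nn


class ConvGRUCell(nn.Module):
    """Simplified ConvGRU cell (illustrative stand-in, not this repository's exact code)."""

    def __init__(self, input_size, hidden_size, kernel_size):
        super().__init__()
        self.hidden_size = hidden_size
        padding = kernel_size // 2  # odd kernels keep height/width unchanged
        in_channels = input_size + hidden_size
        self.reset_gate = nn.Conv2d(in_channels, hidden_size, kernel_size, padding=padding)
        self.update_gate = nn.Conv2d(in_channels, hidden_size, kernel_size, padding=padding)
        self.out_gate = nn.Conv2d(in_channels, hidden_size, kernel_size, padding=padding)

    def forward(self, input_, prev_state=None):
        if prev_state is None:
            # Zero hidden state created on the input's device, so each replica stays local
            batch, _, height, width = input_.shape
            prev_state = input_.new_zeros(batch, self.hidden_size, height, width)
        stacked = torch.cat([input_, prev_state], dim=1)
        update = torch.sigmoid(self.update_gate(stacked))
        reset = torch.sigmoid(self.reset_gate(stacked))
        candidate = torch.tanh(self.out_gate(torch.cat([input_, prev_state * reset], dim=1)))
        return prev_state * (1 - update) + candidate * update


class ConvGRU(nn.Module):
    def __init__(self, input_size, hidden_sizes, kernel_sizes, n_layers):
        super().__init__()
        cells = []
        for i in range(n_layers):
            input_dim = input_size if i == 0 else hidden_sizes[i - 1]
            cells.append(ConvGRUCell(input_dim, hidden_sizes[i], kernel_sizes[i]))
        # Registered container: DataParallel gives every replica its own copies of the cells
        self.cells = nn.ModuleList(cells)

    def forward(self, x, hidden=None):
        if hidden is None:
            hidden = [None] * len(self.cells)
        new_hidden = []
        for cell, cell_hidden in zip(self.cells, hidden):
            x = cell(x, cell_hidden)  # each layer's new state feeds the next layer
            new_hidden.append(x)
        return new_hidden


# Usage: same configuration as the original report
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = ConvGRU(input_size=8, hidden_sizes=[32, 64, 16], kernel_sizes=[3, 5, 3], n_layers=3)
model = nn.DataParallel(model).to(device)
outputs = model(torch.randn(4, 8, 16, 16, device=device))
```

With several GPUs, the batch of 4 is split across devices and each replica iterates over its own registered cells, so inputs and weights share a device; the per-layer outputs are gathered back on cuda:0.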