sksq96 / pytorch-summary

Model summary in PyTorch similar to `model.summary()` in Keras
MIT License

Dividing a model over multiple GPUs in PyTorch. #178

Open sreenithakasarapu opened 2 years ago

sreenithakasarapu commented 2 years ago

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument target in method wrapper_nll_loss_forward)
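The message means the `target` tensor handed to `nll_loss` (the labels) is not on the same GPU as the model output. Printing `.device` on the tensors in the training loop below pins down which is where; a minimal sketch:

```python
# Sketch: check which device each tensor lives on before the loss is computed.
print(output.device)  # cuda:1 here, since fc was moved to cuda:1
print(b_y.device)     # cuda:0 -> this is the mismatch the error reports
```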

```python
from torchvision.models.resnet import ResNet, Bottleneck
import torch.nn as nn

num_classes = 2

class CNN(ResNet):
    def __init__(self, *args, **kwargs):
        super(CNN, self).__init__(
            Bottleneck, [3, 4, 6, 3], num_classes=num_classes, *args, **kwargs)

        # First half of the ResNet-50 stays on the first GPU
        self.seq1 = nn.Sequential(
            self.conv1,
            self.bn1,
            self.relu,
            self.maxpool,
            self.layer1,
            self.layer2,
        ).to('cuda:0')

        # Second half goes on the second GPU
        self.seq2 = nn.Sequential(
            self.layer3,
            self.layer4,
            self.avgpool,
        ).to('cuda:1')

        # The classifier head also lives on the second GPU
        self.fc.to('cuda:1')

    def forward(self, x):
        # Hand the intermediate activation over from cuda:0 to cuda:1
        x = self.seq2(self.seq1(x).to('cuda:1'))
        return self.fc(x.view(x.size(0), -1))
```
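With this split, everything after `layer2` runs on the second GPU, so the model's output always comes back on `cuda:1` regardless of where the input started. A quick check, assuming two visible GPUs:

```python
import torch

model = CNN()
x = torch.randn(4, 3, 224, 224)   # dummy batch, created on the CPU
out = model(x.to('cuda:0'))       # seq1 expects its input on cuda:0
print(out.device)                 # -> cuda:1, because seq2 and fc live there
```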

```python
from torch.autograd import Variable  # Variable is a no-op wrapper in modern PyTorch

num_epochs = 10

def train(num_epochs, cnn, loaders):
    cnn.train()

    # Train the model
    total_step = len(loaders['train'])

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(loaders['train']):

            # gives batch data; the model's first half expects its input on cuda:0
            b_x = Variable(images).to('cuda:0')   # batch x
            b_y = Variable(labels).to('cuda:0')   # batch y
            output = cnn(images.to('cuda:0'))

            loss = loss_func(output, b_y).to('cuda:1')

            print(loss)

            # clear gradients for this training step
            optimizer.zero_grad()

            # backpropagation, compute gradients
            loss.backward()
            # apply gradients
            optimizer.step()

            if (i + 1) % 100 == 0:
                print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))

train(num_epochs, cnn, loaders)
```
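For what it's worth, the traceback points at the loss target: `output` comes back on `cuda:1` (see `forward()` above) while `b_y` stays on `cuda:0`, and calling `.to('cuda:1')` on the already-computed loss runs too late, since the mismatch happens inside `loss_func`. A minimal sketch of the usual fix, assuming `loss_func` is something like `nn.CrossEntropyLoss()` (it is not defined in the snippet):

```python
output = cnn(b_x)                            # lands on cuda:1
loss = loss_func(output, b_y.to('cuda:1'))   # move the target to the output's device
```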

FraCorti commented 2 years ago

By chance were you able to solve this issue?