kakaobrain / torchgpipe

A GPipe implementation in PyTorch
BSD 3-Clause "New" or "Revised" License
798 stars 97 forks source link

With the same batch size but a different number of micro-batches, the training results are inconsistent. #35

Open Kurama622 opened 5 months ago

Kurama622 commented 5 months ago

🐞 Bug

With the same batch size but a different number of micro-batches, the training results are inconsistent.

I have fixed the random seed.

I ran it once with chunks set to 2 and once with chunks set to 4.

Code that reproduces

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torch.nn import functional as F
from torchgpipe import GPipe
import random, os
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class SimpleDNN(nn.Module):
    """Fully connected classifier: 784 -> 512 -> 256 -> 10.

    The input is reshaped to rows of 28*28 features, passed through two
    ReLU-activated linear layers, and projected to 10 output logits.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in fc1, fc2, fc3 order so parameter
        # initialisation consumes the RNG in a fixed sequence.
        self.fc1 = nn.Linear(28 * 28, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 10)

    def forward(self, x):
        """Return the 10 logits for each 28*28-element row of ``x``."""
        flat = x.view(-1, 28 * 28)
        hidden = F.relu(self.fc1(flat))
        hidden = F.relu(self.fc2(hidden))
        logits = self.fc3(hidden)
        return logits

# Reproduction script: train the MLP for one epoch through GPipe and print the
# running training accuracy, so runs with different `chunks` can be compared.

# ToTensor() must come first: Normalize operates on tensors, while the MNIST
# dataset yields PIL images.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)

# set random seed
# Seed every RNG in play *before* the model is built and before the loader is
# iterated, so weight initialisation and shuffling order are the same for
# every value of `chunks`.
seed = 0
# NOTE: PYTHONHASHSEED is read at interpreter start-up, so setting it here does
# not change hashing in this process; it is only inherited by child processes.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)
    # Ask cuDNN for deterministic kernels and disable auto-tuning.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# GPipe only accepts nn.Sequential, so re-express SimpleDNN as three
# sequential stages (one per entry of `balance`) that reuse its layers.
# The flatten step of SimpleDNN.forward is applied to the batch in the loop.
dnn = SimpleDNN().to(device)
model = nn.Sequential(
    nn.Sequential(dnn.fc1, nn.ReLU()),
    nn.Sequential(dnn.fc2, nn.ReLU()),
    dnn.fc3,
)

chunks = 2  # number of micro-batches each mini-batch is split into (compare 2 vs. 4)

model = GPipe(model, balance=[1, 1, 1], chunks=chunks, devices=[device] * 3)

optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

total = 0
correct = 0
for epoch in range(1):
    for batch_idx, (data, target) in enumerate(train_loader):
        if data.size(0) % chunks != 0:
            continue  # Skip batches that cannot be split evenly into micro-batches
        data, target = data.to(device), target.to(device)
        data = data.view(-1, 28 * 28)  # flatten the image, as SimpleDNN.forward does

        optimizer.zero_grad()
        output = model(data)
        loss = F.cross_entropy(output, target)
        # Without backward()/step() the weights never change and the reported
        # accuracy only reflects the random initialisation.
        loss.backward()
        optimizer.step()

        _, predicted = torch.max(output, 1)
        total += target.size(0)
        correct += (predicted == target).sum().item()

        print(f'batch {batch_idx}, Accuracy: {100 * correct / total}%')
Kurama622 commented 5 months ago

I'm not sure whether this is a bug. It may be accumulated floating-point error, or there may be some difference between gradient updates computed over micro-batches and updates computed over the full mini-batch.

ymkasad commented 4 months ago


Kurama622 commented 4 months ago

