hpcaitech / ColossalAI

Making large AI models cheaper, faster and more accessible
https://www.colossalai.org
Apache License 2.0

[DOC]: How to compute loss and backpropagation in 2D tensor parallelism? #3664

Open gaylong9 opened 1 year ago

gaylong9 commented 1 year ago

📚 The doc issue

Description: I am trying to implement 2D tensor parallelism and am having trouble computing the loss and running backpropagation. Following the documentation at https://colossalai.org/docs/features/2D_tensor_parallel/, the final output of the network on each rank is only one quarter of the size of the regular (non-parallel) output, so computing the loss directly from the output and the labels fails with a size mismatch. For example, with CIFAR10 and a batch size of 64, the label tensor has shape (64), but the output tensor on each rank has shape (32, 5) instead of (64, 10), so the loss computation raises a dimension mismatch error.
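The shape bookkeeping can be checked without GPUs. Below is a minimal single-process simulation of the 2x2 grid (tensor size 4, mode '2d'). It assumes the output blocks follow the same layout as the input in the docs: batch split by the PARALLEL_2D_COL local rank, last dim split by the PARALLEL_2D_ROW local rank. `shard_2d`, `col_rank` and `row_rank` are hypothetical names, not ColossalAI APIs, and no process group is created.

```python
import torch

# Hypothetical 2x2 grid: (col_rank, row_rank) stand in for
# gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL / PARALLEL_2D_ROW).
Q = 2  # grid side = sqrt(tensor parallel size 4)
BATCH, FEATURES, CLASSES = 64, 3 * 32 * 32, 10

data = torch.randn(BATCH, FEATURES)
label = torch.randint(0, CLASSES, (BATCH,))
full_output = torch.randn(BATCH, CLASSES)  # stand-in for unsharded logits


def shard_2d(t, col_rank, row_rank):
    """Split dim 0 by the column rank and the last dim by the row rank."""
    t = torch.chunk(t, Q, dim=0)[col_rank]
    return torch.chunk(t, Q, dim=-1)[row_rank]


shapes = {}
for col_rank in range(Q):
    for row_rank in range(Q):
        x = shard_2d(data, col_rank, row_rank)
        out = shard_2d(full_output, col_rank, row_rank)
        # Labels have no feature dim, so only the batch split can apply.
        y = torch.chunk(label, Q, dim=0)[col_rank]
        shapes[(col_rank, row_rank)] = (tuple(x.shape), tuple(out.shape), tuple(y.shape))

print(shapes[(0, 0)])  # ((32, 1536), (32, 5), (32,))
```

Each rank sees a quarter of the logits, [bs/2, 10/2]. The labels would also need to be split along the batch dimension before they can line up with any rank's output.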

Suggestions: I suggest extending the documentation with code that computes the loss and performs backpropagation in the 2D tensor parallel setting.
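One way to compute the loss without materializing the full class dimension is a sharded ("vocab-parallel") cross entropy. Each shard contributes its partial max, its partial softmax denominator, and the target logit if it owns that class. The sketch below is a single-process simulation under that assumption: every Python-level reduction over `logit_shards` stands in for an all-reduce across the ranks that share a batch block. `sharded_cross_entropy` is a hypothetical helper, not a ColossalAI API. In a real multi-GPU run those all-reduces would also have to propagate gradients, which libraries typically handle with a custom autograd function.

```python
import torch
import torch.nn.functional as F


def sharded_cross_entropy(logit_shards, target):
    """Sum-reduced cross entropy when the class dim is split across shards.

    Single-process simulation: each reduction over `logit_shards` stands in
    for an all-reduce across the ranks holding the same batch block.
    """
    # 1) Per-sample global max for numerical stability (all-reduce MAX).
    local_max = [s.max(dim=-1).values for s in logit_shards]
    global_max = torch.stack(local_max).max(dim=0).values

    # 2) Softmax denominator: sum of each shard's partial sum (all-reduce SUM).
    sum_exp = sum(torch.exp(s - global_max[:, None]).sum(dim=-1) for s in logit_shards)

    # 3) Target logit: only the shard that owns the target class contributes.
    target_logit = torch.zeros_like(global_max)
    start = 0
    for s in logit_shards:
        end = start + s.shape[-1]
        owned = (target >= start) & (target < end)
        local_idx = (target - start).clamp(0, s.shape[-1] - 1)
        picked = s.gather(1, local_idx[:, None]).squeeze(1)
        target_logit = target_logit + torch.where(owned, picked, torch.zeros_like(picked))
        start = end

    # loss_i = logsumexp(z_i) - z_i[target_i]
    per_sample = torch.log(sum_exp) + global_max - target_logit
    return per_sample.sum()


torch.manual_seed(0)
logits = torch.randn(16, 10)
target = torch.randint(0, 10, (16,))
shards = torch.chunk(logits, 2, dim=-1)  # two shards of 5 classes each
loss = sharded_cross_entropy(shards, target)
reference = F.cross_entropy(logits, target, reduction='sum')
print(torch.allclose(loss, reference, atol=1e-5))  # True
```

The result matches nn.CrossEntropyLoss(reduction='sum') on the unsharded logits, and in this simulation autograd produces the same gradients. The design avoids gathering the logits but needs three collective reductions per step.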

Thank you!

JThh commented 1 year ago

Thanks for your feedback!

Could you share the code you ran? If it is set up correctly, there should be no dimension mismatch in either the gradient computation or the backward pass.

gaylong9 commented 1 year ago
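"""
Launch with: torchrun --nproc_per_node 4 train.py
"""
# Imports assumed (not shown in the original snippet), following the
# ColossalAI 0.2.x module layout used by the 2D tensor parallel docs.
from pathlib import Path

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.datasets import CIFAR10
from tqdm import tqdm

import colossalai
import colossalai.nn as col_nn
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.utils import get_dataloader, print_rank_0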

CONFIG = dict(parallel=dict(
    data=1,
    pipeline=1,
    tensor=dict(size=4, mode='2d'),
))

class MLP(torch.nn.Module):
    def __init__(self, dim: int = 256):
        super().__init__()
        intermediate_dim = dim * 4
        self.dense_1 = col_nn.Linear(dim, intermediate_dim)
        print_rank_0(f'Weight of the first linear layer: {self.dense_1.weight.shape}')
        self.activation = torch.nn.GELU()
        self.dense_2 = col_nn.Linear(intermediate_dim, 10, bias=False)
        print_rank_0(f'Weight of the second linear layer: {self.dense_2.weight.shape}')

    def forward(self, x):
        print_rank_0(f'Input shape: {x.shape}')
        x = self.dense_1(x)
        print_rank_0(f'Output of the first linear layer: {x.shape}')
        x = self.activation(x)
        x = self.dense_2(x)
        print_rank_0(f'Output of the second linear layer: {x.shape}')
        return x

def train():
    colossalai.launch_from_torch(config=CONFIG)
    logger = get_dist_logger()

    input_size = 3 * 32 * 32
    model = MLP(dim=input_size)

    train_dataset = CIFAR10(
        root=Path('data/cifar10/'),
        download=True,
        transform=transforms.Compose([
            transforms.RandomCrop(size=32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.4914, 0.4822, 0.4465], 
                std=[0.2023, 0.1994, 0.2010]
            ),
            transforms.Lambda(lambda x: x.view(-1)),
        ])
    )
    train_dataloader = get_dataloader(...)
    criterion = nn.CrossEntropyLoss(reduction='sum')
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    lr_scheduler = CosineAnnealingWarmupLR(...)

    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(...)

    for epoch in range(5):
        engine.train()

        if gpc.get_global_rank() == 0:
            trainLoader = tqdm(train_dataloader)
        else:
            trainLoader = train_dataloader

        for data, label in trainLoader:
            engine.zero_grad()
            data = data.cuda()
            # print_rank_0(f'raw data shape: {data.shape}')  # [bs, 3x32x32=3072]
            torch.distributed.broadcast(data, src=0)
            data = torch.chunk(data, 2, dim=0)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)]
            data = torch.chunk(data, 2, dim=-1)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)]
            # print_rank_0(f'broadcasted data shape: {data.shape}')  # [bs/2, 3072/2=1536]
            label = label.cuda()

            output = engine(data)
            # print_rank_0(f'output shape: {output.shape}')  # [bs/2, 10/2=5]

            train_loss = engine.criterion(output, label) 
            # print_rank_0(f'train_loss: {train_loss}')
            engine.backward(train_loss)
            engine.step()

    gpc.destroy()


if __name__ == '__main__':
    train()

Here is my code, which I wrote following the official documentation. I am training on CIFAR10, classifying the images into 10 classes. Before each forward pass I partition the input, which changes its shape from [bs, 3072] to [bs/2, 1536]. Without tensor parallelism the output would be [bs, 10]; here it is [bs/2, 5]. In this run bs is 32, so the output is [16, 5] while the labels are still [32], and the loss computation fails with ValueError: "Expected input batch_size (16) to match target batch_size (32)".
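The mismatch comes from two things. The labels are never split along the batch dimension, and each rank holds only half of the class columns. The sketch below is a single-process simulation of one possible fix: split the labels with the same PARALLEL_2D_COL chunk index as the data, concatenate the class blocks of the ranks that share a batch block (same column-rank index, different row-rank index), and then sum the per-block losses. It assumes the output blocks use the same layout as the input. `grid_loss` is a hypothetical helper. In a real run the concatenation would become a gradient-propagating all-gather (for example `torch.distributed.nn.functional.all_gather`, unlike `torch.distributed.all_gather`) and the final sum would become an all-reduce. Both steps, and the gradient scaling they imply, are left out here. Also, since the snippet broadcasts `data` from rank 0 but not `label`, the labels may need the same broadcast unless every rank's loader yields identical batches.

```python
import torch
import torch.nn.functional as F

Q = 2  # 2x2 grid for tensor=dict(size=4, mode='2d')


def grid_loss(full_logits, label):
    """Sum-reduced cross entropy computed block by block on a simulated 2x2 grid.

    Assumption: rank (col_rank, row_rank) holds row chunk `col_rank` and
    class chunk `row_rank` of the logits, mirroring how the input is split.
    """
    total = full_logits.new_zeros(())
    for col_rank in range(Q):
        rows = torch.chunk(full_logits, Q, dim=0)[col_rank]
        # The [bs/2, 5] blocks held by the ranks that share this batch block.
        blocks = [torch.chunk(rows, Q, dim=-1)[row_rank] for row_rank in range(Q)]
        # Stand-in for an all-gather along the class dim: [bs/2, 10].
        gathered = torch.cat(blocks, dim=-1)
        # Labels split along the batch dim with the same col_rank: [bs/2].
        y = torch.chunk(label, Q, dim=0)[col_rank]
        total = total + F.cross_entropy(gathered, y, reduction='sum')
    # Stand-in for an all-reduce SUM over the batch blocks.
    return total


torch.manual_seed(0)
logits = torch.randn(32, 10)  # batch size 32, as in the failing run
label = torch.randint(0, 10, (32,))
full = F.cross_entropy(logits, label, reduction='sum')
print(torch.allclose(grid_loss(logits, label), full))  # True
```

Because the criterion uses reduction='sum', the per-block losses add up exactly to the unsharded loss. With a mean reduction they would also have to be divided by the global batch size.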

Thank you for your help.