ROCm / MIOpen

AMD's Machine Intelligence Library
https://rocm.docs.amd.com/projects/MIOpen/en/latest/

A100 vs MI250X conv performance #3310

Open etiennemlb opened 5 days ago

etiennemlb commented 5 days ago

I would like to ask about the performance of two kernels, naive_conv_nonpacked_bwd_nchw_half_double_half and naive_conv_nonpacked_fwd_nchw_half_double_half.

When are these kernels used when we call miopen_convolution_forward? I have a PyTorch model that is 6.4 times slower on the MI250X than on the A100.

averinevg commented 4 days ago

Hi @etiennemlb, naive kernels are the last resort when none of the other kernels are applicable. Could you provide more information about the tensor sizes? It would also be useful to have a minimal reproducer.
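For reference, the tensor sizes that reach the convolution library can be worked out by hand from the reproducer's model posted below. That model applies one Conv1d(1, 512, 10, 5) followed by num_conv - 1 Conv1d(512, 512, 3, 2) layers, all without padding, to an input of shape (8, 320000) after unsqueeze(1). The sketch uses the output-length formula from the torch.nn.Conv1d documentation and also counts forward multiply-accumulates per layer (bias ignored). The helper names (Conv1dSpec, conv1d_out_len, layer_specs, conv_table) are hypothetical illustrations and are not part of PyTorch or MIOpen.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class Conv1dSpec:
    in_channels: int
    out_channels: int
    kernel_size: int
    stride: int


def conv1d_out_len(length: int, kernel_size: int, stride: int,
                   padding: int = 0, dilation: int = 1) -> int:
    # L_out formula from the torch.nn.Conv1d documentation.
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


def layer_specs(num_conv: int) -> List[Conv1dSpec]:
    # Mirrors MyAudioModel: Conv1d(1, 512, 10, 5), then num_conv - 1 x Conv1d(512, 512, 3, 2).
    return [Conv1dSpec(1, 512, 10, 5)] + [
        Conv1dSpec(512, 512, 3, 2) for _ in range(num_conv - 1)
    ]


Shape = Tuple[int, int, int]


def conv_table(batch: int, length: int, num_conv: int) -> List[Tuple[Shape, Shape, int]]:
    """(input shape, output shape, forward MACs) for every Conv1d, in NCL layout."""
    rows = []
    for spec in layer_specs(num_conv):
        out_len = conv1d_out_len(length, spec.kernel_size, spec.stride)
        # One multiply-accumulate per (batch, out channel, out position, in channel, tap).
        macs = batch * spec.out_channels * out_len * spec.in_channels * spec.kernel_size
        rows.append(((batch, spec.in_channels, length), (batch, spec.out_channels, out_len), macs))
        length = out_len  # LayerNorm and GELU in between keep the length unchanged
    return rows


if __name__ == "__main__":
    for i, (inp, out, macs) in enumerate(conv_table(batch=8, length=320000, num_conv=6)):
        print(f"conv{i}: {inp} -> {out}, {macs / 1e9:.2f} GMAC forward")
```

With num_conv = 6 this gives output lengths of 63999, 31999, 15999, 7999, 3999 and 1999 for the six convolutions, which is the kind of per-layer shape list that is useful to paste when asked for tensor sizes.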

formiel commented 10 hours ago

Hello @averinevg. Thanks a lot for your reply! I'm @etiennemlb's colleague. We have prepared a minimal code example to reproduce the speed differences we observed, along with profiling results for various configurations. These details are available via the link below. I’ve also included the code here for your convenience.

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torch.profiler import profile, ProfilerActivity, record_function

class TransposeLast(nn.Module):
    def __init__(self, transpose_dim=-2):
        super().__init__()
        self.transpose_dim = transpose_dim
    def forward(self, x):
        return x.transpose(self.transpose_dim, -1)

class MyAudioModel(nn.Module):
    def __init__(self, num_conv=2):
        super().__init__()

        conv_layers = []
        # first conv layer
        conv_layers.append(nn.Conv1d(1, 512, 10, 5))
        conv_layers.append(TransposeLast())
        conv_layers.append(nn.LayerNorm(512, elementwise_affine=True))
        conv_layers.append(TransposeLast())
        conv_layers.append(nn.GELU())

        for _ in range(num_conv - 1):
            conv_layers.append(nn.Conv1d(512, 512, 3, 2))
            conv_layers.append(TransposeLast())
            conv_layers.append(nn.LayerNorm(512, elementwise_affine=True))
            conv_layers.append(TransposeLast())
            conv_layers.append(nn.GELU())
        self.conv_layers = nn.Sequential(*conv_layers)

        self.proj = nn.Sequential(
            TransposeLast(),
            nn.LayerNorm(512),
            nn.Linear(512, 64),
            TransposeLast(),
        )

    def forward(self, x):
        # BxT -> Bx1xT (single input channel)
        x = x.unsqueeze(1)
        x = self.conv_layers(x)  # Bx1xT -> BxCxT' (C = 512)
        x = self.proj(x)  # BxCxT' -> BxDxT' (D = 64)
        return torch.mean(x, dim=-1)  # average over time -> BxD

def main():

    # Main params
    fp16_training = True
    input_size = "8_320000"
    num_conv = 6
    device = "cuda"
    suffix = "_mi250x"
    epochs = 3

    B, T = [int(i) for i in input_size.split("_")]

    # Create dummy dataset and data loader
    x = torch.randn((B, T))
    y = torch.randn(B, 64)
    dataset = TensorDataset(x, y)
    data_loader = DataLoader(dataset, batch_size=8, shuffle=True)

    # Initialize model and optimizer
    model = MyAudioModel(num_conv=int(num_conv))
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    model.to(device=device)
    if fp16_training:
        model = model.half()

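    # Profiling
    folder_name = f"conv{num_conv}L_input{input_size}{suffix}"
    # Fall back to the current directory if $WORK is unset (otherwise the path starts with "None").
    profile_dir = f"{os.environ.get('WORK', '.')}/profile-conv-ops/{folder_name}"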
    os.makedirs(profile_dir, exist_ok=True)

    prof = profile(activities=[
                        ProfilerActivity.CPU, 
                        ProfilerActivity.CUDA
                    ],
                    on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_dir),
                    record_shapes=True, profile_memory=True, with_flops=True,
                    with_stack=True,
                    experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True),
                )
    prof.start()

    for _ in range(epochs):
        for batch_x, batch_y in data_loader:
            batch_x, batch_y = batch_x.to(device=device), batch_y.to(device=device)
            if fp16_training:
                batch_x, batch_y = batch_x.half(), batch_y.half()
            with record_function("model_forward"):
                output = model(batch_x)

            # Compute loss
            loss = criterion(output, batch_y)

            # Backward pass
            with record_function("backward_pass"):
                loss.backward()
                optimizer.step()
                optimizer.zero_grad()

    prof.stop()


if __name__ == "__main__":
    main()

We found that the MI250X is around 7 times slower than the A100 with 6 convolutional layers (num_conv = 6) and input tensors of shape (8, 320000). Disabling the direct convolution algorithm with export MIOPEN_DEBUG_CONV_DIRECT=0 prevents the naive_conv_packed kernels from being invoked, but setting this environment variable did not improve speed for our configuration. Could you please take a look and let us know whether there is anything we can do to improve the training speed?
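To compare runs with and without the direct solvers, and to capture a reproducer in the form MIOpen maintainers typically work with, one option is to launch the training script in a fresh process with the MIOpen environment variables already set, since the library reads them at runtime and setting them before the process starts is the safest route. This is a minimal sketch under stated assumptions: repro.py is a hypothetical name for the script above, MIOPEN_DEBUG_CONV_DIRECT is the switch mentioned in this comment, and MIOPEN_ENABLE_LOGGING_CMD (which MIOpen's logging documentation describes as printing an equivalent MIOpenDriver command line for each call) should be checked against the installed ROCm version.

```python
import os
import subprocess
import sys


def miopen_env(disable_direct: bool = False, log_commands: bool = True) -> dict:
    """Copy of the current environment with MIOpen debug switches applied."""
    env = dict(os.environ)
    if disable_direct:
        # Same switch as `export MIOPEN_DEBUG_CONV_DIRECT=0` in the thread.
        env["MIOPEN_DEBUG_CONV_DIRECT"] = "0"
    if log_commands:
        # Assumption: asks MIOpen to print an MIOpenDriver command line per convolution call.
        env["MIOPEN_ENABLE_LOGGING_CMD"] = "1"
    return env


def run_repro(script: str = "repro.py", **switches) -> int:
    """Run the (hypothetical) reproducer script in a new process so MIOpen sees the variables at startup."""
    result = subprocess.run([sys.executable, script], env=miopen_env(**switches))
    return result.returncode
```

Calling run_repro() and run_repro(disable_direct=True) back to back gives an A/B timing comparison, and the logged MIOpenDriver lines record the exact convolution shapes and data types that were dispatched.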

Many thanks in advance for your help!