pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Poor performance with 1 GPU? #8059

Open Patataman opened 5 days ago

Patataman commented 5 days ago

Hello, I am trying to evaluate the impact of XLA on our models, but before that I want to be sure that I know how to adapt our code and run XLA models without problems.

GPU: NVIDIA RTX 4090 24GB, CUDA 12.2

$ pip freeze | grep torch
torch==2.4.0
torch-xla==2.4.0
torch_xla_cuda_plugin @ https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.4.0-py3-none-any.whl#sha256=208085526f67739c2ea2ab15f1707935b2cfee7c1501116a524cfaa8d7b252d2
torchvision==0.19.0

I have been trying a simple model on MNIST:

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
from tqdm import tqdm
import random
from torch_xla.amp import syncfree, GradScaler, autocast

import torch_xla.debug.metrics as met

def random_seed(seed_value, use_cuda):
    np.random.seed(seed_value) # cpu vars
    torch.manual_seed(seed_value) # cpu  vars
    random.seed(seed_value) # Python
    if use_cuda:
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value) # gpu vars
        torch.backends.cudnn.deterministic = True  #needed
        torch.backends.cudnn.benchmark = False

random_seed(42,True)

XLA = True

# Enable XLA SPMD execution mode.
# xr.use_spmd()
if XLA:
    device = xm.xla_device()
else:
    device = "cuda"

class ToyModel(nn.Module):
    def __init__(self):
        super(ToyModel, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 1)
        self.conv2 = nn.Conv2d(32, 64, 3, 1)
        self.dropout1 = nn.Dropout(0.25)
        self.dropout2 = nn.Dropout(0.5)
        self.fc1 = nn.Linear(9216, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        output = F.log_softmax(x, dim=1)
        return output

model = ToyModel()
model.to(device)

transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = torchvision.datasets.MNIST(
    '.', train=True, download=True, transform=transform)

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, shuffle=False
)

n_epochs = 10
criterion = torch.nn.MSELoss()  # note: unused; the loop below computes F.nll_loss directly
if XLA:
    optimizer = syncfree.SGD(model.parameters(), lr=0.1)  # torch_xla
else:
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

if XLA:
    scaler = GradScaler(use_zero_grad=True)  # torch_xla
else:
    scaler = torch.amp.GradScaler()

for epoch in tqdm(range(n_epochs)):
    xm.mark_step()
    for i, (images, labels) in tqdm(enumerate(train_loader), leave=False):
        if not XLA:
            optimizer.zero_grad()
        if i >= 2000:
            break
        images = images.to(device)
        labels = labels.to(device)
        # Forward pass
        if XLA:
            autoamp = autocast(device, dtype=torch.bfloat16)
        else:
            autoamp = torch.autocast(device)

        with autoamp:
            outputs = model(images)
            loss = F.nll_loss(outputs, labels)
        # Backward
        scaler.scale(loss).backward()
        if XLA:
            gradients = xm._fetch_gradients(optimizer)
            xm.all_reduce('sum', gradients, scale=1.0 / xr.world_size())
        scaler.step(optimizer)
        scaler.update()
        xm.mark_step()

    print(loss)

And I haven't seen any performance improvement; at best the execution time is the same. I thought that maybe the model was being recompiled too many times or something, so I followed https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md
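For reference, the report below can be produced with the debug metrics API described there, along the lines of (a minimal sketch, reusing the met import from the script above):

# Print the counters/metrics accumulated by the XLA client after training.
# met.short_metrics_report() prints a condensed version of the same data.
print(met.metrics_report())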

The metrics are:

Metric: DeviceLockWait
  TotalSamples: 37520
  Accumulator: 113ms908.380us
  ValueRate: 475.217us / second
  Rate: 159.174 / second
  Percentiles: 1%=000.972us; 5%=000.989us; 10%=000.999us; 20%=001.010us; 50%=004.627us; 80%=004.978us; 90%=005.046us; 95%=005.112us; 99%=005.205us
Metric: InputOutputAliasCount
  TotalSamples: 2
  Accumulator: 42.00
  ValueRate: 21.95 / second
  Rate: 1.04547 / second
  Percentiles: 1%=8.00; 5%=8.00; 10%=8.00; 20%=8.00; 50%=34.00; 80%=34.00; 90%=34.00; 95%=34.00; 99%=34.00
Metric: IrValueTensorToXlaData
  TotalSamples: 37508
  Accumulator: 02s925ms072.075us
  ValueRate: 007ms438.792us / second
  Rate: 159.175 / second
  Percentiles: 1%=030.320us; 5%=030.752us; 10%=030.926us; 20%=031.205us; 50%=059.240us; 80%=061.600us; 90%=062.326us; 95%=062.728us; 99%=067.959us
Metric: LazyTracing
  TotalSamples: 3525066
  Accumulator: 46s352ms512.571us
  ValueRate: 216ms224.171us / second
  Rate: 15022.9 / second
  Percentiles: 1%=000.251us; 5%=000.278us; 10%=000.710us; 20%=001.712us; 50%=007.784us; 80%=023.540us; 90%=034.856us; 95%=039.848us; 99%=084.197us
Metric: TensorToData
  TotalSamples: 56264
  Accumulator: 03s688ms709.373us
  ValueRate: 011ms664.802us / second
  Rate: 238.546 / second
  Percentiles: 1%=028.756us; 5%=029.145us; 10%=029.361us; 20%=029.732us; 50%=044.907us; 80%=058.281us; 90%=059.128us; 95%=059.725us; 99%=062.806us
Metric: TensorsGraphSize
  TotalSamples: 18750
  Accumulator: 4087468.00
  ValueRate: 17318.15 / second
  Rate: 79.4411 / second
  Percentiles: 1%=218.00; 5%=218.00; 10%=218.00; 20%=218.00; 50%=218.00; 80%=218.00; 90%=218.00; 95%=218.00; 99%=218.00
Metric: UnwrapXlaData
  TotalSamples: 18760
  Accumulator: 200ms928.384us
  ValueRate: 842.498us / second
  Rate: 79.5082 / second
  Percentiles: 1%=007.854us; 5%=007.986us; 10%=008.205us; 20%=010.646us; 50%=010.826us; 80%=010.981us; 90%=011.084us; 95%=011.166us; 99%=013.599us
Metric: WrapXlaData
  TotalSamples: 18750
  Accumulator: 024ms220.583us
  ValueRate: 101.239us / second
  Rate: 79.4413 / second
  Percentiles: 1%=001.193us; 5%=001.213us; 10%=001.226us; 20%=001.240us; 50%=001.271us; 80%=001.306us; 90%=001.325us; 95%=001.341us; 99%=001.374us
Counter: CachedCompile
  Value: 18748
Counter: CreateXlaTensor
  Value: 1950018
Counter: DestroyLtcTensor
  Value: 1949972
Counter: DestroyXlaTensor
  Value: 1949972
Counter: DeviceDataCacheMiss
  Value: 6
Counter: ExecuteComputation
  Value: 18750
Counter: RegisterXLAFunctions
  Value: 1
Counter: UncachedCompile
  Value: 2
Counter: xla::_amp_foreach_non_finite_check_and_unscale_
  Value: 18750
Counter: xla::_amp_update_scale_
  Value: 18750
Counter: xla::_copy_from
  Value: 468758
Counter: xla::_log_softmax
  Value: 18750
Counter: xla::_log_softmax_backward_data
  Value: 18750
Counter: xla::_propagate_xla_data
  Value: 431242
Counter: xla::_to_copy
  Value: 468768
Counter: xla::add
  Value: 149992
Counter: xla::addmm
  Value: 37500
Counter: xla::bernoulli
  Value: 37500
Counter: xla::clone
  Value: 206258
Counter: xla::convolution_backward_overrideable
  Value: 37500
Counter: xla::convolution_overrideable
  Value: 37500
Counter: xla::detach_copy
  Value: 300016
Counter: xla::div
  Value: 37500
Counter: xla::empty_strided_symint
  Value: 18758
Counter: xla::empty_symint
  Value: 525016
Counter: xla::fill_
  Value: 18750
Counter: xla::full
  Value: 18752
Counter: xla::max_pool2d
  Value: 18750
Counter: xla::mm
  Value: 75000
Counter: xla::mul
  Value: 112500
Counter: xla::nll_loss_backward
  Value: 18750
Counter: xla::nll_loss_forward
  Value: 18750
Counter: xla::reciprocal
  Value: 18750
Counter: xla::relu
  Value: 56250
Counter: xla::stack
  Value: 18750
Counter: xla::sum
  Value: 56250
Counter: xla::t_copy
  Value: 150000
Counter: xla::threshold_backward
  Value: 56250
Counter: xla::view_copy_symint
  Value: 75000
Counter: xla::zero_
  Value: 8
Metric: CompileTime
  TotalSamples: 2
  Accumulator: 03s973ms956.974us
  ValueRate: 03s712ms926.593us / second
  Rate: 1.8244 / second
  Percentiles: 1%=01s078ms134.380us; 5%=01s078ms134.380us; 10%=01s078ms134.380us; 20%=01s078ms134.380us; 50%=02s895ms822.594us; 80%=02s895ms822.594us; 90%=02s895ms822.594us; 95%=02s895ms822.594us; 99%=02s895ms822.594us
Metric: ExecuteTime
  TotalSamples: 18750
  Accumulator: 30s004ms354.099us
  ValueRate: 127ms668.246us / second
  Rate: 79.4416 / second
  Percentiles: 1%=002ms523.522us; 5%=002ms535.358us; 10%=002ms549.469us; 20%=002ms583.945us; 50%=002ms593.995us; 80%=002ms614.090us; 90%=002ms625.599us; 95%=002ms635.458us; 99%=002ms657.749us
Metric: InboundData
  TotalSamples: 10
  Accumulator: 40.00B
  ValueRate: 0.19B / second
  Rate: 0.0471765 / second
  Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=4.00B; 50%=4.00B; 80%=4.00B; 90%=4.00B; 95%=4.00B; 99%=4.00B
Metric: OutboundData
  TotalSamples: 56264
  Accumulator: 1.76GB
  ValueRate: 7.62MB / second
  Rate: 238.546 / second
  Percentiles: 1%=8.00B; 5%=8.00B; 10%=8.00B; 20%=8.00B; 50%=256.00B; 80%=98.00KB; 90%=98.00KB; 95%=98.00KB; 99%=98.00KB
Metric: TransferFromDeviceTime
  TotalSamples: 10
  Accumulator: 087ms396.013us
  ValueRate: 412.304us / second
  Rate: 0.0471765 / second
  Percentiles: 1%=180.350us; 5%=180.350us; 10%=204.842us; 20%=226.256us; 50%=245.965us; 80%=261.532us; 90%=085ms309.387us; 95%=085ms309.387us; 99%=085ms309.387us
Metric: TransferToDeviceTime
  TotalSamples: 56264
  Accumulator: 02s635ms113.038us
  ValueRate: 006ms183.764us / second
  Rate: 238.546 / second
  Percentiles: 1%=017.354us; 5%=017.697us; 10%=017.852us; 20%=018.117us; 50%=028.886us; 80%=030.153us; 90%=030.375us; 95%=030.589us; 99%=035.744us
Counter: CreateCompileHandles
  Value: 2
Counter: CreateDataHandles
  Value: 881264
Counter: MarkStep
  Value: 18760

But as you can see, the model was compiled only a couple of times and there are no context switches.

Is this behavior expected when working with a single GPU? Or should XLA also give some improvement in this case, not only with multiple devices? Maybe the model is too simple? I couldn't find any information about performance with a single GPU.
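For reference, a minimal sketch of a fair wall-clock comparison (excluding the first compile-heavy steps and flushing pending device work before stopping the timer; it reuses the XLA flag and imports from the script above):

import time

start = time.perf_counter()
# ... run one epoch of the training loop above ...
if XLA:
    xm.wait_device_ops()      # block until all queued XLA computations have finished
else:
    torch.cuda.synchronize()  # same idea for the plain CUDA path
print(f"epoch time: {time.perf_counter() - start:.2f}s")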

Thanks

JackCaoG commented 3 days ago

Try one of the examples we have in https://github.com/pytorch/xla/tree/master/examples maybe. MNIST is kind of outdated. I would expect you to see a difference with a decoder-only model and ResNet.
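For instance, something along these lines (a sketch; script names as in the current examples folder, with PJRT_DEVICE=CUDA selecting the GPU plugin):

$ git clone https://github.com/pytorch/xla.git
$ PJRT_DEVICE=CUDA python xla/examples/train_resnet_base.py
$ PJRT_DEVICE=CUDA python xla/examples/train_decoder_only_base.py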

Patataman commented 3 days ago

Hello, I tried MNIST because it was in one of the examples in https://github.com/pytorch/xla/blob/master/docs/amp.md. I'll take a look at it ASAP.