eth-easl / orion

An interference-aware scheduler for fine-grained GPU sharing
MIT License
77 stars 11 forks source link

Reproducing Table 2 from Orion Paper #35

Open atomicapple0 opened 3 weeks ago

atomicapple0 commented 3 weeks ago

I am trying to reproduce the numbers from the conv/bnorm toy benchmark in the Orion paper. I saw some code provided here but did not see a script that runs bnorm and conv in parallel on different streams, so I rewrote the benchmark in the following script and ran it on an H100. I got the following results and didn't see any significant speedup from running in parallel. Any advice?

(.venv) ubuntu@209-20-159-95:~/conv_bnorm$ python3 main.py 
------------------------------
solo_conv: 0.42575 +- 0.00234 ms
solo_bnorm: 0.14967 +- 0.00172 ms
------------------------------
conv_conv_seq 0.80283 +- 0.00713 ms
conv_conv_par 0.77428 +- 0.01028 ms
speedup: 1.04x
------------------------------
bnorm_bnorm_seq 0.25994 +- 0.00549 ms
bnorm_bnorm_par 0.25029 +- 0.00464 ms
speedup: 1.04x
------------------------------
conv_bnorm_seq 0.53533 +- 0.00751 ms
conv_bnorm_par 0.51155 +- 0.01055 ms
speedup: 1.05x
------------------------------
bnorm_conv_seq 0.53632 +- 0.00601 ms
bnorm_conv_par 0.52873 +- 0.00558 ms
speedup: 1.01x
------------------------------

Source:

import os
from platform import node
import sched
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import models, datasets, transforms
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
from torch.multiprocessing import Pool, Process, set_start_method, Manager, Value, Lock
from datetime import timedelta
import random
import numpy as np
import time
import os
import argparse
import threading

class Conv(torch.nn.Module):
    """Single 7x7 convolution, stride 2, padding 3, mapping 3 -> 64 channels (no bias)."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)

    def forward(self, x):
        """Apply the convolution to ``x`` and return the result.

        The previous version assigned the result to a local and fell off the
        end of the method, so every caller received ``None``.
        """
        return self.conv(x)

class Bnorm(torch.nn.Module):
    """Single BatchNorm2d layer over 64 channels."""

    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(64)

    def forward(self, x):
        """Apply batch normalization to ``x`` and return the result.

        The previous version computed the result but never returned it, so
        every caller received ``None``.
        """
        return self.bn(x)

# Batch size shared by both workloads.
bs = 32
conv_model = Conv().cuda().eval()
bnorm_model = Bnorm().cuda().eval()
conv_data = torch.rand([bs, 3, 224, 224]).cuda().contiguous()
# Size bnorm's input like Conv's output on a 224x224 image:
# (224 + 2*3 - 7) // 2 + 1 == 112. The previous 122x122 input made the bnorm
# workload ~19% larger than the conv output it is meant to mirror.
bnorm_data = torch.rand([bs, 64, 112, 112]).cuda().contiguous()
# Two independent copies so that two concurrently launched kernels
# (data_idx 0 and 1) do not read the same buffer.
conv_datas = [conv_data.clone() for _ in range(2)]
bnorm_datas = [bnorm_data.clone() for _ in range(2)]

def run_conv(curr_stream, data_idx=0):
    """Enqueue one inference pass of ``conv_model`` on ``curr_stream``.

    Reads its input from ``conv_datas[data_idx]`` and runs without autograd
    tracking. Returns nothing; the launch is asynchronous.
    """
    inp = conv_datas[data_idx]
    with torch.no_grad(), torch.cuda.stream(curr_stream):
        result = conv_model(inp)

def run_bnorm(curr_stream, data_idx=0):
    """Enqueue one inference pass of ``bnorm_model`` on ``curr_stream``.

    Reads its input from ``bnorm_datas[data_idx]`` and runs without autograd
    tracking. Returns nothing; the launch is asynchronous.
    """
    inp = bnorm_datas[data_idx]
    with torch.no_grad(), torch.cuda.stream(curr_stream):
        result = bnorm_model(inp)

# ``stream`` only records/waits on timing events; streamA and streamB carry
# the benchmarked kernels.
stream, streamA, streamB = (torch.cuda.Stream() for _ in range(3))

# Timing-enabled events shared by the solo/seq/par measurement helpers.
event1, event2, event3, event4 = (
    torch.cuda.Event(enable_timing=True) for _ in range(4)
)

def _timed_launch(launch, work_streams, done_events):
    """Time the GPU work that ``launch()`` enqueues on ``work_streams``.

    Uses the module-level ``stream`` as the timing stream, with ``event1`` as
    the start marker and ``event4`` as the end marker. ``done_events[i]`` is
    recorded on ``work_streams[i]`` after the work, and ``stream`` waits on
    all of them before ``event4`` is recorded.

    Returns:
        Elapsed GPU time in milliseconds between ``event1`` and ``event4``.
    """
    torch.cuda.synchronize()

    # Record the start marker BEFORE enqueueing any work and make every work
    # stream wait on it, so the measured window is guaranteed to contain the
    # whole workload. Previously event1 was recorded only after the kernels
    # had been launched, with no ordering against the work streams. The GPU
    # could already be running them, and the window silently dropped that
    # prefix, including the host-side launch time of the second kernel in
    # seq/par.
    # NOTE(review): the window now includes the host launch latency of the
    # first kernel, because the GPU sits idle until that launch arrives.
    event1.record(stream)
    for work_stream in work_streams:
        event1.wait(stream=work_stream)

    launch()

    for work_stream, done in zip(work_streams, done_events):
        done.record(work_stream)
    for done in done_events:
        done.wait(stream=stream)
    event4.record(stream)

    torch.cuda.synchronize()
    return event1.elapsed_time(event4)


def solo(fun):
    """Time a single launch of ``fun`` on ``streamA``; returns milliseconds."""
    return _timed_launch(lambda: fun(streamA), (streamA,), (event2,))


def seq(funA, funB):
    """Time ``funA`` then ``funB`` back-to-back on ``streamA``; returns milliseconds.

    ``funB`` uses the second data copy (``data_idx=1``).
    """
    def launch():
        funA(streamA)
        funB(streamA, data_idx=1)

    return _timed_launch(launch, (streamA,), (event2,))


def par(funA, funB):
    """Time ``funA`` on ``streamA`` and ``funB`` on ``streamB``; returns milliseconds.

    ``funB`` uses the second data copy (``data_idx=1``). The two streams may
    overlap on the GPU.
    """
    def launch():
        funA(streamA)
        funB(streamB, data_idx=1)

    return _timed_launch(launch, (streamA, streamB), (event2, event3))

def timeit(fun, iters=100, warmup=10):
    """Call ``fun()`` repeatedly and report the mean and std of its return values.

    The first ``warmup`` results are discarded. The summary is printed with an
    ``ms`` suffix, so ``fun`` is expected to return a duration in milliseconds.

    Args:
        fun: Zero-argument callable returning one timing sample.
        iters: Total number of calls (default 100, the previous hard-coded value).
        warmup: Leading samples to drop (default 10, the previous hard-coded value).

    Returns:
        Mean of the retained samples.

    Raises:
        ValueError: If ``warmup`` is negative or not smaller than ``iters``.
            No samples would remain, and ``np.mean`` of an empty list
            silently yields ``nan``.
    """
    if not 0 <= warmup < iters:
        raise ValueError(f'need 0 <= warmup < iters, got warmup={warmup}, iters={iters}')
    times = [fun() for _ in range(iters)][warmup:]
    avg = np.mean(times)
    std = np.std(times)
    print(f'{avg:.5f} +- {std:.5f} ms')
    return avg

def speedup(seq_time, par_time):
    """Print how many times faster the parallel run was than the sequential one."""
    ratio = seq_time / par_time
    print('speedup: {:.2f}x'.format(ratio))

# Drive the benchmark: solo timings first, then seq-vs-par for every pairing.
SEPARATOR = '------------------------------'

print(SEPARATOR)
for label, kernel in (('conv', run_conv), ('bnorm', run_bnorm)):
    print('solo_' + label + ': ', end='')
    timeit(lambda k=kernel: solo(k))
print(SEPARATOR)

PAIRINGS = (
    ('conv_conv', run_conv, run_conv),
    ('bnorm_bnorm', run_bnorm, run_bnorm),
    ('conv_bnorm', run_conv, run_bnorm),
    ('bnorm_conv', run_bnorm, run_conv),
)
for tag, first, second in PAIRINGS:
    print(tag + '_seq ', end='')
    seq_time = timeit(lambda a=first, b=second: seq(a, b))
    print(tag + '_par ', end='')
    par_time = timeit(lambda a=first, b=second: par(a, b))
    speedup(seq_time, par_time)
    print(SEPARATOR)
fotstrt commented 3 weeks ago

Hello!

I have not tried this example on an H100 GPU, only on a V100. The problem with the code snippet (I think) is that there are a lot of torch.cuda.synchronize calls, which add kernel-launch overhead, and my guess is that the kernels never end up executing in parallel.

Now, the way I recommend getting the timings is to remove the torch.cuda.synchronize calls and use the NVIDIA Nsight Systems tool to see what is actually happening on the GPU.

For example, I adapted your snippet to:

import os
from platform import node
import sched
import sys
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torchvision import models, datasets, transforms
from torch.nn.parallel import DistributedDataParallel as DDP
import torch.nn.functional as F
from torch.multiprocessing import Pool, Process, set_start_method, Manager, Value, Lock
from datetime import timedelta
import random
import numpy as np
import time
import os
import argparse
import threading

class Conv(torch.nn.Module):
    """7x7 convolution, stride 2, padding 3, from 3 to 64 channels, without bias."""

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels=3,
            out_channels=64,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
        )

    def forward(self, x):
        """Return the convolution of ``x``."""
        return self.conv(x)

class Bnorm(torch.nn.Module):
    """Thin wrapper around a 64-channel BatchNorm2d layer."""

    def __init__(self):
        super().__init__()
        self.bn = torch.nn.BatchNorm2d(num_features=64)

    def forward(self, x):
        """Return the batch-normalized ``x``."""
        return self.bn(x)

# Models stay in their default (training) mode in this script.
conv_model = Conv().cuda()
bnorm_model = Bnorm().cuda()
conv_data = torch.rand([64, 3, 224, 224]).cuda().contiguous()
bnorm_data = torch.rand([32, 64, 112, 112]).cuda().contiguous()
# NOTE(review): my_test below reads conv_data/bnorm_data directly; these
# per-stream copies are not used by it.
conv_datas = [conv_data.clone(), conv_data.clone()]
bnorm_datas = [bnorm_data.clone(), bnorm_data.clone()]

stream, streamA, streamB = (torch.cuda.Stream() for _ in range(3))

def my_test(s1, s2, iters=100, warmup=10):
    """Time repeated rounds of bnorm on ``s1`` and conv on ``s2``, and print the result.

    Pass the same stream twice to measure the serialized baseline, or two
    distinct streams to let the kernels overlap.

    Args:
        s1: Stream that runs ``bnorm_model``.
        s2: Stream that runs ``conv_model``.
        iters: Total rounds launched (default 100, the previous hard-coded value).
        warmup: Leading rounds excluded from the timing (default 10, as before).

    Raises:
        ValueError: If ``warmup`` is negative or not smaller than ``iters``.
            Otherwise ``start`` would never be set.
    """
    if not 0 <= warmup < iters:
        raise ValueError(f'need 0 <= warmup < iters, got warmup={warmup}, iters={iters}')
    torch.cuda.synchronize()
    for i in range(iters):
        if i == warmup:
            # Drain the warmup rounds so presumably one-time startup costs
            # are not included in the measured interval.
            torch.cuda.synchronize()
            start = time.perf_counter()

        with torch.cuda.stream(s1):
            output_b1 = bnorm_model(bnorm_data)  # bn

        # Bug fix: the conv must be launched on s2. The previous version used
        # s1 for both kernels, which serializes them regardless of the
        # streams passed in.
        with torch.cuda.stream(s2):
            output_c1 = conv_model(conv_data)  # conv

    torch.cuda.synchronize()
    # perf_counter is monotonic and high-resolution, unlike time.time().
    end = time.perf_counter()
    print(f"It took {(end-start)*1000} ms")

# Bracket the run with cudaProfilerStart/Stop so that
# `nsys --capture-range=cudaProfilerApi` captures only this region.
_cudart = torch.cuda.profiler.cudart()
_cudart.cudaProfilerStart()
my_test(streamA, streamB)
_cudart.cudaProfilerStop()

and ran it on a V100 GPU (with CUDA 11.6), getting a 1.2-1.3x overall speedup compared to using the same stream for the two kernels. To verify that the kernels indeed run together, I profiled with the NVIDIA Nsight Systems tool, running:

nsys profile -w true -t cuda,nvtx,osrt,cudnn,cublas -s none -o output_nsys --cudabacktrace=true --capture-range=cudaProfilerApi --stop-on-range-end=true -f true -x true python3 test1.py

and then checked the trace, and saw:

Screenshot from 2024-06-17 20-25-01

which means the kernels actually are scheduled together.

Now, if I try to schedule the two convolution kernels together (using the same script, but with both streams running conv), I see the following trace:

Screenshot from 2024-06-17 20-27-49

meaning that the kernels are serialized.

So I would recommend using the nsys tool to see what happens. Unfortunately, I do not have access to an H100 GPU to check exactly what happens there.

I hope this helped, and please let me know if anything else is needed!