pytorch / torchdynamo

A Python-level JIT compiler designed to make unmodified PyTorch programs faster.
BSD 3-Clause "New" or "Revised" License

Problems of using the acceleration of torchdynamo for Resnet when training #694

Closed: TraceCS closed this issue 2 years ago.

TraceCS commented 2 years ago

We are running into problems using torchdynamo to accelerate training. Our model is very similar to ResNet, so we first ran tests on ResNet-152. Here is the test code:

```python
import contextlib
import datetime

import torch
import torch.nn as nn
import torchdynamo
import torchvision.models as models

device = torch.device("cuda")


class Res152(nn.Module):
    """Thin wrapper around torchvision's ResNet-152, moved to the GPU."""

    def __init__(self, input_sz=(3, 224, 224)) -> None:
        super().__init__()
        self.input_sz = input_sz
        self.model = models.resnet152()
        self.model.to(device)

    def forward(self, x):
        return self.model(x)


def train(Net=Res152, iters=10, bsz=32, input_sz=(3, 224, 224), use_dynamo=False):
    """Run `iters` forward/backward steps and return the wall-clock time of each step."""
    step_times = []
    loss_fcn = nn.CrossEntropyLoss()
    resnet = Net()

    # Same loop for both runs; only the dynamo run is wrapped in torchdynamo.optimize.
    ctx = (
        torchdynamo.optimize("aot_nvfuser", nopython=True)
        if use_dynamo
        else contextlib.nullcontext()
    )
    with ctx:
        for _ in range(iters):
            st_time = datetime.datetime.now()

            bdata = torch.randn((bsz,) + input_sz, device=device)
            labels = torch.ones(bsz, dtype=torch.long, device=device)
            output = resnet(bdata)
            loss = loss_fcn(output, labels)
            loss.backward()  # no optimizer step: only forward + backward are timed

            # No torch.cuda.synchronize() here, so queued GPU work may still be running.
            # Under dynamo, the first step also includes tracing and compilation.
            ed_time = datetime.datetime.now()
            step_times.append(ed_time - st_time)

    return step_times


def run_test(iters=1, bsz=8, Net=Res152):
    time_lis = train(Net=Net, iters=iters, bsz=bsz, use_dynamo=False)
    time_lis_tdy = train(Net=Net, iters=iters, bsz=bsz, use_dynamo=True)
    assert len(time_lis) == len(time_lis_tdy)

    # .microseconds is only the sub-second component of each timedelta.
    t_mic = sum(t.microseconds for t in time_lis)
    t_mic_tdy = sum(t.microseconds for t in time_lis_tdy)

    print("iters={}  bsz={}  avg time compare:{:.2f} -- (dynamo){:.2f}   ratio:{:.4f}".format(
        iters, bsz, t_mic / iters, t_mic_tdy / iters, (t_mic / iters) / (t_mic_tdy / iters)
    ))


run_test(iters=20, bsz=16)
run_test(iters=50, bsz=16)
run_test(iters=100, bsz=16)
run_test(iters=200, bsz=16)
```
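One detail in `run_test` affects the printed averages: `datetime.timedelta.microseconds` is only the microseconds component of a duration (always below one second), not its total length. Any step that takes one second or longer, as a first dynamo step with slow compilation might, has its whole seconds silently dropped. `timedelta.total_seconds()` gives the full length. A small standard-library illustration; the helper name `to_microseconds` and the 2.5-second step are hypothetical:

```python
import datetime


def to_microseconds(delta: datetime.timedelta) -> float:
    """Full length of `delta` in microseconds (unlike delta.microseconds)."""
    return delta.total_seconds() * 1e6


# A hypothetical 2.5-second step, e.g. a first step that includes compilation.
step = datetime.timedelta(seconds=2, microseconds=500_000)

print(step.microseconds)       # 500000: only the sub-second component
print(to_microseconds(step))   # 2500000.0: the whole duration
```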

The results are:
iters=20  bsz=16  avg time compare:209189.15 -- (dynamo)240762.80   ratio:0.8689
iters=50  bsz=16  avg time compare:189001.68 -- (dynamo)198850.06   ratio:0.9505
iters=100  bsz=16  avg time compare:189400.90 -- (dynamo)204025.96   ratio:0.9283
iters=200  bsz=16  avg time compare:191534.67 -- (dynamo)195713.67   ratio:0.9786

The torch and related package versions we use are:
torch 1.13.0.dev20220801+cu113
torch-struct 0.5
torchaudio 0.12.0
torchdynamo 1.13.0.dev0 /data/dev/torchdynamo
torchfile 0.1.0
torchmetrics 0.9.3
torchrec-nightly 2022.8.1
torchstat 0.0.7
torchtext 0.13.0
torchvision 0.13.0
torchx-nightly 2022.8.1

We ran this test inside Docker on an NVIDIA TITAN Xp. Are there any mistakes in our usage? We would like to know whether we are using torchdynamo incorrectly, or how to use it to speed up training.
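The loop in `train` records `datetime.now()` around each step without waiting for the GPU, and averages every step, including the first steps under `torchdynamo.optimize` that pay for tracing and compilation, so neither average necessarily reflects steady-state speed. Below is a minimal sketch of a more careful timing helper, not an official torchdynamo utility: the names `time_steps` and `summarize` and the `warmup`/`sync` parameters are hypothetical, and it assumes that for CUDA work you pass `torch.cuda.synchronize` as `sync` so queued kernels finish before each timestamp.

```python
import statistics
import time
from typing import Callable, List, Optional


def time_steps(
    step: Callable[[], None],
    iters: int,
    warmup: int = 3,
    sync: Optional[Callable[[], None]] = None,
) -> List[float]:
    """Hypothetical helper: time `iters` calls of `step` (seconds) after `warmup` untimed calls.

    `sync` should block until all queued device work has finished
    (for CUDA, e.g. torch.cuda.synchronize).
    """
    if iters < 1:
        raise ValueError("iters must be >= 1")

    # Warmup calls absorb one-time costs such as compilation and are not recorded.
    for _ in range(warmup):
        step()
    if sync is not None:
        sync()  # start the timed loop with an idle device

    times = []
    for _ in range(iters):
        start = time.perf_counter()
        step()
        if sync is not None:
            sync()  # wait for the device so the interval covers the queued kernels
        times.append(time.perf_counter() - start)
    return times


def summarize(times: List[float]) -> str:
    """Mean and median in milliseconds; the median is less sensitive to outliers."""
    mean_ms = statistics.mean(times) * 1e3
    median_ms = statistics.median(times) * 1e3
    return f"mean={mean_ms:.2f} ms  median={median_ms:.2f} ms"
```

For example, pass a closure that runs one forward/backward step of the model, with `sync=torch.cuda.synchronize`, and call `time_steps` inside the `torchdynamo.optimize(...)` context for the dynamo run so that the first compilation falls inside the warmup calls.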

anijain2305 commented 2 years ago

Hi @TraceCS, my numbers look like this:

iters=20  bsz=16  avg time compare:83546.40 -- (dynamo)91269.50   ratio:0.9154
iters=50  bsz=16  avg time compare:65139.68 -- (dynamo)61867.60   ratio:1.0529
iters=100  bsz=16  avg time compare:64421.50 -- (dynamo)54427.80   ratio:1.1836
iters=200  bsz=16  avg time compare:64340.89 -- (dynamo)50062.75   ratio:1.2852

I am not sure what effect Docker has. Also, these numbers are from A100 GPUs.
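In both sets of printouts, `ratio` is the eager average divided by the dynamo average, so values above 1 mean the dynamo run was faster and values below 1 mean it was slower. A minimal sketch that recomputes it from printed rows (the helper name `speedup` is hypothetical):

```python
def speedup(eager_avg_us: float, dynamo_avg_us: float) -> float:
    """Eager time / dynamo time, matching the `ratio` field printed by run_test."""
    return eager_avg_us / dynamo_avg_us


# Rows copied from the printouts in this thread.
print(round(speedup(64340.89, 50062.75), 4))    # A100, iters=200: 1.2852 (dynamo faster)
print(round(speedup(209189.15, 240762.80), 4))  # TITAN Xp, iters=20: 0.8689 (dynamo slower)
```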

TraceCS commented 2 years ago

Thanks for the follow-up. That looks like the output of the code I posted, so maybe the difference is because my GPU is a TITAN Xp?

(Replied by email to Animesh's comment of Fri, Aug 12, 2022 07:58 AM.)

ezyang commented 2 years ago

The backends are targeted at Volta and later GPUs, so yeah, if you do get improvements on TITAN Xp it would be a pleasant surprise, but we are not specifically targeting it.
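Volta GPUs have CUDA compute capability 7.0, while the TITAN Xp is a Pascal card with compute capability 6.1 and the A100 is 8.0. A small sketch for checking which side of that line a device falls on; the helper names and the `(7, 0)` threshold constant are assumptions that simply encode the "Volta and later" statement above, and `torch.cuda.get_device_capability` is only called when torch and CUDA are available:

```python
from typing import Optional, Tuple

VOLTA = (7, 0)  # assumed threshold encoding "Volta and later"


def is_volta_or_newer(capability: Tuple[int, int]) -> bool:
    """True if a (major, minor) CUDA compute capability is Volta (7.0) or later."""
    return tuple(capability) >= VOLTA


def current_device_capability() -> Optional[Tuple[int, int]]:
    """Capability of the current CUDA device, or None if torch/CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability()


print(is_volta_or_newer((6, 1)))  # TITAN Xp (Pascal): False
print(is_volta_or_newer((8, 0)))  # A100 (Ampere): True
cap = current_device_capability()
if cap is not None:
    print(cap, is_volta_or_newer(cap))
```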

Jack47 commented 2 years ago

@TraceCS so may we close this issue?

TraceCS commented 2 years ago

Sorry for the delay, and yes, thanks for asking.

(Replied by email to Jack's comment of Wed, Aug 31, 2022 12:51 PM.)