pytorch / functorch

functorch provides JAX-like composable function transforms for PyTorch.
https://pytorch.org/functorch/
BSD 3-Clause "New" or "Revised" License

runtime performance gain on model ensembling #1061

Closed · xuyxu closed this issue 1 year ago

xuyxu commented 1 year ago

Hi,

After noticing this nice package in the PyTorch release notes, we are working to integrate it into our repo Ensemble-PyTorch, a member of the PyTorch ecosystem focused on state-of-the-art ensemble methods.

Following the tutorial on model ensembling, here is our code snippet for runtime benchmarking. It trains 5 simple LeNet-5 models on CIFAR-10, then measures the inference runtime on test_loader using both functorch and the original forward method.

import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchensemble.voting import VotingClassifier
from torchensemble.utils.logging import set_logger

from functorch import vmap
from memory_profiler import profile

class LeNet5(nn.Module):
    def __init__(self):
        super(LeNet5, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(400, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 400)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# @profile
def functorch_inference(data_loader, fmodel, params, buffers, device):
    # Vectorized inference: a single vmap call evaluates all ensemble members
    for idx, (data, target) in enumerate(data_loader):
        data = data.to(device)
        vmap(fmodel, in_dims=(0, 0, None))(params, buffers, data)

# @profile
def pytorch_inference(data_loader, model):
    # Baseline: the ensemble's original forward pass (a loop over members)
    for idx, (data, target) in enumerate(data_loader):
        data = data.to(model.device)
        model(data)

if __name__ == "__main__":

    # Hyper-parameters
    n_estimators = 5
    lr = 1e-3
    weight_decay = 5e-4
    epochs = 5
    n_trials = 10

    # Utils
    batch_size = 128
    data_dir = "../../Dataset/cifar"  # MODIFY THIS IF YOU WANT
    records = []
    torch.manual_seed(0)

    # Load data
    train_transformer = transforms.Compose(
        [
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, 4),
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    test_transformer = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize(
                (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
            ),
        ]
    )

    train_loader = DataLoader(
        datasets.CIFAR10(
            data_dir, train=True, download=True, transform=train_transformer
        ),
        batch_size=batch_size,
        shuffle=True,
    )

    test_loader = DataLoader(
        datasets.CIFAR10(data_dir, train=False, transform=test_transformer),
        batch_size=batch_size,
        shuffle=True,
    )

    logger = set_logger("functorch_benchmark", use_tb_logger=True)

    # VotingClassifier
    model = VotingClassifier(
        estimator=LeNet5, n_estimators=n_estimators, cuda=False
    )

    # Set the optimizer
    model.set_optimizer("Adam", lr=lr, weight_decay=weight_decay)

    # Training
    tic = time.time()
    model.fit(train_loader, epochs=epochs)
    toc = time.time()
    training_time = toc - tic

    fmodel, params, buffers = model.vectorize()  # Internally: fmodel, params, buffers = combine_state_for_ensemble(self.estimators_)

    tic = time.time()
    for _ in range(n_trials):
        functorch_inference(test_loader, fmodel, params, buffers, model.device)
    toc = time.time()
    print("functorch: {:.3f}s".format(toc - tic))

    tic = time.time()
    for _ in range(n_trials):
        pytorch_inference(test_loader, model)
    toc = time.time()
    print("pytorch: {:.3f}s".format(toc - tic))

The result is kind of strange: the performance gain is marginal compared to the official documentation. I would appreciate it very much if anyone could tell me where this goes wrong. Thanks!

zou3519 commented 1 year ago

Hey @xuyxu, it's cool to hear that you're integrating functorch with Ensemble-PyTorch.

The CPU results are expected. PyTorch's convolution kernels are not optimized for CPU; changing the kernel can lead to different performance characteristics. vmap ends up turning the convolution call into a different convolution call that, unfortunately, appears to be slower.
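
To illustrate (a minimal sketch, not from the original reply), you can isolate the convolution itself and compare a vmapped call over a stack of per-model kernels against a plain Python loop on CPU:

import torch
import torch.nn.functional as F
from functorch import vmap
from torch.utils.benchmark import Timer

# 5 per-model conv kernels (LeNet-5's first layer) and a shared input batch
weights = torch.randn(5, 6, 3, 5, 5)
x = torch.randn(128, 3, 32, 32)

def for_loop():
    return [F.conv2d(x, w) for w in weights]

def vmapped():
    # in_dims=(0, None): map over the kernel stack, broadcast the input
    return vmap(lambda w, xb: F.conv2d(xb, w), in_dims=(0, None))(weights, x)

for fn in (for_loop, vmapped):
    print(fn.__name__, Timer(stmt="fn()", globals={"fn": fn}).timeit(100))

On CPU, vmap lowers the mapped call to a different convolution configuration, so depending on the backend the vmapped version may well be the slower one, consistent with the observation above.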

For CUDA: what GPU are you benchmarking on? Is there an easy way for us to repro your results? Using your input sizes, and without the Ensemble-PyTorch library, I compared the vmap ensembling approach to a for-loop ensembling approach on an A100 GPU, and vmap looks to be significantly faster (https://gist.github.com/zou3519/98e69289ba28f80247039723d073ef07). Though I'm not completely sure this is what your code is doing under the hood.

(pt1.13) [0] rzou@a100-st-p4d24xlarge-55:~  $ python foo.py
<torch.utils.benchmark.utils.common.Measurement object at 0x7f386b623f70>
vmap_inference()
setup: from __main__ import vmap_inference
  860.66 us
  1 measurement, 1000 runs , 1 thread
<torch.utils.benchmark.utils.common.Measurement object at 0x7f386b6237c0>
forloop_inference()
setup: from __main__ import forloop_inference
  2.11 ms
  1 measurement, 1000 runs , 1 thread
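
(For reference, a self-contained benchmark in the spirit of the gist might look like the sketch below; the shapes follow the snippet above, and this is an approximation, not the gist's exact code.)

import torch
import torch.nn as nn
from functorch import combine_state_for_ensemble, vmap
from torch.utils.benchmark import Timer

device = "cuda" if torch.cuda.is_available() else "cpu"

# 5 identical untrained LeNet-5-like models with CIFAR-10 sized inputs
def make_model():
    return nn.Sequential(
        nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
        nn.Conv2d(6, 16, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
        nn.Flatten(), nn.Linear(400, 120), nn.ReLU(),
        nn.Linear(120, 84), nn.ReLU(), nn.Linear(84, 10),
    ).to(device)

models = [make_model() for _ in range(5)]
fmodel, params, buffers = combine_state_for_ensemble(models)
x = torch.randn(128, 3, 32, 32, device=device)

def vmap_inference():
    # One vectorized call over all 5 parameter sets; x is shared (in_dims=None)
    return vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)

def forloop_inference():
    # Baseline: evaluate each ensemble member separately
    return [m(x) for m in models]

for f in (vmap_inference, forloop_inference):
    print(Timer(stmt="f()", globals={"f": f}).timeit(1000))

(torch.utils.benchmark.Timer handles warm-up and CUDA synchronization, which time.time() around a loop does not.)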
xuyxu commented 1 year ago

Thanks @zou3519. We will add the vectorize API first; stay tuned.