pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX

Docs should say what's the smallest model users will see a benefit for #280

Closed: msaroufim closed this issue 2 months ago

msaroufim commented 3 months ago

I was working on a minimal example to showcase the benefits of fp8 on an H100 without forcing users to download a chunky model, like in https://github.com/pytorch-labs/float8_experimental/issues/279

I guess it's expected that fp8 will be slower for tiny models because of the scaling/casting overhead, in which case the docs should state the minimum model size at which people can expect a speedup.
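For intuition (back-of-envelope, not measured): each 32x32 linear here with batch 32 is only about 2*32*32*32 ≈ 65K FLOPs per matmul, while the dynamic float8 path adds amax/scale computation and casts for the activations, weights, and gradients, so the extra kernels and launch overhead dominate the actual matmul at this size.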

Training time in FP16: 7.10 seconds
Training time in FP8: 9.80 seconds
import torch
import torch.nn as nn
import copy
from torch.cuda.amp import autocast  # note: not actually used below, so the baseline runs in fp32 with TF32 matmuls allowed
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear
from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
torch.set_float32_matmul_precision('high')

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer1 = nn.Linear(32, 32)
        self.layer2 = nn.Linear(32, 32)

    def forward(self, x):
        x = torch.relu(self.layer1(x))
        x = self.layer2(x)
        return x

def train(model, data_loader):
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters())
    model.train()

    for data, target in data_loader:
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

def benchmark_training(model, data_loader, iterations=100, warmup_iterations=10):
    # Warm-up phase: Run a few iterations to get the GPU to a steady state
    model = torch.compile(model)
    for _ in range(warmup_iterations):
        train(model, data_loader)

    # Timing phase
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.synchronize()  # Wait for all operations on the CUDA device to complete
    start_event.record()

    for _ in range(iterations):
        train(model, data_loader)

    end_event.record()
    torch.cuda.synchronize()  # Wait for the events to be recorded
    elapsed_time = start_event.elapsed_time(end_event) / 1000.0  # Convert milliseconds to seconds
    return elapsed_time

data_loader = [(torch.randn(32, 32, device="cuda"), torch.randn(32, 32, device="cuda")) for _ in range(110)]  # targets match the (32, 32) model output; a (32, 1) target would only broadcast and warn

# Initial model setup
base_model = Model().cuda()

# Training in fp16
model_fp16 = copy.deepcopy(base_model)
fp16_time = benchmark_training(model_fp16, data_loader)

# Training in fp8
model_fp8 = copy.deepcopy(base_model)
swap_linear_with_float8_linear(model_fp8, Float8DynamicLinear)
fp8_time = benchmark_training(model_fp8, data_loader)

print(f"Training time in FP16: {fp16_time:.2f} seconds")
print(f"Training time in FP8: {fp8_time:.2f} seconds")
vkuzo commented 3 months ago

Great idea, let's do it

vkuzo commented 2 months ago

https://github.com/pytorch/ao/issues/572