pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License
194 stars 18 forks source link

memory alignment issue in torch.compile mode #259

Open czmrand opened 2 months ago

czmrand commented 2 months ago

There appears to be an alignment issue when attempting to run with torch.compile.

This reproduces with the nightly build (2.4.0.dev20240511+cu121) using the simple script below. Printing the shape and stride of the internal tensors gives `torch.Size([1040, 1040])` and `(1, 1152)`. This indicates that the memory allocator has padded the row stride from 1040 up to a multiple of 128 (1152 = 9 × 128).

import torch
def train(hidden_dim: int = 1040, batch_size: int = 16) -> None:
    """Run one compiled float8 training step to reproduce the alignment issue.

    Builds a small ``Flatten -> Linear -> Linear`` model and swaps its linear
    layers for ``Float8Linear``. It then compiles the model with
    ``torch.compile(fullgraph=True)`` and runs a single
    forward/backward/optimizer step on CUDA under bfloat16 autocast.

    Args:
        hidden_dim: Output width of both linear layers and the number of label
            classes. The default 1040 is not a multiple of 128
            (1040 = 8 * 128 + 16), which is the configuration from the report.
            Pass a different width (e.g. a multiple of 128) to compare.
        batch_size: Number of random 3x256x256 input samples in the batch.
    """
    from float8_experimental import config

    # Set these before importing Float8Linear, matching the original repro.
    config.enable_amax_init = False  # only needed for autocast + compile + FSDP + float8 delayed
    config.enable_pre_and_post_forward = False  # only needed for autocast + compile + FSDP + float8 delayed

    from float8_experimental.float8_linear import Float8Linear
    from float8_experimental.float8_linear_utils import (
        swap_linear_with_float8_linear,
    )

    device = torch.device("cuda")
    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(256 * 256 * 3, hidden_dim),
        torch.nn.Linear(hidden_dim, hidden_dim),
    )
    swap_linear_with_float8_linear(model, Float8Linear)

    # Move the model to its device *before* building the optimizer, as the
    # torch.optim documentation recommends, so the optimizer holds the
    # parameters in their final placement.
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()

    model = torch.compile(model, fullgraph=True)
    model.train()

    inputs = torch.randn(
        [batch_size, 3, 256, 256], dtype=torch.float32
    ).to(device)
    # Class indices must lie in [0, hidden_dim): the final layer emits
    # hidden_dim logits, so both come from the same parameter.
    label = torch.randint(
        high=hidden_dim, size=[batch_size], dtype=torch.int64
    ).to(device)

    with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=True):
        outputs = model(inputs)
        loss = criterion(outputs, label)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

if __name__ == "__main__":
    # Script entry point: run the single-step reproduction only when this
    # file is executed directly, not when it is imported.
    train()