There appears to be an alignment issue when running with torch.compile.
It reproduces with the nightly build (2.4.0.dev20240511+cu121) and the simple script below.
Printing the shape and stride of the internal tensors gives:
torch.Size([1040, 1040]) and (1, 1152), indicating that the memory allocator has padded the leading dimension from 1040 up to 1152, the next multiple of 128.
import torch

def train():
    from float8_experimental import config
    # Only needed for autocast + compile + FSDP + float8 delayed scaling.
    config.enable_amax_init = False
    config.enable_pre_and_post_forward = False
    from float8_experimental.float8_linear import Float8Linear
    from float8_experimental.float8_linear_utils import (
        swap_linear_with_float8_linear,
    )

    model = torch.nn.Sequential(
        torch.nn.Flatten(),
        torch.nn.Linear(256 * 256 * 3, 1040),
        torch.nn.Linear(1040, 1040),
    )
    # Replace the nn.Linear layers with Float8Linear.
    swap_linear_with_float8_linear(model, Float8Linear)

    device = torch.device('cuda')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.CrossEntropyLoss()
    model = model.to(device)
    model = torch.compile(model, fullgraph=True)
    model.train()

    inputs = torch.randn([16, 3, 256, 256], dtype=torch.float32).cuda()
    label = torch.randint(high=1040, size=[16], dtype=torch.int64).cuda()

    with torch.autocast(device_type='cuda', dtype=torch.bfloat16, enabled=True):
        outputs = model(inputs)
        loss = criterion(outputs, label)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

if __name__ == '__main__':
    train()
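For reference, here is a minimal sketch of the padding arithmetic (an assumption inferred from the observed layout, not from Inductor internals; `pad_to_multiple` is a hypothetical helper for illustration): rounding the leading dimension of 1040 up to the next multiple of 128 gives 1152, which matches the (1, 1152) stride.

import torch

def pad_to_multiple(n: int, align: int = 128) -> int:
    # Round n up to the nearest multiple of `align`.
    return ((n + align - 1) // align) * align

print(pad_to_multiple(1040))  # 1152

# A tensor with the exact layout reported above can be constructed directly:
t = torch.empty_strided((1040, 1040), (1, 1152), device='cuda')
print(t.shape, t.stride())  # torch.Size([1040, 1040]) (1, 1152)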