pytorch / ao

PyTorch native quantization and sparsity for training and inference
BSD 3-Clause "New" or "Revised" License

[NF4] Various bugs in how NF4 handles `.to()` to move to a different device #1310

Closed by gau-nernst 14 minutes ago

gau-nernst commented 1 week ago

Reproduction

import torch
from torch import nn
from torchao.dtypes.nf4tensor import to_nf4

x = torch.randn(1024, 1024)
x_nf4 = to_nf4(x)
print(x_nf4.cuda())  # this will dequantize NF4 -> unwanted
print(x_nf4.to(device="cuda"))  # this will raise error
print(x_nf4.to("cuda"))  # this will do the right thing

# .cpu() does not move .nf4 to CPU, because call_from_inner_tensors does not call the method on .nf4
x = torch.randn(1024, 1024).cuda()
x_nf4 = to_nf4(x).cpu()
print(x_nf4.quantized_data.device)  # cpu
print(x_nf4.nf4.device)  # cuda:0
print(x_nf4.to(torch.float32))  # error due to device mismatch

# not working with nn.Module
linear = nn.Linear(1024, 1024)
linear.weight = nn.Parameter(to_nf4(linear.weight.detach()), requires_grad=False)
linear.cuda()  # NF4 weight is not moved to CUDA
# linear.to("cuda")  # same problem

print(linear.weight.device)  # cuda:0
print(linear.weight.quantized_data.device)  # cpu
print(linear.weight.to(torch.float32).device)  # cpu

Summary:

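  1. NF4Tensor.cuda() dequantizes the tensor instead of moving the quantized data, which is unwanted.
  2. NF4Tensor.to(device="cuda") raises IndexError, because the implementation reads args[1], which does not exist when the device is passed as a keyword argument.
  3. NF4Tensor.cpu() does not move the .nf4 attribute, so dequantization then fails with a device mismatch.
  4. nn.Module.cuda() and nn.Module.to(device) do not move the NF4 weight: weight.device reports cuda:0, but quantized_data stays on the CPU.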
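
To illustrate item 2, here is a minimal sketch in plain Python of why reading args[1] fails for a keyword call, and one way to handle both call styles. This is not torchao's actual code: both function names are hypothetical, args[0] stands in for the tensor itself, and devices are simplified to strings (the real .to() also accepts torch.device objects, dtypes, and other tensors).

```python
def buggy_parse_to_device(args, kwargs):
    # Hypothetical stand-in for the failing pattern: assumes the device
    # is always the second positional argument (args[0] is the tensor).
    return args[1]


def parse_to_device(args, kwargs):
    # Hypothetical fix: accept the device either as a keyword argument
    # or as the second positional argument; return None if absent.
    if "device" in kwargs:
        return kwargs["device"]
    if len(args) > 1:
        return args[1]
    return None


tensor = object()  # placeholder for the NF4 tensor

# x.to("cuda"): the device is positional, so both versions find it.
print(buggy_parse_to_device((tensor, "cuda"), {}))
print(parse_to_device((tensor, "cuda"), {}))

# x.to(device="cuda"): args holds only the tensor, so args[1] is missing.
try:
    buggy_parse_to_device((tensor,), {"device": "cuda"})
except IndexError as exc:
    print("IndexError:", exc)

print(parse_to_device((tensor,), {"device": "cuda"}))
```

A real fix would also need to cover .cuda() and .cpu(), and move every inner tensor, including .nf4, together.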