Reproduction:

import torch
from torch import nn
from torchao.dtypes.nf4tensor import to_nf4
x = torch.randn(1024, 1024)
x_nf4 = to_nf4(x)
print(x_nf4.cuda()) # this dequantizes NF4 to a regular CUDA tensor -> unwanted
print(x_nf4.to(device="cuda")) # this raises IndexError
print(x_nf4.to("cuda")) # this does the right thing: moves the NF4Tensor to CUDA
# .cpu() does not move .nf4 to CPU, because call_from_inner_tensors does not call the method on .nf4
x = torch.randn(1024, 1024).cuda()
x_nf4 = to_nf4(x).cpu()
print(x_nf4.quantized_data.device) # cpu
print(x_nf4.nf4.device) # cuda:0
print(x_nf4.to(torch.float32)) # error due to device mismatch
# does not work with nn.Module
linear = nn.Linear(1024, 1024)
linear.weight = nn.Parameter(to_nf4(linear.weight.detach()), requires_grad=False)
linear.cuda() # NF4 weight is not moved to CUDA
# linear.to("cuda") # same problem
print(linear.weight.device) # cuda:0
print(linear.weight.quantized_data.device) # cpu
print(linear.weight.to(torch.float32).device) # cpu
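As a side note, one way to sidestep the device-move problem entirely (a sketch based only on the to_nf4() API used above, not an official recommendation) is to move the fp32 weight to the target device first and quantize there, so the NF4Tensor never needs to change device:

# workaround sketch: quantize on the target device instead of moving NF4
linear = nn.Linear(1024, 1024)
w_cuda = linear.weight.detach().cuda()  # move the fp32 weight first
linear.weight = nn.Parameter(to_nf4(w_cuda), requires_grad=False)
print(linear.weight.quantized_data.device)  # cuda:0 as expected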
Summary:
NF4Tensor.cuda() will dequantize -> this is unwanted
NF4Tensor.to(device="cuda") will raise IndexError, since args[1] does not exist
NF4Tensor.cpu() does not move .nf4 attribute -> cannot dequantize
Does not work with nn.Module.to(device)
IMO, the semantics that NF4Tensor.to(torch.float32) dequantizes is the culprit behind these troubles, and it is also inconsistent with AQT behavior. If .to(dtype) did not dequantize (and only changed the apparent dtype), we would only need to implement aten._to_copy instead of Tensor.cpu, Tensor.to, and a myriad of others. Though I understand this design is meant to make NF4 feel more like a true dtype.
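To illustrate the point, here is a minimal sketch (a hypothetical wrapper subclass, not torchao's actual NF4Tensor code; the inner tensor names quantized_data and scales are illustrative) of how a single aten._to_copy handler could serve .to(device), .cuda(), and .cpu() at once when .to(dtype) does not dequantize:

import torch

class QuantTensor(torch.Tensor):
    # hypothetical quantized wrapper subclass for illustration only
    @staticmethod
    def __new__(cls, quantized_data, scales, orig_dtype, shape):
        return torch.Tensor._make_wrapper_subclass(
            cls, shape, dtype=orig_dtype,
            device=quantized_data.device, requires_grad=False,
        )

    def __init__(self, quantized_data, scales, orig_dtype, shape):
        self.quantized_data = quantized_data
        self.scales = scales

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.ops.aten._to_copy.default:
            self = args[0]
            device = kwargs.get("device", self.device)
            # move every inner tensor; the apparent dtype stays unchanged,
            # so .to("cuda"), .cuda(), and .cpu() all hit this one handler
            return QuantTensor(
                self.quantized_data.to(device=device),
                self.scales.to(device=device),
                self.dtype, self.shape,
            )
        if func is torch.ops.aten.detach.default:
            return args[0]
        raise NotImplementedError(f"{func} is not supported")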
I think it makes more sense to designate NF4Tensor.dequantize() as the method that dequantizes the tensor (also consistent with plain Tensor behavior, although plain Tensor.dequantize() always returns FP32), instead of the current situation, where NF4Tensor.dequantize() is a static method for the lookup table while NF4Tensor.get_original_weight() does the dequantization.
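In code, the proposed change would look roughly like this (a sketch; get_original_weight() is the current method mentioned above, and the dequantize() semantics are the proposal, not existing API):

x_nf4 = to_nf4(torch.randn(1024, 1024))
x_fp32 = x_nf4.get_original_weight()  # current torchao: instance-level dequant
x_fp32 = x_nf4.dequantize()           # proposed: dequantize() does the dequant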
Changing this is BC-breaking, so we should probably leave it as is.