Closed · drisspg closed this 3 weeks ago
On a simple MLP with static weight and activation quantization, I am seeing a copy_ error:
```python
import torch

from float8_experimental.float8_dynamic_linear import Float8DynamicLinear
from float8_experimental.float8_linear_utils import swap_linear_with_float8_linear

# static_fp8_mlp and input_tensor are defined earlier in the script
swap_linear_with_float8_linear(
    static_fp8_mlp,
    Float8DynamicLinear,
    from_float_kwargs={
        "static_quantize_weight": True,
        "activation_scale": torch.tensor([1.0], device="cuda", dtype=torch.float32),
    },
)
print(f"out_static = {static_fp8_mlp(input_tensor)}")

torch.save(static_fp8_mlp.state_dict(), "/home/drisspg/meta/scripts/fp8/saving/dumm_dict2.pt")
static_load = torch.load("/home/drisspg/meta/scripts/fp8/saving/dumm_dict2.pt")
static_fp8_mlp.load_state_dict(static_load)
print(f"out_static_load = {static_fp8_mlp(input_tensor)}")
```
```
RuntimeError: Error(s) in loading state_dict for FeedForward:
    While copying the parameter named "w1.weight", whose dimensions in the model are torch.Size([14336, 4096]) and whose dimensions in the checkpoint are torch.Size([14336, 4096]), an exception occurred: ('attempting to run aten.copy_.default, this is not supported',).
    While copying the parameter named "w3.weight", whose dimensions in the model are torch.Size([14336, 4096]) and whose dimensions in the checkpoint are torch.Size([14336, 4096]), an exception occurred: ('attempting to run aten.copy_.default, this is not supported',).
    While copying the parameter named "w2.weight", whose dimensions in the model are torch.Size([4096, 14336]) and whose dimensions in the checkpoint are torch.Size([4096, 14336]), an exception occurred: ('attempting to run aten.copy_.default, this is not supported',).
```
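A possible workaround (my assumption, not something this PR does) is to load with `assign=True`, available in PyTorch >= 2.1, so the checkpoint tensors replace the parameters outright rather than being `copy_`-ed into the existing Float8Tensor subclass instances:

```python
# Workaround sketch (assumption): assign the checkpoint tensors instead of
# copying in place, which avoids dispatching aten.copy_ on the Float8Tensor
# subclass. Requires PyTorch >= 2.1 for the `assign` kwarg.
state = torch.load("/home/drisspg/meta/scripts/fp8/saving/dumm_dict2.pt")
static_fp8_mlp.load_state_dict(state, assign=True)
print(f"out_static_load = {static_fp8_mlp(input_tensor)}")
```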
@ani300 I imagine you were doing some state_dict loading for Float8Tensors?
Hey, we were quantizing bf16 weights on the fly from our checkpoints, but I think we'll do something akin to AutoFP8 (https://github.com/neuralmagic/AutoFP8) to handle FP8 checkpoints and load them into Float8Tensors.
@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.
@drisspg merged this pull request in pytorch-labs/float8_experimental@36405a7781bb0cebf323dfb79cac336a882c908b.
Summary
Perf script:
https://gist.github.com/drisspg/f7a553710d64cce013227a2249d582d2
Performance
In eager this produces:
UX
Dynamic activation quantization
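A minimal sketch of the dynamic path, assuming the same `swap_linear_with_float8_linear` / `Float8DynamicLinear` entry points used in the repro earlier in this thread; scales are derived from the tensors at runtime:

```python
# Sketch (based on the call shape from the repro above): no from_float_kwargs,
# so activation scales are computed dynamically on each forward pass.
swap_linear_with_float8_linear(dynamic_fp8_mlp, Float8DynamicLinear)
out_dynamic = dynamic_fp8_mlp(input_tensor)
```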
Static activation quantization
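The static path follows the call from the repro at the top of the thread: the weight is quantized once at swap time and a fixed activation scale is supplied via `from_float_kwargs`:

```python
# Static activation quantization, using the kwargs shown in the repro above:
# the weight is quantized at swap time and a fixed activation scale is
# provided instead of being recomputed each forward pass.
swap_linear_with_float8_linear(
    static_fp8_mlp,
    Float8DynamicLinear,
    from_float_kwargs={
        "static_quantize_weight": True,
        "activation_scale": torch.tensor([1.0], device="cuda", dtype=torch.float32),
    },
)
out_static = static_fp8_mlp(input_tensor)
```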
Weight Only quantization
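As a rough sketch only (the exact weight-only API is an assumption on my part): the idea is that omitting `activation_scale` while passing `static_quantize_weight` quantizes just the weight and leaves activations in their original dtype:

```python
# Weight-only sketch (assumption): statically quantize only the weight and
# omit activation_scale so activations are not quantized.
swap_linear_with_float8_linear(
    weight_only_fp8_mlp,
    Float8DynamicLinear,
    from_float_kwargs={"static_quantize_weight": True},
)
out_weight_only = weight_only_fp8_mlp(input_tensor)
```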
All of these use per-tensor scaling; a follow-up PR will add row-wise scaling and likely make it the default.