pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

Add a Float8LinearInference module to support static, dynamic, and wo quant #287

Closed: drisspg closed this pull request 3 weeks ago

drisspg commented 1 month ago

Summary

Perf script:

https://gist.github.com/drisspg/f7a553710d64cce013227a2249d582d2

Performance

In eager mode this produces:

Operation                      Time (μs)
bf16                           2667.9172
fp8_dynamic_activations        2494.7294
fp8_static_activations         2449.1784
fp8_weight_only_activations    4084.7190

With torch.compile this produces:

Operation                      Time (μs)
bf16                           2547.1938
fp8_dynamic_activations        1542.0729
fp8_static_activations         1407.0310
fp8_weight_only_activations    2750.6369
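For reference, the speedup of each fp8 variant over the bf16 baseline is just the ratio of the two timings. The short sketch below recomputes those ratios from the numbers in the tables above; nothing is re-measured.

```python
# Timings in μs, copied from the tables above.
# speedup = bf16 time / variant time, so values above 1.0 are faster than bf16.
eager_us = {
    "bf16": 2667.9172,
    "fp8_dynamic_activations": 2494.7294,
    "fp8_static_activations": 2449.1784,
    "fp8_weight_only_activations": 4084.7190,
}
compiled_us = {
    "bf16": 2547.1938,
    "fp8_dynamic_activations": 1542.0729,
    "fp8_static_activations": 1407.0310,
    "fp8_weight_only_activations": 2750.6369,
}


def speedups(timings):
    """Return {variant: bf16_time / variant_time} for every non-baseline entry."""
    base = timings["bf16"]
    return {name: base / t for name, t in timings.items() if name != "bf16"}


for label, timings in (("eager", eager_us), ("compile", compiled_us)):
    for name, s in speedups(timings).items():
        print(f"{label:8s} {name:30s} {s:.2f}x")
```

With compile this works out to roughly 1.65x for dynamic, 1.81x for static and 0.93x for weight-only; in eager mode weight-only comes out at roughly 0.65x.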

UX

Dynamic activation quantization


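import copy
import torch
from float8_experimental.inference import ActivationCasting, QuantConfig, quantize_to_float8  # module path assumed

dtype = torch.bfloat16  # assumed: bf16 baseline, as in the benchmark
original_mlp = FeedForward().to("cuda", dtype=dtype)  # FeedForward: the MLP from the perf script (assumed)
original_mlp.reset_parameters()

dynamic_fp8_mlp = copy.deepcopy(original_mlp)

quant_config = QuantConfig(ActivationCasting.DYNAMIC)
quantize_to_float8(dynamic_fp8_mlp, quant_config)  # converts the module's nn.Linear layers in place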
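Conceptually, dynamic activation casting recomputes a per-tensor scale from each incoming activation's absolute maximum at runtime, so the largest value always lands on the top of the fp8 range. The pure-Python sketch below emulates that with a hand-written float8 e4m3fn rounding helper. The helper names are hypothetical, saturating at ±448 is a simplification, and this is not how the library's kernels are implemented.

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(x: float) -> float:
    """Emulate a cast to float8 e4m3fn: round-half-to-even, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)  # a = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (max(e - 1, -6) - 3)  # 3 mantissa bits; subnormals below 2**-6
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def quantize_dynamic(values):
    """Hypothetical helper: per-tensor scale derived from this tensor's own amax."""
    amax = max(abs(v) for v in values)
    scale = E4M3_MAX / amax if amax > 0 else 1.0
    return [round_to_e4m3(v * scale) for v in values], scale


def dequantize(q, scale):
    """Map fp8 values back to high precision by undoing the scale."""
    return [v / scale for v in q]


activations = [0.02, -1.5, 3.0, 0.7]  # a fresh scale is computed on every call
q, scale = quantize_dynamic(activations)
print(scale, q, dequantize(q, scale))
```

The cost of this flexibility is an amax reduction over every activation tensor on every forward pass.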

Static activation quantization

original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

static_fp8_mlp = copy.deepcopy(original_mlp)
quant_config = QuantConfig(
    ActivationCasting.STATIC,
    # placeholder scale of 1.0; normally derived from calibration data
    static_quantization_scale=torch.tensor([1.0], device="cuda", dtype=torch.float32),
)
quantize_to_float8(static_fp8_mlp, quant_config)
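Static casting instead fixes the activation scale ahead of time. One common way to pick it is to calibrate on a few batches; the hypothetical `calibrate_static_scale` below maps the largest absolute value seen to the e4m3 maximum (448). The sketch assumes the same multiply-by-scale-then-cast convention as the dynamic case, which may not match exactly how `static_quantization_scale` is interpreted, and values outside the calibrated range simply saturate here.

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(x: float) -> float:
    """Emulate a cast to float8 e4m3fn: round-half-to-even, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)  # a = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (max(e - 1, -6) - 3)  # 3 mantissa bits; subnormals below 2**-6
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def calibrate_static_scale(calibration_batches):
    """Hypothetical calibration: map the largest |value| across all batches to 448."""
    amax = max(abs(v) for batch in calibration_batches for v in batch)
    return E4M3_MAX / amax


def quantize_static(values, scale):
    """Reuse the fixed scale at inference time; no amax is computed per call."""
    return [round_to_e4m3(v * scale) for v in values]


scale = calibrate_static_scale([[0.5, -2.0], [1.0, 4.0]])  # amax 4.0 -> scale 112.0
q = quantize_static([1.0, 8.0], scale)  # 8.0 is outside the calibrated range
restored = [v / scale for v in q]
print(scale, q, restored)  # 8.0 comes back as 4.0 after saturating
```

Skipping the per-call reduction is what makes static casting cheaper than dynamic, at the price of clipping inputs the calibration data did not cover.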

Weight-only quantization

original_mlp = FeedForward().to("cuda", dtype=dtype)
original_mlp.reset_parameters()

wo_fp8_mlp = copy.deepcopy(original_mlp)
quant_config = QuantConfig(ActivationCasting.WEIGHT_ONLY)
quantize_to_float8(wo_fp8_mlp, quant_config)
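With weight-only casting, only the weights are stored in fp8; activations stay in high precision and the weight is scaled back up before the matmul. The pure-Python sketch below shows that dataflow for an `nn.Linear`-style `y = x @ W.T`, with hypothetical helper names and the same emulated e4m3 rounding as the sketches above.

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(x: float) -> float:
    """Emulate a cast to float8 e4m3fn: round-half-to-even, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)  # a = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (max(e - 1, -6) - 3)  # 3 mantissa bits; subnormals below 2**-6
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def quantize_weight(weight):
    """Hypothetical helper: cast a [out_features, in_features] weight to fp8 once, offline."""
    amax = max(abs(v) for row in weight for v in row)
    scale = E4M3_MAX / amax
    return [[round_to_e4m3(v * scale) for v in row] for row in weight], scale


def weight_only_linear(x, w_q, w_scale):
    """y = x @ W.T: dequantize W back to high precision, keep x unquantized."""
    w = [[v / w_scale for v in row] for row in w_q]
    return [sum(xi * wi for xi, wi in zip(x, row)) for row in w]


weight = [[0.5, -1.0], [2.0, 0.25]]
w_q, w_scale = quantize_weight(weight)
y = weight_only_linear([1.0, 2.0], w_q, w_scale)
print(y)
```

The matmul itself runs in high precision here, so this mode saves weight memory rather than using fp8 matmuls.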

All of these use per-tensor scaling. A follow-up PR will add row-wise scaling and will likely make it the default.
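Per-tensor scaling uses a single scale for the whole tensor, so one large-magnitude row pushes small-magnitude rows down into the coarse low end of the fp8 range. Row-wise scaling gives each row its own scale. The pure-Python sketch below compares the worst-case relative reconstruction error of the two schemes on a toy matrix; helper names are hypothetical and the e4m3 rounding is emulated.

```python
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(x: float) -> float:
    """Emulate a cast to float8 e4m3fn: round-half-to-even, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)  # a = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (max(e - 1, -6) - 3)  # 3 mantissa bits; subnormals below 2**-6
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def fake_quantize_rows(rows, scales):
    """Quantize then dequantize each row with its own scale.
    Per-tensor scaling is the special case where every row gets the same scale."""
    return [[round_to_e4m3(v * s) / s for v in row] for row, s in zip(rows, scales)]


def max_rel_error(original, restored):
    """Largest |a - b| / |a| over all nonzero entries."""
    return max(
        abs(a - b) / abs(a)
        for ro, rr in zip(original, restored)
        for a, b in zip(ro, rr)
        if a != 0
    )


matrix = [[0.001, 0.002, -0.0015], [100.0, -50.0, 25.0]]  # one tiny row, one large row

tensor_amax = max(abs(v) for row in matrix for v in row)
per_tensor = fake_quantize_rows(matrix, [E4M3_MAX / tensor_amax] * len(matrix))
row_wise = fake_quantize_rows(matrix, [E4M3_MAX / max(abs(v) for v in row) for row in matrix])

per_tensor_err = max_rel_error(matrix, per_tensor)
row_wise_err = max_rel_error(matrix, row_wise)
print(f"per-tensor: {per_tensor_err:.3f}, row-wise: {row_wise_err:.3f}")
```

Under per-tensor scaling the small row falls into the subnormal range and loses most of its precision; with its own scale it uses the full mantissa.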

ani300 commented 3 weeks ago

On a simple MLP with static weight and activation quantization, I am seeing a copy_ error when reloading the state_dict:

swap_linear_with_float8_linear(
    static_fp8_mlp,
    Float8DynamicLinear,
    from_float_kwargs={
        "static_quantize_weight": True,
        "activation_scale": torch.tensor([1.0], device="cuda", dtype=torch.float32),
    },
)

print(f"out_static = {static_fp8_mlp(input_tensor)}")
torch.save(static_fp8_mlp.state_dict(), "/home/drisspg/meta/scripts/fp8/saving/dumm_dict2.pt")

static_load = torch.load("/home/drisspg/meta/scripts/fp8/saving/dumm_dict2.pt")
static_fp8_mlp.load_state_dict(static_load)  # raises the RuntimeError below
print(f"out_static_load = {static_fp8_mlp(input_tensor)}")

RuntimeError: Error(s) in loading state_dict for FeedForward:
       While copying the parameter named "w1.weight", whose dimensions in the model are torch.Size([14336, 4096]) and whose dimensions in the checkpoint are torch.Size([14336, 4096]), an exception occurred : ('attempting to run aten.copy_.default, this is not supported',).
       While copying the parameter named "w3.weight", whose dimensions in the model are torch.Size([14336, 4096]) and whose dimensions in the checkpoint are torch.Size([14336, 4096]), an exception occurred : ('attempting to run aten.copy_.default, this is not supported',).
       While copying the parameter named "w2.weight", whose dimensions in the model are torch.Size([4096, 14336]) and whose dimensions in the checkpoint are torch.Size([4096, 14336]), an exception occurred : ('attempting to run aten.copy_.default, this is not supported',).

@ani300 I imagine you were doing some state_dict loading for Float8Tensors?

Hey, we were quantizing bf16 weights on the fly from our checkpoints, but I think we'll do something akin to AutoFP8 (https://github.com/neuralmagic/AutoFP8) to handle FP8 checkpoints and load them into Float8Tensors.
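The `copy_` failure above comes from `load_state_dict` copying into parameters that are already Float8Tensors. An AutoFP8-style flow sidesteps that by storing the quantized values and their scale in the checkpoint and constructing the fp8 weights from them at load time. A minimal pure-Python sketch of that round trip, using hypothetical helper names and an assumed `<name>_scale` key layout (not AutoFP8's documented format), with JSON standing in for the checkpoint file:

```python
import json
import math

E4M3_MAX = 448.0  # largest finite float8_e4m3fn value


def round_to_e4m3(x: float) -> float:
    """Emulate a cast to float8 e4m3fn: round-half-to-even, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    a = min(abs(x), E4M3_MAX)
    _, e = math.frexp(a)  # a = m * 2**e with 0.5 <= m < 1
    step = 2.0 ** (max(e - 1, -6) - 3)  # 3 mantissa bits; subnormals below 2**-6
    return math.copysign(min(round(a / step) * step, E4M3_MAX), x)


def save_fp8_checkpoint(state_dict):
    """Quantize each high-precision weight offline; store fp8 values and scale side by side."""
    ckpt = {}
    for name, weight in state_dict.items():
        scale = E4M3_MAX / max(abs(v) for row in weight for v in row)
        ckpt[name] = [[round_to_e4m3(v * scale) for v in row] for row in weight]
        ckpt[name + "_scale"] = scale  # assumed key layout, e.g. "w1.weight_scale"
    return json.dumps(ckpt)


def load_fp8_checkpoint(blob):
    """Rebuild (fp8 values, scale) pairs from the checkpoint instead of
    copy_-ing into parameters that are already fp8."""
    ckpt = json.loads(blob)
    return {
        name: (values, ckpt[name + "_scale"])
        for name, values in ckpt.items()
        if not name.endswith("_scale")
    }


blob = save_fp8_checkpoint({"w1.weight": [[0.5, -1.0], [2.0, 0.25]]})
values, scale = load_fp8_checkpoint(blob)["w1.weight"]
print(scale, [[v / scale for v in row] for row in values])
```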

facebook-github-bot commented 3 weeks ago

@drisspg has imported this pull request. If you are a Meta employee, you can view this diff on Phabricator.


facebook-github-bot commented 3 weeks ago

@drisspg merged this pull request in pytorch-labs/float8_experimental@36405a7781bb0cebf323dfb79cac336a882c908b.