pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX
BSD 3-Clause "New" or "Revised" License

Float8Tensor.to_original_precision() returns wrong dtype #292


ani300 commented 2 weeks ago

For some reason, the unit tests that cover this particular issue pass, but when I try these changes on (fp468) I get the following output from a debug print I added in FromFloat8ConstrFunc.forward:

@torch._dynamo.allow_in_graph
class FromFloat8ConstrFunc(torch.autograd.Function):
    """
    A differentiable conversion from fp8.
    * forward: convert from float8 to high precision
    * backward: pass the gradient without changes
    """

    @staticmethod
    def forward(ctx, tensor):
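        # debug print added here to inspect the intermediate dtypes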
        print(tensor._data.to(tensor._orig_dtype).dtype, tensor._scale.dtype, (tensor._data.to(tensor._orig_dtype) / tensor._scale).dtype)
        return tensor._data.to(tensor._orig_dtype) / tensor._scale

    @staticmethod
    def backward(ctx, g):
        return g, None, None

which prints:

torch.float16 torch.float32 torch.float32
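
For context, this looks like plain PyTorch type promotion rather than anything float8-specific: dividing a float16 tensor by a float32 tensor promotes the result to float32. A minimal sketch with ordinary tensors (data_hp and scale here are illustrative stand-ins for tensor._data.to(tensor._orig_dtype) and tensor._scale):

import torch

# Stand-ins for the tensors inside FromFloat8ConstrFunc.forward:
data_hp = torch.randn(4, dtype=torch.float16)   # like tensor._data.to(tensor._orig_dtype)
scale = torch.tensor(8.0, dtype=torch.float32)  # like tensor._scale; scales are kept in float32

# Standard type promotion: float16 / float32 -> float32
out = data_hp / scale
print(data_hp.dtype, scale.dtype, out.dtype)  # torch.float16 torch.float32 torch.float32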

Clearly, the scale division is upcasting the result to float32, since scales must always be stored in float32. I don't know whether the fix is to downcast the scale to the original dtype before dividing (fast and much more memory efficient, but possibly wrong results?) or to downcast the result after the division (much slower, with a giant extra memory copy, but presumably correct results).
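
A rough sketch of the two options using the same stand-in tensors, in case it helps the discussion (neither is meant as the final fix):

import torch

orig_dtype = torch.float16
data_hp = torch.randn(4, dtype=orig_dtype)      # like tensor._data.to(tensor._orig_dtype)
scale = torch.tensor(8.0, dtype=torch.float32)  # like tensor._scale

# Option A: downcast the scale first, then divide.
# No float32 temporary, but the division itself runs in the lower precision.
out_a = data_hp / scale.to(orig_dtype)

# Option B: divide in float32, then downcast the result.
# Keeps the division in float32 at the cost of a full-size float32 temporary.
out_b = (data_hp / scale).to(orig_dtype)

print(out_a.dtype, out_b.dtype)  # torch.float16 torch.float16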