For some reason, the unit tests that cover this particular issue pass, but when I run these changes on (fp468) with the instrumented forward below, I get the following output:
```python
@torch._dynamo.allow_in_graph
class FromFloat8ConstrFunc(torch.autograd.Function):
    """
    A differentiable conversion from fp8.
    * forward: convert from float8 to high precision
    * backward: pass the gradient without changes
    """

    @staticmethod
    def forward(ctx, tensor):
        print(
            tensor._data.to(tensor._orig_dtype).dtype,
            tensor._scale.dtype,
            (tensor._data.to(tensor._orig_dtype) / tensor._scale).dtype,
        )
        return tensor._data.to(tensor._orig_dtype) / tensor._scale

    @staticmethod
    def backward(ctx, g):
        return g, None, None
```

```
torch.float16 torch.float32 torch.float32
```
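This matches PyTorch's standard type-promotion behavior. A minimal standalone sketch (illustration only, not the project's code, assuming the scale is a non-scalar float32 tensor, consistent with the printed dtypes):

```python
import torch

# Stand-ins for tensor._data.to(tensor._orig_dtype) and tensor._scale:
data_hp = torch.ones(4, dtype=torch.float16)
scale = torch.ones(1, dtype=torch.float32)

# float16 / float32 promotes to float32 under PyTorch's type-promotion rules.
print((data_hp / scale).dtype)  # torch.float32
```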
Clearly, the scale division is upcasting to float32, since scales must always be stored in float32, but I don't know whether the fix is to downcast the scale first (fast and much more memory efficient, but possibly the wrong result?) or to downcast the result afterwards (much slower, with a giant memory copy, but presumably the correct result?).
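For concreteness, here is a sketch of the two options (the helper names are hypothetical; `tensor` is the same float8 tensor as in the forward above):

```python
# Option A (hypothetical helper): downcast the scale before dividing.
# Cheap, since the scale is tiny relative to the data, but the division
# then happens entirely in the lower precision.
def from_float8_downcast_scale(tensor):
    return tensor._data.to(tensor._orig_dtype) / tensor._scale.to(tensor._orig_dtype)

# Option B (hypothetical helper): divide with the float32 scale, letting the
# result promote to float32, then downcast. The division happens in higher
# precision, but it materializes a full-size float32 intermediate.
def from_float8_downcast_result(tensor):
    return (tensor._data.to(tensor._orig_dtype) / tensor._scale).to(tensor._orig_dtype)
```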