pytorch-labs / float8_experimental

This repository contains the experimental PyTorch native float8 training UX

Float8Tensor + DTensor composability #202

Closed · wanchaol closed this issue 2 months ago

wanchaol commented 8 months ago

Summary

We would like models to be able to compose Float8Tensor and DTensor. Under the hood, this runs into challenges around how the two tensor subclasses compose, in particular which subclass should wrap which.

Work Items

Instead, we plan to special-case the Float8Tensor construction when the input is a DTensor. The Float8Tensor constructor needs to maintain the invariant that DTensor is always the outermost tensor. PR: https://github.com/pytorch-labs/float8_experimental/pull/224
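
A minimal sketch of what that special-casing could look like is below. The helper name `hp_tensor_to_float8` and the exact `Float8Tensor.to_float8` / `DTensor` signatures are assumptions for illustration; the authoritative version is in the PR linked above.

```python
import torch
from torch.distributed._tensor import DTensor

# Assumed import path for the subclass in this repo.
from float8_experimental.float8_tensor import Float8Tensor


def hp_tensor_to_float8(hp_tensor, scale, float8_dtype=torch.float8_e4m3fn):
    """Convert a high-precision tensor to Float8Tensor, keeping DTensor outermost.

    Hypothetical helper: if the input is a DTensor, convert its local shard and
    re-wrap, so the result is DTensor(Float8Tensor) rather than
    Float8Tensor(DTensor).
    """
    if isinstance(hp_tensor, DTensor):
        local_fp8 = Float8Tensor.to_float8(hp_tensor.to_local(), scale, float8_dtype)
        return DTensor.from_local(
            local_fp8,
            hp_tensor.device_mesh,
            hp_tensor.placements,
            run_check=False,
        )
    return Float8Tensor.to_float8(hp_tensor, scale, float8_dtype)
```

Per the plan above, the real change puts this branch inside the Float8Tensor construction path rather than in a standalone helper.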

Example

DTensor/Float8Tensor subclass composability ordering issue

linear = nn.Linear(...)

fp8_linear = swap_linear_with_float8_linear(linear, ...)

pre-runtime:
fp8_linear.weight = nn.Parameter(distribute_tensor(tensor, mesh, [Shard(0)]))  # fp32

run-time:
1. activation/weight: torch.Tensor (fp32) -> DTensor.from_local(x, mesh, [Shard(0)])
2. DTensor(fp32) -> DTensor(Float8Tensor)
    2.1 tensor_to_scale(DTensor) -> amax_scale
    2.2 Float8Tensor.to_float8(DTensor) -> Float8Tensor(DTensor), which is the wrong nesting order (DTensor should stay outermost)

3. DTensor.redistribute(placements=[Replicate()]) -> allgather on Float8Tensor -> Float8Tensor.__torch_dispatch__ -> c10d_functional.allgather.default
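
For context, here is a hedged end-to-end sketch of the intended runtime flow with the corrected nesting. It assumes an initialized process group (e.g. via torchrun), the hypothetical `hp_tensor_to_float8` helper sketched above, and that `tensor_to_scale`'s exact signature matches what the issue text implies.

```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DeviceMesh, DTensor, Replicate, Shard

# Assumed import; where the scaling utility lives in this repo.
from float8_experimental.float8_utils import tensor_to_scale

mesh = DeviceMesh("cuda", list(range(dist.get_world_size())))

# Step 1: wrap the fp32 local shard in a DTensor.
x = torch.randn(1024, 1024, device="cuda")
dx = DTensor.from_local(x, mesh, [Shard(0)])

# Step 2: compute the scale from the amax, then convert to float8 while
# keeping DTensor as the outermost wrapper, i.e. DTensor(Float8Tensor).
scale = tensor_to_scale(dx, torch.float8_e4m3fn)
dx_fp8 = hp_tensor_to_float8(dx, scale)  # hypothetical helper sketched above

# Step 3: redistributing to Replicate() issues an allgather on the inner
# Float8Tensor, which routes through Float8Tensor.__torch_dispatch__ to
# c10d_functional.allgather.default.
dx_replicated = dx_fp8.redistribute(placements=[Replicate()])
```
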
vkuzo commented 2 months ago

Closing as this has been working for a while!