Summary
We would like models to be able to compose Float8 and DTensor. Under the hood, this faces some challenges related to the composition of tensor subclasses.
Work Items
[ ] Convert casting into an op that DTensor can override. This makes sense, and we should do it.
This option was worked on in https://github.com/pytorch-labs/float8_experimental/pull/218. However, that work highlighted some issues. The operations we are trying to convert into an op produce a Float8Tensor subclass, which means the meta implementation and FakeTensor need to be aware of this subclass. This could be done using private FakeTensor APIs, but there will likely be problems with AOTAutograd/Inductor.
What breaks:
NestedTensor has an analogous problem: it wants a constructor that takes in a plain tensor and produces a NestedTensor. One option is to have this all happen during tracing and define a FakeTensor rule specialized for this one factory op. "This is the wrong thing to do" - Brian.
The above won't work because of how subclass tracing works today. For traceable wrapper subclasses (TWS), the final graph contains no subclasses. Proxy mode propagates proxies from the start of the graph through each function. For TWS, proxies are attached to the inner tensors, not the outer tensor. Since the proxy objects live on the inner tensors, proxy mode never actually sees the outer wrapper subclass. If a constructor with a FakeTensor rule existed, calling it would produce an output subclass that itself carries the proxy. Instead, we want the constructor to produce a Float8Tensor whose inner components carry the correct proxy tensors.
How is NT trying to solve this? It defines no FakeTensor rule for the "nested_from_buffer" op and instead passes extra args to the NT rule.
With Float8, __torch_dispatch__ can intercept an op before proxy mode and fake mode run. This is what allows the desugared aten ops to appear in the graph with new proxies. Factory functions are a problem because they can't be intercepted this way: their inputs are plain tensors, so there is no Float8Tensor argument to trigger the subclass's __torch_dispatch__.
One could envision a Float8Tensor mode that intercepts all Float8Tensor constructor calls and returns the object.
Instead, we plan to special-case Float8Tensor construction when the input is a DTensor. The Float8Tensor constructor needs to maintain the invariant that DTensor is always the outermost tensor.
PR: https://github.com/pytorch-labs/float8_experimental/pull/224
[ ] Hook reordering. This is still blocked on the compile work; see this issue: https://github.com/pytorch-labs/float8_experimental/issues/223.
[ ] Implement sharding strategy
[ ] Test the TP/SP strategy end to end. See https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/parallel/style.py#L19 for more detail on the TP and SP strategies.
Example
DTensor/Float8Tensor subclass composability ordering issue
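The ordering issue named in this example can be pictured with a toy model of the constructor special case described in the first work item. This is a sketch under stated assumptions: `ToyTensor`, `ToyFloat8Tensor`, `ToyDTensor`, `to_float8`, and `nesting` are hypothetical plain-Python stand-ins, not the float8_experimental or torch.distributed classes, and the "cast" is just multiplication by the scale. It only shows the nesting rule: when the input is a DTensor, unwrap the local shard, cast it, and rewrap it, so DTensor stays outermost. Real code would presumably go through DTensor's own local-tensor accessors (such as `to_local()` and `DTensor.from_local`), but that is an assumption here, not a description of PR 224.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union


@dataclass
class ToyTensor:
    data: List[float]


@dataclass
class ToyFloat8Tensor:
    data: List[float]  # toy "quantized" values: input times scale, no fp8 rounding
    scale: float


@dataclass
class ToyDTensor:
    local: Union[ToyTensor, ToyFloat8Tensor]  # this rank's shard
    placements: Tuple[str, ...]               # e.g. ("Shard(0)",)


def _cast_local(t: ToyTensor, scale: float) -> ToyFloat8Tensor:
    # Toy cast only: a real cast would also clamp and convert to an fp8 dtype.
    return ToyFloat8Tensor([v * scale for v in t.data], scale)


def to_float8(t: Union[ToyTensor, ToyDTensor], scale: float):
    """Hypothetical constructor that keeps DTensor as the outermost wrapper."""
    if isinstance(t, ToyDTensor):
        # Produce DTensor(Float8Tensor(...)), never Float8Tensor(DTensor(...)).
        return ToyDTensor(_cast_local(t.local, scale), t.placements)
    return _cast_local(t, scale)


def nesting(obj) -> List[str]:
    """Wrapper type names from outermost to innermost."""
    names = [type(obj).__name__]
    while isinstance(obj, ToyDTensor):
        obj = obj.local
        names.append(type(obj).__name__)
    return names


dt = ToyDTensor(ToyTensor([1.0, 2.0]), ("Shard(0)",))
out = to_float8(dt, scale=2.0)
print(nesting(out))     # ['ToyDTensor', 'ToyFloat8Tensor']
print(out.local.data)   # [2.0, 4.0]
print(out.placements)   # ('Shard(0)',)
```

A plain (non-DTensor) input takes the ordinary path and comes back as a bare `ToyFloat8Tensor`, so the special case only changes behavior when sharding is involved.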