Open ruisearch42 opened 1 month ago
Another data point is that Pytorch has a number of other data types that aren't natively supported in numpy (like float8, etc). I'm not particularly sure of the details for which tensor
was converted to a numpy array here but it would be best to not assume any interoperability with numpy and keep tensors as tensors during serialization.
Hmm really what we need is to be able to zero-copy deserialize torch.Tensors. Using numpy was easiest at the time but this is indeed an issue. Maybe arrow is an option?
Hmm actually seems like this could be supported by using torch.Tensor.view with a uint8 dtype:
```python
In [99]: t
Out[99]:
tensor([ 1.8750, -0.6875, -1.1250, -1.3750,  1.3750, -1.1250,  0.4688, -0.4062,
         0.8750, -1.7500], dtype=torch.float8_e4m3fn)

In [104]: torch.as_tensor(t.view(torch.uint8).numpy()).view(torch.float8_e4m3fn)
Out[104]:
tensor([ 1.8750, -0.6875, -1.1250, -1.3750,  1.3750, -1.1250,  0.4688, -0.4062,
         0.8750, -1.7500], dtype=torch.float8_e4m3fn)
```
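A quick sanity check of the same trick outside IPython (using `float16` to stand in for `float8`, since not all torch builds ship the float8 dtypes; the mechanics of the roundtrip are identical). It also confirms the roundtrip is zero-copy:

```python
import torch

# float16 stands in for float8 here; the roundtrip mechanics are the same.
t = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float16)

# Reinterpret the same bytes as uint8, then hand them to numpy;
# torch.as_tensor() wraps the numpy buffer without copying.
arr = t.view(torch.uint8).numpy()
back = torch.as_tensor(arr).view(torch.float16)

assert torch.equal(back, t)

# Zero-copy check: a write through the original tensor is visible
# in the roundtripped view, because all three share one buffer.
t[0] = 9.0
assert back[0].item() == 9.0
```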
Hi @stephanie-wang,
I am interested in taking up this issue.
Thanks, feel free to open a PR. Will assign it to you for now and revisit in a week or so.
Thanks! Do you have suggestions on how to reproduce the issue locally?
You want a DAG that looks something like:
```python
t = A.return_tensor(...).with_type_hint(TorchTensorType())
dag = B.read_tensor(t)
```
And A should return a tensor with a dtype unsupported by numpy.
I am new to the ecosystem, so I'm sorry if this is an obvious question: what do `A` and `B` represent in your previous comment? I tried creating an example using the "Getting Started" docs, but I am struggling to reproduce the issue.
```python
import ray
from ray import workflow
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import torch

@ray.remote
def return_tensor():
    return torch.Tensor([1.0, 2.0])

@ray.remote
def read_tensor(a):
    return a

dag = read_tensor.bind(return_tensor.bind().with_type_hint(TorchTensorType()))
print(workflow.run(dag))
```
Am I on the right track? Also, if I try to specify the `dtype` param in the `torch.Tensor` init, the program fails with a weird torch data type error. I am not sure if that's related.
Hi @sahilgupta2105, yes we don't have great docs for compiled graphs yet. Please work through the developer guide first and comment here if you still have questions.
Thanks, the developer guide is helpful. I was able to reproduce the issue.
```python
import ray
import ray.dag
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import torch

@ray.remote
class Actor:
    def process(self, tensor: torch.Tensor):
        return tensor.shape

actor = Actor.remote()

with ray.dag.InputNode() as inp:
    inp = inp.with_type_hint(TorchTensorType())
    dag = actor.process.bind(inp)

dag = dag.experimental_compile()
print(ray.get(dag.execute(torch.zeros(10, dtype=torch.float8_e4m3fn))))
```
> Hmm actually seems like this could be supported by using `torch.Tensor.view` with a uint8 dtype:
> `torch.as_tensor(t.view(torch.uint8).numpy()).view(torch.float8_e4m3fn)`
@stephanie-wang to be sure, you meant type-casting the torch tensor to the unsigned int type with the appropriate precision, right? E.g. float8 -> uint8, float16 -> uint16, ...
No, cast everything to uint8 (because it's a supported np dtype and the lowest possible precision) and then cast back on the receiving side.
Gotcha. For my understanding, wouldn't casting everything to uint8 cause a loss of information if the input tensor is of higher precision?
@stephanie-wang It seems like the view operation only updates the metadata, which ensures that information is not lost when `floatXX` is converted to `uint8`. However, to ensure data integrity, we need information about the `dtype` of the original array. Numpy arrays support custom metadata; shall we use that to store the `dtype` on serialization, so deserialization can restore the correct format?
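For what it's worth, the difference between a byte-level view (metadata-only, lossless) and a value cast (lossy) can be demonstrated with plain numpy, using `float16` since numpy supports it natively:

```python
import numpy as np

a = np.array([1.5, 0.25, 3.0], dtype=np.float16)

# view() reinterprets the same bytes: a metadata-only change, no copy
as_bytes = a.view(np.uint8)            # shape (6,): the raw bytes of a
roundtrip = as_bytes.view(np.float16)  # restore the original dtype
assert np.array_equal(roundtrip, a)    # lossless

# astype() converts the *values*: fractional parts are destroyed
cast = a.astype(np.uint8)              # [1, 0, 3]
assert not np.array_equal(cast.astype(np.float16), a)  # lossy
```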
It would be better to not rely on a third-party API for passing the dtype. You can pass it through Ray instead. Check out how TorchTensorType is used.
I dug more into the code. Can you confirm my understanding before I send out a PR?

1. `TorchTensorType` uses a serializer from the serialization context that internally calls `.numpy()`.
2. The `TorchTensorType` hint is unique per DAG node and contains the type in its `_dtype` attribute.
3. On deserialization, the `_dtype` attribute can be used to restore the original dtype and maintain data integrity.

If you agree so far, do I need special handling for the "AUTO" dtype?
I don't think you need to touch TorchTensorType at all. Should be enough to just modify the existing custom serialization function.
**What happened + What you expected to happen**

Got the following error when experimenting with OpenRLHF:

```
TypeError: Got unsupported ScalarType BFloat16
```

**Versions / Dependencies**

head

**Reproduction script**

will add later

**Issue Severity**

None