ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[core][compiled graphs] Support all torch.dtypes for tensors sent through shared memory channels #48141

Open ruisearch42 opened 1 month ago

ruisearch42 commented 1 month ago

What happened + What you expected to happen

Got the following error when experimenting with OpenRLHF:

TypeError: Got unsupported ScalarType BFloat16 :

(CriticModelRayActor pid=33730)   cpu_tensor = torch.from_numpy(np_array) [repeated 2x across cluster]
(RewardModelRayActor pid=33731) Traceback (most recent call last): [repeated 4x across cluster]
(RewardModelRayActor pid=33731) 2024-10-21 09:10:49 ERROR    Compiled DAG task exited with exception
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/experimental/channel/shared_memory_channel.py", line 469, in write [repeated 5x across cluster]
(RewardModelRayActor pid=33731)     serialized_value = self._worker.get_serialization_context().serialize(
(RewardModelRayActor pid=33731)                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/experimental/channel/torch_tensor_type.py", line 108, in serialize [repeated 2x across cluster]
(RewardModelRayActor pid=33731)     return self._serialize_to_msgpack(value)
(RewardModelRayActor pid=33731)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/serialization.py", line 497, in _serialize_to_msgpack
(RewardModelRayActor pid=33731)     pickle5_serialized_object = self._serialize_to_pickle5(
(RewardModelRayActor pid=33731)                                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/serialization.py", line 439, in _serialize_to_pickle5 [repeated 2x across cluster]
(RewardModelRayActor pid=33731)     raise e
(RewardModelRayActor pid=33731)     inband = pickle.dumps(
(RewardModelRayActor pid=33731)            ^^^^^^^^^^^^^ [repeated 2x across cluster]
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/cloudpickle/cloudpickle.py", line 1479, in dumps
(RewardModelRayActor pid=33731)     cp.dump(obj)
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/cloudpickle/cloudpickle.py", line 1245, in dump
(RewardModelRayActor pid=33731)     return super().dump(obj)
(RewardModelRayActor pid=33731)            ^^^^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/_private/serialization.py", line 175, in _CloudPicklerReducer
(RewardModelRayActor pid=33731)     return custom_deserializer, (custom_serializer(obj),)
(RewardModelRayActor pid=33731)                                  ^^^^^^^^^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731)     return ctx.serialization_context.serialize_tensor(t)
(RewardModelRayActor pid=33731)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/experimental/channel/serialization_context.py", line 77, in serialize_tensor
(RewardModelRayActor pid=33731)     return self.serialize_to_numpy(tensor)
(RewardModelRayActor pid=33731)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/experimental/channel/serialization_context.py", line 88, in serialize_to_numpy
(RewardModelRayActor pid=33731)     return tensor.numpy()
(RewardModelRayActor pid=33731)            ^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731) TypeError: Got unsupported ScalarType BFloat16
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py", line 118, in do_exec_tasks
(RewardModelRayActor pid=33731)     done = tasks[operation.exec_task_idx].exec_operation(
(RewardModelRayActor pid=33731)            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py", line 547, in exec_operation
(RewardModelRayActor pid=33731)     return self._write()
(RewardModelRayActor pid=33731)   File "/home/ray/anaconda3/lib/python3.11/site-packages/ray/dag/compiled_dag_node.py", line 517, in _write
(RewardModelRayActor pid=33731)     self.output_writer.write(output_val)
(RewardModelRayActor pid=33731)     channel.write(val_i, timeout)
(RewardModelRayActor pid=33731)     channel.write(value, timeout)
(RewardModelRayActor pid=33731)     self._buffers[self._next_write_index].write(value, timeout)

Versions / Dependencies

head

Reproduction script

will add later

Issue Severity

None

SumanthRH commented 1 month ago

Another data point: PyTorch has a number of other data types that aren't natively supported in NumPy (float8, etc.). I'm not sure exactly which tensor was converted to a NumPy array here, but it would be best not to assume any interoperability with NumPy and to keep tensors as tensors during serialization.
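For illustration, the incompatibility is easy to trigger outside of Ray (a minimal standalone check):

import torch

t = torch.zeros(4, dtype=torch.bfloat16)
# NumPy has no bfloat16, so the conversion raises the same error as the
# traceback above: TypeError: Got unsupported ScalarType BFloat16
t.numpy()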

stephanie-wang commented 1 month ago

Hmm, really what we need is the ability to zero-copy deserialize torch.Tensors. Using numpy was the easiest option at the time, but this is indeed an issue. Maybe Arrow is an option?
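For context, the appeal of the numpy path is that the reverse direction is already zero-copy; a standalone illustration:

import numpy as np
import torch

arr = np.zeros(4, dtype=np.float32)
t = torch.from_numpy(arr)   # wraps the same buffer; no bytes are copied
arr[0] = 1.0
assert t[0].item() == 1.0   # the tensor observes the write to the array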

stephanie-wang commented 1 month ago

Hmm actually seems like this could be supported by using torch.Tensor.view with a uint8 dtype:

In [99]: t
Out[99]:
tensor([ 1.8750, -0.6875, -1.1250, -1.3750,  1.3750, -1.1250,  0.4688, -0.4062,
         0.8750, -1.7500], dtype=torch.float8_e4m3fn)

In [104]: torch.as_tensor(t.view(torch.uint8).numpy()).view(torch.float8_e4m3fn)
Out[104]:
tensor([ 1.8750, -0.6875, -1.1250, -1.3750,  1.3750, -1.1250,  0.4688, -0.4062,
         0.8750, -1.7500], dtype=torch.float8_e4m3fn)
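
The same round trip covers the BFloat16 case from the original traceback; a minimal standalone check (not Ray code):

import torch

t = torch.randn(4, dtype=torch.bfloat16)
# Reinterpret the same bytes as uint8 so .numpy() succeeds, then reverse it:
restored = torch.as_tensor(t.view(torch.uint8).numpy()).view(torch.bfloat16)
assert torch.equal(t, restored)  # bit-for-bit identical, no precision loss
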
sahilgupta2105 commented 4 weeks ago

Hi @stephanie-wang,

I am interested in taking up this issue.

stephanie-wang commented 4 weeks ago

> Hi @stephanie-wang,
>
> I am interested in taking up this issue.

Thanks, feel free to open a PR. Will assign it to you for now and revisit in a week or so.

sahilgupta2105 commented 4 weeks ago

Thanks, do you have suggestions for reproducing the issue locally?

stephanie-wang commented 4 weeks ago

You want a DAG that looks something like:

t = A.return_tensor(...).with_type_hint(TorchTensorType())
dag = B.read_tensor(t)

And A should return a tensor with a dtype unsupported by numpy.

sahilgupta2105 commented 4 weeks ago

I am new to the ecosystem, so I'm sorry if this is an obvious question: what do A and B represent in your previous comment? I tried creating an example using the "Getting Started" docs, but I am struggling to reproduce the issue.

import ray
from ray import workflow
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import torch

@ray.remote
def return_tensor():
    return torch.Tensor([1.0, 2.0])

@ray.remote
def read_tensor(a):
    return a

dag = read_tensor.bind(return_tensor.bind().with_type_hint(TorchTensorType()))

print(workflow.run(dag))

Am I on the right track? Also, if I try to specify the dtype param in the torch.Tensor constructor, the program fails with a weird torch data type error. I am not sure if that's related.

stephanie-wang commented 4 weeks ago

Hi @sahilgupta2105, yes we don't have great docs for compiled graphs yet. Please work through the developer guide first and comment here if you still have questions.

sahilgupta2105 commented 4 weeks ago

Thanks, the developer guide is helpful. I was able to reproduce the issue.

import ray
import ray.dag
from ray.experimental.channel.torch_tensor_type import TorchTensorType
import torch

@ray.remote
class Actor:
  def process(self, tensor: torch.Tensor):
    return tensor.shape

actor = Actor.remote()

with ray.dag.InputNode() as inp:
  inp = inp.with_type_hint(TorchTensorType())
  dag = actor.process.bind(inp)

dag = dag.experimental_compile()
print(ray.get(dag.execute(torch.zeros(10, dtype=torch.float8_e4m3fn))))

sahilgupta2105 commented 4 weeks ago

> Hmm actually seems like this could be supported by using torch.Tensor.view with a uint8 dtype:
>
> In [99]: t
> Out[99]:
> tensor([ 1.8750, -0.6875, -1.1250, -1.3750,  1.3750, -1.1250,  0.4688, -0.4062,
>          0.8750, -1.7500], dtype=torch.float8_e4m3fn)
>
> In [104]: torch.as_tensor(t.view(torch.uint8).numpy()).view(torch.float8_e4m3fn)
> Out[104]:
> tensor([ 1.8750, -0.6875, -1.1250, -1.3750,  1.3750, -1.1250,  0.4688, -0.4062,
>          0.8750, -1.7500], dtype=torch.float8_e4m3fn)

@stephanie-wang to be sure, you meant type-casting the torch tensor to the unsigned int type with the appropriate precision, right? E.g. float8 -> uint8, float16 -> uint16, ...

stephanie-wang commented 4 weeks ago

No, cast everything to uint8 (because it's a supported np dtype and the lowest possible precision) and then cast back on the receiving side.

sahilgupta2105 commented 4 weeks ago

Gotcha. For my understanding, wouldn't casting everything to uint8 cause a loss of information if the input tensor is of higher precision?

sahilgupta2105 commented 3 weeks ago

@stephanie-wang It seems like the view operation only reinterprets the tensor's metadata (dtype and shape) without touching the underlying bytes, so no information is lost when floatXX is viewed as uint8. However, to restore the data correctly on the receiving side, we need the dtype of the original tensor. NumPy arrays support custom metadata; shall we use that to store the dtype at serialization time so deserialization can restore the correct format?
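A quick standalone check of that claim:

import torch

t = torch.randn(3, dtype=torch.float16)
b = t.view(torch.uint8)                # reinterpret: shape (3,) becomes (6,)
assert b.shape == (6,)                 # 2 bytes per float16 element, nothing dropped
assert b.data_ptr() == t.data_ptr()    # no copy: both views share one buffer
assert torch.equal(b.view(torch.float16), t)  # the round trip is exact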

stephanie-wang commented 3 weeks ago

It would be better to not rely on a third-party API for passing the dtype. You can pass it through Ray instead. Check out how TorchTensorType is used.

sahilgupta2105 commented 3 weeks ago

I dug more into the code. Can you confirm my understanding before I send out a PR?

If you agree so far, do I need special handling for the "AUTO" dtype?

stephanie-wang commented 2 weeks ago

I don't think you need to touch TorchTensorType at all. Should be enough to just modify the existing custom serialization function.
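
As a hedged sketch of that change, keyed to the serialize_to_numpy frame in the traceback above (the class shape and the deserialization-side method name here are assumptions, not Ray's actual code):

import numpy as np
import torch

class SerializationContextSketch:
    # Stand-in for SerializationContext in
    # ray/experimental/channel/serialization_context.py; sketch only.

    def serialize_to_numpy(self, tensor: torch.Tensor):
        tensor = tensor.to("cpu").contiguous()
        # Instead of the current bare `return tensor.numpy()` (which rejects
        # bfloat16, float8, ...), reinterpret as uint8 and pass the dtype along.
        return tensor.view(torch.uint8).numpy(), tensor.dtype

    def deserialize_from_numpy(self, np_array: np.ndarray, dtype: torch.dtype):
        # Wrap the buffer without copying, then restore the original dtype/shape.
        return torch.as_tensor(np_array).view(dtype)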