ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Ray component: Train] Ray can integrate with Lightning or XLA, but not both #45035

Open BrianF-tessera opened 2 months ago

BrianF-tessera commented 2 months ago

Description

Based on our testing, Ray works well with AWS Trainium and Torch, and Lightning integrates well with Trainium, but there is currently no way to use all three together. Making this work would require new strategies such as RayDDPXLAStrategy() and RayFSDPXLAStrategy().

Use case

This would make it easier to productionize XLA-based training for end users who primarily work in Lightning rather than writing raw Torch.

BrianF-tessera commented 2 months ago

I took an initial shot at this in our sandbox environment with the following strategy:

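# Imports were not in the original snippet; these paths are assumptions.
from typing import Any, Dict

import lightning as pl  # so that pl.pytorch.strategies resolves as written
import torch
import torch_xla.core.xla_model as xm
from ray._private.usage.usage_lib import TagKey, record_extra_usage_tag


# DDP strategy that places each Ray worker's model on its XLA (Trainium) device
class RayXLADDPStrategy(pl.pytorch.strategies.DDPStrategy):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Ray telemetry tag only; it does not affect training
        record_extra_usage_tag(TagKey.TRAIN_LIGHTNING_RAYDDPSTRATEGY, "1")

    @property
    def root_device(self) -> torch.device:
        # Use the XLA device assigned to this process instead of a CUDA device
        device = xm.xla_device()
        return device

    @property
    def distributed_sampler_kwargs(self) -> Dict[str, Any]:
        # Shard data across workers by Lightning's world size and global rank
        return dict(
            num_replicas=self.world_size,
            rank=self.global_rank,
        )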

and hit the following stacktrace:

TypeError: Could not serialize the put value <function train_func at 0x7f7016477250>:
===================================================================
Checking Serializability of <function train_func at 0x7f7016477250>
===================================================================
!!! FAIL serialization: cannot pickle 'google.protobuf.pyext._message.EnumDescriptor' object
Detected 6 global variables. Checking serializability...
    Serializing 'FashionMNIST' <class 'torchvision.datasets.mnist.FashionMNIST'>...
    Serializing 'transforms' <module 'torchvision.transforms' from '/home/ec2-user/miniconda3/envs/raytest/lib/python3.10/site-packages/torchvision/transforms/__init__.py'>...
    Serializing 'DataLoader' <class 'torch.utils.data.dataloader.DataLoader'>...
    Serializing 'MNISTClassifier' <class '__main__.MNISTClassifier'>...
    Serializing 'RayXLADDPStrategy' <class '__main__.RayXLADDPStrategy'>...
    !!! FAIL serialization: cannot pickle 'google.protobuf.pyext._message.EnumDescriptor' object
        Serializing '__getstate__' <function Strategy.__getstate__ at 0x7f6f8798d630>...
        Serializing '__init__' <function RayXLADDPStrategy.__init__ at 0x7f7016476e60>...
        !!! FAIL serialization: cannot pickle 'google.protobuf.pyext._message.EnumDescriptor' object
        Detected 2 global variables. Checking serializability...
            Serializing 'record_extra_usage_tag' <function record_extra_usage_tag at 0x7f7002d4e7a0>...
            Serializing 'TagKey' <google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper object at 0x7f7002b78220>...
            !!! FAIL serialization: cannot pickle 'google.protobuf.pyext._message.EnumDescriptor' object
        Detected 1 nonlocal variables. Checking serializability...
            Serializing '__class__' <class '__main__.RayXLADDPStrategy'>...
            !!! FAIL serialization: cannot pickle 'google.protobuf.pyext._message.EnumDescriptor' object
        Serializing '_abc_impl' <_abc._abc_data object at 0x7f7016488700>...
        !!! FAIL serialization: cannot pickle '_abc._abc_data' object
        WARNING: Did not find non-serializable object in <_abc._abc_data object at 0x7f7016488700>. This may be an oversight.
===================================================================
Variable: 

    FailTuple(__class__ [obj=<class '__main__.RayXLADDPStrategy'>, parent=<function RayXLADDPStrategy.__init__ at 0x7f7016476e60>])
FailTuple(TagKey [obj=<google.protobuf.internal.enum_type_wrapper.EnumTypeWrapper object at 0x7f7002b78220>, parent=<function RayXLADDPStrategy.__init__ at 0x7f7016476e60>])

was found to be non-serializable. There may be multiple other undetected variables that were non-serializable. 
Consider either removing the instantiation/imports of these variables or moving the instantiation into the scope of the function/class. 
===================================================================
Check https://docs.ray.io/en/master/ray-core/objects/serialization.html#troubleshooting for more information.
If you have any suggestions on how to improve this error message, please reach out to the Ray developers on github.com/ray-project/ray/issues/
===================================================================

I'm now far enough outside my area of expertise that I won't be much more help here.

woshiyyya commented 1 month ago

!!! FAIL serialization: cannot pickle 'google.protobuf.pyext._message.EnumDescriptor' object

Hi @BrianF-tessera, this looks like a serialization issue with TagKey. You can try removing the record_extra_usage_tag call; it only reports Ray telemetry metrics and won't affect the training logic.
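To see why the trace points at TagKey, here is a minimal, stdlib-only sketch. It assumes a hypothetical helper `find_unpicklable_globals` and uses a `threading.Lock` as a stand-in for TagKey; plain `pickle` stands in for cloudpickle, which Ray uses to ship functions and classes defined in `__main__` to workers by value, together with the globals their code references. That is how `train_func` -> `RayXLADDPStrategy.__init__` -> `TagKey` ended up being pickled. The helper only approximates one level of that walk and is not Ray's own `inspect_serializability`.

```python
import pickle
import threading
import types
from typing import Dict

# Hypothetical stand-in for TagKey: a module-level object that pickle rejects,
# much as protobuf's EnumDescriptor was rejected in the trace above.
TELEMETRY_LOCK = threading.Lock()


def train_func_broken() -> int:
    # Reads the unpicklable global, so serializing this function by value
    # also requires serializing the lock.
    with TELEMETRY_LOCK:
        return 1


def train_func_fixed() -> int:
    # The unpicklable object is created inside the function, so it only
    # exists where the function runs and is never captured as a global.
    lock = threading.Lock()
    with lock:
        return 1


def find_unpicklable_globals(func: types.FunctionType) -> Dict[str, str]:
    """Hypothetical helper: report globals referenced by `func` that pickle rejects.

    Only the function's own code object is scanned (not nested functions), and
    modules are skipped because cloudpickle serializes them by name.
    """
    failures: Dict[str, str] = {}
    for name in func.__code__.co_names:
        if name not in func.__globals__:
            continue  # attribute names and builtins are not module globals
        value = func.__globals__[name]
        if isinstance(value, types.ModuleType):
            continue
        try:
            pickle.dumps(value)
        except Exception as exc:
            failures[name] = f"{type(exc).__name__}: {exc}"
    return failures


print(find_unpicklable_globals(train_func_broken))
print(find_unpicklable_globals(train_func_fixed))
```

The fixed variant follows the error message's advice of moving the instantiation into the function's scope; removing the record_extra_usage_tag call, as suggested above, drops the reference to TagKey altogether.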

Also cc @haohanchen-aws to track this issue.

anyscalesam commented 1 month ago

@haohanchen-aws, do you have a ticket on the AWS side tracking the fix that you could link here?