ray-project / ray

Ray is an AI compute engine. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[Core] serialisation of dataclass in separate module fails to recognise parameter change in child dataclass, but functions correctly if in the same module #34366

Open DrJohnDale opened 1 year ago

DrJohnDale commented 1 year ago

What happened + What you expected to happen

I run multiprocessing with Ray tasks and pass each task a dataclass instance that holds a child dataclass. When both dataclasses are defined in a module separate from the main script, any changes I make to the child dataclass's values are not passed to the worker processes. When the dataclasses are defined in the same module as the main script, it works correctly.

A workaround is to make the parent a standard Python class instead of a dataclass. That also works correctly.
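Here is a minimal sketch of that workaround. It assumes the standard class assigns the child inside `__init__`, because the exact class is not shown in the issue. Plain `pickle` stands in for Ray's serializer so the sketch runs without Ray. Ray uses cloudpickle, which should store instance attributes in the same way.

```python
import pickle
from dataclasses import dataclass


@dataclass
class Child:
    c1: int
    c2: float


class Parent:
    """Plain class: both attributes are assigned per instance in __init__."""

    def __init__(self) -> None:
        self.p1 = 1.0
        # A fresh Child per instance, stored in the instance __dict__.
        self.p2 = Child(c1=1, c2=1.0)


parent = Parent()
parent.p2.c1 = 10
parent.p2.c2 = 99.99

# The mutated child is part of the instance state, so it is pickled with it.
restored = pickle.loads(pickle.dumps(parent))
print(restored.p2)  # Child(c1=10, c2=99.99)
```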

Versions / Dependencies

Ray 2.3.1

Reproduction script

data_classes.py

from dataclasses import dataclass

@dataclass
class Child:
    c1: int
    c2: float

@dataclass
class Parent:
    p1: float = 1.0
    p2 = Child(c1=1, c2=1.0)

separate_files.py

import ray
from data_classes import Parent
import time

@ray.remote
def print_values(data: Parent):
    print(data.p1)
    print(data.p2)

parent = Parent()

parent.p2.c1 = 10
parent.p2.c2 = 99.99

print(parent.p1)
print(parent.p2)

num_cpu = 4

futures = [print_values.remote(parent) for _ in range(num_cpu)]
matching_results = ray.get(futures)

time.sleep(1)
ray.shutdown()

Running the above gives the following incorrect output:

1.0
Child(c1=10, c2=99.99)
2023-04-13 16:45:38,894 INFO worker.py:1553 -- Started a local Ray instance.
(print_values pid=111732) 1.0
(print_values pid=111732) Child(c1=1, c2=1.0)
(print_values pid=111734) 1.0
(print_values pid=111734) Child(c1=1, c2=1.0)
(print_values pid=111726) 1.0
(print_values pid=111726) Child(c1=1, c2=1.0)
(print_values pid=111730) 1.0
(print_values pid=111730) Child(c1=1, c2=1.0)

single_file.py

import ray
from dataclasses import dataclass
import time

@dataclass
class Child:
    c1: int
    c2: float

@dataclass
class Parent:
    p1: float = 1.0
    p2 = Child(c1=1, c2=1.0)

@ray.remote
def print_values(data: Parent):
    print(data.p1)
    print(data.p2)

parent = Parent()

parent.p2.c1 = 10
parent.p2.c2 = 99.99

print(parent.p1)
print(parent.p2)

num_cpu = 4

futures = [print_values.remote(parent) for _ in range(num_cpu)]
matching_results = ray.get(futures)

time.sleep(1)
ray.shutdown()

Running this works as expected and gives the following output:

1.0
Child(c1=10, c2=99.99)
2023-04-13 16:46:36,204 INFO worker.py:1553 -- Started a local Ray instance.
(print_values pid=113184) 1.0
(print_values pid=113184) Child(c1=10, c2=99.99)
(print_values pid=113172) 1.0
(print_values pid=113172) Child(c1=10, c2=99.99)
(print_values pid=113170) 1.0
(print_values pid=113170) Child(c1=10, c2=99.99)
(print_values pid=113174) 1.0
(print_values pid=113174) Child(c1=10, c2=99.99)
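What follows is one possible explanation. It is an assumption and is not confirmed in this thread.

`p2` has no type annotation, so `@dataclass` does not treat it as a field. It stays a single class attribute shared by every `Parent`. As a result, `parent.p2.c1 = 10` mutates that shared object instead of instance state.

Pickling an instance captures only its `__dict__`, and that holds `p1` alone. cloudpickle serializes classes from an importable module by reference, so the workers re-import `data_classes` and see the original `Child(c1=1, c2=1.0)`. Classes defined in `__main__` are serialized by value, class attributes included. That would explain why single_file.py appears to work.

The sketch below checks the first part with the standard library. It also shows an annotated field with `default_factory` (`FixedParent` is a hypothetical name) as a per-instance alternative.

```python
import pickle
from dataclasses import dataclass, field, fields


@dataclass
class Child:
    c1: int
    c2: float


@dataclass
class Parent:
    p1: float = 1.0
    p2 = Child(c1=1, c2=1.0)  # no annotation: a class attribute, not a field


parent = Parent()
parent.p2.c1 = 10

# Only p1 is a dataclass field, and only p1 lives on the instance.
print([f.name for f in fields(Parent)])  # ['p1']
print(vars(parent))                      # {'p1': 1.0}
# The "change" was made to the object shared by the class itself.
print(Parent.p2)                         # Child(c1=10, c2=1.0)

# The pickled instance state carries p1 only; p2 is looked up on the class.
state = pickle.loads(pickle.dumps(parent)).__dict__
print(state)                             # {'p1': 1.0}


@dataclass
class FixedParent:
    p1: float = 1.0
    # Annotated field with a factory: every instance gets its own Child.
    p2: Child = field(default_factory=lambda: Child(c1=1, c2=1.0))


fixed = FixedParent()
fixed.p2.c1 = 10
print(vars(fixed))  # {'p1': 1.0, 'p2': Child(c1=10, c2=1.0)}
```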

Issue Severity

Medium: It is a significant difficulty but I can work around it.

cadedaniel commented 1 year ago

cc @jjyao FYI

rkooo567 commented 1 year ago

This is most likely a case of cloudpickle not playing well with the dataclass. You may need to define your own serializer: https://docs.ray.io/en/master/ray-core/objects/serialization.html#customized-serialization
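One option on the linked page, when you control the class's code, is to define `__reduce__` on it. The minimal sketch below follows that approach and copies the child into the pickled state explicitly. `_rebuild_parent` is a hypothetical helper; in the two-file setup it would need to live in `data_classes.py` so the workers can import it. Plain `pickle` is used so the sketch runs without Ray.

```python
import pickle
from dataclasses import dataclass


@dataclass
class Child:
    c1: int
    c2: float


def _rebuild_parent(p1, p2):
    """Hypothetical helper: recreate a Parent with an explicit child object."""
    obj = Parent(p1=p1)
    # Store the child on the instance so it shadows the class attribute.
    obj.p2 = p2
    return obj


@dataclass
class Parent:
    p1: float = 1.0
    p2 = Child(c1=1, c2=1.0)  # still a class attribute, as in the issue

    def __reduce__(self):
        # Serialize p2 alongside p1, whether it is found on the instance or the class.
        return (_rebuild_parent, (self.p1, self.p2))


parent = Parent()
parent.p2.c1 = 10
parent.p2.c2 = 99.99

restored = pickle.loads(pickle.dumps(parent))
print(restored.p2)              # Child(c1=10, c2=99.99)
print("p2" in vars(restored))   # True
```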