ray-project / ray_lightning

Pytorch Lightning Distributed Accelerators using Ray
Apache License 2.0

[raystrategy] multi-strategy in the worker is not consistent #161

Closed JiahaoYao closed 2 years ago

JiahaoYao commented 2 years ago

In the current ray strategy, the strategy object appears in three places: two are obvious and one is hidden.

One is in the ray launcher:

class RayLauncher(_SpawnLauncher):
    def __init__(self, strategy: "RayPlugin") -> None:
        self._strategy = strategy  # the launcher keeps its own reference to the strategy
        self._start_method = "ray"
        self._workers = []
        self._futures = []
        self._master_addr = None

https://github.com/JiahaoYao/ray_lightning/blob/2727fd441a62e0e6763fd1f25ed97575dc5a6733/ray_lightning/ray_ddp.py#L38-L48

Later, this reference is used in _wrapping_function, which executes remotely on the worker, so these setters mutate the worker's copy of the launcher's strategy:

https://github.com/JiahaoYao/ray_lightning/blob/main/ray_lightning/ray_ddp.py#L241-L242

        self._strategy.set_remote(True)
        self._strategy.set_global_to_local(global_to_local)

The second is an attribute of the trainer, trainer.strategy.

The last, hidden one is in the worker dispatch:

https://github.com/JiahaoYao/ray_lightning/blob/2727fd441a62e0e6763fd1f25ed97575dc5a6733/ray_lightning/ray_ddp.py#L222-L226

        self._futures = [
            # `trainer` is shipped to each worker as an argument of a Ray
            # remote call, so every worker deserializes its own copy of it
            w.execute.remote(self._wrapping_function, i, self._global_to_local,
                             trainer, function, args, kwargs, self.tune_queue)
            for i, w in enumerate(self._workers)
        ]

Ray remote calls serialize their arguments, so each worker receives its own copy of the trainer.

Thus, the strategy.teardown that actually runs is the one attached to a copy of the trainer, not the strategy held by the launcher on the driver.
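To make the copy semantics concrete, here is a minimal sketch (the Strategy class and task function below are illustrative, not from ray_lightning) showing that mutating an argument inside a remote task does not touch the driver's object:

import ray

ray.init()

class Strategy:
    def __init__(self):
        self.remote_flag = False

@ray.remote
def task(strategy):
    # `strategy` here is a deserialized copy living in the worker process
    strategy.remote_flag = True
    return strategy.remote_flag

s = Strategy()
print(ray.get(task.remote(s)))  # True: the worker's copy was mutated
print(s.remote_flag)            # False: the driver's object is unchanged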

Evidence for this assumption: printing out the pid of the strategy in both places shows that the pids differ.

[screenshots: strategy pids printed on the driver and in the worker]
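A minimal sketch of that pid check, assuming only Ray itself: compare the process id on the driver with the one seen inside a remote task.

import os
import ray

ray.init()

@ray.remote
def where_am_i(obj):
    # runs in a separate worker process with its own copy of `obj`
    return os.getpid(), id(obj)

strategy = object()
worker_pid, worker_id = ray.get(where_am_i.remote(strategy))
print("driver pid:", os.getpid(), "id:", id(strategy))
print("worker pid:", worker_pid, "id:", worker_id)
# a different pid confirms a separate process, and thus a separate copy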

Proposal: we might remove the redundant uses of the strategy so that only one consistent strategy object is involved (a sketch of one direction follows).
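One hedged direction, a sketch only and not the repository's actual fix: ship a single object per worker and derive the strategy from it, so the setters and the eventual teardown see the same copy.

import ray

# hypothetical sketch: pass only the trainer and look the strategy up
# through it, instead of shipping the launcher's strategy and the
# trainer's strategy as two independently serialized objects
@ray.remote
def run(trainer, function, *args):
    strategy = trainer.strategy   # the single per-worker strategy copy
    strategy.set_remote(True)     # mutates the same object teardown will use
    return function(*args)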

JiahaoYao commented 2 years ago

#143

JiahaoYao commented 2 years ago

This can be reduced to a minimal reproduction:

import ray

ray.init()

class Trainer:
    def __init__(self, state):
        self.state = state

    def method(self, arg):
        print(arg)
        print(self.state)
        return self.state

@ray.remote
def wrapper(some_function, args, trainer):
    # `some_function` is a bound method, so it carries a copy of the
    # trainer via `__self__`; `trainer` is serialized separately
    print("some_function.__self__", some_function.__self__)
    print("type", type(some_function.__self__))
    print("trainer", trainer)
    return some_function(args)

t = Trainer(1234)
ray.get(wrapper.remote(t.method, 6789, t))
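If the hypothesis holds, some_function.__self__ and trainer arrive in the worker as two distinct deserialized Trainer objects, since each argument is serialized on its own, and both are copies of the driver's t; the strategy in ray_lightning is presumably duplicated in the same way.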

thanks @ckw017 !! 👍