ray-project / ray

Ray is a unified framework for scaling AI and Python applications. Ray consists of a core distributed runtime and a set of AI Libraries for accelerating ML workloads.
https://ray.io
Apache License 2.0

[rllib] PyTorch A2C is not GPU accelerated #4333

Closed nautilus22 closed 5 years ago

nautilus22 commented 5 years ago

System information

Describe the problem

I tested atari-a2c with the tuned parameters (/tuned_examples/atari-a2c.yaml), and it showed great results on Atari Breakout. However, when "use_pytorch": true was added, the results were quite different (I used "atari-a2c-pytorch.yaml" from the 'Source code / logs' section): training was very slow and there seemed to be no improvement. I guess there is a performance issue in the PyTorch A2C implementation, but are there any options that PyTorch A2C requires?


Source code / logs

atari-a2c.yaml

atari-a2c:
    env:
        grid_search:
            - BreakoutNoFrameskip-v4
            - BeamRiderNoFrameskip-v4
            - QbertNoFrameskip-v4
            - SpaceInvadersNoFrameskip-v4
    run: A2C
    config:
        sample_batch_size: 20
        clip_rewards: True
        num_workers: 5
        num_envs_per_worker: 5
        num_gpus: 1
        lr_schedule: [
            [0, 0.0007],
            [20000000, 0.000000000001],
        ]

atari-a2c-pytorch.yaml

atari-a2c:
    env:
        grid_search:
            - BreakoutNoFrameskip-v4
            - BeamRiderNoFrameskip-v4
            - QbertNoFrameskip-v4
            - SpaceInvadersNoFrameskip-v4
    run: A2C
    config:
        sample_batch_size: 20
        clip_rewards: True
        num_workers: 5
        num_envs_per_worker: 5
        num_gpus: 1
        lr_schedule: [
            [0, 0.0007],
            [20000000, 0.000000000001],
        ]
        use_pytorch: True

ericl commented 5 years ago

We haven't spent time tuning the torch vision models, so this is probably expected. Also, PyTorch needs explicit tensor.cuda() calls to support GPU acceleration, and that is not implemented either (help here would be welcome)!
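
To illustrate the kind of explicit device placement meant here, below is a minimal sketch in plain PyTorch. It is not RLlib's code: `TinyPolicy`, its layer sizes, and the batch shape are hypothetical stand-ins, and the sketch falls back to CPU when no GPU is visible. Both the model parameters and each sampled batch have to be moved to the device before the forward pass.

```python
import torch
import torch.nn as nn

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class TinyPolicy(nn.Module):
    """Hypothetical actor-critic network, used only for illustration."""

    def __init__(self, obs_dim=4, num_actions=2):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(obs_dim, 32), nn.ReLU())
        self.logits = nn.Linear(32, num_actions)  # policy head
        self.value = nn.Linear(32, 1)             # value head

    def forward(self, obs):
        h = self.body(obs)
        return self.logits(h), self.value(h).squeeze(-1)


# Move the parameters to the chosen device once, at construction time.
model = TinyPolicy().to(device)

# Batches collected by rollout workers live on the CPU, so each one has to
# be moved explicitly before it is fed to the model.
obs_batch = torch.randn(20, 4)
logits, values = model(obs_batch.to(device))
```

Using `.to(device)` rather than a bare `.cuda()` keeps the same code path working on machines without a GPU.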

pong-a3c-pytorch.yaml might still work ok though, cc @richardliaw

nautilus22 commented 5 years ago

Thank you for your answer. My team and I are seriously considering working on PyTorch A2C acceleration.