proroklab / graph-conv-memory

Graph convolutional memory

Ray CartPole Example is using A2C instead of PPO #2

Closed heng2j closed 3 years ago

heng2j commented 3 years ago

Hi @smorad,

I noticed that the latest change to the CartPole example switched the RL algorithm from PPO to A2C. Is this the intended direction going forward? And would you mind providing some details on the reasoning behind the change?

Thank you, Heng

smorad commented 3 years ago

@heng2j you can train using whichever algorithm you prefer. In the paper, we show GCM works with PPO, IMPALA, or A2C. It likely works with other algorithms as well. Ray PPO tends to be quite slow in my opinion (it does nearly 900 forward passes over each training batch), hence the switch to A2C. You might want to switch algorithms depending on the task you're trying to solve.
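
A rough way to see why PPO costs more per batch than A2C: RLlib's PPO runs several epochs of minibatch SGD over each train batch, while A2C (roughly) takes a single gradient step on it. The sketch below just counts PPO minibatch updates per train batch. The default values used (train_batch_size=4000, sgd_minibatch_size=128, num_sgd_iter=30) are an assumption about Ray 1.x PPO defaults, not something stated in this thread, so check the DEFAULT_CONFIG of your installed RLlib version.

```python
def ppo_minibatch_updates(train_batch_size: int = 4000,
                          sgd_minibatch_size: int = 128,
                          num_sgd_iter: int = 30) -> int:
    """Approximate number of minibatch forward/backward passes PPO makes per train batch.

    The defaults are ASSUMED Ray 1.x PPO defaults; verify them against your RLlib.
    A partial trailing minibatch is ignored (floor division).
    """
    minibatches_per_epoch = train_batch_size // sgd_minibatch_size
    return num_sgd_iter * minibatches_per_epoch


# With the assumed defaults: 30 epochs x 31 minibatches = 930 updates per batch,
# the same order of magnitude as the ~900 passes quoted above.
print(ppo_minibatch_updates())  # 930

# Fewer epochs and larger minibatches cut the per-batch cost considerably.
print(ppo_minibatch_updates(num_sgd_iter=10, sgd_minibatch_size=500))  # 80
```

In an RLlib config these knobs are the num_sgd_iter and sgd_minibatch_size keys; lowering the number of updates makes PPO faster per batch at a possible cost in sample efficiency.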

heng2j commented 3 years ago

Thank you @smorad. I'm currently running into a sample batch index out of range issue when running the example with PPO. Perhaps I need to supply PPO-specific parameters in my RLlib config. Could you check whether the same thing happens for you when simply swapping A2C for PPO?

smorad commented 3 years ago

I think this must be due to your PPO parameters or perhaps the task itself? We did not run into issues using PPO. We do set max_seq_len = graph_size + 1 in the config, but besides that, I don't think we make any other significant changes.
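
If you start from an existing A2C or PPO config, one way to keep the max_seq_len = graph_size + 1 setting described above in sync with the graph size is a small helper like the one below. with_gcm_seq_len is a hypothetical name for illustration; it only edits a plain RLlib-style config dict, assumes the "model" / "custom_model_config" / "graph_size" layout used in this thread's example, and does not import Ray.

```python
from typing import Any, Dict


def with_gcm_seq_len(cfg: Dict[str, Any], graph_size: int) -> Dict[str, Any]:
    """Return a copy of an RLlib-style config with max_seq_len = graph_size + 1.

    Hypothetical helper: assumes the GCM graph size lives under
    model -> custom_model_config -> graph_size, as in the example in this thread.
    """
    # Shallow-copy only the dicts we modify, so heavy objects inside the
    # config (e.g. GNN modules) are shared rather than cloned.
    new_cfg = dict(cfg)
    model = dict(new_cfg.get("model", {}))
    custom = dict(model.get("custom_model_config", {}))
    custom["graph_size"] = graph_size
    model["custom_model_config"] = custom
    # The setting recommended in this thread for GCM.
    model["max_seq_len"] = graph_size + 1
    new_cfg["model"] = model
    return new_cfg


base = {"env": "CartPole-v0", "model": {"custom_model_config": {"graph_size": 16}}}
cfg = with_gcm_seq_len(base, graph_size=32)
print(cfg["model"]["max_seq_len"])  # 33
```

The input dict is left untouched, so the same base config can be reused for several graph sizes or algorithms.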

Make sure you have an up-to-date Ray installation. They keep patching 1.6.0, and we ran into similar issues that forced us to upgrade from Ray 1.4.0.
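
A quick pre-flight check along these lines can catch an outdated install before a long run. It is only a sketch: the thread says to stay on an up-to-date 1.6.0-series release rather than 1.4.0, so the (1, 6, 0) minimum below is an assumption, and the parser only handles plain numeric versions (a suffix such as "rc1" is cut off, so "1.6.0rc1" is treated as 1.6.0).

```python
import re
from typing import Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn '1.6.0' (or '1.6.0rc1') into (1, 6, 0); non-numeric suffixes are dropped."""
    parts = []
    for piece in version.split("."):
        match = re.match(r"\d+", piece)
        if match is None:
            break  # stop at the first component without a leading number
        parts.append(int(match.group()))
    return tuple(parts)


def ray_is_recent_enough(version: str, minimum: Tuple[int, ...] = (1, 6, 0)) -> bool:
    """Return True if `version` is at least `minimum` (an assumed threshold)."""
    return parse_version(version) >= minimum


print(ray_is_recent_enough("1.4.0"))  # False
print(ray_is_recent_enough("1.6.0"))  # True
```

In practice you would call ray_is_recent_enough(ray.__version__) before ray.init().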

heng2j commented 3 years ago

Thank you @smorad. It turns out that for PPO, setting this LSTM-related parameter (max_seq_len) is what enabled training. Here is my sample pytest test for PPO:


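import ray
import torch
import torch_geometric
from ray import tune

# GCM imports: these module paths follow the graph-conv-memory README
# and may differ depending on the version you have installed
from gcm.ray_gcm import RayDenseGCM
from gcm.edge_selectors.temporal import TemporalBackedge


def test_Ray_gcm():
    hidden = 32
    graph_size = 32
    ray.init(
        local_mode=True,
        object_store_memory=int(3e10),  # ~30 GB, in bytes
    )
    # Two dense graph-convolution layers, passed to GCM as its GNN
    dgc = torch_geometric.nn.Sequential(
        "x, adj, weights, B, N",
        [
            # Mean and sum aggregation perform roughly the same
            # Preprocessor with 1 layer did not help
            (torch_geometric.nn.DenseGraphConv(hidden, hidden), "x, adj -> x"),
            (torch.nn.Tanh()),
            (torch_geometric.nn.DenseGraphConv(hidden, hidden), "x, adj -> x"),
            (torch.nn.Tanh()),
        ],
    )
    cfg = {
        "framework": "torch",
        "num_gpus": 1,  # set to 0 on a CPU-only machine
        "env": "CartPole-v0",
        "num_workers": 0,
        "model": {
            "custom_model": RayDenseGCM,
            "custom_model_config": {
                "graph_size": graph_size,
                "gnn_input_size": hidden,
                "gnn_output_size": hidden,
                "gnn": dgc,
                "edge_selectors": TemporalBackedge([1]),
                "edge_weights": False,
            },
            # Setting max_seq_len = graph_size + 1 is what let PPO train here
            "max_seq_len": graph_size + 1,
        },
    }
    try:
        tune.run("PPO", config=cfg, stop={"info/num_steps_trained": 100})
    finally:
        # Release Ray resources even if training fails
        ray.shutdown()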

Thank you @smorad for your help! Greatly appreciated!