instadeepai / og-marl

Datasets with baselines for offline multi-agent reinforcement learning.
https://instadeepai.github.io/og-marl/
Apache License 2.0

Performance drop between offline and online #27

Closed: zhonghai1995 closed this issue 4 months ago.

zhonghai1995 commented 5 months ago

Thanks so much for your work, I find it very helpful.

I am confused by a problem. I trained QMIX+CQL on StarCraft (SMAC v1), 3m scenario, first offline and then online. I commented out the training part of the train_online function, so it is actually just evaluating the performance. But I see a big performance discrepancy: offline training achieves an episode return of around 20, but when I just evaluate it in train_online, the return is much lower, roughly in the range [2, 4]. I am very confused and hope you could provide some help or insight. Thanks so much!

Below are my episode return curves for offline and online. I also paste my online training code and part of my main() function.

[Figures: two W&B charts (5/10/2024) of episode return for the offline and online runs]

def train_online(
    self,
    replay_buffer: FlashbaxReplayBuffer,
    max_env_steps: int = int(1e4),
    train_period: int = 20
) -> None:
    """Method to train the system online."""
    episodes = 0
    while True:  # breaks out when env_steps > max_env_steps
        self.reset()  # reset the system
        observations_ = self._environment.reset()

        if isinstance(observations_, tuple):
            observations, infos = observations_
        else:
            observations = observations_
            infos = {}

        episode_return = 0.0
        while True:

            if "legals" in infos:
                legal_actions = infos["legals"]
            else:
                legal_actions = None

            start_time = time.time()
            actions = self.select_actions(observations, legal_actions)
            end_time = time.time()
            time_for_action_selection = end_time - start_time

            start_time = time.time()
            (
                next_observations,
                rewards,
                terminals,
                truncations,
                next_infos,
            ) = self._environment.step(actions)
            end_time = time.time()
            time_to_step = end_time - start_time

            # Add step to replay buffer
            replay_buffer.add(observations, actions, rewards, terminals, truncations, infos)

            # Critical!!
            observations = next_observations
            infos = next_infos

            # Bookkeeping
            episode_return += np.mean(list(rewards.values()), dtype="float")
            self._env_step_ctr += 1
            """Comment out the training part
            if (
                self._env_step_ctr > 100 and self._env_step_ctr % train_period == 0
            ):  # TODO burn in period
                # Sample replay buffer
                start_time = time.time()
                experience = replay_buffer.sample()
                end_time = time.time()
                time_to_sample = end_time - start_time

                # Train step
                start_time = time.time()
                train_logs = self.train_step(experience,offline = False)
                end_time = time.time()
                time_train_step = end_time - start_time

                train_steps_per_second = 1 / (time_train_step + time_to_sample)
                env_steps_per_second = 1 / (time_to_step + time_for_action_selection)

                train_logs = {
                    **train_logs,
                    **self.get_stats(),
                    "Environment Steps": self._env_step_ctr,
                    "Time to Sample": time_to_sample,
                    "Time for Action Selection": time_for_action_selection,
                    "Time to Step Env": time_to_step,
                    "Time for Train Step": time_train_step,
                    "Train Steps Per Second": train_steps_per_second,
                    "Env Steps Per Second": env_steps_per_second,
                }

                self._logger.write(train_logs)
            """
            if all(terminals.values()) or all(truncations.values()):
                break

        episodes += 1
        if episodes % 1 == 0:  # TODO: make variable
            self._logger.write(
                {   
                    "Episodes": episodes,
                    "Episode Return": episode_return,
                    "Environment Steps": self._env_step_ctr,
                },
                force=True,
            )

        if self._env_step_ctr > max_env_steps:
            break

def main(_):
    config = {
        "env": FLAGS.env,
        "scenario": FLAGS.scenario,
        "dataset": FLAGS.dataset,
        "system": FLAGS.system,
        "backend": "tf2",
    }

    env = get_environment(FLAGS.env, FLAGS.scenario)

    buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)

    download_and_unzip_vault(FLAGS.env, FLAGS.scenario)

    is_vault_loaded = buffer.populate_from_vault(FLAGS.env, FLAGS.scenario, FLAGS.dataset)
    if not is_vault_loaded:
        print("Vault not found. Exiting.")
        return

    run_name = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    logger = WandbLogger(project="og-marl-baselines", config=config, name=run_name)

    json_writer = JsonWriter(
        "logs",
        f"{FLAGS.system}",
        f"{FLAGS.scenario}_{FLAGS.dataset}",
        FLAGS.env,
        FLAGS.seed,
        file_name=f"{FLAGS.scenario}_{FLAGS.dataset}_{FLAGS.seed}.json",
        save_to_wandb=True,
    )

    system_kwargs = {"add_agent_id_to_obs": True}
    if FLAGS.scenario == "pursuit":
        system_kwargs["observation_embedding_network"] = CNNEmbeddingNetwork()

    system = get_system(FLAGS.system, env, logger, **system_kwargs)

    system.train_offline(buffer, max_trainer_steps=FLAGS.trainer_steps, json_writer=json_writer)

    online_buffer = FlashbaxReplayBuffer(sequence_length=20, sample_period=1)
    system.train_online(online_buffer)

jcformanek commented 5 months ago

Hi @zhonghai1995, this is very cool. Looks like you are trying to do some offline-to-online training. Let me take some time to look into why it doesn't seem to be working. I'll get back to you as soon as possible; I am working on it now.

jcformanek commented 5 months ago

I think I know what is going on. The QMIX system (qmix.py) has an argument called eps_decay_timesteps=50_000. This means that the QMIX+CQL system (qmix_cql.py) uses epsilon-greedy action selection with epsilon decaying over the first 50,000 timesteps, so your system is choosing mostly random actions when it goes online. Try setting that value to zero.

I see that in qmix_cql.py I did not expose the eps_decay_timesteps argument, so you may want to modify the code a bit so that you can change it.
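To make the effect concrete, here is a minimal sketch of a linear epsilon schedule and epsilon-greedy selection over legal actions. The function names, the linear form, and eps_min=0.05 are assumptions for illustration, not og-marl's exact code.

```python
import numpy as np


def linear_epsilon(env_step: int, eps_decay_timesteps: int, eps_min: float = 0.05) -> float:
    """Hypothetical linear schedule: epsilon falls from 1.0 to eps_min over
    eps_decay_timesteps environment steps, then stays at eps_min."""
    if eps_decay_timesteps <= 0:
        # No decay period: act (almost) greedily from the very first step.
        return eps_min
    return max(1.0 - env_step / eps_decay_timesteps, eps_min)


def epsilon_greedy(q_values: np.ndarray, legal_mask: np.ndarray, epsilon: float,
                   rng: np.random.Generator) -> int:
    """Pick a random legal action with probability epsilon, else the greedy legal action."""
    if rng.random() < epsilon:
        return int(rng.choice(np.flatnonzero(legal_mask)))
    # Illegal actions get -inf so argmax can never select them.
    return int(np.argmax(np.where(legal_mask, q_values, -np.inf)))


# With the 50_000-step default, epsilon is still 0.98 after 1_000 online steps,
# so a policy trained offline acts almost uniformly at random.
for decay in (50_000, 0):
    print(decay, [round(linear_epsilon(t, decay), 3) for t in (0, 1_000, 10_000, 50_000)])
```

With a decay period of zero, the sketch starts at eps_min, which is why evaluation returns should then track the offline ones.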

By the way, I have found IDRQN+CQL works better than QMIX+CQL.

https://instadeepai.github.io/og-marl/baselines/smac_v1/

zhonghai1995 commented 5 months ago

I see, and that solved the problem. Thanks again!

jcformanek commented 5 months ago

I am glad it's working. We actually also did a research project on offline-to-online MARL, which you might find interesting. You can find it here:

https://instadeepai.github.io/selective-reincarnation-marl/

zhonghai1995 commented 5 months ago

By the way, did you try to use OMAR for discrete actions? I tried the Gumbel-max trick in the SMAC environments, and the performance is bad.
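For reference, a minimal NumPy sketch of the Gumbel tricks for discrete actions. The Gumbel-max sample (argmax of Gumbel-perturbed logits) is an exact draw from softmax(logits); the Gumbel-softmax is its temperature-controlled relaxation. In an autodiff framework an actor update would usually use the straight-through variant (hard one-hot forward, soft gradients), which this sampling-only sketch does not show. The function name is hypothetical, and nothing here is claimed to be how OMAR or og-marl implement it.

```python
import numpy as np


def gumbel_softmax_sample(logits: np.ndarray, temperature: float, rng: np.random.Generator):
    """Perturb logits with Gumbel(0, 1) noise.

    Returns (soft, hard): soft is the Gumbel-softmax relaxation (a probability
    vector per row), hard is the Gumbel-max sample (argmax of perturbed logits).
    """
    u = rng.uniform(1e-10, 1.0, size=logits.shape)  # lower bound avoids log(0)
    perturbed = logits - np.log(-np.log(u))  # add Gumbel(0, 1) noise
    hard = np.argmax(perturbed, axis=-1)
    z = perturbed / temperature
    z = z - z.max(axis=-1, keepdims=True)  # numerical stability
    soft = np.exp(z) / np.exp(z).sum(axis=-1, keepdims=True)
    return soft, hard


rng = np.random.default_rng(0)
logits = np.log(np.array([0.7, 0.2, 0.1]))
soft, hard = gumbel_softmax_sample(np.tile(logits, (20_000, 1)), temperature=1.0, rng=rng)
print(np.bincount(hard, minlength=3) / len(hard))  # approximately [0.7, 0.2, 0.1]
```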

jcformanek commented 5 months ago

I have also not successfully implemented OMAR for discrete actions. I have seen other people also have challenges with this. See here: https://github.com/thu-rllab/CFCQL/issues/1

zhonghai1995 commented 5 months ago

Hi @jcformanek, I tried OMAR on 2ant in MAMuJoCo with the Good dataset. I used Adam instead of RMSprop and increased the hidden sizes to 256, and the performance seems better: it roughly matches the performance of BC and ITD3+BC in Table D.5 of your paper. Please have a look.

https://wandb.ai/haizhong/og-marl-baselines/reports/evaluator-episode_return_offline-24-05-25-23-51-46---Vmlldzo4MDkyMjU4

[Figure: W&B chart (5/25/2024) of OMAR evaluator episode return on 2ant]

zhonghai1995 commented 5 months ago

I ran more seeds, so the result is now across 20 seeds, and it looks like OMAR achieves an average return of roughly 1700 on the Good dataset in the 2ant scenario: worse than the BC-based methods, but still better than reported in the table.

https://wandb.ai/haizhong/og-marl-baselines/reports/evaluator-episode_return_offline-24-05-26-21-21-23---Vmlldzo4MDk5MjA3

jcformanek commented 5 months ago

Oh that is great, thank you for sharing. We will work on updating all of the benchmark results.

zhonghai1995 commented 4 months ago

https://github.com/instadeepai/og-marl/blob/68db0c007c73a06197f7b66d453ee4dd7429434e/og_marl/tf2/systems/idrqn.py#L115-L116

One more question: why do you increment the environment step counter here? Thanks!

jcformanek commented 4 months ago

That's used to control the epsilon-greedy exploration. It only has an effect if you train online.

zhonghai1995 commented 4 months ago

https://github.com/instadeepai/og-marl/blob/68db0c007c73a06197f7b66d453ee4dd7429434e/og_marl/tf2/systems/base.py#L103-L127

But you also increment the environment step counter here, and the default value of the explore argument of select_actions is True, so the counter is incremented twice for a single environment step. Is this expected?

jcformanek commented 4 months ago

Oh I see. I think you are right! That would result in exploration decreasing 2x faster than I expected. You are welcome to open a PR to fix it if you like. Alternatively, I can attend to it.
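The double-increment can be reproduced with a toy model. ToySystem, count_steps and the linear epsilon schedule below are hypothetical stand-ins for the pattern described in this exchange (one increment inside select_actions when explore=True, another in the online loop), not og-marl's code.

```python
def epsilon(env_step_ctr: int, eps_decay_timesteps: int = 50_000, eps_min: float = 0.05) -> float:
    """Hypothetical linear epsilon schedule driven by the step counter."""
    return max(1.0 - env_step_ctr / eps_decay_timesteps, eps_min)


class ToySystem:
    """Hypothetical stand-in reproducing the bug pattern."""

    def __init__(self) -> None:
        self.env_step_ctr = 0

    def select_actions(self, explore: bool = True) -> None:
        if explore:
            self.env_step_ctr += 1  # increment #1: inside action selection


def count_steps(num_env_steps: int, also_increment_in_loop: bool) -> int:
    system = ToySystem()
    for _ in range(num_env_steps):
        system.select_actions()  # explore defaults to True
        if also_increment_in_loop:
            system.env_step_ctr += 1  # increment #2: inside the online loop
    return system.env_step_ctr


for buggy in (False, True):
    ctr = count_steps(10_000, also_increment_in_loop=buggy)
    print(f"buggy={buggy}: counter={ctr}, epsilon={epsilon(ctr):.2f}")
```

After 10,000 real environment steps, the buggy counter reads 20,000, so epsilon has already dropped to 0.60 instead of 0.80.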

zhonghai1995 commented 4 months ago

https://github.com/instadeepai/og-marl/blob/68db0c007c73a06197f7b66d453ee4dd7429434e/og_marl/tf2/systems/qmix_cql.py#L191-L192

I also found that here the CQL loss is not multiplied by its weight. Is this expected? If not, please fix both issues.
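For context, a minimal sketch of the standard discrete CQL(H) regulariser, the batch mean of logsumexp over actions of Q(s, a) minus Q(s, a_data), and of scaling it by its weight before adding it to the TD loss. The names cql_regulariser, total_loss and cql_weight are illustrative. og-marl's QMIX+CQL term may be computed differently (for example over sampled joint actions), so treat this as the single-agent form only.

```python
import numpy as np


def cql_regulariser(q_values: np.ndarray, dataset_actions: np.ndarray) -> float:
    """Discrete CQL(H) term: mean over the batch of logsumexp_a Q(s, a) - Q(s, a_data).

    q_values has shape [batch, num_actions]; dataset_actions has shape [batch].
    """
    max_q = q_values.max(axis=1, keepdims=True)  # subtract the max for a stable logsumexp
    logsumexp = (max_q + np.log(np.exp(q_values - max_q).sum(axis=1, keepdims=True)))[:, 0]
    q_data = q_values[np.arange(len(dataset_actions)), dataset_actions]
    return float(np.mean(logsumexp - q_data))


def total_loss(td_loss: float, cql_loss: float, cql_weight: float) -> float:
    """The CQL term is scaled by its weight before being added to the TD loss."""
    return td_loss + cql_weight * cql_loss


q = np.zeros((2, 4))  # uniform Q-values, so the regulariser equals log(4)
reg = cql_regulariser(q, np.array([0, 3]))
print(total_loss(td_loss=1.0, cql_loss=reg, cql_weight=3.0))
```

Dropping cql_weight silently fixes the weight at 1, which changes how strongly the policy is pushed towards dataset actions.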

jcformanek commented 4 months ago

I have just merged (#28) in a fix for this and for the env_step_ctr bug. Thank you so much for finding and reporting these bugs. I really appreciate your contributions. Let me know if you find any more.

zhonghai1995 commented 4 months ago

Hi @jcformanek, I see you added more benchmark results for datasets from previous works. Thanks for this, it is really helpful. If I want to convert OMAR's MPE datasets (other than simple_spread), how can I do it? Also, do I need to calculate the normalized score myself? If so, where can I find the expert and random scores of the datasets themselves? Thanks so much!

jcformanek commented 4 months ago

I am glad you find it helpful. I'll upload the datasets for the other scenarios; we already converted them. The challenge we faced on those scenarios is that the MPE environment code they used depends on loading a pre-trained (PyTorch) model for the adversaries. If you can properly instantiate the environment for evaluation, then everything should work fine.

With regards to normalisation, the CFCQL paper says they normalise in one way, but if you inspect the code you can see they simply normalise by dividing by the mean episode return of the dataset. You need to do the normalisation yourself, yes.
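The two normalisation schemes can be sketched as follows: division by the dataset's mean episode return (as described above for the CFCQL code), and, for comparison, the common D4RL-style random/expert min-max scaling to 0-100. Function names and the example numbers are illustrative only.

```python
import numpy as np


def normalise_by_dataset_mean(episode_returns, dataset_episode_returns):
    """Divide evaluation returns by the mean episode return of the dataset."""
    return np.asarray(episode_returns, dtype=float) / np.mean(dataset_episode_returns)


def normalise_random_expert(episode_returns, random_return, expert_return):
    """D4RL-style min-max normalisation: 0 = random policy, 100 = expert policy."""
    returns = np.asarray(episode_returns, dtype=float)
    return 100.0 * (returns - random_return) / (expert_return - random_return)


# Illustrative numbers only, not real benchmark values.
print(normalise_by_dataset_mean([10.0, 20.0], [5.0, 15.0]))
print(normalise_random_expert([50.0], random_return=0.0, expert_return=100.0))
```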

zhonghai1995 commented 4 months ago

Thanks! I am trying to run this simple_spread environment from offline to online. During online training I need the state, but the infos returned by step are just info_n [{}, {}, {}]. What is the state for MPE simple_spread? I could extract it myself, but I do not know how it is composed. Please help me. Thanks so much!

I think I figured it out: the state is just the concatenation of the three agents' observations.
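A minimal sketch of building that state, assuming observations arrive as a dict keyed by agent name (as the rewards dict does in the train_online code above). The helper name, agent names and the 4-dim observations are illustrative assumptions.

```python
import numpy as np


def build_state(observations: dict, agent_order: list) -> np.ndarray:
    """Global state as the concatenation of per-agent observations.

    agent_order fixes the concatenation order so the state layout is the same
    at every step (and matches whatever order the offline dataset used).
    """
    return np.concatenate(
        [np.asarray(observations[agent], dtype=np.float32) for agent in agent_order], axis=-1
    )


obs = {f"agent_{i}": np.full(4, float(i)) for i in range(3)}  # illustrative 4-dim observations
state = build_state(obs, ["agent_0", "agent_1", "agent_2"])
print(state.shape)
```

Passing an explicit agent order avoids surprises from dict or lexicographic ordering (for example "agent_10" sorting before "agent_2").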

jcformanek commented 4 months ago

Yes, I think you are correct! Also @callumtilbury is uploading the other MPE vaults now. We will add the download link to the file og_marl/offline_dataset.py.

callumtilbury commented 4 months ago

Hi @zhonghai1995 👋🏻 Here are the MPE vaults from OMAR:

    "mpe_omar": {
        "simple_spread": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/omar/simple_spread.zip"},
        "simple_tag": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/omar/simple_tag.zip"},
        "simple_world": {"url": "https://s3.kao.instadeep.io/offline-marl-dataset/omar/simple_world.zip"},
    }

Note that for the simple_world and simple_tag scenarios, the observation dimensions are not homogeneous, so we pad them with -inf when appropriate.
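A minimal sketch of that padding, assuming per-agent observations in a dict and right-padding to the largest observation size; the vault conversion notebook may pad differently. Because -inf would propagate through a network's matrix multiplications, the second helper swaps it for a finite value first. Agent names, sizes and both helpers are hypothetical.

```python
import numpy as np


def pad_observations(observations: dict, fill_value: float = -np.inf) -> dict:
    """Right-pad every agent's observation to the largest observation size."""
    max_dim = max(np.shape(obs)[-1] for obs in observations.values())
    padded = {}
    for agent, obs in observations.items():
        obs = np.asarray(obs, dtype=np.float32)
        pad_width = [(0, 0)] * (obs.ndim - 1) + [(0, max_dim - obs.shape[-1])]
        padded[agent] = np.pad(obs, pad_width, mode="constant", constant_values=fill_value)
    return padded


def replace_padding(obs: np.ndarray, value: float = 0.0) -> np.ndarray:
    """Swap -inf padding for a finite value before feeding obs to a network."""
    return np.where(np.isneginf(obs), value, obs)


padded = pad_observations({"adversary_0": np.ones(4), "agent_0": np.ones(2)})
print(padded["agent_0"])
```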

The vault conversion code can be found here: https://bit.ly/vault-conversion-notebook. For OMAR's MPE datasets, see Example 4.

Please let us know if you have any further questions or problems! 🚀

zhonghai1995 commented 4 months ago

@callumtilbury This is super helpful for me! Thanks so much!

jcformanek commented 4 months ago

I am going to convert this "issue" into a "discussion" and then we can continue discussing using OG-MARL for offline-to-online MARL. :rocket: