FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Unable to replicate performance with Q-Learning on SMAX #66

Closed: corentinartaud closed this issue 5 months ago

corentinartaud commented 5 months ago

Hi! First and foremost, fantastic work! I'm trying to replicate the performance shown in the paper for the Q-Learning baselines locally; however, using the exact versions pinned in your requirements, under Python 3.9 and JAX 0.4.11, I can't get anywhere close to the results shown in Figure 12 with the hyperparameters provided in Table 10 (i.e., TD_LAMBDA_LOSS=False and NUM_STEPS=100 under the current configuration files). The results for 5m_vs_6m are on par with the plots in the paper; however, 2s3z, 3s5z, and 3s_vs_5z do not match. I don't expect a one-to-one match, but the results should be reasonably close. I would greatly appreciate any insights on how to reproduce these plots locally with the current state of the repo.
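For clarity, these are the overrides I'm referring to, as a minimal sketch (the key names are the ones in Table 10 and the config files; the plain-dict form is just illustrative):

```python
# Sketch of the hyperparameters in question; key names match Table 10 and
# the repo's config files, but the plain-dict form is illustrative only.
config = {
    "TD_LAMBDA_LOSS": False,  # disable the TD(lambda) loss variant
    "NUM_STEPS": 100,         # steps collected per rollout
}
```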

Report: https://api.wandb.ai/links/corentinartaud/2fgsq67j

I've also tried running these without changing the hyperparameters in the configuration files, under Python 3.11 and JAX 0.4.25; however, the results are very close to those in the report above. Have there been any significant changes between when the plots were generated and the current HEAD of the repo? If so, could you provide the commit ID at which you ran the experiments that produced the plots in the paper?

mttga commented 5 months ago

Hi @corentinartaud, you're reporting results for a single seed, while in the paper we report results over 4 seeds. Would you mind changing NUM_SEEDS in the config file to at least 2?
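For reference, a minimal sketch of how NUM_SEEDS is typically consumed in PureJaxRL-style baselines such as ours: training is vmapped over a batch of RNG keys, so all seeds run in a single jitted call. make_train below is a stand-in stub just to keep the snippet self-contained, not the baseline's actual builder.

```python
import jax

def make_train(config):
    # Stand-in for the baseline's make_train; the real one builds the full
    # Q-learning training loop. Here it returns a dummy pure function of rng.
    def train(rng):
        return {"metric": jax.random.uniform(rng)}
    return train

config = {"NUM_SEEDS": 4}  # at least 2, as suggested above

rng = jax.random.PRNGKey(0)
rngs = jax.random.split(rng, config["NUM_SEEDS"])  # one key per seed

# vmap over seeds: NUM_SEEDS independent runs execute in one jitted call,
# and each metric comes back with a leading NUM_SEEDS dimension.
outs = jax.jit(jax.vmap(make_train(config)))(rngs)
```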

corentinartaud commented 5 months ago

Hi @mttga, I've run 2s3z, 3s5z, and 3s_vs_5z with 2 and 4 seeds, and the results still seem off, especially for 2s3z and 3s5z. You can see the new results in the wandb report linked above.

mttga commented 5 months ago

Hi @corentinartaud, I was able to replicate your results. I investigated, and the cause seems to be the following:

  1. When I was running the benchmarks, I was scaling the SMAX returns to the SMAC scale (i.e., reward * 10, since the SMAX maximum is 2 and the SMAC maximum is 20).
  2. Afterwards, to be consistent with PPO, I removed the rescaling. I ran some tests (5m_vs_6m, smac_v2_5units, and so on), and this didn't seem to affect performance. My bad for not checking more maps.
  3. After your comments, I indeed found that the rescaling is crucial for 2s3z and 3s5z: with the SMAX scale (maximum 2), QMIX performs poorly, but it performs well when trained with the SMAC scale (maximum 20).

I think it is kind of bizarre that this happens only on some maps, and at the moment I don't know why. Maybe @benellis3 has some intuitions on this.

In any case, I re-introduced the *10 scaling during training. With the latest version of JaxMARL, you should be able to reproduce the results of the paper. Please let me know if that's not the case.
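If it helps, here is a minimal sketch of what the fix amounts to, assuming the standard JaxMARL functional env API (step returning obs, state, per-agent reward dicts, dones, infos); RewardScaleWrapper and its exact signatures are illustrative, not the actual wrapper in the repo:

```python
import jax

class RewardScaleWrapper:
    """Multiplies every agent's reward by a constant factor, e.g. mapping
    SMAX's ~[0, 2] return scale back to SMAC's ~[0, 20] scale."""

    def __init__(self, env, scale=10.0):
        self._env = env
        self._scale = scale

    def reset(self, key):
        return self._env.reset(key)

    def step(self, key, state, actions):
        obs, state, rewards, dones, infos = self._env.step(key, state, actions)
        # Rescale the per-agent reward dict; everything else passes through.
        rewards = jax.tree_util.tree_map(lambda r: r * self._scale, rewards)
        return obs, state, rewards, dones, infos

    def __getattr__(self, name):
        # Delegate any other attribute to the wrapped environment.
        return getattr(self._env, name)
```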

You can find the runs here: https://wandb.ai/mttga/jaxmarl_pull_request_71?nw=nwusermttga

Again, thank you very much for opening the issue and sharing your findings with us.

mttga commented 5 months ago

Hi @corentinartaud, I will close this soon if you don't have any other comments.

corentinartaud commented 5 months ago

Hi @mttga, thank you very much for taking the time to solve this issue. I was able to reproduce the results and have no other comments.