FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Unable to replicate performance on MABrax #47

Closed sash-a closed 9 months ago

sash-a commented 9 months ago

Hi there :wave:

I've been trying to match the performance you show in the paper, where you get around 5000 return for ant_4x2. However, when I run your code locally with the config shared in this repo, I get around -600 after 10M timesteps (in your paper you get ~3500 at 10M timesteps).

If you could let me know any config differences or provide a script that is able to match the performance that would be great :+1:

Another question: when plotting your results in the paper, do you sum all agent returns to get 5000, or do you take the mean return per agent?

Your paper results: [image]

Our local run of your repo: [image]

gardarjuto commented 9 months ago

Thanks for letting us know! There is actually an error in the plot in this version of the paper, caused by misinterpreted logging. For Ant, the x-axis should run up to 3e8, not 3e7. Furthermore, the plot was made using an older version of brax (0.0.16), which may behave differently. We will update this particular plot so that it is reproducible with the current state of the repo.

If you set TOTAL_TIMESTEPS to 3e8, you should get an average return above 2000 with the current implementation.

In these tasks, the same return is shared between all agents, so we just plot that return, not the sum.
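
To illustrate the distinction raised in the question, here is a minimal sketch of the difference between plotting the shared return and summing over agents. The agent names, the count of four agents, and the reward values are all assumptions made up for illustration; none of them come from the repo's logs.

```python
# Hypothetical per-step rewards for a cooperative task in which every agent
# receives the same team reward at each step. Agent names, the number of
# agents, and the reward values are illustrative only.
AGENTS = ["agent_0", "agent_1", "agent_2", "agent_3"]
TEAM_REWARDS = [1.0, 2.5, -0.5]

rewards_per_step = [{agent: r for agent in AGENTS} for r in TEAM_REWARDS]


def per_agent_returns(steps):
    """Sum each agent's rewards over the episode."""
    agents = list(steps[0].keys())
    return {agent: sum(step[agent] for step in steps) for agent in agents}


def shared_return(returns):
    """Return the common episode return, checking that all agents agree."""
    values = list(returns.values())
    if any(v != values[0] for v in values):
        raise ValueError("agents do not share the same return")
    return values[0]


per_agent = per_agent_returns(rewards_per_step)
team_return = shared_return(per_agent)       # the quantity that is plotted
summed_return = sum(per_agent.values())      # would overcount by the agent count

if __name__ == "__main__":
    print("shared return:", team_return)
    print("sum over agents:", summed_return)
```

With a shared reward, summing over agents simply scales the return by the number of agents, so the two conventions are easy to confuse when comparing against another library's numbers.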

sash-a commented 9 months ago

Thanks for the reply @gardarjuto! Sorry, I'm going to re-open this issue, as I don't get the results you describe and I think it might be useful for people who look into this in the future.

I set TOTAL_TIMESTEPS to 3e8 and I'm only seeing ~750 return, and it doesn't seem to learn past 1e8 steps: [image]

As a point of comparison, Mava's feed-forward IPPO gets ~1300 return, but again nowhere near 5000. I checked the brax version that installs with JaxMARL and it's now well past 0.0.16: it's at 0.9.3, so I assume this is the issue. The question is then: why do you require brax>=0.9? I don't think pinning to such an old version of brax is at all ideal (in fact, I couldn't get it running), so are there any plans to update your benchmark? Not necessarily in the paper, but possibly as a note in this repo, so that people have a point of comparison for the current version of brax.
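
For anyone comparing their own setup, a minimal sketch of how the installed brax version could be checked against the one used for the paper's plot. It relies only on the standard library's `importlib.metadata`; the helper names are hypothetical, and the version parsing is a simplification that assumes plain numeric `major.minor.patch` strings.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(package):
    """Return the installed version string of a package, or None if absent."""
    try:
        return version(package)
    except PackageNotFoundError:
        return None


def version_tuple(version_string):
    """Parse a plain 'major.minor.patch' string into a tuple of ints.

    Simplification: non-numeric parts (e.g. 'rc1' suffixes) are dropped.
    """
    parts = version_string.split(".")[:3]
    return tuple(int(p) for p in parts if p.isdigit())


# Version reported in this thread as the one used for the paper's plot.
PAPER_BRAX_VERSION = "0.0.16"

if __name__ == "__main__":
    found = installed_version("brax")
    if found is None:
        print("brax is not installed in this environment")
    else:
        newer = version_tuple(found) > version_tuple(PAPER_BRAX_VERSION)
        print(f"brax {found} installed; newer than {PAPER_BRAX_VERSION}: {newer}")
```

Comparing tuples of integers avoids the pitfall of comparing version strings lexically, where "0.0.16" would sort before "0.0.9".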