FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Naive question about multiple seeds #95

Open kfu02 opened 3 months ago

kfu02 commented 3 months ago

Hi,

I am familiar with MARL in PyTorch, but very new to JAX, so please forgive me if this question is naive.

I see that many of your baselines are parallelized over multiple seeds at once (e.g. in QMIX and transfQMIX). However, when running the baselines I notice that the resulting WandB runs seem to aggregate the seeds together. Is there some way to separate the performance of each seed for plotting purposes (e.g. to report the min/avg/max)? Your paper has several average return curves with some sort of error shading, so I imagine I must be missing something obvious.

amacrutherford commented 3 months ago

Hey!

Thanks for reaching out, and it's exciting that you are trying out JAX. Off the top of my head, I think the easiest way with wandb is to run one seed per script and then sweep over the seeds with wandb sweeps. If you set XLA_PYTHON_CLIENT_PREALLOCATE=false as an environment variable, you can run multiple scripts on one GPU, but this is quite a bit less efficient than running multiple seeds in parallel on one device. Have I missed something, @mttga?
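A minimal sketch of the one-seed-per-process idea, using plain subprocesses rather than wandb sweeps. The script name `train.py` and the `SEED=<n>` argument format are hypothetical placeholders, not JaxMARL's actual command line; only the `XLA_PYTHON_CLIENT_PREALLOCATE=false` setting comes from the comment above.

```python
import os
import subprocess
import sys


def build_seed_jobs(seeds, script="train.py"):
    """Return one (command, environment) pair per seed.

    `script` and the `SEED=<n>` argument format are hypothetical placeholders.
    """
    env = dict(os.environ)
    # Turn off GPU memory preallocation so that several processes
    # can share a single device.
    env["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    return [
        ([sys.executable, script, f"SEED={seed}"], dict(env))
        for seed in seeds
    ]


def launch(jobs):
    """Start every job as its own process and wait for all of them."""
    procs = [subprocess.Popen(cmd, env=env) for cmd, env in jobs]
    return [p.wait() for p in procs]


jobs = build_seed_jobs([0, 1, 2])
# launch(jobs)  # uncomment once `script` points at a real training script
```

Each process then logs to its own wandb run, so the seeds stay separate at the cost of the efficiency noted above.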

kfu02 commented 3 months ago

Hi, thanks for the reply!

Okay, so you're saying the answer is simply not to parallelize across seeds, then use WandB's tools to aggregate separate 1-seed runs together. If I'm understanding that correctly, then what is being plotted when I run multiple seeds in parallel? The average across those seeds?

mttga commented 3 months ago

The parallel runs will plot in the same space, meaning that you will have datapoints from all your runs but you will not be able to distinguish them. To distinguish them you can use an approach like this:

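import jax
import wandb

def function(rng):
    # First element of this run's PRNG key, used as a per-run identifier.
    original_seed = rng[0]

    # random stuff

    metrics = {}  # a dictionary of your logging metrics

    def callback(metrics, original_seed):
        # Add a copy of every metric under a per-seed prefix, e.g. "rng123/returns".
        metrics.update({
            f'rng{int(original_seed)}/{k}': v
            for k, v in metrics.items()
        })
        wandb.log(metrics)

    jax.debug.callback(callback, metrics, original_seed)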

We will include training code like this soon.
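For anyone who wants to try the per-seed prefix idea without a WandB account, here is a small self-contained sketch. It assumes the seeds are parallelized with `jax.vmap` over keys produced by `jax.random.split`; the names `train`, `log_callback`, and `logged` are illustrative only, and the `logged` list stands in for `wandb.log`.

```python
import jax

logged = []  # stand-in for wandb.log in this sketch


def log_callback(metrics, seed_id):
    # Runs on the host. Under vmap it is invoked once per seed,
    # receiving that seed's unbatched values.
    prefix = f"rng{int(seed_id)}"
    logged.append({f"{prefix}/{k}": float(v) for k, v in metrics.items()})


def train(rng):
    # First element of this run's key, used as a per-run identifier.
    seed_id = rng[0]
    # Placeholder for real training: one random "return" per seed.
    episode_return = jax.random.uniform(rng)
    metrics = {"returns": episode_return}
    jax.debug.callback(log_callback, metrics, seed_id)
    return episode_return


num_seeds = 3
rngs = jax.random.split(jax.random.PRNGKey(0), num_seeds)
returns = jax.jit(jax.vmap(train))(rngs)
jax.effects_barrier()  # wait until all pending callbacks have run

for entry in logged:
    print(entry)
```

Each entry carries a different `rng<id>/returns` key, so each seed can be plotted as its own series; min/avg/max across seeds can then be computed from those series or directly from the stacked `returns` array.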