FLAIROx / JaxMARL

Multi-Agent Reinforcement Learning with JAX
Apache License 2.0

Strange behaviour from agents in coin game environment #81

Closed Dronie closed 4 months ago

Dronie commented 6 months ago

Hi there,

I am trying to get my head around the library but am having an issue in the coin game environment. This is likely down to user error but I am not 100% sure what the issue may be. I am using a slightly modified version of the random agent in the README:

import jax
from jaxmarl import make
from PIL import Image
import numpy as np

key = jax.random.PRNGKey(0)
key, key_reset, key_act, key_step = jax.random.split(key, 4)

# Initialise environment.
env = make('coin_game')

# Reset the environment.
obs, state = env.reset(key_reset)

images = []
dones = 0
while dones < 10:
    images.append(env.render(state))
    # Sample random actions.
    actions = {agent: jax.numpy.asarray(4) for i, agent in enumerate(env.agents)}

    # Perform the step transition.
    obs, state, reward, done, infos = env.step(jax.numpy.asarray([4, 2], dtype='uint32'), state, actions)
    if done:
        dones += 1

gif = [i.convert("P", palette=Image.ADAPTIVE) for i in images]
gif[0].save('test2.gif', save_all=True, optimize=False, append_images=gif[1:], loop=0)

My problem is that I am fairly sure that both agents should always just take the 'stay' action due to

actions = {agent: jax.numpy.asarray(4) for i, agent in enumerate(env.agents)}

but when viewing the gif generated by the above code, this does not seem to be the case. It also seems, from the gif, that agents aren't picking up coins when stepping on them?

Any help/instruction on where I may have misunderstood something would be greatly appreciated!

Thanks, Stefan

Dronie commented 6 months ago

Update: I have realised that the agents are indeed picking up the coins, as evidenced by the scores going up on the right of the gif (I assume the coins just spawn back in the same location due to the fixed environment seed), but I'm still not sure about the agent behaviour.

amacrutherford commented 6 months ago

Hey, thanks for raising this! We should hopefully have time to take a look later this week. Keep updating us with what you find :slightly_smiling_face:

12tqian commented 6 months ago

It might be because you're using the same rng state without changing it?

Dronie commented 6 months ago

> It might be because you're using the same rng state without changing it?

This is likely why the coins remain in the same position when collected. However, I'm not sure this explains why the agent behaviour appears random despite the 'stay' action being hardcoded at every timestep.
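For reference, a minimal sketch of the key-splitting pattern suggested above, using only `jax.random`. The loop length and the `randint` draw are illustrative stand-ins for the environment's internal randomness, not part of the coin game itself:

import jax

key = jax.random.PRNGKey(0)

step_keys = []
for _ in range(3):
    # Split off a fresh subkey for every step instead of reusing one fixed key.
    key, key_step = jax.random.split(key)
    step_keys.append(key_step)
    # In the loop from the first post this subkey would be passed as:
    # env.step(key_step, state, actions)

# Reusing one fixed key reproduces the same draw every time, which would be
# consistent with coins respawning in the same place.
fixed_key = jax.random.PRNGKey(0)
fixed_draws = [int(jax.random.randint(fixed_key, (), 0, 9)) for _ in range(3)]

Each call to `jax.random.split` returns a new carry key plus a subkey, so no key is consumed twice.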

luchris429 commented 4 months ago

Sorry for taking so long to get back to you on this!

The API for Coin Game actions is as follows:

actions = jax.numpy.asarray([4, 4])

Passing the actions this way instead fixes the issue:

obs, state = env.reset(key_reset)

images = []
dones = 0
while dones < 10:
    images.append(env.render(state))
    # Pass the actions as a single array with one entry per agent, not as a dict.
    # actions = {agent: jax.numpy.asarray(4) for i, agent in enumerate(env.agents)}
    actions = jax.numpy.asarray([4, 4])

    # Perform the step transition.
    obs, state, reward, done, infos = env.step(jax.numpy.asarray([4, 2], dtype='uint32'), state, actions)
    if done:
        dones += 1

gif = [i.convert("P", palette=Image.ADAPTIVE) for i in images]
gif[0].save('test2.gif', save_all=True, optimize=False, append_images=gif[1:], loop=0)
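
One possible explanation for why the dict form misbehaved, which is an assumption and not something confirmed in this thread: if the environment unpacks `actions` positionally, a dict yields its keys rather than its values, so the hardcoded 4 never reaches the step logic. The agent ids and the `unpack` helper below are hypothetical and only illustrate that Python behaviour:

agents = [0, 1]  # hypothetical agent ids, chosen only for illustration

# Dict form from the original snippet: one 'stay' action (4) per agent.
dict_actions = {agent: 4 for agent in agents}
# Sequence form matching the fix: one entry per agent.
list_actions = [4, 4]

def unpack(actions):
    # Hypothetical stand-in for an environment that unpacks actions positionally.
    action_0, action_1 = actions
    return action_0, action_1

from_dict = unpack(dict_actions)  # iterating a dict yields its keys
from_list = unpack(list_actions)  # iterating a list yields its values

Here `from_dict` holds the agent ids while `from_list` holds the intended actions, so only the second form carries the 'stay' action through.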
