sotetsuk / pgx

♟️ Vectorized RL game environments in JAX
http://sotets.uk/pgx/
Apache License 2.0
416 stars 27 forks

Import error in Python 3.11 #996

Closed · Nightbringers closed 1 year ago

Nightbringers commented 1 year ago

from pgx.go import Go

I get this error:

ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field current_player is not allowed: use default_factory

sotetsuk commented 1 year ago

Hi, @Nightbringers

Thank you for your comment! Could you give us some information about your environment (Python, JAX, and Pgx versions, OS, etc.) so that we can reproduce the error?

At least, I found that it works in Google Colab:

[screenshot: the import succeeding in Google Colab]

Nightbringers commented 1 year ago

jax: 0.4.13
jaxlib: 0.4.13+cuda12.cudnn89
python: 3.11
pgx: 1.1.0
os: Ubuntu 18.04.5
gpu: 3090
flax: 0.7.0

sotetsuk commented 1 year ago

Thank you for additional information!

I guess this error comes from the dataclass change in Python 3.11. Pgx is currently only tested on Python 3.8, 3.9, and 3.10. We will support Python 3.11 in the near future, but you can use 3.10 instead as a hotfix.
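
(For reference, a minimal sketch of the Python 3.11 behavior that likely triggers this; the field name current_player is taken from the error message, and the array-valued default is an assumption about Pgx's State dataclass:)

from dataclasses import dataclass, field
import jax.numpy as jnp

# Python 3.11 rejects any dataclass default whose type is unhashable
# (earlier versions only special-cased list, dict, and set). JAX arrays
# are unhashable, so an array-valued default now raises:
#   ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'>
#   for field current_player is not allowed: use default_factory

@dataclass
class State:
    # current_player: jnp.ndarray = jnp.int32(0)  # ValueError on Python 3.11
    current_player: jnp.ndarray = field(default_factory=lambda: jnp.int32(0))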

Thank you for your report!

Nightbringers commented 1 year ago

The paper says: "The final feature plane, C, represents the colour to play, and has a constant value of either 1 if black is to play or 0 if white is to play." I found through testing that the implementation is contrary to this description: currently 0 means black to play and 1 means white to play. Does it matter?

sotetsuk commented 1 year ago

> I found through testing that the implementation is contrary to the description in the paper: currently 0 means black to play and 1 means white to play. Does it matter?

I suppose it doesn't matter.

Nightbringers commented 1 year ago

I think the colour of the current player to move and the player id are being conflated. I don't understand why it has been set up this way; these two could be one. I see that the black player is sometimes 0 and sometimes 1. Does this affect feature plane C? How can I make sure the black player is always 0?

sotetsuk commented 1 year ago

> these two could be one

This is wrong. Reading the documentation and the discussion in this issue may help.

Also, read here about the Go observation feature:

https://sotets.uk/pgx/go/#observation

Please note that black/white is independent of current_player. current_player is necessary because of Pgx's vectorization feature. If you do not take the vectorization into account, you may misunderstand its role.
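
(For illustration, a minimal sketch using the same vectorized API as the snippet below; the printed values are only an example:)

import jax
import pgx

# In a vectorized batch, which player id moves first differs per game,
# so player id 0 is not always black. current_player reports whose turn
# it is by id; colour is tracked separately by the environment.
env = pgx.make("go_9x9")
init_fn = jax.jit(jax.vmap(env.init))
keys = jax.random.split(jax.random.PRNGKey(0), 4)
state = init_fn(keys)
print(state.current_player)  # e.g. [1 0 0 1]: the id that plays black differs per game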

sotetsuk commented 1 year ago

~~Given your comment, I will check the implementation and documentation. Please wait a few days.~~

I think our implementation is OK, but I'm happy to help you 👍 You can give us a minimal code example in which you think the behavior is strange. @Nightbringers

Nightbringers commented 1 year ago

I suppose you do this so that two different agents can play without one agent being black or white all the time. But in a self-play setting, these two could be one, is that right?

sotetsuk commented 1 year ago

> I suppose you do this so that two different agents can play without one agent being black or white all the time

Wrong. Please refer to the documents above.

> but in a self-play setting, these two could be one

I'm not sure what you intend, but in any case, current_player has nothing to do with black/white.

Nightbringers commented 1 year ago

Thanks, I'm still a little confused; I will read more of the documentation later. I need to confirm: no matter how current_player changes, the observation stays the same? It does not affect the C feature? If 0 means black to play and the black player's id is 1, does 0 still mean black to play, or does it change to 1 meaning black to play?

sotetsuk commented 1 year ago

Sorry, I'm not sure exactly what you want to confirm, but this snippet may help. I modified the example in the docs.

import jax
import jax.numpy as jnp
import pgx
from pgx.experimental.utils import act_randomly

seed = 42
batch_size = 10
key = jax.random.PRNGKey(seed)

# Prepare agent A and B
#   Agent A: random player
#   Agent B: baseline player provided by Pgx
A = 0
B = 1

# Load the environment
env = pgx.make("go_9x9")
init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))

# Prepare baseline model
# Note that it additionally requires the Haiku library ($ pip install dm-haiku)
model_id = "go_9x9_v0"
model = pgx.make_baseline_model(model_id)

# Initialize the states
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, batch_size)
state = init_fn(keys)
print(f"Game index: {jnp.arange(batch_size)}")  #  [0 1 2 3 4 5 6 7 8 9]
# print(f"Black player: {state.current_player}")  #  [1 1 0 1 0 0 1 1 1 1]
# In other words
# print(f"A is black: {state.current_player == A}")  # [False False  True False  True  True False False False False]
# print(f"B is black: {state.current_player == B}")  # [ True  True False  True False False  True  True  True  True]

# Run simulation
R = state.rewards
for i in range(10):
    print(i, state.observation[:, 0, 0, -1])  # !!!!LAST DIMENSION OF FEATURE!!!!

    # Action of random player A
    key, subkey = jax.random.split(key)
    action_A = jax.jit(act_randomly)(subkey, state)
    # Greedy action of baseline model B
    logits, value = model(state.observation)
    action_B = logits.argmax(axis=-1)

    action = jnp.where(state.current_player == A, action_A, action_B)
    state = step_fn(state, action)
    R += state.rewards

# print(f"Return of agent A = {R[:, A]}")  # [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1.]
# print(f"Return of agent B = {R[:, B]}")  # [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

This results in

Game index: [0 1 2 3 4 5 6 7 8 9]
0 [False False False False False False False False False False]
1 [ True  True  True  True  True  True  True  True  True  True]
2 [False False False False False False False False False False]
3 [ True  True  True  True  True  True  True  True  True  True]
4 [False False False False False False False False False False]
5 [ True  True  True  True  True  True  True  True  True  True]
6 [False False False False False False False False False False]
7 [ True  True  True  True  True  True  True  True  True  True]
8 [False False False False False False False False False False]
9 [ True  True  True  True  True  True  True  True  True  True]
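
(As the output suggests, the last plane simply alternates with the colour to move, starting from black's turn, independently of current_player.)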
Nightbringers commented 1 year ago

So the LAST DIMENSION OF FEATURE has nothing to do with current_player; it always starts with False, then True, then False, and so on. current_player won't affect the observation. Right?

sotetsuk commented 1 year ago

In this usage, yes.

Nightbringers commented 1 year ago

Have you ever considered supporting Chinese Chess?

sotetsuk commented 1 year ago

No.

sotetsuk commented 1 year ago

But if there are many requests, we may support it.

Nightbringers commented 1 year ago

Well, I think the demand for it would exceed that of many of the games here.

Nightbringers commented 1 year ago

The AlphaZero Go example is good. Can you make a MuZero Go example using mctx? That would be great.

sotetsuk commented 1 year ago

Thank you for your comment! Unfortunately, a MuZero example is not currently planned 🙏