Closed. @Nightbringers closed this issue 1 year ago.

```python
from pgx.go import Go
```

gets this error:

```
ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'> for field current_player is not allowed: use default_factory
```
Hi, @Nightbringers
Thank you for your comment! Could you give us the following information about your environment to reproduce the error?
- jax version
- jaxlib version
- pgx version

At least, I found it works in Google Colab.
- jax: 0.4.13
- jaxlib: 0.4.13+cuda12.cudnn89
- python: 3.11
- pgx: 1.1.0
- OS: Ubuntu 18.04.5
- GPU: 3090
- flax: 0.7.0
Thank you for additional information!
I guess this error comes from a dataclass change in Python 3.11. Pgx is currently tested only on Python 3.8, 3.9, and 3.10. We will support Python 3.11 in the near future, but you can use 3.10 instead as a hotfix.
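For context, here is a minimal runnable sketch of the failure mode (a hypothetical one-field State, not Pgx's actual class): Python 3.11 extended the dataclass mutable-default check from list/dict/set to any unhashable type, and JAX arrays are unhashable, so a class-level array default is rejected when the class is defined. The `default_factory` workaround that the error message suggests avoids this.

```python
import dataclasses
import jax.numpy as jnp

# On Python 3.11 the class definition below raises:
#   ValueError: mutable default <class 'jaxlib.xla_extension.ArrayImpl'>
#   for field current_player is not allowed: use default_factory
#
# @dataclasses.dataclass
# class State:
#     current_player: jnp.ndarray = jnp.int32(0)

# Workaround: build the default lazily so no array is stored on the class.
@dataclasses.dataclass
class State:
    current_player: jnp.ndarray = dataclasses.field(
        default_factory=lambda: jnp.int32(0)
    )

print(State().current_player)  # 0
```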
Thank you for your report!
> The final feature plane, C, represents the colour to play, and has a constant value of either 1 if black is to play or 0 if white is to play.

I found through testing that the implementation is contrary to this description in the paper: currently 0 means black to play and 1 means white to play. Does it matter?
I suppose it doesn't matter.
I think the color of the current player to play and the color given by `player_id` are being confused. I don't understand why it has been set up this way; these two could be one. I see the black player is sometimes 0 and sometimes 1. Does this affect feature plane C? How can I make sure the black player is always 0?
> these two could be one

This is wrong. Reading the documentation and the discussion in this issue may help.
Also, read here for the Go features:

https://sotets.uk/pgx/go/#observation
Please note that black/white is independent of `current_player`. `current_player` is necessary because of Pgx's vectorization feature; if you do not take the vectorization into account, you may misunderstand its role.
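Here is a minimal sketch of the vectorization point (assuming pgx 1.x and the go_9x9 environment; the printed values are illustrative): each game in a batch independently decides which player id moves first, so `current_player` varies across the batch even though black is to play in every game at the first step.

```python
import jax
import pgx

env = pgx.make("go_9x9")
init_fn = jax.jit(jax.vmap(env.init))

# Initialize a batch of 4 independent games.
keys = jax.random.split(jax.random.PRNGKey(0), 4)
state = init_fn(keys)

# Which player id moves first differs per game...
print(state.current_player)            # e.g. [1 0 0 1]
# ...but black is to play everywhere, so the colour plane is uniform.
print(state.observation[:, 0, 0, -1])  # [False False False False]
```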
~Given your comment, I will check the implementation and documentation. Please wait a few days.~
I think our implementation is OK, but I'm happy to help you 👍 You can give us a minimal code example in which you think the behavior is strange. @Nightbringers
I suppose you do this so that two different agents can play, to avoid one agent always being black or white. But in a self-play setting, these two could be one, is that right?
> I suppose you do this so that two different agents can play, to avoid one agent always being black or white

Wrong. Please refer to the documents above.
> but in a self-play setting, these two could be one

I'm not sure what you intend, but in any case, `current_player` has nothing to do with black/white.
Thanks, I'm still a little bit confused; I will read more of the documents later. I need to confirm: no matter how `current_player` changes, the observation stays the same? It does not affect the C feature? If 0 means black to play, and the black player's id is 1, is it still 0 for black to play, rather than changing to 1 for black to play?
Sorry, I'm not sure what you want to confirm, but this snippet may help. I modified the example in the docs.
```python
import jax
import jax.numpy as jnp
import pgx
from pgx.experimental.utils import act_randomly

seed = 42
batch_size = 10
key = jax.random.PRNGKey(seed)

# Prepare agent A and B
#   Agent A: random player
#   Agent B: baseline player provided by Pgx
A = 0
B = 1

# Load the environment
env = pgx.make("go_9x9")
init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))

# Prepare baseline model
# Note that it additionally requires the Haiku library ($ pip install dm-haiku)
model_id = "go_9x9_v0"
model = pgx.make_baseline_model(model_id)

# Initialize the states
key, subkey = jax.random.split(key)
keys = jax.random.split(subkey, batch_size)
state = init_fn(keys)
print(f"Game index: {jnp.arange(batch_size)}")  # [0 1 2 3 4 5 6 7 8 9]
# print(f"Black player: {state.current_player}")  # [1 1 0 1 0 0 1 1 1 1]
# In other words
# print(f"A is black: {state.current_player == A}")  # [False False True False True True False False False False]
# print(f"B is black: {state.current_player == B}")  # [ True True False True False False True True True True]

# Run simulation
R = state.rewards
for i in range(10):
    print(i, state.observation[:, 0, 0, -1])  # !!!!LAST DIMENSION OF FEATURE!!!!
    # Action of random player A
    key, subkey = jax.random.split(key)
    action_A = jax.jit(act_randomly)(subkey, state)
    # Greedy action of baseline model B
    logits, value = model(state.observation)
    action_B = logits.argmax(axis=-1)
    action = jnp.where(state.current_player == A, action_A, action_B)
    state = step_fn(state, action)
    R += state.rewards
# print(f"Return of agent A = {R[:, A]}")  # [-1. -1. -1. -1. -1. -1. -1. -1. -1. -1.]
# print(f"Return of agent B = {R[:, B]}")  # [1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
```
This results in:

```
Game index: [0 1 2 3 4 5 6 7 8 9]
0 [False False False False False False False False False False]
1 [ True True True True True True True True True True]
2 [False False False False False False False False False False]
3 [ True True True True True True True True True True]
4 [False False False False False False False False False False]
5 [ True True True True True True True True True True]
6 [False False False False False False False False False False]
7 [ True True True True True True True True True True]
8 [False False False False False False False False False False]
9 [ True True True True True True True True True True]
```
So the LAST DIMENSION OF FEATURE has nothing to do with `current_player`; it always starts with False, then True, then False, and so on. `current_player` won't affect the observation. Right?
In this usage, yes.
Have you ever considered supporting Chinese Chess?
No.
But if there are many requests, we may support it.
Well, I think the demand for it would exceed that of many of the games here.
The AlphaZero Go example is good. Can you make a MuZero Go example using mctx? That would be great.
Thank you for your comment! Unfortunately, MuZero example is not currently planned🙏