google-deepmind / mctx

Monte Carlo tree search in JAX
Apache License 2.0

How to use image observations of shape (96, 96, 1) with muax, e.g. Atari PongNoFrameskip-v4 #63

Closed amineoui closed 11 months ago

amineoui commented 11 months ago

I tried Muax with MCTX and it is a nice piece of work, but I run into a problem when I use a convolutional observation of shape (96, 96, 1). Please guide me on how to solve this, and on the correct way to set train_env and eval_env as inputs to the muax.fit() function.

I use this code:

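import muax

# Representation, Prediction, Dynamic and the init_*_func helpers used below
# are assumed to be defined or imported elsewhere; they are not shown here.
support_size = 20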
embedding_size = 10
full_support_size = int(support_size * 2 + 1)
num_actions = 2

repr_fn = init_representation_func(Representation, embedding_size)
pred_fn = init_prediction_func(Prediction, num_actions, full_support_size)
dy_fn = init_dynamic_func(Dynamic, embedding_size, num_actions, full_support_size)

tracer = muax.PNStep(50, 0.999, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)

gradient_transform = muax.model.optimizer(init_value=0.002, peak_value=0.002, end_value=0.0005, warmup_steps=20000, transition_steps=20000)

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=0.999,
                    optimizer=gradient_transform, support_size=support_size)

model_path = muax.fit(model, 'CartPole-v1', 
                max_episodes=1000,
                max_training_steps=50000,
                tracer=tracer,
                buffer=buffer,
                k_steps=10,
                sample_per_trajectory=1,
                buffer_warm_up=128,
                num_trajectory=128,
                tensorboard_dir='/content/data/tensorboard/',
                save_name='model_params',
                random_seed=i,  # i is not defined in this snippet; presumably a loop variable from surrounding code
                log_all_metrics=True)

fidlej commented 11 months ago

Maybe try asking this question on the muax repository: https://github.com/bwfbowen/muax

Try to provide a better description of the problem and maybe include the error message.
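
One generic thing to check with image observations, independent of muax: a representation network built for flat vectors (such as CartPole's 4-number state) will not fit a (96, 96, 1) frame as-is. A convolutional encoder is the usual substitute, and the size of its flattened output has to be known before projecting it down to embedding_size. The pure-Python sketch below only does that shape arithmetic for unpadded convolutions. The layer stack and the helper names conv_out and encoder_shapes are hypothetical illustrations, not part of muax or mctx.

```python
def conv_out(size, kernel, stride):
    """Output length along one spatial axis of an unpadded strided convolution."""
    return (size - kernel) // stride + 1


def encoder_shapes(obs_shape, layers):
    """Trace (height, width, channels) through a stack of conv layers.

    `layers` is a list of (out_channels, kernel, stride) tuples.
    """
    h, w, c = obs_shape
    shapes = [(h, w, c)]
    for out_channels, kernel, stride in layers:
        h = conv_out(h, kernel, stride)
        w = conv_out(w, kernel, stride)
        if h < 1 or w < 1:
            raise ValueError("feature map collapsed; use smaller kernels or strides")
        c = out_channels
        shapes.append((h, w, c))
    return shapes


# Hypothetical layer stack chosen for illustration, not taken from muax.
LAYERS = [(32, 8, 4), (64, 4, 2), (64, 3, 1)]

shapes = encoder_shapes((96, 96, 1), LAYERS)
last_h, last_w, last_c = shapes[-1]
flat_size = last_h * last_w * last_c  # input width of the final dense projection

for shape in shapes:
    print(shape)
print("flattened size:", flat_size)
```

If the flattened size comes out very large, adding a layer or increasing a stride shrinks it before the dense projection. Whether muax.fit() accepts such an environment directly is a question for the muax repository.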