bwfbowen / muax

A project that provides help for using DeepMind's mctx on gym-style environments.
MIT License
45 stars · 9 forks

How to use image observations of shape (96, 96, 1), e.g. Atari PongNoFrameskip-v4 #7

Open · amineoui opened this issue 10 months ago

amineoui commented 10 months ago

I tried mctx with muax and it's great work, but I run into a problem when I use convolutional (image) observations with shape (96, 96, 1). Could you explain how to solve this, and what the correct way is to set up train_env and eval_env as inputs to the muax.fit() function?

I use this code:


```python
import muax

# NOTE: init_representation_func / init_prediction_func / init_dynamic_func and the
# MLP modules Representation, Prediction and Dynamic are defined elsewhere (not shown).
# These MLP modules take flat vector observations, such as CartPole's.

support_size = 20
embedding_size = 10
full_support_size = int(support_size * 2 + 1)  # bins from -support_size to +support_size
num_actions = 2  # CartPole-v1 has 2 discrete actions

repr_fn = init_representation_func(Representation, embedding_size)
pred_fn = init_prediction_func(Prediction, num_actions, full_support_size)
dy_fn = init_dynamic_func(Dynamic, embedding_size, num_actions, full_support_size)

tracer = muax.PNStep(50, 0.999, 0.5)
buffer = muax.TrajectoryReplayBuffer(500)

gradient_transform = muax.model.optimizer(init_value=0.002, peak_value=0.002, end_value=0.0005,
                                          warmup_steps=20000, transition_steps=20000)

model = muax.MuZero(repr_fn, pred_fn, dy_fn, policy='muzero', discount=0.999,
                    optimizer=gradient_transform, support_size=support_size)

random_seed = 0  # set explicitly (e.g. from a loop variable when sweeping seeds)

model_path = muax.fit(model, 'CartPole-v1',
                      max_episodes=1000,
                      max_training_steps=50000,
                      tracer=tracer,
                      buffer=buffer,
                      k_steps=10,
                      sample_per_trajectory=1,
                      buffer_warm_up=128,
                      num_trajectory=128,
                      tensorboard_dir='/content/data/tensorboard/',
                      save_name='model_params',
                      random_seed=random_seed,
                      log_all_metrics=True)
```
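In the snippet, `full_support_size = 2 * support_size + 1 = 41` counts the integer bins from -20 to +20 used by a MuZero-style categorical value/reward head. The sketch below shows the common "two-hot" encoding of a scalar onto that support and the expectation that decodes it. The function names `scalar_to_support` and `support_to_scalar` are hypothetical, and the sketch omits the invertible scaling h(x) that the MuZero paper applies before encoding, so it is an illustration of the idea rather than what muax necessarily does internally.

```python
import numpy as np


def scalar_to_support(x: float, support_size: int) -> np.ndarray:
    """Two-hot encode a scalar onto integer bins -support_size..support_size (hypothetical helper)."""
    x = float(np.clip(x, -support_size, support_size))
    probs = np.zeros(2 * support_size + 1, dtype=np.float32)
    low = int(np.floor(x))
    frac = x - low
    # Split the probability mass between the two neighbouring integer bins
    probs[low + support_size] = 1.0 - frac
    if frac > 0:
        probs[low + support_size + 1] = frac
    return probs


def support_to_scalar(probs: np.ndarray, support_size: int) -> float:
    """Decode by taking the expectation over the bin values."""
    bins = np.arange(-support_size, support_size + 1)
    return float(np.dot(probs, bins))


p = scalar_to_support(3.25, support_size=20)
print(len(p))                    # 41
print(support_to_scalar(p, 20))  # 3.25
```

Values outside the support are clipped to the end bins, which is why `support_size` should be chosen large enough for the (scaled) returns of the environment.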

Could you give a simple example of how to run muax with the Atari gym environment PongNoFrameskip-v4?
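A raw PongNoFrameskip-v4 frame is a 210×160 RGB array, so producing the (96, 96, 1) shape from the title requires a preprocessing step. The sketch below is hypothetical: the name `preprocess_frame`, the BT.601 luma weights for grayscale, and the nearest-neighbour resize are illustrative choices, not part of muax. In practice such a function would typically sit inside an observation wrapper around the gym environment; how a wrapped environment is then handed to `muax.fit` is not covered here.

```python
import numpy as np


def preprocess_frame(frame: np.ndarray, size: int = 96) -> np.ndarray:
    """Convert an RGB Atari frame (H, W, 3) uint8 into a (size, size, 1) float32 image in [0, 1].

    Hypothetical helper: grayscale via ITU-R BT.601 luma weights,
    then nearest-neighbour resize by index sampling.
    """
    if frame.ndim != 3 or frame.shape[-1] != 3:
        raise ValueError(f"expected an (H, W, 3) RGB frame, got {frame.shape}")
    # Weighted sum over the colour channel -> (H, W)
    gray = frame.astype(np.float32) @ np.array([0.299, 0.587, 0.114], dtype=np.float32)
    h, w = gray.shape
    # Source row/column index for each output pixel (nearest neighbour)
    rows = np.arange(size) * h // size
    cols = np.arange(size) * w // size
    resized = gray[rows[:, None], cols[None, :]]
    # Scale to [0, 1] and add the trailing channel axis -> (size, size, 1)
    return (resized / 255.0)[..., None]


# Usage with a dummy frame of Pong's raw shape (210, 160, 3)
dummy = np.random.randint(0, 256, size=(210, 160, 3), dtype=np.uint8)
obs = preprocess_frame(dummy)
print(obs.shape, obs.dtype)  # (96, 96, 1) float32
```

Nearest-neighbour sampling keeps the sketch dependency-free; an image library's area or bilinear resize would usually preserve thin objects like Pong's ball better.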

bwfbowen commented 9 months ago

Hi amineoui, sorry for the delayed response. The problem is that your code uses the MLP modules (Representation in your code), whereas for image input I used ResNet modules (e.g. ResNetRepresentation) to process the observations. Here is an example project that uses muax and accepts visual inputs: NTS
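To illustrate the mismatch described in the reply: an MLP representation treats the observation as a flat vector (a (96, 96, 1) image would become 96 × 96 × 1 = 9216 inputs), while a convolutional/residual encoder consumes the image directly and downsamples it to an embedding. The NumPy forward pass below is a hypothetical sketch with random, untrained weights. It is not muax's ResNetRepresentation (whose layer sizes are not given here), and the names `conv2d`, `residual_block`, `init_params` and `encode` are illustrative.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def relu(x):
    return np.maximum(x, 0.0)


def conv2d(x, kernel, stride=1):
    """Zero-padded ('same') 2-D cross-correlation, as used in deep-learning conv layers.

    x: (H, W, C_in), kernel: (k, k, C_in, C_out) with k odd.
    Returns (ceil(H / stride), ceil(W / stride), C_out).
    """
    k = kernel.shape[0]
    pad = k // 2
    xp = np.pad(x, ((pad, pad), (pad, pad), (0, 0)))
    # One k x k patch per spatial position and input channel: (H, W, C_in, k, k)
    windows = sliding_window_view(xp, (k, k), axis=(0, 1))[::stride, ::stride]
    # Contract over patch rows/cols (i, j) and input channels (c)
    return np.einsum("hwcij,ijco->hwo", windows, kernel)


def residual_block(x, k1, k2):
    """Two same-resolution convolutions plus a skip connection."""
    return relu(x + conv2d(relu(conv2d(x, k1)), k2))


def init_params(rng, in_channels=1, embedding_size=10):
    """Random weights for a hypothetical tiny encoder (untrained, not muax's)."""
    def w(*shape):
        return rng.normal(0.0, 0.1, size=shape).astype(np.float32)
    return {
        "down1": w(3, 3, in_channels, 16),  # 96x96 -> 48x48
        "down2": w(3, 3, 16, 32),           # 48x48 -> 24x24
        "res1": w(3, 3, 32, 32),
        "res2": w(3, 3, 32, 32),
        "proj": w(32, embedding_size),      # pooled features -> embedding
    }


def encode(obs, params):
    """Map a (96, 96, C) observation to an (embedding_size,) vector."""
    h = relu(conv2d(obs, params["down1"], stride=2))
    h = relu(conv2d(h, params["down2"], stride=2))
    h = residual_block(h, params["res1"], params["res2"])
    pooled = h.mean(axis=(0, 1))  # global average pool over H and W
    return pooled @ params["proj"]


rng = np.random.default_rng(0)
params = init_params(rng, in_channels=1, embedding_size=10)
obs = rng.random((96, 96, 1), dtype=np.float32)
emb = encode(obs, params)
print(emb.shape)  # (10,)
```

Strided convolutions do the downsampling and the residual block keeps the spatial resolution; the MuZero paper's Atari representation network follows the same general pattern on 96×96 frames, at a much larger scale.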