alpc91 opened 1 year ago
I refer to another issue (https://github.com/frt03/mxt_bench/issues/1) and configured my environment as follows; it currently runs successfully:

```
conda create -n YOUR_WANT_NAME python=3.8
```
I forked brax to my own GitHub repo, then git cloned the forked brax repo to my local computer. I then reset the cloned local repo to the relevant past commit:

```
git reset --hard 0ebb332
```

Finally, I added the local brax repo to PYTHONPATH so it could be found by mxt:

```
export PYTHONPATH='/path/to/brax'
```
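As a sanity check for the PYTHONPATH step, a directory listed there is prepended to Python's module search path, which is how the pinned local brax checkout shadows any pip-installed brax. A minimal, self-contained sketch (the `mypkg` package here is a hypothetical stand-in for the local brax directory):

```python
import os
import subprocess
import sys
import tempfile

# Build a throwaway package directory, then confirm that pointing
# PYTHONPATH at its parent makes it importable in a fresh interpreter.
with tempfile.TemporaryDirectory() as d:
    pkg = os.path.join(d, "mypkg")          # stand-in for /path/to/brax
    os.makedirs(pkg)
    open(os.path.join(pkg, "__init__.py"), "w").close()

    env = dict(os.environ, PYTHONPATH=d)    # like `export PYTHONPATH=...`
    out = subprocess.run(
        [sys.executable, "-c", "import mypkg; print('ok')"],
        env=env, capture_output=True, text=True,
    )
    print(out.stdout.strip())               # prints "ok"
```

If the import fails after `git reset`, the first thing to check is that the PYTHONPATH entry is the repo root (the directory containing the `brax/` package), not the package directory itself.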
```
pip install --upgrade "jax[cuda111]==0.2.21" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install flax==0.3.6
pip install optax==0.0.9
pip install tensorflow-probability==0.14.0
pip install google
pip install protobuf==3.20.0
pip install trimesh
pip install gym
pip install chex==0.1.0
pip install numpy==1.22
```

(refer to https://github.com/google/jax/discussions/9951)
Hello:
The code can run successfully in this way.
The training command:

```
CUDA_VISIBLE_DEVICES=0 python train_ppo_mlp.py --logdir ../results --seed 0 --env unimal_reach_handsup_5506-2-16-01-10-58-23 --obs_config mtg_v2_base_m
```
But there are some bugs when I run "generate_behavior_and_qp.py" with the unimal environments:
The command:

```
CUDA_VISIBLE_DEVICES=0 python generate_behavior_and_qp.py --seed 0 --env unimal_reach_handsup_5506-2-16-01-10-58-23 --task_name unimal_reach_handsup --params_path ../results/ao_ppo_mlp_single_pro_unimal_reach_handsup_5506-2-16-01-10-58-23_20230724_165922/ppo_mlp_98304000.pkl --obs_config mtg_v2_base_m
```
The bugs:

```
....
  File "generate_behavior_and_qp.py", line 277, in main
    collect_data(environment_fn=env_fn, **train_job_params)
  File "generate_behavior_and_qp.py", line 203, in collect_data
    key_sample, state, replay_buffer)
  File "generate_behavior_and_qp.py", line 187, in run_collect_data
    steps_per_envs, length=None)
  File "generate_behavior_and_qp.py", line 164, in collect_and_update_buffer
    state.qp.pos.reshape(-1, (state_qp_shape) * 3),
....
jax.core.InconclusiveDimensionOperation: Cannot divide evenly the sizes of shapes (50, 11, 3) and (-1, 48)
```
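For what it's worth, the error says the reshape cannot divide the total element count evenly: `qp.pos` is (50, 11, 3), i.e. 11 bodies × 3 coordinates = 33 values per environment, but the target width 48 implies `state_qp_shape` = 16 bodies. A minimal NumPy sketch (shapes taken from the traceback; deriving the width from the array itself is my guess at a fix, not the repo's actual code):

```python
import numpy as np

# qp.pos as reported in the traceback: 50 envs, 11 bodies, xyz coords
pos = np.zeros((50, 11, 3))

# Reshaping with an assumed 16 bodies (16 * 3 = 48) fails, because
# 50 * 11 * 3 = 1650 elements is not divisible by 48:
try:
    pos.reshape(-1, 16 * 3)
except ValueError as e:
    print("reshape failed:", e)

# Reading the body count from the array avoids the hard-coded mismatch:
num_bodies = pos.shape[1]            # 11 for this unimal morphology
flat = pos.reshape(-1, num_bodies * 3)
print(flat.shape)                    # (50, 33)
```

So the crash suggests `state_qp_shape` in the script does not match the number of bodies in this particular unimal morphology.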
Have you encountered similar problems?
I am using an NVIDIA graphics card; which version of Python did the author install?