Accepted to ICLR 2023 (notable-top-25%, Spotlight) [arXiv] [Website]
If you use this codebase for your research, please cite the paper:
@inproceedings{furuta2023asystem,
  title={A System for Morphology-Task Generalization via Unified Representation and Behavior Distillation},
  author={Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo and Shixiang Shane Gu},
  booktitle={International Conference on Learning Representations},
  year={2023},
}
Install the dependencies:
pip install -r requirements.txt
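The training scripts below pin GPUs via CUDA_VISIBLE_DEVICES, so it is worth confirming that JAX can see your accelerators after installation. A minimal check using only the standard JAX API (no repo-specific code):

import jax

# Lists the visible devices; expect GPU entries when a CUDA-enabled jaxlib is installed,
# otherwise JAX silently falls back to CPU.
print(jax.devices())
print(jax.default_backend())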
Train a single-task, single-morphology PPO policy on each environment:
CUDA_VISIBLE_DEVICES=0 python train_ppo_mlp.py --logdir ../results --seed 0 --env ant_reach_4
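Since the multi-task dataset needs an expert policy per morphology-task combination, it can help to sweep this command over environments and seeds. A minimal launcher sketch that reuses only the flags shown above (--logdir, --seed, --env); the environment list and seed set are illustrative:

import subprocess

envs = [f'ant_reach_{i}' for i in range(2, 7)]  # illustrative subset of environments
for seed in (0, 1, 2):
    for env in envs:
        # Launches train_ppo_mlp.py sequentially with the documented flags.
        subprocess.run(
            ['python', 'train_ppo_mlp.py', '--logdir', '../results',
             '--seed', str(seed), '--env', env],
            check=True)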
Pick a trained policy weight, and collect expert brax.QP:
CUDA_VISIBLE_DEVICES=0,1 python generate_behavior_and_qp.py --seed 0 --env ant_reach_4 --task_name ant_reach --params_path ../results/ao_ppo_mlp_single_pro_ant_reach_4_20220707_174507/ppo_mlp_98304000.pkl
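The --params_path above points into a timestamped run directory, so the exact path differs per run. A small helper sketch to pick the newest PPO checkpoint under the results directory, assuming only the run-directory and ppo_mlp_<step>.pkl naming visible in the example path:

import glob
import os

def latest_ppo_checkpoint(results_dir='../results', env='ant_reach_4'):
    # Run directories are timestamped (e.g. ..._ant_reach_4_20220707_174507), so glob for
    # any directory containing the environment name and take the newest ppo_mlp_*.pkl.
    pattern = os.path.join(results_dir, f'*{env}*', 'ppo_mlp_*.pkl')
    checkpoints = glob.glob(pattern)
    if not checkpoints:
        raise FileNotFoundError(f'no PPO checkpoints match {pattern}')
    return max(checkpoints, key=os.path.getmtime)

print(latest_ppo_checkpoint())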
Register qp_path (path to the saved brax.QP) in dataset_config.py.
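The actual schema lives in dataset_config.py; the snippet below is only a hypothetical illustration of the idea, i.e. mapping a dataset name to the qp_path where the collected brax.QP data were saved (the dict name and key layout are assumptions, not the repo's real config):

# Hypothetical entry; check dataset_config.py for the actual key names and structure.
QP_DATASETS = {
    'ant_reach_4': {
        # qp_path should point at the brax.QP file written by generate_behavior_and_qp.py.
        'qp_path': '../data/ant_reach_4_qp.pkl',
    },
}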
Convert brax.QP to the morphology-task graph representation (e.g. mtg_v2_base_m):
CUDA_VISIBLE_DEVICES=0 python generate_behavior_from_qp.py --seed 0 --env ant_reach_4 --task_name ant_reach --data_name ant_reach_4_mtg_v2_base_m --obs_config2 mtg_v2_base_m
Register dataset_path (path to the saved observations) in dataset_config.py and task_config.py.
Train Transformer policy via multi-task behavior cloning:
CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer.py --task_name example --seed 0
# zero-shot evaluation
CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer_zs.py --task_name example --seed 0
# fine-tuning on multi-task imitation learning
CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer_fs.py --task_name example --seed 0 --params_path ../results/bc_transformer_zs/policy.pkl
New morphology-task environments are registered by adding entries to ENV_DESCS, each mapping an environment name to a functools.partial over load_desc (the environment description loader defined in the repo):

import functools

ENV_DESCS = dict()

# add environments: ant_reach with 2-6 legs, plus harder goal ranges
for i in range(2, 7, 1):
    ENV_DESCS[f'ant_reach_{i}'] = functools.partial(load_desc, num_legs=i)
    ENV_DESCS[f'ant_reach_hard_{i}'] = functools.partial(load_desc, num_legs=i, r_min=10.5, r_max=11.5)

# missing-leg (broken ant) variants
for i in range(3, 7, 1):
    for j in range(i):
        ENV_DESCS[f'ant_reach_{i}_b_{j}'] = functools.partial(load_desc, agent='broken_ant', num_legs=i, broken_id=j)
        ENV_DESCS[f'ant_reach_hard_{i}_b_{j}'] = functools.partial(load_desc, agent='broken_ant', num_legs=i, broken_id=j, r_min=10.5, r_max=11.5)
Also specify min_dist=0 in each reward function dict.
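Following the same pattern, one more variant can be registered with an extra ENV_DESCS entry. A sketch is below; the seven-legged key names are made up, only keyword arguments that already appear above are passed to load_desc, and whether a given num_legs value is actually supported depends on the agent descriptions in the repo:

# Sketch: register an additional (hypothetical) variant reusing the existing keyword arguments.
ENV_DESCS['ant_reach_7'] = functools.partial(load_desc, num_legs=7)
ENV_DESCS['ant_reach_hard_7'] = functools.partial(load_desc, num_legs=7, r_min=10.5, r_max=11.5)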