frt03 / mxt_bench

A System for Morphology-Task Generalization via Unified Representation and Behavior Distillation (ICLR2023)
https://arxiv.org/abs/2211.14296

A System for Morphology-Task Generalization via Unified Representation and Behavior Distillation

Accepted to ICLR2023 (notable-top-25%, Spotlight) [arxiv] [Website]

Citation

If you use this codebase for your research, please cite the paper:

@inproceedings{furuta2023asystem,
  title={A System for Morphology-Task Generalization via Unified Representation and Behavior Distillation},
  author={Hiroki Furuta and Yusuke Iwasawa and Yutaka Matsuo and Shixiang Shane Gu},
  booktitle={International Conference on Learning Representations},
  year={2023},
}

Installation

pip install -r requirements.txt

Behavior Distillation Pipeline

  1. Train a single-task, single-morphology PPO policy on the environment:

    CUDA_VISIBLE_DEVICES=0 python train_ppo_mlp.py --logdir ../results --seed 0 --env ant_reach_4

  2. Pick the trained policy weights and collect the expert brax.QP data:

    CUDA_VISIBLE_DEVICES=0,1 python generate_behavior_and_qp.py --seed 0 --env ant_reach_4 --task_name ant_reach --params_path ../results/ao_ppo_mlp_single_pro_ant_reach_4_20220707_174507/ppo_mlp_98304000.pkl

  3. Register qp_path (path to saved brax.QP) in dataset_config.py.

  4. Convert brax.QP to the morphology-task graph representation (e.g. mtg_v2_base_m):

    CUDA_VISIBLE_DEVICES=0 python generate_behavior_from_qp.py --seed 0 --env ant_reach_4 --task_name ant_reach --data_name ant_reach_4_mtg_v2_base_m --obs_config2 mtg_v2_base_m

  5. Register dataset_path (path to saved observations) in dataset_config.py and task_config.py.

  6. Train a Transformer policy via multi-task behavior cloning:

    CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer.py --task_name example --seed 0
    # zero-shot evaluation
    CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer_zs.py --task_name example --seed 0
    # fine-tuning on multi-task imitation learning
    CUDA_VISIBLE_DEVICES=0,1 python train_bc_transformer_fs.py --task_name example --seed 0 --params_path ../results/bc_transformer_zs/policy.pkl
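When repeating steps 1, 2 and 4 for several environments, it can help to generate the command lines instead of typing them. The following is a minimal sketch, not part of the codebase: `pipeline_commands` is a hypothetical helper, the `data_name` pattern `<env>_<obs_config>` is assumed from the example in step 4, and `path/to/ppo_mlp.pkl` is a placeholder for the checkpoint picked in step 2. It only builds and prints the commands; it runs nothing and omits `CUDA_VISIBLE_DEVICES`.

```python
import shlex


def pipeline_commands(env, task_name, obs_config, params_path, seed=0):
    """Build the commands for steps 1, 2 and 4 as argument lists (hypothetical helper)."""
    # Assumption: data_name follows the "<env>_<obs_config>" pattern used in step 4.
    data_name = f"{env}_{obs_config}"
    common = ["--seed", str(seed), "--env", env]
    return [
        # Step 1: train the single-task, single-morphology PPO policy.
        ["python", "train_ppo_mlp.py", "--logdir", "../results", *common],
        # Step 2: roll out the trained policy and save the expert brax.QP.
        ["python", "generate_behavior_and_qp.py", *common,
         "--task_name", task_name, "--params_path", params_path],
        # Step 4: convert the saved brax.QP to the graph representation.
        ["python", "generate_behavior_from_qp.py", *common,
         "--task_name", task_name, "--data_name", data_name,
         "--obs_config2", obs_config],
    ]


# Print the commands for the example environment; the params path is a placeholder.
for cmd in pipeline_commands("ant_reach_4", "ant_reach", "mtg_v2_base_m",
                             "path/to/ppo_mlp.pkl"):
    print(shlex.join(cmd))
```

Steps 3 and 5 (registering qp_path and dataset_path) still have to be done by hand between these commands.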

How to Register New Morphology

How to Register New Task

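import functools

# NOTE: load_desc is not defined in this snippet; it is assumed to be the
# codebase's function that builds an environment description.
ENV_DESCS = dict()

# add environments: ants with 2 to 6 legs; the "hard" variants pass r_min=10.5, r_max=11.5
for i in range(2, 7, 1):
  ENV_DESCS[f'ant_reach_{i}'] = functools.partial(load_desc, num_legs=i)
  ENV_DESCS[f'ant_reach_hard_{i}'] = functools.partial(load_desc, num_legs=i, r_min=10.5, r_max=11.5)

# missing: 'broken_ant' variants, one entry per broken_id j for ants with 3 to 6 legs
for i in range(3, 7, 1):
  for j in range(i):
    ENV_DESCS[f'ant_reach_{i}_b_{j}'] = functools.partial(load_desc, agent='broken_ant', num_legs=i, broken_id=j)
    ENV_DESCS[f'ant_reach_hard_{i}_b_{j}'] = functools.partial(load_desc, agent='broken_ant', num_legs=i, broken_id=j, r_min=10.5, r_max=11.5)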
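To see what this registry holds, here is a small self-contained sketch. The `load_desc` below is a hypothetical stand-in that just returns its keyword arguments; the real function in the codebase builds an actual environment description and may have different defaults. Each entry is a `functools.partial`, so nothing is built until the entry is called.

```python
import functools


def load_desc(**kwargs):
    # Hypothetical stand-in for the real load_desc: echo the arguments back.
    return dict(kwargs)


ENV_DESCS = dict()

# Same registration loops as above.
for i in range(2, 7):
    ENV_DESCS[f"ant_reach_{i}"] = functools.partial(load_desc, num_legs=i)
    ENV_DESCS[f"ant_reach_hard_{i}"] = functools.partial(
        load_desc, num_legs=i, r_min=10.5, r_max=11.5)

for i in range(3, 7):
    for j in range(i):
        ENV_DESCS[f"ant_reach_{i}_b_{j}"] = functools.partial(
            load_desc, agent="broken_ant", num_legs=i, broken_id=j)
        ENV_DESCS[f"ant_reach_hard_{i}_b_{j}"] = functools.partial(
            load_desc, agent="broken_ant", num_legs=i, broken_id=j,
            r_min=10.5, r_max=11.5)

# Calling an entry invokes load_desc with the arguments bound at registration.
print(len(ENV_DESCS))
print(ENV_DESCS["ant_reach_hard_4_b_2"]())
```

The two loops register 46 names in total: 10 from the first loop (5 leg counts, standard and hard) and 36 from the second (3 + 4 + 5 + 6 = 18 broken-leg choices, standard and hard).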

Structure

Reference