absdnd / aux_distill


Reinforcement Learning via Auxiliary Task Distillation

We present Reinforcement Learning via Auxiliary Task Distillation (AuxDistill), a new method that enables reinforcement learning (RL) to solve long-horizon robotic control problems by distilling behaviors from auxiliary RL tasks. AuxDistill achieves this by concurrently carrying out multi-task RL with auxiliary tasks, which are easier to learn and relevant to the main task. A weighted distillation loss transfers behaviors from these auxiliary tasks to solve the main task. We demonstrate that AuxDistill can learn a pixels-to-actions policy for a challenging multi-stage embodied object rearrangement task from the environment reward without demonstrations, a learning curriculum, or pre-trained skills. AuxDistill achieves a 2.3× higher success rate than the previous state-of-the-art baseline in the Habitat Object Rearrangement benchmark and outperforms methods that use pre-trained skills and expert demonstrations.
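The README describes the weighted distillation loss only at a high level. The sketch below is a hypothetical illustration, not the paper's actual objective. It assumes discrete action distributions, uses KL(auxiliary || main) as the per-task distillation term, and averages those terms with fixed per-task weights. The function names `kl_divergence` and `weighted_distillation_loss` are made up for this example.

```python
import math
from typing import Sequence


def kl_divergence(p: Sequence[float], q: Sequence[float], eps: float = 1e-8) -> float:
    """KL(p || q) for two categorical distributions over the same actions.

    eps avoids division by zero when q assigns zero probability to an action.
    """
    return sum(pi * math.log((pi + eps) / (qi + eps)) for pi, qi in zip(p, q) if pi > 0)


def weighted_distillation_loss(
    main_probs: Sequence[float],
    aux_probs: Sequence[Sequence[float]],
    weights: Sequence[float],
) -> float:
    """Hypothetical weighted distillation term (illustration only).

    main_probs: action distribution of the main-task policy (student) at one state.
    aux_probs:  action distributions of the auxiliary-task policies (teachers)
                at the same state, treated as fixed targets.
    weights:    assumed non-negative weight per auxiliary task.
    """
    if len(aux_probs) != len(weights):
        raise ValueError("need exactly one weight per auxiliary task")
    total = sum(weights)
    if total == 0:
        return 0.0
    # Weighted average of KL(aux_k || main) over the auxiliary tasks.
    weighted = sum(w * kl_divergence(p_aux, main_probs) for w, p_aux in zip(weights, aux_probs))
    return weighted / total
```

In a real training loop the auxiliary distributions would act as fixed targets, with no gradient flowing through them, and a term like this would be added to the RL loss. How AuxDistill actually computes its weights is not described in this README.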

Installation

Datasets

To obtain the training data, download the contents of the train folder from the following link and place them in the data/datasets/replica_cad/rearrange/train folder. Place the validation dataset in the data/datasets/replica_cad/rearrange/val folder.

Training

Rearrangement

To train a policy for the rearrangement task, run bash skill_chain/train/rearrange.sh. This trains the model using the configuration file configs/rearrange.yaml and, by default, saves checkpoints to the data/ckpts/rearrange directory. To train with a different seed, append the argument habitat.seed=$SEED to the command.

Language Pick

To train the language pick task, run bash skill_chain/train/lang_pick.sh. To train with a different seed, append habitat.seed=$SEED to the command.
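Both training scripts accept the same habitat.seed=$SEED override, so a multi-seed run can be scripted. The launcher below is a hypothetical convenience, not part of the repository. It assumes it is run from the repository root, runs seeds one after another, and only prints the commands unless --run is passed. The file name sweep.py and the functions `build_command` and `run_sweep` are made up for this example.

```python
import shlex
import subprocess
import sys
from typing import List

# Training scripts named in this README.
TRAIN_SCRIPTS = {
    "rearrange": "skill_chain/train/rearrange.sh",
    "lang_pick": "skill_chain/train/lang_pick.sh",
}


def build_command(task: str, seed: int) -> List[str]:
    """Build the training command for one seed by appending habitat.seed=<seed>."""
    return ["bash", TRAIN_SCRIPTS[task], f"habitat.seed={seed}"]


def run_sweep(task: str, seeds: List[int], dry_run: bool = True) -> None:
    """Print each command and, unless dry_run is set, run it to completion."""
    for seed in seeds:
        cmd = build_command(task, seed)
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    # Usage: python sweep.py rearrange 1 2 3 [--run]
    args = [a for a in sys.argv[1:] if a != "--run"]
    if args:
        run_sweep(args[0], [int(s) for s in args[1:]], dry_run="--run" not in sys.argv)
```

Running seeds sequentially keeps the sketch simple. On a cluster you would more likely submit one job per seed.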

Evaluation

Pretrained checkpoints

The pre-trained checkpoints for the rearrangement and language pick tasks can be found at the following link; place them in the data/ckpts/pretrained folder. To evaluate the pretrained rearrangement checkpoint, run bash scripts/eval/rearrange_pretrained.sh. To evaluate a different seed, change the checkpoint path in this script. To evaluate the pretrained language pick model, run bash scripts/eval/lang_pick_pretrained.sh.

Evaluating Trained Checkpoints

To evaluate trained checkpoints, run bash scripts/eval/rearrange.sh. By default, this evaluates the latest checkpoint, saved at data/ckpts/rearrange/latest.pth; this path can be changed in the evaluation script. Evaluation videos are written to the directory given by habitat_baselines.video_dir, which is set in the script.

Documentation

Citation

@article{harish2024,
  title={Reinforcement Learning via Auxiliary Task Distillation},
  author={Harish, Abhinav Narayan and Heck, Larry and Hanna, Josiah and Kira, Zsolt and Szot, Andrew},
  journal={arXiv preprint arXiv:2406.17168},
  year={2024}
}