stepjam / TecNets

Official code for "Task-Embedded Control Networks for Few-Shot Imitation Learning".

Task-Embedded Control Networks

(Figure: TecNets example)

The code presented here was used in the paper Task-Embedded Control Networks for Few-Shot Imitation Learning.

Running Paper Experiments

To re-run the experiments presented in the paper, you will need some of the dependencies from a paper that we compare against: One-Shot Visual Imitation Learning via Meta-Learning (MIL).

Follow these steps:

  1. Clone the fork of the gym repo found here, and switch to the mil branch.
  2. Either install the fork, or add it to your PYTHONPATH.
  3. Download the mil_sim_reach and mil_sim_push datasets from here, and unzip them into the datasets folder. Note: the data format has been changed slightly compared to the original data from the MIL paper.
  4. (Optional) Run the integration test to make sure everything is set up correctly.
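Before running the integration test, a quick sanity check of steps 2 and 3 can save time. The sketch below is not part of the repository: it assumes the unzipped folders are named mil_sim_reach and mil_sim_push and live under datasets/, and it only reports where (or whether) the gym package would be imported from, without importing it.

```python
import importlib.util
import os

# Assumed folder names for the unzipped datasets (hypothetical; adjust to match your download).
EXPECTED_DATASETS = ("mil_sim_reach", "mil_sim_push")


def missing_datasets(datasets_dir="datasets"):
    """Return the expected dataset folders that are not present under datasets_dir."""
    return [
        name
        for name in EXPECTED_DATASETS
        if not os.path.isdir(os.path.join(datasets_dir, name))
    ]


def gym_location():
    """Return the file gym would be loaded from, or None if it cannot be found on the path."""
    spec = importlib.util.find_spec("gym")  # locates the module without importing it
    return None if spec is None else spec.origin


if __name__ == "__main__":
    # If the fork is set up correctly, this path should point inside your clone of the gym fork.
    print("gym found at:", gym_location())
    print("missing datasets:", missing_datasets() or "none")
```

If the printed gym path points at a different installation than the mil branch of the fork, the PYTHONPATH entry is probably missing or shadowed.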

To run the reaching task, run:

./tecnets_corl_results.sh sim_reach

To run the pushing task, run:

./tecnets_corl_results.sh sim_push

Code Design

This section is for people who wish to extend the framework.

The code is designed as a pipeline: a list of consumers, each of which takes a dictionary of inputs (produced by the previous consumer) and outputs a dictionary that combines those inputs with its own outputs. For example:

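a = GeneratorConsumer(...)  # supplies batches of demonstration data
b = TaskEmbedding(...)      # embeds demonstrations into a task embedding
c = MarginLoss(...)         # margin loss on the task embeddings
d = Control(...)            # control network conditioned on the task embedding
e = ImitationLoss(...)      # loss between predicted and demonstrated actions
consumers = [a, b, c, d, e]
p = Pipeline(consumers)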
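To make the consumer contract concrete, here is a minimal, self-contained sketch of the pattern described above. It is a hypothetical re-implementation for illustration, not the repository's Consumer or Pipeline classes: each consumer computes only its new outputs in consume(), and its __call__ merges them into the incoming dictionary, so every later consumer can read everything produced upstream. The toy consumers Double and Total stand in for stages like TaskEmbedding and Control.

```python
from typing import Any, Dict, List


class Consumer:
    """Hypothetical base class: consume() returns only this stage's new outputs."""

    def consume(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        raise NotImplementedError

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        # Combine the incoming dictionary with this consumer's outputs.
        return {**inputs, **self.consume(inputs)}


class Pipeline:
    """Runs consumers in order, feeding each one the combined dictionary so far."""

    def __init__(self, consumers: List[Consumer]):
        self.consumers = consumers

    def __call__(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
        for consumer in self.consumers:
            inputs = consumer(inputs)
        return inputs


# Toy stages (hypothetical), standing in for e.g. TaskEmbedding and Control.
class Double(Consumer):
    def consume(self, inputs):
        return {"doubled": [2 * x for x in inputs["data"]]}


class Total(Consumer):
    def consume(self, inputs):
        # Reads a key produced by the previous stage.
        return {"total": sum(inputs["doubled"])}


result = Pipeline([Double(), Total()])({"data": [1, 2, 3]})
print(result)  # {'data': [1, 2, 3], 'doubled': [2, 4, 6], 'total': 12}
```

Because stages communicate only through dictionary keys, a stage can be replaced by any other stage that reads and writes the same keys.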

This allows the TecNet to be built in a modular way. For example, to use a prototypical loss rather than a margin loss, one would only need to swap out a single consumer.
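As an illustration of such a swap, the sketch below shows what a prototypical-loss stage might look like when written against a dictionary-in, dictionary-out contract. It is not taken from the repository: the input keys support_embeddings (shape tasks × support examples × dim) and query_embeddings (shape tasks × query examples × dim) and the output key embedding_loss are hypothetical, and NumPy is used instead of the framework's own tensors. The loss is the standard prototypical-network loss: each task's prototype is the mean of its support embeddings, and each query is classified by a softmax over negative squared Euclidean distances to the prototypes.

```python
import numpy as np


class PrototypicalLoss:
    """Hypothetical drop-in replacement for a margin-loss stage (NumPy sketch)."""

    def __init__(self, support_key="support_embeddings", query_key="query_embeddings"):
        self.support_key = support_key
        self.query_key = query_key

    def __call__(self, inputs):
        support = np.asarray(inputs[self.support_key], dtype=float)  # (tasks, n_support, dim)
        query = np.asarray(inputs[self.query_key], dtype=float)  # (tasks, n_query, dim)
        prototypes = support.mean(axis=1)  # one mean embedding per task: (tasks, dim)
        num_tasks, num_query, dim = query.shape
        flat_query = query.reshape(-1, dim)  # rows ordered task by task
        # Squared Euclidean distance from every query to every prototype: (tasks * n_query, tasks).
        sq_dist = ((flat_query[:, None, :] - prototypes[None, :, :]) ** 2).sum(axis=-1)
        logits = -sq_dist
        # Numerically stable log-softmax over the task dimension.
        logits = logits - logits.max(axis=1, keepdims=True)
        log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
        # Each query's correct class is the task it came from.
        labels = np.repeat(np.arange(num_tasks), num_query)
        loss = -log_probs[np.arange(labels.size), labels].mean()
        # Pass the inputs through, adding this stage's output.
        return {**inputs, "embedding_loss": float(loss)}
```

When all prototypes coincide, every query is equally likely to belong to any task, so with two tasks the loss is log 2; well-separated embeddings drive it towards zero.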

Citation

@article{james2018task,
  title={Task-Embedded Control Networks for Few-Shot Imitation Learning},
  author={James, Stephen and Bloesch, Michael and Davison, Andrew J},
  journal={Conference on Robot Learning (CoRL)},
  year={2018}
}