Implementation of Relation Network. Original paper: https://arxiv.org/abs/1706.01427
Implementation of Recurrent Relational Network. Original paper: https://arxiv.org/abs/1711.08028
This repository uses PyTorch v1.3 (Python 3.7).
This implementation tests the Relation Network model (RN) and the Recurrent Relational Network model (RRN) against the bAbI dataset, available at https://research.fb.com/downloads/babi/
This repository uses Weights and Biases (W&B) to monitor experiments. You can create a free account on W&B (https://www.wandb.com/) or comment out the (few) lines starting with wandb. Without W&B, accuracy and loss plots will still be created and saved locally in the results folder.
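As an alternative to commenting out the wandb lines by hand, the W&B calls can be made optional. The following is a minimal sketch, not the code used in this repository: the class name MetricLogger, the metrics.jsonl file, and the project name "rn-babi" are all hypothetical, and it only assumes the standard wandb.init and wandb.log calls.

```python
import json
import tempfile
from pathlib import Path

try:
    import wandb  # optional dependency; the sketch still works without it
except ImportError:
    wandb = None


class MetricLogger:
    """Hypothetical helper: always logs locally, and to W&B only on request."""

    def __init__(self, results_dir, use_wandb=False, project="rn-babi"):
        # metrics.jsonl is an assumed file name, not one used by the repository
        self.path = Path(results_dir) / "metrics.jsonl"
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # W&B is used only if requested and importable
        self.use_wandb = use_wandb and wandb is not None
        if self.use_wandb:
            wandb.init(project=project)

    def log(self, metrics):
        # Append one JSON object per line so local plots can be built later
        with self.path.open("a") as f:
            f.write(json.dumps(metrics) + "\n")
        if self.use_wandb:
            wandb.log(metrics)


# Usage: local-only logging into a temporary folder
with tempfile.TemporaryDirectory() as tmp:
    logger = MetricLogger(tmp, use_wandb=False)
    logger.log({"epoch": 1, "loss": 0.5})
    lines = logger.path.read_text().splitlines()
```

With use_wandb=False no W&B account is needed, and the local log is still written.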
pip install -r requirements.txt
babi
results (to store experiment results, both plots and saved models)
src/models/RN.py
task/babi_task/rn/train.py
launch_rn_babi.py
python launch_rn_babi.py experiment_name [options]
python launch_rn_babi.py --help
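To illustrate the command-line shape above, here is a hedged sketch of a parser that accepts the same positional argument and the flags this README mentions for the RN launcher. It assumes the launcher is built on argparse, which the README does not confirm. The flag names come from the README; the types, help strings, and defaults (copied from the reproduction command) are assumptions and may differ from the real script.

```python
import argparse


def build_parser():
    # Hypothetical parser: only the flag names are taken from the README
    parser = argparse.ArgumentParser(description="Sketch of an RN bAbI launcher CLI")
    parser.add_argument("experiment_name", type=str,
                        help="name of the experiment (results folder name)")
    # Defaults mirror the reproduction command, not necessarily the script's own
    parser.add_argument("--learning_rate", type=float, default=1e-4)
    parser.add_argument("--batch_size", type=int, default=20)
    parser.add_argument("--epochs", type=int, default=50)
    parser.add_argument("--test_on_test", action="store_true",
                        help="run the final test on the test set")
    return parser


# Parse the same arguments as the reproduction command
args = build_parser().parse_args(
    ["test", "--learning_rate", "1e-4", "--batch_size", "20", "--epochs", "50"]
)
print(args)
```

Calling build_parser().parse_args(["--help"]) prints the generated option list and exits, which is the usual behavior behind a --help flag in argparse-based scripts.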
To reproduce results, execute python launch_rn_babi.py test --learning_rate 1e-4 --batch_size 20 --epochs 50 and then check under results/test to see the results. To run the final test on the test set instead of the validation set, use the --test_on_test option. The final accuracy on the test set is (task: accuracy):
batch size > 1. If batch size == 1, relu prevents learning, while tanh achieves ~74% accuracy on the joint dataset.
batch size == 1 in the branch no_batch.
src/models/RRN.py
task/babi_task/rrn/train.py
launch_rrn_babi.py
python launch_rrn_babi.py experiment_name [options]
python launch_rrn_babi.py --help
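When several runs need to be launched from another script, the command line can be assembled programmatically. This is a small sketch under stated assumptions: build_command is a hypothetical helper, not part of this repository, and the flag names are the ones this README uses for the RRN launcher.

```python
import sys


def build_command(script, experiment_name, **options):
    """Hypothetical helper: turn keyword options into a launcher command list."""
    cmd = [sys.executable, script, experiment_name]
    for name, value in options.items():
        if value is True:
            # Boolean switches such as --cuda take no value
            cmd.append(f"--{name}")
        elif value is False or value is None:
            # Disabled switches and unset options are skipped
            continue
        else:
            cmd.extend([f"--{name}", str(value)])
    return cmd


cmd = build_command(
    "launch_rrn_babi.py", "test",
    cuda=True, epochs=500, batch_size=512,
    weight_decay="1e-5", learning_rate="2e-4",
)
print(" ".join(cmd[1:]))
# To start the run from the repository root:
#   import subprocess; subprocess.run(cmd, check=True)
```

Numeric values are passed as strings here so that they reach the launcher exactly as written (str(1e-5) would otherwise become 1e-05).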
To reproduce results, execute python launch_rrn_babi.py test --cuda --epochs 500 --batch_size 512 --weight_decay 1e-5 --learning_rate 2e-4 and then check under results/rrn/test to see the results. To run the final test on the test set instead of the validation set, use the --test_on_test option. The final accuracy on the validation set is (task: accuracy):