AndreaCossu / Relation-Network-PyTorch

Implementation of Relation Network (RN) and Recurrent Relational Network (RRN) using PyTorch v1.3. Original papers: RN: https://arxiv.org/abs/1706.01427, RRN: https://arxiv.org/abs/1711.08028
MIT License
machine-learning neural-reasoning python pytorch pytorch-implementation relation-network

Relation-Network-PyTorch

Implementation of Relation Network. Original paper: https://arxiv.org/abs/1706.01427

Implementation of Recurrent Relational Network. Original paper: https://arxiv.org/abs/1711.08028

This repository uses PyTorch v1.3 (Python 3.7).

Implementation details

This implementation evaluates the Relation Network model (RN) and the Recurrent Relational Network model (RRN) on the bAbI dataset, available at https://research.fb.com/downloads/babi/
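For orientation, the core idea of the RN paper is to apply a shared function g to every pair of objects (together with the question) and to pass the sum of the results through a second function f. The block below is a minimal sketch of that idea, not the code used in this repository: the class name `RelationNetworkSketch`, the layer sizes, and the tensor shapes are all assumptions chosen for illustration.

```python
import torch
import torch.nn as nn


class RelationNetworkSketch(nn.Module):
    """Illustrative RN: answer = f(sum over pairs (i, j) of g(o_i, o_j, q))."""

    def __init__(self, object_dim, question_dim, hidden_dim, num_answers):
        super().__init__()
        # g scores one (object, object, question) triple.
        self.g = nn.Sequential(
            nn.Linear(2 * object_dim + question_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        # f maps the aggregated relations to answer logits.
        self.f = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, num_answers),
        )

    def forward(self, objects, question):
        # objects: (batch, n, object_dim), question: (batch, question_dim)
        n = objects.shape[1]
        o_i = objects.unsqueeze(2).expand(-1, -1, n, -1)  # o_i[b, i, j] = objects[b, i]
        o_j = objects.unsqueeze(1).expand(-1, n, -1, -1)  # o_j[b, i, j] = objects[b, j]
        q = question.unsqueeze(1).unsqueeze(2).expand(-1, n, n, -1)
        pairs = torch.cat([o_i, o_j, q], dim=-1)          # (batch, n, n, 2*d + q)
        relations = self.g(pairs)                         # (batch, n, n, hidden)
        aggregated = relations.sum(dim=(1, 2))            # sum over all ordered pairs
        return self.f(aggregated)                         # (batch, num_answers)


# Usage with made-up sizes: 2 stories, 3 objects (e.g. encoded sentences) each.
model = RelationNetworkSketch(object_dim=8, question_dim=4, hidden_dim=16, num_answers=5)
objects = torch.randn(2, 3, 8)
question = torch.randn(2, 4)
logits = model(objects, question)
```

Because the pairwise outputs are summed, the result does not depend on the order in which the objects are presented, which is the property that makes the architecture suitable for set-like inputs.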

Weights and Biases

This repository uses Weights and Biases (W&B) to monitor experiments. You can create a free account on W&B (https://www.wandb.com/) or comment out the (few) lines starting with wandb. Without W&B, accuracy and loss plots will still be created and saved locally in the results folder.
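As an alternative to commenting lines out by hand, W&B calls can be guarded behind a flag. The following is only a sketch of that pattern under stated assumptions: `MetricLogger`, its `use_wandb` argument, and the project name are hypothetical and do not exist in this repository. Only `wandb.init` and `wandb.log` are real W&B calls.

```python
try:
    import wandb  # optional dependency
except ImportError:
    wandb = None


class MetricLogger:
    """Hypothetical helper: stores metrics locally and mirrors them to W&B if enabled."""

    def __init__(self, use_wandb=False, project="relation-network"):
        self.history = []
        # W&B is used only when requested and when the package is installed.
        self.use_wandb = bool(use_wandb) and wandb is not None
        if self.use_wandb:
            wandb.init(project=project)  # project name is an assumption

    def log(self, epoch, loss, accuracy):
        record = {"epoch": epoch, "loss": loss, "accuracy": accuracy}
        self.history.append(record)      # always kept for local plots
        if self.use_wandb:
            wandb.log(record)
        return record


# Usage: local-only logging, no W&B account needed.
logger = MetricLogger(use_wandb=False)
logger.log(epoch=1, loss=0.5, accuracy=0.8)
```

With this design the training loop calls `logger.log` unconditionally, and the local history remains available for plotting whether or not W&B is active.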

Prerequisites

Train and test RN

To reproduce the results, execute

    python launch_rn_babi.py test --learning_rate 1e-4 --batch_size 20 --epochs 50

and then check under results/test to see the results. To run the final evaluation on the test set instead of the validation set, add the --test_on_test option. The final accuracy on the test set is (task: accuracy):

Observations

Train and test RRN

To reproduce the results, execute

    python launch_rrn_babi.py test --cuda --epochs 500 --batch_size 512 --weight_decay 1e-5 --learning_rate 2e-4

and then check under results/rrn/test to see the results. To run the final evaluation on the test set instead of the validation set, add the --test_on_test option. The final accuracy on the validation set is (task: accuracy):