This is a PyTorch implementation of Recurrent Models of Visual Attention by Volodymyr Mnih, Nicolas Heess, Alex Graves and Koray Kavukcuoglu.
The Recurrent Attention Model (RAM) is a neural network that processes inputs sequentially, attending to different locations within the image one at a time, and incrementally combining information from these fixations to build up a dynamic internal representation of the image.
In this paper, the attention problem is modeled as the sequential decision process of a goal-directed agent interacting with a visual environment. The agent is built around a recurrent neural network: at each time step, it processes the sensor data, integrates information over time, and chooses how to act and how to deploy its sensor at the next time step.
Concretely, the agent is composed of the following subnetworks:

- glimpse sensor: extracts a retina-like representation phi around location l from an image x. It encodes the region around l at a high resolution but uses a progressively lower resolution for pixels further from l, resulting in a compressed representation of the original image x.
- glimpse network: combines the "what" (phi) and the "where" (l) into a glimpse feature vector g_t.
- core network: an RNN that maintains an internal state h_t that gets updated at every time step t.
- location network: uses the internal state h_t of the core network to produce the location coordinates l_t for the next time step.
- action network: uses the internal state h_t of the core network to produce the final output classification y.
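
To make the glimpse sensor concrete, here is a minimal PyTorch sketch of the multi-resolution extraction it performs. This is an illustration written for this description, not this repository's actual code; the names `extract_glimpse`, `patch_size`, `num_patches`, and `scale` are my own.

```python
import torch
import torch.nn.functional as F

def extract_glimpse(x, l, patch_size=8, num_patches=3, scale=2):
    """Retina-like representation phi(x, l).

    x: (B, C, H, W) batch of images.
    l: (B, 2) glimpse centers in [-1, 1] coordinates.
    Returns (B, num_patches * C, patch_size, patch_size): the first patch is
    full resolution; each following patch covers a `scale`-times larger region
    but is resized down to patch_size, so detail fades further from l.
    """
    B, C, H, W = x.shape
    # Convert l from [-1, 1] to integer pixel coordinates of the patch center.
    center = (0.5 * (l + 1.0) * torch.tensor([W, H], dtype=x.dtype)).long()
    patches, size = [], patch_size
    for _ in range(num_patches):
        half = size // 2
        # Zero-pad so patches near the border stay in bounds.
        padded = F.pad(x, (half, half, half, half))
        crops = []
        for b in range(B):
            cx = center[b, 0].item() + half
            cy = center[b, 1].item() + half
            crops.append(padded[b:b + 1, :, cy - half:cy + half, cx - half:cx + half])
        crop = torch.cat(crops, dim=0)
        # Coarser patches are downsampled back to the base patch size.
        if size != patch_size:
            crop = F.interpolate(crop, size=(patch_size, patch_size),
                                 mode="bilinear", align_corners=False)
        patches.append(crop)
        size *= scale
    return torch.cat(patches, dim=1)
```

The glimpse network would then flatten this output and combine it with an embedding of l to produce the glimpse feature vector g_t.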
I decided to tackle the 28x28 MNIST task with a RAM model using 6 glimpses of size 8x8 and a scale factor of 1.
| Model | Validation Error (%) | Test Error (%) |
|---|---|---|
| 6 glimpses, 8x8, scale 1 | 1.1 | 1.21 |
I haven't done a random search over the policy standard deviation to tune it, so I expect the test error can be pushed below 1%. I'll update the table above with results for 60x60 Translated MNIST, 60x60 Cluttered Translated MNIST, and the newer Fashion-MNIST dataset when I get the time.
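
For context, the "policy standard deviation" mentioned above is the fixed standard deviation of the Gaussian from which the next fixation location is sampled. The sketch below shows how such a location policy typically looks in PyTorch; the class name `LocationNetwork`, the `std` argument, and the tanh squashing are illustrative choices, not necessarily what this repository does.

```python
import torch
import torch.nn as nn

class LocationNetwork(nn.Module):
    """Maps the core RNN state h_t to a stochastic 2-D fixation l_t."""

    def __init__(self, hidden_size, std):
        super().__init__()
        self.fc = nn.Linear(hidden_size, 2)
        self.std = std  # the policy standard deviation referred to above

    def forward(self, h_t):
        # Mean of the Gaussian policy, squashed into the [-1, 1] image frame.
        mu = torch.tanh(self.fc(h_t))
        # Sample the next location around the mean with fixed noise scale std.
        l_t = mu + self.std * torch.randn_like(mu)
        # Log-probability of the sample, needed for the REINFORCE gradient
        # that trains this network.
        log_pi = torch.distributions.Normal(mu, self.std).log_prob(l_t).sum(dim=1)
        return log_pi, torch.clamp(l_t, -1.0, 1.0)
```

A larger std explores more aggressively but makes the REINFORCE gradient noisier, which is why it is worth tuning.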
Finally, here's an animation showing the glimpses extracted by the network on a random batch at epoch 23.
With the Adam optimizer, paper accuracy can be reached in ~160 epochs.
The easiest way to start training your RAM variant is to edit the parameters in `config.py` and run the following command:

`python main.py`
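
The snippet below is a hypothetical sketch of the kind of knobs `config.py` exposes; the real argument names and defaults live in the file itself and may differ. The values shown mirror the 6-glimpse, 8x8, scale-1 setup reported in the table above.

```python
# Hypothetical sketch only: illustrates the glimpse-related hyperparameters
# you would edit in config.py. Check the actual file for the real names.
import argparse

parser = argparse.ArgumentParser(description="RAM hyperparameters")
parser.add_argument("--num_glimpses", type=int, default=6,
                    help="number of glimpses taken per image")
parser.add_argument("--patch_size", type=int, default=8,
                    help="height/width of the first retina patch")
parser.add_argument("--glimpse_scale", type=int, default=1,
                    help="scale factor between successive retina patches")
config, unparsed = parser.parse_known_args()
```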
To resume training, run:

`python main.py --resume=True`
Finally, to test the model checkpoint that achieved the best validation accuracy, run:

`python main.py --is_train=False`