This repo contains code accompanying the paper, Active One-shot Learning (Woodward and Finn, NIPS Deep RL Workshop 2016). It includes code for running the experiment described in the paper.
Variables representing a tensor in the graph end with `_t`. For example, you might feed the tensor `last_label_t` with the numpy variable `last_label`. Also, a variable ending with `_ts` is a list of tensors.
This code requires the following:
A preprocessed version of the original Omniglot dataset is included with this project.
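To train the model, run:
```shell
$ python3 train.py
```
To view the training curves, launch TensorBoard on the log directory:
```shell
$ tensorboard --logdir ./logs
```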
The accuracy curves should look similar to the following figure, which shows accuracy on the second instance of a class in an episode.
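To make "second instance" concrete: within an episode, each time step can be tagged with how many times its class has appeared so far (1 for the first occurrence, 2 for the second, and so on), and accuracy can then be averaged separately for each of those counts. The sketch below does this in plain Python. It is an illustration only, assuming you already have, for each episode, the true label and a correct/incorrect flag for every step; the function name `accuracy_by_instance` is hypothetical, and how steps where the agent requested the label are scored is left to the caller, so it may not match exactly how this repo computes its TensorBoard curves.

```python
from collections import defaultdict


def accuracy_by_instance(episodes):
    """Pool accuracy by occurrence index of a class within each episode.

    episodes: iterable of (labels, correct) pairs, where labels[i] is the true
    class at step i and correct[i] says whether step i was scored as correct.
    Returns {k: accuracy}, where k=2 corresponds to the "second instance".
    """
    hits = defaultdict(int)    # correct steps per occurrence index k
    totals = defaultdict(int)  # all steps per occurrence index k
    for labels, correct in episodes:
        seen = defaultdict(int)  # occurrence counts restart for every episode
        for label, ok in zip(labels, correct):
            seen[label] += 1
            k = seen[label]
            totals[k] += 1
            hits[k] += int(ok)
    return {k: hits[k] / totals[k] for k in sorted(totals)}
```

For example, `accuracy_by_instance([([0, 1, 0, 1, 0], [False, False, True, False, True])])` gives `{1: 0.0, 2: 0.5, 3: 1.0}`: every first sight of a class is wrong, half of the second sights are right, and the single third sight is right.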
The code in this project builds a graph of the full training episode. To use the model after training, you would more likely run it one step at a time. Here is an example of what that code might look like:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, sess.run)
import model  # provides model.Agent

# params: the hyperparameters used during training.
# sess: a tf.Session in which the trained weights have been restored.
last_label_t = tf.placeholder(tf.int32, shape=(1, params.num_labels))
features_t = tf.placeholder(tf.int32, shape=(1, 28, 28))  # shape of your data
agent = model.Agent(False, params)
action_t, _ = agent.next_action([last_label_t, features_t])

# Loop, calling the following once per step.
# Feeds the model your own numpy arrays "features" and "last_label".
# last_label will generally be zeros, unless "action" requested the last label.
# The numpy variable rnn_state is updated on each call.
# The numpy variable action gives you the model's chosen action.
action, rnn_state = sess.run(
    [action_t, agent.rnn_state_t],
    {
        agent.rnn_initial_state_t: rnn_state,
        last_label_t: last_label,  # feed zeros if no label was requested
        features_t: features,
    },
)
```
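The loop around that `sess.run` call can be sketched as follows, in plain numpy so the control flow is visible without TensorFlow. This is a minimal sketch under several assumptions: `step_fn` is a hypothetical wrapper around the `sess.run` call that returns an integer action and the new RNN state; `NUM_LABELS` stands in for `params.num_labels`; `REQUEST_ACTION` (index `NUM_LABELS`) is an assumed encoding of the "request the label" action that you should check against `model.Agent`; and the true labels come from a simulated episode, whereas in real use they would only be obtained when the agent asks for them. How the initial `rnn_state` is produced is not shown here.

```python
import numpy as np

NUM_LABELS = 5               # hypothetical; use params.num_labels in practice
REQUEST_ACTION = NUM_LABELS  # assumption: this action index means "request the label"


def run_episode(step_fn, episode, rnn_state):
    """Run the agent one step at a time over a simulated episode.

    step_fn(rnn_state, last_label, features) -> (action, rnn_state) is a
    hypothetical wrapper around sess.run; episode is a list of
    (features, true_label) pairs with features shaped (28, 28).
    Returns the chosen actions and the final RNN state.
    """
    actions = []
    # Nothing has been requested before the first step, so feed zeros.
    last_label = np.zeros((1, NUM_LABELS), dtype=np.int32)
    for features, true_label in episode:
        # Add the batch dimension expected by the (1, 28, 28) placeholder.
        action, rnn_state = step_fn(rnn_state, last_label, features[np.newaxis])
        actions.append(action)
        # Label input for the next step: one-hot of this step's true label
        # if the agent requested it, otherwise all zeros.
        last_label = np.zeros((1, NUM_LABELS), dtype=np.int32)
        if action == REQUEST_ACTION:
            last_label[0, true_label] = 1
    return actions, rnn_state
```

Threading `rnn_state` through each call is what lets a single-step graph reproduce the recurrence that the training graph unrolls over the whole episode.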
To ask questions or report issues, please open an issue on the issues tracker. Also, feel free to contact Mark Woodward at mwoodward@cs.stanford.edu.