openai / supervised-reptile

Code for the paper "On First-Order Meta-Learning Algorithms"
https://arxiv.org/abs/1803.02999
MIT License

About the role of the training set in the prediction process #13

Closed: jaegerstar closed this issue 6 years ago

jaegerstar commented 6 years ago

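```python
def _test_predictions(self, train_set, test_set, input_ph, predictions):
    if self.transductive:
        # Transductive: feed the whole test set to the network as one batch.
        inputs, _ = zip(*test_set)
        return self.session.run(predictions, feed_dict={input_ph: inputs})
    res = []
    for test_sample in test_set:
        # Non-transductive: batch all training inputs with a single test sample
        # and keep only the prediction for that (last) sample.
        inputs, _ = zip(*train_set)
        inputs += (test_sample[0],)
        res.append(self.session.run(predictions, feed_dict={input_ph: inputs})[-1])
    return res
```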

Why does the `_test_predictions` function take `train_set` as an argument? The training samples have already been learned in

```python
self.session.run(minimize_op, feed_dict={input_ph: inputs, label_ph: labels})
```

which is line 119 of reptile.py.

unixpickle commented 6 years ago

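The reason is that, when you don't use transduction, BatchNorm still needs to get its statistics from somewhere. So `_test_predictions` feeds the training inputs into the network together with each test sample, so that the BatchNorm statistics come from the training set (plus that one test sample) rather than from the other test samples. Only the output for the appended test sample (the `[-1]` entry) is kept.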
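A minimal numpy illustration of why the statistics have to come from somewhere, assuming (as the answer implies) that BatchNorm normalizes with the current batch's mean and variance at prediction time. `batch_norm` is a hypothetical helper without learned scale or shift, and the inputs are made-up values. A test sample fed alone is normalized against itself and always maps to zero, whereas batching it with the training inputs yields an output that depends on the sample.

```python
import numpy as np


def batch_norm(x, eps=1e-5):
    """Hypothetical training-mode BatchNorm without learned scale/shift:
    normalize each feature with the mean and variance of the current batch."""
    x = np.asarray(x, dtype=np.float64)
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


# Made-up training inputs and two different test samples (2 features each).
train_inputs = np.array([[0.0, 1.0], [2.0, 3.0], [4.0, -1.0]])
test_a = np.array([[10.0, 10.0]])
test_b = np.array([[-7.0, 0.5]])

# A batch of one is normalized against itself: its mean equals the sample,
# so the output is all zeros no matter what the input was.
alone_a = batch_norm(test_a)[-1]
alone_b = batch_norm(test_b)[-1]

# Batched with the training inputs, the statistics include the training
# samples, and the test sample's output depends on its actual values.
with_train_a = batch_norm(np.vstack([train_inputs, test_a]))[-1]
with_train_b = batch_norm(np.vstack([train_inputs, test_b]))[-1]

print("alone:     ", alone_a, alone_b)
print("with train:", with_train_a, with_train_b)
```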