Hi, sorry for the delay!

I put your questions in a list:

1. DataLoaders
2. OptimizerParameters

The OptimizerParameters class is only relevant if you are going to meta-learn the parameters of the optimizer. See below for how to instantiate task learners using meta-learned optimizer parameters. Here is a complete MVP using the Omniglot dataloader:
import torch
import torch.nn as nn
import torch.optim
import warpgrad
from data import DataContainer
class SimpleModel(nn.Module):
    # The second linear layer will be used as the warp layer.

    def __init__(self, hidden_dim):
        super(SimpleModel, self).__init__()
        self.linear1 = nn.Linear(784, hidden_dim)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(hidden_dim, 20)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        out = self.linear1(x)
        out = self.relu1(out)
        out = self.linear2(out)
        out = self.relu2(out)
        out = self.linear3(out)
        return out

    def init_adaptation(self):
        """Things like resetting heads and running stats."""
        return
# Data and loss
criterion = nn.CrossEntropyLoss()
data = DataContainer(root='./data',
                     num_pretrain_alphabets=20, num_classes=20,
                     seed=1, num_workers=1, pin_memory=True)

# Model
simple_model = SimpleModel(hidden_dim=100)

# WarpGrad wrappers
updater = warpgrad.DualUpdater(criterion)
replay_buffer = warpgrad.ReplayBuffer()
optimizer_parameters = warpgrad.OptimizerParameters(trainable=False,
                                                    default_lr=0.01,
                                                    default_momentum=0.)

model = warpgrad.Warp(
    model=simple_model,
    adapt_modules=[simple_model.linear1, simple_model.relu1,
                   simple_model.relu2, simple_model.linear3],
    warp_modules=[simple_model.linear2],
    updater=updater,
    buffer=replay_buffer,
    optimizer_parameters=optimizer_parameters)
# Meta-optimizer: updates the warp (meta) parameters only.
meta_optimizer = torch.optim.Adam(model.warp_parameters(), lr=0.001)

def step_fn():
    # For meta-learning.
    meta_optimizer.step()
    meta_optimizer.zero_grad()
# Meta-training
print("starting meta-training")
for train_step in range(1, 11):
    print(f"meta-training step {train_step} / 10", end=" ")
    meta_batch = data.train(4, 20, 10, False)
    for task in meta_batch:
        # Set up task learner
        model.register_task(task)  # task is an instance of torch.utils.data.DataLoader
        model.collect()            # store this task's trajectory in the replay buffer
        model.init_adaptation()
        inner_optimizer = torch.optim.SGD(model.optimizer_parameter_groups())

        # Train task learner
        for inputs, target in task:
            prediction = model(inputs)
            loss = criterion(prediction, target)
            loss.backward()
            inner_optimizer.step()
            inner_optimizer.zero_grad()
    print("done.")

    # All tasks in meta-batch are done. Time for meta-learning.
    model.backward(step_fn)
# Meta-evaluation
print("starting meta-evaluation on held-out tasks")
meta_batch = data.test(20, 10, False)
for task_iter, task in enumerate(meta_batch, 1):
    print(f"task {task_iter} / 10:", end=" ")

    # Set up task learner
    model.no_collect()  # don't store task trajectories in buffer
    model.init_adaptation()
    inner_optimizer = torch.optim.SGD(model.optimizer_parameter_groups())

    # Train task learner
    for inputs, target in task:
        prediction = model(inputs)
        loss = criterion(prediction, target)
        loss.backward()
        inner_optimizer.step()
        inner_optimizer.zero_grad()

    # Evaluate
    task.dataset.eval()  # Grabs test data: your implementation might look different
    n_correct = []
    n_samples = 0
    for inputs, target in task:
        prediction = model(inputs)
        n_correct.append((prediction.max(1)[1] == target).sum().item())
        n_samples += prediction.size(0)
    accuracy = sum(n_correct) / n_samples
    print(f"accuracy={accuracy}")
I notice that in your MVP you never call inner_optimizer.step(). Can you explain why?

Oops, my bad! I updated the MVP to include the inner_optimizer.step() call.
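For reference, the relevant lines in the inner loop of the MVP above now read:

loss.backward()
inner_optimizer.step()
inner_optimizer.zero_grad()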
I am struggling to implement a minimal example using this repository as a library. Based on the code sample in the README file, could you specify the helper classes I should use for the following steps?

Creating the dataloaders:

model.register_task(mytask)

After successfully handling the first two steps, according to the instructions, I should be able to initialize the Warpgrad classes as follows: