flennerhag / warpgrad

Meta-Learning with Warped Gradient Descent
https://openreview.net/forum?id=rkeiQlBFPB
Apache License 2.0
92 stars 18 forks source link

How to create minimal working example? #5

Closed rcmalli closed 4 years ago

rcmalli commented 4 years ago

I am struggling to implement a minimal example using this repository as a library. Based on the code sample in the readme file, could you specify the required helper classes I should use in the following steps?

  1. Assuming that I have a three-layer neural network initialized as follows:
class SimpleModel(nn.Module):
    """Three-layer MLP; the second linear layer serves as the warp layer."""

    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(2, hidden_dim)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(hidden_dim, 2)

    def forward(self, x):
        hidden = self.relu1(self.linear1(x))
        hidden = self.relu2(self.linear2(hidden))
        return self.linear3(hidden)

# It seems like any model that will be given as an argument to the initialization of the `Warp` class should have a function called `init_adaptation`.
  2. Creating the dataloaders:

    • How should I define the task generator classes to use with model.register_task(mytask)?
  3. After successfully handling the first two steps, according to the instructions, I should be able to initialize the Warpgrad classes as follows:

# NOTE: DualUpdater needs a loss *instance*, not the loss class itself
# (the maintainer's MVP below does `criterion = nn.CrossEntropyLoss()` and
# passes that in).
updater = warpgrad.DualUpdater(criterion=nn.CrossEntropyLoss())
buffer = warpgrad.ReplayBuffer()

# I am already going to use inner_optimizer and meta_optimizer for the task adaptation problem.
# What is the connection of these optimization parameters with the algorithm itself?

optimizer_parameters = warpgrad.OptimizerParameters(
                trainable=False,      # keep per-parameter learning rates fixed
                default_lr=1e-3,
                default_momentum=0.9)
simple_model = SimpleModel(hidden_dim=100)

model = warpgrad.Warp(simple_model,
                        # List of nn.Modules for adaptation (task-specific parameters)
                        [simple_model.linear1, simple_model.relu1, simple_model.relu2, simple_model.linear3],
                        # List of nn.Modules for warping (meta-learned parameters)
                        [simple_model.linear2],
                        updater,
                        buffer,
                        optimizer_parameters)

# Should we use the optimizers defined in the Warpgrad library as `meta_opt_class`?
# Pseudo-code: `meta_opt_class` / `meta_opt_kwargs` are placeholders for a concrete
# optimizer class (e.g. torch.optim.Adam) and its keyword arguments.
meta_opt = meta_opt_class(model.warp_parameters(), **meta_opt_kwargs)
  4. Implementing the training loop
def meta_step_fn():
    """Apply the accumulated meta-gradient to the warp parameters, then clear it."""
    meta_opt.step()
    meta_opt.zero_grad()

# Outer loop: one iteration per meta-update over a batch of tasks.
for meta_step in range(meta_steps):
    meta_batch = mytaskgenerator.sample()
    for task in meta_batch:
        model.init_adaptation()     # Initialize adaptation on the model side
        # How should `task` be defined? It seems like it needs to be a specific class.
        model.register_task(task)   # Register task in replay buffer
        model.collect()             # Turn parameter collection on

        # Should this optimizer also be from the warpgrad library?
        opt = opt_cls(model.adapt_parameters(), **opt_kwargs)

        # Inner loop: adapt the task parameters on this task's data.
        for x, y in task:
            loss = criterion(model(x), y)
            loss.backward()
            opt.step()
            opt.zero_grad()

        # Evaluation
        model.eval()
        model.no_collect()          # stop recording the trajectory before evaluating
        your_evaluation(task, model)

    # Meta-objective: replay collected trajectories and update warp parameters.
    model.backward(meta_step_fn)
flennerhag commented 4 years ago

Hi, sorry for the delay!

I put your questions in a list:

Here is a complete MVP using the Omniglot dataloader:

import torch
import torch.nn as nn
import torch.optim

import warpgrad
from data import DataContainer

class SimpleModel(nn.Module):
    """Three-layer MLP over flattened 784-dim inputs with a 20-way output head.

    The second linear layer (`linear2`) is intended as the warp layer; the
    remaining layers are adapted per task.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(784, hidden_dim)
        self.relu1 = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, hidden_dim)
        self.relu2 = nn.ReLU()
        self.linear3 = nn.Linear(hidden_dim, 20)

    def forward(self, x):
        # Flatten everything but the batch dimension before the first layer.
        flat = x.view(x.size(0), -1)
        hidden = self.relu1(self.linear1(flat))
        hidden = self.relu2(self.linear2(hidden))
        return self.linear3(hidden)

    def init_adaptation(self):
        """Things like resetting heads and running stats."""
        return

# Data and loss
criterion = nn.CrossEntropyLoss()
data = DataContainer(root='./data',
                     num_pretrain_alphabets=20, num_classes=20,
                     seed=1, num_workers=1, pin_memory=True)

# Model: wrap the plain nn.Module in warpgrad.Warp, splitting its modules
# into task-adapted ones and the meta-learned warp layer.
simple_model = SimpleModel(hidden_dim=100)
updater = warpgrad.DualUpdater(criterion)  # takes a loss *instance*
replay_buffer = warpgrad.ReplayBuffer()
optimizer_parameters = warpgrad.OptimizerParameters(trainable=False,
                                                    default_lr=0.01,
                                                    default_momentum=0.)
model = warpgrad.Warp(
  model=simple_model,
  adapt_modules=[simple_model.linear1, simple_model.relu1,
                 simple_model.relu2, simple_model.linear3],
  warp_modules=[simple_model.linear2],
  updater=updater,
  buffer=replay_buffer,
  optimizer_parameters=optimizer_parameters)

# Meta-optimizer acts on the warp parameters only.
meta_optimizer = torch.optim.Adam(model.warp_parameters(), lr=0.001)

def step_fn():
    """Meta-learning step: apply the accumulated warp-parameter gradient, then clear it."""
    meta_optimizer.step()
    meta_optimizer.zero_grad()

# Meta-training
print("starting meta-training")
for train_step in range(1, 11):
    print(f"meta-training step {train_step} / 10", end=" ")
    # NOTE(review): positional args presumably (num tasks, classes, samples,
    # shuffle) — confirm against the DataContainer implementation.
    meta_batch = data.train(4, 20, 10, False)
    for task in meta_batch:
        # Set up task learner
        model.register_task(task)  # task is an instance of torch.utils.data.DataLoader
        model.collect()            # record the adaptation trajectory in the replay buffer
        model.init_adaptation()
        inner_optimizer = torch.optim.SGD(model.optimizer_parameter_groups())

        # Train task learner
        for inputs, target in task:
            prediction = model(inputs)
            loss = criterion(prediction, target)
            loss.backward()
            inner_optimizer.step()
            inner_optimizer.zero_grad()
    print("done.")

    # All tasks in meta-batch are done. Time for meta-learning.
    model.backward(step_fn)

# Meta-evaluation
print("starting meta-evaluation on held-out tasks")
meta_batch = data.test(20, 10, False)
# Start the counter at 1 so the printout matches "task X / 10" (was 0-based).
for task_iter, task in enumerate(meta_batch, 1):
    print(f"task {task_iter} / 10:", end=" ")
    # Set up task learner
    model.no_collect()         # don't store task trajectories in buffer
    model.init_adaptation()
    inner_optimizer = torch.optim.SGD(model.optimizer_parameter_groups())

    # Train task learner on the task's adaptation data.
    for inputs, target in task:
        prediction = model(inputs)
        loss = criterion(prediction, target)
        loss.backward()
        inner_optimizer.step()
        inner_optimizer.zero_grad()

    # Evaluate: no gradients needed, so skip building the autograd graph.
    task.dataset.eval()  # Grabs test data: your implementation might look different
    n_correct = 0
    n_samples = 0
    with torch.no_grad():
        for inputs, target in task:
            prediction = model(inputs)
            n_correct += (prediction.max(1)[1] == target).sum().item()
            n_samples += prediction.size(0)
    accuracy = n_correct / n_samples
    print(f"accuracy={accuracy}")
blake-camp commented 4 years ago

I notice here in your MVP you never call inner_optimizer.step(). Can you explain why?

flennerhag commented 4 years ago

Oops, my bad! Updated the MVP to include the inner_opt call