fmeirinhos / pytorch-hessianfree

PyTorch implementation of Hessian Free optimisation
MIT License

Train Deep Linear/CNN #3

Closed: opooladz closed this issue 10 months ago

opooladz commented 4 years ago

I was hoping to get a simple train/test example using the Hessian-free optimizer on MNIST. Eventually I want to try a Hessian for the Levenberg-Marquardt update rule, but for now I am trying to work with the EFM (empirical Fisher matrix). I am also, for now, using the Hessian-vector product instead of the GGN matrix-vector product. Below is my equivalent of the hf_test file. I also needed to change the code for the EFM slightly.

When I run the following, the loss either explodes and then becomes NaN, or I am told that the matrix is not invertible. In both cases the loss is exploding.

I define the closure inside the for loop so that I don't have to track the rest of the changes in the other file.

I eventually want to get this working on a CNN as well, so help with either/both would be amazing.

Thank you in advance.

import torch
from torch import nn
from torchvision import datasets, transforms

# adjust the import to wherever this repo's hessianfree.py lives
from hessianfree import HessianFree, empirical_fisher_diagonal

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
path = 'PATH/TO/DATASET'
trainset = datasets.MNIST(path+'/train', download=True, train=True, transform=transform)
valset = datasets.MNIST(path+'/test', download=True, train=False, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
valloader = torch.utils.data.DataLoader(valset, batch_size=64, shuffle=True)

dataiter = iter(trainloader)
images, labels = next(dataiter)  # peek at one batch

# reduced image size (14x14 = 196 inputs) to fit on the GPU
input_size = 196
hidden_sizes = [128, 64]
output_size = 10

model = nn.Sequential(nn.Linear(input_size, hidden_sizes[0]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[0], hidden_sizes[1]),
                      nn.ReLU(),
                      nn.Linear(hidden_sizes[1], output_size),
                      nn.LogSoftmax(dim=1))
criterion = nn.NLLLoss()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# def M_inv(x,y):  # inverse preconditioner
#     return empirical_fisher_diagonal(model, x, y, criterion)
def M_inv(data,label):  # inverse preconditioner
    return empirical_fisher_matrix(model, data, label, criterion)

optimizer = HessianFree(model.parameters(), use_gnm=False, verbose=True)

epochs = 15
for e in range(epochs):
    running_loss = 0
    for images, labels in trainloader:
        # Downsample to 14x14 and flatten into a 196-long vector
        images = images[:, :, ::2, ::2].contiguous()
        images = images.view(images.shape[0], -1)
        images, labels = images.to(device), labels.to(device)

        # Training pass
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels)
        loss.backward()
        # The closure re-runs forward/backward with create_graph=True so the
        # optimizer can take curvature-vector products during CG
        def closure():
            z = model(images)
            loss = criterion(z, labels)
            loss.backward(create_graph=True)
            return loss, z

        optimizer.step(closure, M_inv=M_inv(images, labels))

        running_loss += loss.item()
    else:  # for/else: runs once after the inner loop finishes
        print("Epoch {} - Training loss: {}".format(e, running_loss/len(trainloader)))

# The empirical Fisher matrix (Section 20.11.3).
# NB: in the actual script this must be defined before the training loop,
# otherwise M_inv(images, labels) raises a NameError.
def empirical_fisher_matrix(net, xs, ys, criterion):
    grads = list()
    for (x, y) in zip(xs, ys):
        fi = criterion(net(x[None,:]), y.repeat(1))
        grad = torch.autograd.grad(fi, net.parameters(),
                                   retain_graph=False)
        grads.append(torch.cat([g.detach().flatten() for g in grad]))

    grads = torch.stack(grads)
    n_batch = grads.shape[0]
    return torch.einsum('ij,ik->jk', grads, grads) / n_batch
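For what it's worth, the commented-out `empirical_fisher_diagonal` preconditioner is the much lighter option: the full empirical Fisher above is a dense P x P matrix (P is roughly 34k parameters even for this downsampled model), which is expensive to build and invert on every step and is presumably why the images had to be shrunk in the first place. Below is a minimal sketch of what such a diagonal helper typically computes, assuming the same `(net, xs, ys, criterion)` signature as the repo's own `empirical_fisher_diagonal`; prefer the repo's version if it differs.

def empirical_fisher_diagonal_sketch(net, xs, ys, criterion):
    """Sketch: mean of squared per-sample gradients, i.e. diag of the empirical Fisher."""
    diag = None
    for x, y in zip(xs, ys):
        fi = criterion(net(x[None, :]), y.repeat(1))
        grad = torch.autograd.grad(fi, net.parameters(), retain_graph=False)
        g = torch.cat([p.detach().flatten() for p in grad])
        diag = g ** 2 if diag is None else diag + g ** 2
    return diag / len(xs)  # shape (n_params,): only the diagonal entries

Whether a 1-D M_inv is accepted depends on how the optimizer's preconditioning step is written; the `torch.inverse(m_inv + damping * torch.eye(*m_inv.shape))` line quoted later in this thread assumes a 2-D matrix, so a vector would need `torch.diag(...)` or a dedicated diagonal code path first.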
opooladz commented 4 years ago

Hope all is well,

Just to make clear: I have the code working when I set M_inv = None in the optimizer step. My interest lies in using the preconditioner, i.e. the step

torch.inverse(m_inv + damping * torch.eye(*m_inv.shape).to(device))

I have not been able to get it working (trying it on MNIST) with the Fisher matrix, its diagonal, or the H ≈ JᵀJ (Gauss-Newton) approximation of the Hessian.

Any direction would be greatly appreciated. Thank you in advance.
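Not the repo's code, just a diagnostic sketch that may help: since the step quoted above inverts `m_inv + damping * I`, it is worth inspecting the empirical Fisher for one batch before handing it to the optimizer. With a batch of 64 samples and tens of thousands of parameters, the empirical Fisher has rank at most 64, so almost all of its eigenvalues are zero and everything hinges on the damping term; NaN/Inf entries from an already-exploding loss would also make the inverse fail. The names (`M_inv`, `images`, `labels`) are the ones from the script in the first post, and `damping` stands in for whatever value the optimizer is using.

import torch

@torch.no_grad()
def check_preconditioner(m_inv, damping=1e-2):
    """Sanity-check a candidate (inverse) preconditioner before optimizer.step."""
    assert torch.isfinite(m_inv).all(), "Fisher matrix contains NaN/Inf entries"
    eigvals = torch.linalg.eigvalsh(m_inv)  # symmetric PSD, so eigvalsh is fine
    lo, hi = eigvals.min().item(), eigvals.max().item()
    rank = (eigvals > 1e-10 * max(hi, 1e-30)).sum().item()
    print(f"eigenvalues in [{lo:.2e}, {hi:.2e}], numerical rank ~ {rank}")
    # For symmetric PSD A, cond(A + damping*I) = (max + damping) / (min + damping)
    cond = (hi + damping) / (max(lo, 0.0) + damping)
    print(f"condition number after damping ~ {cond:.2e}")

# Inside the training loop, before optimizer.step(...):
#     check_preconditioner(M_inv(images, labels))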

fmeirinhos commented 4 years ago

@opooladz , sorry for the very late reply.

Unfortunately I have no experience with Fisher information matrices and I can't really remember why I ended up implementing them.

Have you read Martens' section Designing a Good Preconditioner? I remember skimming through some papers exploring preconditioners for this kind of optimisation problem. It seems to be quite tricky, and there are some hyper-parameters that will require you to play with the source code a bit. It could also be that there is a bug in their implementation :|

Have you had any progress on this problem?
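For reference, the recipe Martens gives in that discussion (Martens, 2010) is a diagonal preconditioner built from summed squared per-sample gradients, damped and raised to an exponent: M = (diag(Σᵢ ∇fᵢ ⊙ ∇fᵢ) + λI)^α with α ≈ 0.75. The following is only a sketch under the assumption that it is passed through the same `M_inv` argument as in the script above, so it is returned as a dense matrix purely to keep the existing `torch.inverse(m_inv + damping * I)` call working.

import torch

def martens_diag_preconditioner(net, xs, ys, criterion, lam=1e-2, alpha=0.75):
    """Sketch of Martens' diagonal preconditioner: (diag(sum_i g_i * g_i) + lam)^alpha."""
    diag = None
    for x, y in zip(xs, ys):
        fi = criterion(net(x[None, :]), y.repeat(1))
        grad = torch.autograd.grad(fi, net.parameters(), retain_graph=False)
        g = torch.cat([p.detach().flatten() for p in grad])
        diag = g ** 2 if diag is None else diag + g ** 2
    scaled = (diag + lam) ** alpha  # element-wise damping and exponent
    # Dense so the optimizer's torch.inverse(m_inv + damping * I) call works
    # unchanged; Martens applies it element-wise inside CG instead.
    return torch.diag(scaled)

# e.g. optimizer.step(closure, M_inv=martens_diag_preconditioner(model, images, labels, criterion))

Note that `lam` here partly duplicates the damping the optimizer already adds before inverting; Martens ties it to the current Tikhonov damping of the curvature, so the two should be kept consistent.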

opooladz commented 4 years ago

To my understanding, the Fisher information matrix just acts like the Hessian matrix.

Yes, I have read through some parts of Martens' work. If you don't pass a preconditioner to your code, it won't use line 126 of hessianfree:

m = torch.inverse(m_inv + damping * torch.eye(*m_inv.shape))

I have not been able to make progress. I basically just want a second-order optimizer, such as Newton's method, that works on CNNs and is integrated with PyTorch. Have you been able to run even a vanilla second-order method? Running with conjugate gradient and backtracking is of course a plus.
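In case it helps while this is open: since the M_inv=None configuration is the one reported to run, here is a minimal sketch of the same loop on a small CNN, with use_gnm=True so the curvature is the Gauss-Newton matrix (positive semi-definite, so CG tends to be better behaved than with the raw Hessian). The architecture is made up for illustration; `device`, `criterion`, `trainloader`, and `HessianFree` are the ones from the script in the first post, and the closure follows the same pattern used there.

import torch.nn as nn

cnn = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Conv2d(8, 16, kernel_size=3, padding=1), nn.ReLU(), nn.MaxPool2d(2),
    nn.Flatten(),
    nn.Linear(16 * 7 * 7, 10),
    nn.LogSoftmax(dim=1),
).to(device)

optimizer = HessianFree(cnn.parameters(), use_gnm=True, verbose=True)

for images, labels in trainloader:
    images, labels = images.to(device), labels.to(device)  # keep 4-D: (N, 1, 28, 28)
    optimizer.zero_grad()

    def closure():
        z = cnn(images)
        loss = criterion(z, labels)
        loss.backward(create_graph=True)  # create_graph for curvature-vector products
        return loss, z

    optimizer.step(closure, M_inv=None)  # the known-working, unpreconditioned setup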