Sobolev Training with PyTorch

Small scale replication of Sobolev Training for NNs.

Overview

To use the code, import SobolevLoss from sobolev.py; for a complete working example, check out main.py. The general recipe for distillation is:

import torch.nn as nn
from sobolev import SobolevLoss

teacher = Net()  # Net: any nn.Module; the teacher is assumed converged
student = Net()  # same architecture, freshly initialized
loss = SobolevLoss(loss=nn.MSELoss(), weight=1.0, order=2)

# compute the gradients of teacher and student
# (e.g. via a backward pass on each network's own loss)

sobolev = loss(student.parameters(), teacher.parameters())

# At this point, the student's parameter gradients look like:
#   s.grad = s.original_grad + s.grad.grad
# where s.grad.grad comes from the Sobolev loss
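
For context, the "compute the gradients" placeholder above might expand into a full training step along the following lines. This is only a sketch under assumptions: Net, loader, x, y, and the optimizer choice are stand-ins not defined by the repo; only the SobolevLoss calls mirror the snippet above.

import torch.nn as nn
import torch.optim as optim
from sobolev import SobolevLoss

teacher = Net()  # placeholder module, assumed converged
student = Net()
distill = nn.MSELoss()
sobolev = SobolevLoss(loss=nn.MSELoss(), weight=1.0, order=2)
optimizer = optim.SGD(student.parameters(), lr=0.01)

for x, y in loader:  # loader is a placeholder DataLoader
    # populate the teacher's parameter gradients via its own loss
    teacher.zero_grad()
    nn.MSELoss()(teacher(x), y).backward()

    # populate the student's parameter gradients via the distillation loss
    optimizer.zero_grad()
    distill(student(x), teacher(x).detach()).backward()

    # fold the Sobolev term into the student's existing .grad buffers,
    # as described in the comments above
    sobolev(student.parameters(), teacher.parameters())
    optimizer.step()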


Benchmark results

The results were obtained by distilling a converged LeNet teacher into a LeNet student that shares the same architecture but starts from a random initialization. Results are reported as train / test values at the 100th epoch of training.

| Metric        | Vanilla     | Sobolev     |
| ------------- | ----------- | ----------- |
| Distill. Loss | 1.2 / 1.19  | 0.56 / 0.64 |
| Student Loss  | 0.94 / 0.9  | 0.8 / 0.82  |
| Teacher Loss  | 0.7 / 0.72  | 0.7 / 0.72  |
| Sobolev Loss  | n/a         | 2e-4 / 4e-4 |
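
For readers new to the technique, the Sobolev Loss row tracks how well the student matches the teacher's derivatives, not just its outputs. The repo's SobolevLoss operates on parameter gradients as shown above; purely as a conceptual illustration of the idea from the Sobolev Training paper, a first-order penalty on input gradients could be sketched as follows (sobolev_penalty is a hypothetical helper, not part of this repo):

import torch
import torch.nn as nn

def sobolev_penalty(student, teacher, x, base=nn.MSELoss()):
    # match outputs and first derivatives w.r.t. the input x
    x = x.requires_grad_(True)
    s_out, t_out = student(x), teacher(x)
    s_grad = torch.autograd.grad(s_out.sum(), x, create_graph=True)[0]
    t_grad = torch.autograd.grad(t_out.sum(), x)[0]
    return base(s_out, t_out.detach()) + base(s_grad, t_grad)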