Open phquang opened 5 years ago
Thanks for your inquiry, I wouldn't be too surprised if this implementation has bugs. (In fact, I noticed after your comment that the `forward` method of `SobolevLoss` should return `loss`, not `sobolev`.)
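Since the actual signature isn't shown in this thread, here is only a hypothetical sketch of what the corrected `forward` might look like (the `weight` parameter and the argument names are my own assumptions, not the repo's code):

```python
import torch
import torch.nn as nn

class SobolevLoss(nn.Module):
    """Hypothetical sketch: distillation loss plus a derivative-matching term."""

    def __init__(self, weight=1.0):
        super().__init__()
        self.weight = weight  # assumed trade-off between the two terms
        self.mse = nn.MSELoss()

    def forward(self, student_out, teacher_out, student_jac, teacher_jac):
        distill = self.mse(student_out, teacher_out)   # match the outputs
        sobolev = self.mse(student_jac, teacher_jac)   # match the derivatives
        loss = distill + self.weight * sobolev
        return loss  # the reported bug: returning `sobolev` here instead of `loss`
```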
Regarding your concern, could you clarify the second expression in your post? (Especially `D student(x)`.) My understanding is that Sobolev training consists of:
Do you agree with this description?
Your steps 2 and 3 are correct; the difference is in which variable you differentiate with respect to when computing the gradient in the first step. The objective in Sobolev training takes the gradient with respect to the input, whereas you are taking the gradient with respect to the parameters. For example, consider a teacher
t(z) = cz^2 + 1, where c = 2 (fixed)
and we want to approximate it by learning the parameter a in a student
s(z) = az^2 + 1
For some input x, we first calculate the MSE between the teacher and student (this part is common to both approaches) as
MSE = 0.5 * ( s(x) - t(x) )^2
The gradient stored in a when MSE.backward() is called is
a.grad = D MSE / D a = (ax^2 - cx^2) * x^2 = (a - c) x^4
and the gradient in c is
c.grad = D MSE / D c = (cx^2 - ax^2) * x^2 = (c - a) x^4
So the Sobolev loss you are calculating, which matches the parameter gradients, is
Sobolev = a.grad - c.grad = 2x^4 (a - c) = 2x^4 (a - 2)
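As a quick sanity check on the parameter-gradient route, autograd computes these derivatives directly for the toy teacher/student above (a = 1.5 and the sample inputs are my own assumptions):

```python
import torch

# Toy setup from the thread: t(z) = c*z^2 + 1 with fixed c = 2, s(z) = a*z^2 + 1.
c = torch.tensor(2.0, requires_grad=True)  # treated as a "parameter" for illustration
a = torch.tensor(1.5, requires_grad=True)
x = torch.tensor([0.5, 1.0, 2.0])

mse = (0.5 * (a * x**2 - c * x**2) ** 2).sum()

# Gradients of the loss w.r.t. the parameters (not w.r.t. the input x).
ga, gc = torch.autograd.grad(mse, (a, c))
# Analytically: ga = (a - c) * sum(x^4), gc = (c - a) * sum(x^4).
```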
To calculate the correct Sobolev loss, the Jacobian (the gradient, in this scalar case) is computed as
J(s) = D s(x) / D x = D (ax^2 + 1) / D (x) = 2ax
J(t) = D t(x) / D x = 2cx
So the correct Sobolev loss is
Sobolev = 2x(a - c) = 2x(a -2)
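The input-Jacobian version is easy to check with autograd as well (the sample points and a = 1.5 are assumed values):

```python
import torch

c = 2.0
a = torch.tensor(1.5, requires_grad=True)
x = torch.tensor([0.5, 1.0, 2.0], requires_grad=True)

s = a * x**2 + 1  # student
t = c * x**2 + 1  # teacher

# Jacobians w.r.t. the input x: J(s) = 2*a*x, J(t) = 2*c*x.
J_s = torch.autograd.grad(s.sum(), x, create_graph=True)[0]
J_t = torch.autograd.grad(t.sum(), x)[0]

# Sobolev term: penalize the mismatch between the two Jacobians.
sobolev = 0.5 * ((J_s - J_t) ** 2).mean()
```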
The two versions are not exactly the same. I picked a not-so-good example because the two losses look quite similar, but you can see a clear difference if you change the MSE to another type of loss.
Thanks, your explanation makes sense and it matches equation (1) from the paper.
Are you interested in fixing the implementation? Unfortunately, I don't have the bandwidth to work on this nowadays, so if not I'll put a link to this issue in the README to warn future users.
Thanks again for spotting that!
This implementation does not seem to be correct compared with the formulation in the original paper. Your implementation seems to be matching the gradient of the loss with respect to the parameters of the teacher and student networks; the gradient you match is
D Loss / D theta_student and D Loss / D theta_teacher
where D denotes the partial derivative and Loss is the distillation loss. However, the correct Sobolev training should be matching the Jacobian of the two networks with respect to the input, which is calculated as
D student(x) / D x and D teacher(x) / D x
I don't think the two gradient calculations are equivalent in this case.
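Putting this together, a minimal Sobolev-training step (matching input Jacobians, in the spirit of equation (1) from the paper) might look like the sketch below. The network shapes, learning rate, and equal loss weighting are all assumptions, not the repo's actual code:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in networks; the real teacher/student architectures are not shown here.
teacher = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))
student = nn.Sequential(nn.Linear(3, 16), nn.Tanh(), nn.Linear(16, 1))
opt = torch.optim.SGD(student.parameters(), lr=1e-2)

x = torch.randn(8, 3, requires_grad=True)  # inputs need grad for the Jacobian

t_out = teacher(x).detach()  # teacher outputs serve as fixed targets
s_out = student(x)

# Jacobians of the (scalar) outputs w.r.t. the input x.
t_jac = torch.autograd.grad(teacher(x).sum(), x)[0]                    # constant target
s_jac = torch.autograd.grad(s_out.sum(), x, create_graph=True)[0]      # differentiable

# Value matching plus Jacobian matching.
loss = F.mse_loss(s_out, t_out) + F.mse_loss(s_jac, t_jac)

opt.zero_grad()
loss.backward()
opt.step()
```

Note that `create_graph=True` is needed only on the student's Jacobian, so that the mismatch term can itself be backpropagated through the student's parameters (a double-backward pass).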