pystiche / papers

Reference implementation and replication of prominent NST papers
BSD 3-Clause "New" or "Revised" License

Li wand score correction factor #234

Closed: pmeier closed this 3 years ago

pmeier commented 3 years ago

I finally remembered why I added the score_correction_factor to the implementation parameters. The key observation is that in the original authors' reference implementation, the loss functionalities are implemented in the backward rather than the forward pass. That means the gradient is calculated directly instead of the loss.
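To illustrate what "implemented in the backward pass" means, here is a minimal PyTorch sketch. It assumes a pass-through layer whose forward pass returns the input unchanged and whose backward pass adds a hand-written total variation gradient; the class name TotalVariationGrad is hypothetical and belongs neither to pystiche nor to the reference implementation. Reproducing the injected gradient with autograd from an explicit score only works once the squared differences are scaled by 0.5.

```python
import torch


class TotalVariationGrad(torch.autograd.Function):
    """Hypothetical pass-through layer that injects a hand-written TV gradient in backward."""

    @staticmethod
    def forward(ctx, input):
        ctx.save_for_backward(input)
        # Identity in the forward pass: no loss value is computed here.
        return input.clone()

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        horz_diff = input[..., :-1, :-1] - input[..., :-1, 1:]
        vert_diff = input[..., :-1, :-1] - input[..., 1:, :-1]
        grad = torch.zeros_like(input)
        grad[..., :-1, :-1] = horz_diff + vert_diff
        grad[..., :-1, 1:] -= horz_diff
        grad[..., 1:, :-1] -= vert_diff
        # Add the regularizer's gradient to whatever flows back from later layers.
        return grad_output + grad


torch.manual_seed(0)
image = torch.rand(1, 3, 8, 8, requires_grad=True)

# Nothing downstream: a zero upstream gradient leaves only the injected TV gradient.
output = TotalVariationGrad.apply(image)
output.backward(torch.zeros_like(output))
backward_pass_grad = image.grad.clone()

# Same gradient from an explicit score: 0.5 * squared differences, differentiated by autograd.
image.grad = None
horz = image[..., :-1, :-1] - image[..., :-1, 1:]
vert = image[..., :-1, :-1] - image[..., 1:, :-1]
score = 0.5 * torch.sum(horz ** 2 + vert ** 2)
score.backward()

assert torch.allclose(image.grad, backward_pass_grad)
```

In this pattern no loss value ever exists during optimization, which is why any constant factor has to be reconstructed on the score side when porting to a forward-pass loss.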

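Roughly speaking, if no correction factor appears in the gradient calculation, it has to appear in the score calculation instead. Since the loss is calculated with the squared Euclidean distance, differentiating it introduces a factor of 2 (the derivative of (x - y)^2 is 2 (x - y)). The reference gradient does not contain this factor, so we need score_correction_factor=0.5 to compensate for it.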
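A minimal check of this reasoning with plain PyTorch autograd on random vectors; x, y and squared_euclidean_score are illustrative names, not part of the pystiche API.

```python
import torch

torch.manual_seed(0)
x = torch.rand(5, requires_grad=True)
y = torch.rand(5)


def squared_euclidean_score(x, y):
    # Squared Euclidean distance, as used for the score.
    return torch.sum((x - y) ** 2)


# Plain autograd gradient: d/dx (x - y)**2 = 2 * (x - y)
(grad_plain,) = torch.autograd.grad(squared_euclidean_score(x, y), x)

# With the 0.5 correction factor the 2 cancels, leaving the plain difference x - y
score_correction_factor = 0.5
(grad_corrected,) = torch.autograd.grad(
    score_correction_factor * squared_euclidean_score(x, y), x
)

diff = (x - y).detach()
assert torch.allclose(grad_plain, 2 * diff)
assert torch.allclose(grad_corrected, diff)
```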

For example, compare the gradient calculation of the total variation loss in the original authors' reference implementation (reproduced here in PyTorch) with pystiche's implementation:

import torch
import pystiche.ops.functional as F

input = torch.rand(1, 3, 128, 128, requires_grad=True)

# Hand-written gradient as in the reference implementation: differences to the
# right and lower neighbours, taken only for pixels outside the last row and column.
with torch.no_grad():
    horz_diff = input[..., :-1, :-1] - input[..., :-1, 1:]
    vert_diff = input[..., :-1, :-1] - input[..., 1:, :-1]

    grad = torch.zeros_like(input)
    grad[..., :-1, :-1] = horz_diff + vert_diff
    grad[..., :-1, 1:] -= horz_diff
    grad[..., 1:, :-1] -= vert_diff

# Autograd gradient of pystiche's loss, scaled by the correction factor
score_correction_factor = 0.5
loss = score_correction_factor * F.total_variation_loss(input, exponent=2.0, reduction="sum")
loss.backward()

assert torch.allclose(input.grad, grad)

Without score_correction_factor=0.5 the assertion would fail, because the autograd gradient would be twice the hand-written one. The same holds for the MRF loss.
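For the MRF loss, the analogous check can be sketched as follows. This is a simplified stand-in, not pystiche's MRF implementation or the reference code: random tensors replace encoder activations, patches are extracted with torch.nn.functional.unfold, each input patch is matched to its nearest style patch by normalized cross-correlation, and the hand-written gradient is assumed to be the plain patch difference, summed back onto the grid with torch.nn.functional.fold.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-ins for encoder activations of the input and the style image.
input_features = torch.rand(1, 4, 16, 16, requires_grad=True)
style_features = torch.rand(1, 4, 16, 16)
patch_size = 3


def extract_patches(features):
    # (1, C * k * k, num_patches) -> (num_patches, C * k * k)
    return F.unfold(features, kernel_size=patch_size).squeeze(0).t()


input_patches = extract_patches(input_features)
style_patches = extract_patches(style_features)

with torch.no_grad():
    # Nearest neighbour of every input patch by normalized cross-correlation.
    similarity = F.normalize(input_patches, dim=1) @ F.normalize(style_patches, dim=1).t()
    matched = style_patches[similarity.argmax(dim=1)]

    # Backward-pass style gradient: plain patch difference, summed back onto the grid.
    manual_grad = F.fold(
        (input_patches - matched).t().unsqueeze(0),
        output_size=tuple(input_features.shape[-2:]),
        kernel_size=patch_size,
    )

# Autograd gradient of the squared Euclidean MRF score, scaled by the correction factor.
score_correction_factor = 0.5
score = score_correction_factor * torch.sum((input_patches - matched) ** 2)
score.backward()

assert torch.allclose(input_features.grad, manual_grad, atol=1e-6)
```

Because the matched style patches are treated as constants and fold is the adjoint of unfold, the 0.5 factor is again exactly what aligns the autograd gradient with the plain-difference gradient.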

Cc @jbueltemeier

codecov[bot] commented 3 years ago

Codecov Report

Merging #234 (874e781) into master (2170ae3) will increase coverage by 0.0%. The diff coverage is 100.0%.


@@          Coverage Diff           @@
##           master    #234   +/-   ##
======================================
  Coverage    98.8%   98.8%           
======================================
  Files          38      38           
  Lines        1508    1512    +4     
======================================
+ Hits         1490    1494    +4     
  Misses         18      18           

Impacted Files                          Coverage Δ
pystiche_papers/li_wand_2016/_loss.py   98.7% <100.0%> (+<0.1%) ↑


Legend: Δ = absolute <relative> (impact), ø = not affected, ? = missing data. Last update 2170ae3...874e781.