photosynthesis-team / piq

Measures and metrics for image2image tasks. PyTorch.
Apache License 2.0

Wrong total variation calculation #328

Open Dobatymo opened 1 year ago

Dobatymo commented 1 year ago

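The total variation (l2 version) is computed as sqrt(sum(d_w**2 + d_h**2)); see https://github.com/photosynthesis-team/piq/blob/26d044e28231cd286b4a7e9e0e6c704d1ed39398/piq/tv.py#L34-L37. Shouldn't it be sum(sqrt(d_w**2 + d_h**2)) instead? The tricky part is vectorizing this correctly: horizontal and vertical forward differences taken over the full image have different shapes, (H, W-1) and (H-1, W), so they cannot be added elementwise as they are.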
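To make the difference concrete, here is a minimal sketch (not piq's code) that evaluates both placements of the sum on the same forward-difference grid. The helper names `forward_diffs`, `tv_sqrt_of_sum` and `tv_sum_of_sqrt` are hypothetical, and cropping both differences to a common (H-1, W-1) grid is an assumption made here so that only the placement of the sum changes. Because the square root is subadditive, the single-square-root form can never exceed the per-pixel form.

```python
import torch

def forward_diffs(x: torch.Tensor):
    """Horizontal and vertical forward differences of an (N, C, H, W) tensor,
    cropped to a common (H-1, W-1) grid so they can be combined per pixel."""
    d_w = x[:, :, :-1, 1:] - x[:, :, :-1, :-1]
    d_h = x[:, :, 1:, :-1] - x[:, :, :-1, :-1]
    return d_w, d_h

def tv_sqrt_of_sum(x: torch.Tensor) -> torch.Tensor:
    # sqrt(sum(d_w**2 + d_h**2)): one square root over the whole image
    d_w, d_h = forward_diffs(x)
    return torch.sqrt(torch.sum(d_w ** 2 + d_h ** 2, dim=(1, 2, 3)))

def tv_sum_of_sqrt(x: torch.Tensor) -> torch.Tensor:
    # sum(sqrt(d_w**2 + d_h**2)): per-pixel gradient magnitude, then summed
    d_w, d_h = forward_diffs(x)
    return torch.sum(torch.sqrt(d_w ** 2 + d_h ** 2), dim=(1, 2, 3))

# A 3x3 ramp: every pixel on the cropped grid has d_w = 1 and d_h = 3
x = torch.arange(9.0).reshape(1, 1, 3, 3)
print(tv_sqrt_of_sum(x))  # sqrt(4 * 10) = 2 * sqrt(10), about 6.3246
print(tv_sum_of_sqrt(x))  # 4 * sqrt(10), about 12.6491
```

For a uniform gradient over n pixels the two forms differ by a factor of sqrt(n), so the gap grows with image size.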

zakajd commented 1 year ago

Hi @Dobatymo, I believe both variants are equally common in the literature. The Wikipedia article puts the summation outside the square root, while other sources (see image) put it inside. We include the exact formula in the docs, so users can decide whether it fits their use case.

Feel free to close the issue if this answers your question!

Dobatymo commented 1 year ago

Hi @zakajd, sorry, I missed the formula in the docs. However, both Wikipedia and the two references cited in the docs put the sum outside the square root, and I am not aware of any formulation that puts it inside. The formulations I know are the isotropic and the anisotropic ones, and both put the sum outside (strictly, this only matters for the isotropic version). They differ only in the per-pixel norm that is summed.
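For reference, the two standard discrete formulations mentioned here can be written with forward differences as follows; indices i, j run over pixels where both neighbours exist, and boundary handling is left out. The third line is the single-square-root form quoted at the top of this issue, written in the same notation.

```math
\begin{aligned}
\mathrm{TV}_{\text{aniso}}(x) &= \sum_{i,j} \bigl( |x_{i+1,j} - x_{i,j}| + |x_{i,j+1} - x_{i,j}| \bigr) \\
\mathrm{TV}_{\text{iso}}(x) &= \sum_{i,j} \sqrt{(x_{i+1,j} - x_{i,j})^2 + (x_{i,j+1} - x_{i,j})^2} \\
\text{quoted form} &= \sqrt{\sum_{i,j} \bigl( (x_{i+1,j} - x_{i,j})^2 + (x_{i,j+1} - x_{i,j})^2 \bigr)}
\end{aligned}
```

In the anisotropic form the sum splits into one sum per direction, so moving it around changes nothing; only the isotropic form depends on whether the sum sits inside or outside the square root.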

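EDIT: I would suggest cropping both differences to a common (H-1, W-1) grid so they can be combined per pixel:

```python
import torch

def tv_l2(x: torch.Tensor) -> torch.Tensor:  # hypothetical wrapper name; x is (N, C, H, W)
    d_w = x[:, :, :-1, 1:] - x[:, :, :-1, :-1]  # horizontal differences on the (H-1, W-1) grid
    d_h = x[:, :, 1:, :-1] - x[:, :, :-1, :-1]  # vertical differences on the same grid
    return torch.sum(torch.sqrt(torch.square(d_w) + torch.square(d_h)), dim=(1, 2, 3))
```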
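One way to check that the vectorized slicing in this suggestion is correct is to compare it with a direct per-pixel loop. The sketch below does that; `tv_l2_vectorized`, `tv_l2_loop` and the `eps` parameter are hypothetical and not part of piq's API. The `eps` term is an optional design choice: the derivative of `torch.sqrt` is infinite at 0, so pixels with zero gradient can produce NaN gradients during backpropagation, and a small `eps` under the square root avoids that at the cost of a tiny bias.

```python
import torch

def tv_l2_vectorized(x: torch.Tensor, eps: float = 0.0) -> torch.Tensor:
    # Forward differences cropped to the common (H-1, W-1) grid
    d_w = x[:, :, :-1, 1:] - x[:, :, :-1, :-1]
    d_h = x[:, :, 1:, :-1] - x[:, :, :-1, :-1]
    # Per-pixel gradient magnitude, then sum over C, H, W (one score per image)
    return torch.sum(torch.sqrt(d_w ** 2 + d_h ** 2 + eps), dim=(1, 2, 3))

def tv_l2_loop(x: torch.Tensor) -> torch.Tensor:
    # Reference implementation: explicit loops over images, channels and pixels
    n, c, h, w = x.shape
    scores = torch.zeros(n, dtype=x.dtype)
    for b in range(n):
        for ch in range(c):
            for i in range(h - 1):
                for j in range(w - 1):
                    dw = x[b, ch, i, j + 1] - x[b, ch, i, j]
                    dh = x[b, ch, i + 1, j] - x[b, ch, i, j]
                    scores[b] += torch.sqrt(dw ** 2 + dh ** 2)
    return scores

torch.manual_seed(0)
x = torch.rand(2, 3, 5, 7, dtype=torch.float64)
print(torch.allclose(tv_l2_vectorized(x), tv_l2_loop(x)))  # True

# With eps > 0, a constant image (all differences zero) still has finite gradients
flat = torch.ones(1, 1, 4, 4, requires_grad=True)
tv_l2_vectorized(flat, eps=1e-12).sum().backward()
print(torch.isfinite(flat.grad).all())  # tensor(True)
```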

For l2_squared, where there is no square root, the placement of the sum doesn't matter either.
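As a small numerical sanity check of this remark, and of the earlier point that the placement only matters for the isotropic version, the sketch below compares grouping the terms per pixel with grouping them per direction. It assumes that l2_squared means the sum of squared forward differences without a square root, and it uses the cropped (H-1, W-1) grid from the suggestion above for both directions; the variable names are illustrative only.

```python
import torch

torch.manual_seed(0)
x = torch.rand(1, 1, 6, 6, dtype=torch.float64)

# Forward differences on a common (H-1, W-1) grid
d_w = x[:, :, :-1, 1:] - x[:, :, :-1, :-1]
d_h = x[:, :, 1:, :-1] - x[:, :, :-1, :-1]

# l2_squared (assumed here to be squared differences with no square root):
# summing per pixel or per direction gives the same total
sq_per_pixel = torch.sum(d_w ** 2 + d_h ** 2)
sq_per_direction = torch.sum(d_w ** 2) + torch.sum(d_h ** 2)
print(torch.allclose(sq_per_pixel, sq_per_direction))  # True

# Anisotropic (l1) version: no square root either, so the same holds
l1_per_pixel = torch.sum(d_w.abs() + d_h.abs())
l1_per_direction = torch.sum(d_w.abs()) + torch.sum(d_h.abs())
print(torch.allclose(l1_per_pixel, l1_per_direction))  # True
```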