photosynthesis-team / piq

Measures and metrics for image2image tasks. PyTorch.
Apache License 2.0

FID estimation is too imprecise (too few iterations) #381

Open catwell opened 7 months ago

catwell commented 7 months ago

FID estimation uses the Newton-Schulz iterative method to compute the square root of a matrix, and it stops iterating too early.

To decide when to stop computing sqrt(M), the algorithm computes a relative error: the norm of the difference between M and the square of the current approximation, divided by the norm of M. It stops as soon as this error drops below 1e-5.
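To make the stopping rule concrete, here is a minimal numpy sketch of the coupled Newton-Schulz iteration with the residual-based criterion described above. It is not piq's code (piq works on PyTorch tensors). The function name `sqrtm_newton_schulz` and the `history` list are hypothetical, and the list mirrors the kind of per-iteration (error, trace) log shown in this issue. The sketch assumes, as is usual for this method, that the input has nonnegative real eigenvalues. The matrix is divided by its Frobenius norm so the iteration converges, and the result is rescaled by the square root of that norm.

```python
import numpy as np


def sqrtm_newton_schulz(matrix, tol=1e-5, max_iters=100):
    """Approximate sqrt(matrix) with the coupled Newton-Schulz iteration.

    Stops once the relative residual ||M - S @ S||_F / ||M||_F drops below
    `tol`. Returns the approximation S and a list of (error, trace) pairs,
    one per iteration.
    """
    dim = matrix.shape[0]
    norm = np.linalg.norm(matrix, "fro")
    eye = np.eye(dim)
    y = matrix / norm          # normalized input, converges to sqrt(M / norm)
    z = np.eye(dim)            # converges to the inverse square root
    s = y * np.sqrt(norm)
    history = []
    for _ in range(max_iters):
        t = 0.5 * (3.0 * eye - z @ y)
        y = y @ t
        z = t @ z
        s = y * np.sqrt(norm)  # undo the normalization
        error = np.linalg.norm(matrix - s @ s, "fro") / norm
        history.append((error, np.trace(s)))
        if error < tol:        # the only stopping criterion: the residual
            break
    return s, history
```

Nothing in this loop looks at the trace: the residual can fall below `tol` while `np.trace(s)` is still moving, which is the behaviour reported here.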

However, this is not the quantity FID uses: what matters is the trace of the resulting square root.
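For reference, the standard FID between two sets of features with means $\mu_1, \mu_2$ and covariances $\Sigma_1, \Sigma_2$ is given below (this notation is introduced here, not taken from the issue). The matrix square root only enters through its trace, with a factor of $-2$, so any error in that trace moves the FID by twice as much:

```latex
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2
  + \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2)
  - 2\,\operatorname{Tr}\!\left( (\Sigma_1 \Sigma_2)^{1/2} \right)
```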

For instance, on a real dataset, I modified the code to keep iterating until this error drops below 1e-7 instead. Here are the error and the trace of `s_matrix` at each iteration:

| Iteration | Error | Trace of `s_matrix` |
|---:|---:|---:|
| 0 | 0.3495 | 495.6085 |
| 1 | 0.2281 | 650.9324 |
| 2 | 0.1414 | 825.6144 |
| 3 | 0.0814 | 1010.0005 |
| 4 | 0.0444 | 1190.7549 |
| 5 | 0.0230 | 1355.3980 |
| 6 | 0.0115 | 1496.0165 |
| 7 | 0.0057 | 1610.0487 |
| 8 | 0.0027 | 1698.8550 |
| 9 | 0.0013 | 1764.7837 |
| 10 | 0.0006 | 1811.5501 |
| 11 | 0.0003 | 1843.3400 |
| 12 | 0.0001 | 1864.3024 |
| 13 | 5.0639e-05 | 1877.7849 |
| 14 | 2.2064e-05 | 1886.2064 |
| 15 | 9.5923e-06 | 1891.3386 |
| 16 | 4.1472e-06 | 1894.3981 |
| 17 | 1.7905e-06 | 1896.1724 |
| 18 | 7.7794e-07 | 1897.1686 |
| 19 | 3.4087e-07 | 1897.7101 |
| 20 | 1.4866e-07 | 1897.9915 |
| 21 | 6.5038e-08 | 1898.1279 |

With the 1e-5 threshold, the loop stops at iteration 15 and uses 1891.34 as the trace. The actual trace in this case is 1898.22 (computed with scipy's `linalg.sqrtm`). Because FID subtracts twice this trace, the 6.88 shortfall inflates the result by 13.76: the FID comes out as 90.74 instead of 76.98, which is a significant difference.
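An exact reference value is useful for checking any iterative result. Below is a minimal numpy sketch. It swaps the issue's `scipy.linalg.sqrtm` for an eigenvalue identity so that only numpy is needed. Since $\Sigma_1$ is positive definite, $\Sigma_1\Sigma_2$ is similar to the positive semidefinite matrix $\Sigma_1^{1/2}\Sigma_2\Sigma_1^{1/2}$. Its eigenvalues are therefore real and nonnegative, and the trace of its principal square root is the sum of their square roots. The function names `trace_sqrt_product` and `fid_from_stats` are hypothetical and do not belong to piq's API.

```python
import numpy as np


def trace_sqrt_product(sigma1, sigma2):
    """Tr(sqrtm(sigma1 @ sigma2)) computed from eigenvalues.

    The eigenvalues are real and nonnegative in exact arithmetic; drop the
    imaginary round-off and clip tiny negative values before the sqrt.
    """
    eigvals = np.linalg.eigvals(sigma1 @ sigma2).real
    return np.sqrt(np.clip(eigvals, 0.0, None)).sum()


def fid_from_stats(mu1, sigma1, mu2, sigma2, trace_sqrt=None):
    """FID from feature statistics.

    `trace_sqrt` optionally overrides Tr(sqrtm(sigma1 @ sigma2)), e.g. with a
    value produced by an iterative approximation.
    """
    if trace_sqrt is None:
        trace_sqrt = trace_sqrt_product(sigma1, sigma2)
    diff = mu1 - mu2
    return diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * trace_sqrt
```

The optional `trace_sqrt` argument lets you plug in an approximate trace, for example one taken from a Newton-Schulz run, and see how far the FID moves from the exact value.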

This issue was introduced in this commit (@denproc).

A possible fix would be to add a second stopping criterion based on the absolute change in the trace between two consecutive iterations. In theory there could still be cases where this is not enough. I can submit a PR with that change if you want.
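Below is a minimal numpy sketch of one reading of this proposal. It assumes the loop should stop only once both conditions hold: the residual is small, and the trace changed by less than an absolute tolerance since the previous iteration. It is not piq's implementation, which is in PyTorch. The function name, `trace_tol`, and its default value are hypothetical. An absolute tolerance also depends on the scale of the features, so a relative one (the change divided by the trace) is another option.

```python
import numpy as np


def sqrtm_newton_schulz_trace(matrix, tol=1e-5, trace_tol=1e-3, max_iters=100):
    """Coupled Newton-Schulz square root with two stopping criteria.

    Stops only when the relative residual ||M - S @ S||_F / ||M||_F is below
    `tol` AND the trace of S changed by less than `trace_tol` (absolute)
    since the previous iteration. Both thresholds are illustrative.
    """
    dim = matrix.shape[0]
    norm = np.linalg.norm(matrix, "fro")
    eye = np.eye(dim)
    y = matrix / norm          # normalize so the iteration converges
    z = np.eye(dim)
    s = y * np.sqrt(norm)
    prev_trace = np.trace(s)
    for _ in range(max_iters):
        t = 0.5 * (3.0 * eye - z @ y)
        y = y @ t
        z = t @ z
        s = y * np.sqrt(norm)  # undo the normalization
        trace = np.trace(s)
        error = np.linalg.norm(matrix - s @ s, "fro") / norm
        if error < tol and abs(trace - prev_trace) < trace_tol:
            break
        prev_trace = trace
    return s
```

A small change between two iterations still does not strictly bound the remaining error in the trace, which is the caveat mentioned above.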