FID estimation uses the Newton-Schulz iterative method to compute the matrix square root, but it stops iterating too early.
To decide when to stop computing `sqrt(M)`, the algorithm computes the relative error between the square of the current approximation and `M` (the norm of their difference divided by the norm of `M`) and stops once it drops below `1e-5`.
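A minimal NumPy sketch of the coupled Newton-Schulz iteration with the relative-error stopping rule described above may help make the criterion concrete. It is an illustration, not the library's code: the function name `sqrtm_newton_schulz`, its parameters, and the use of NumPy instead of PyTorch are my assumptions. As in common implementations, `M` is first divided by its Frobenius norm and the result is rescaled by the square root of that norm.

```python
import numpy as np


def sqrtm_newton_schulz(matrix, num_iters=100, tol=1e-5):
    """Approximate sqrt(matrix) with the coupled Newton-Schulz iteration (hypothetical helper).

    Stops as soon as the relative error ||M - S @ S||_F / ||M||_F drops below `tol`.
    Returns the approximation S, the final error, and the number of iterations run.
    """
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")
    dim = matrix.shape[0]
    norm = np.linalg.norm(matrix)  # Frobenius norm for a 2-D array
    y = matrix / norm              # Y_k -> sqrt(M / norm)
    z = np.eye(dim)                # Z_k -> inverse of sqrt(M / norm)
    identity = np.eye(dim)
    for iters in range(1, num_iters + 1):
        t = 0.5 * (3.0 * identity - z @ y)
        y = y @ t
        z = t @ z
        s = y * np.sqrt(norm)      # undo the scaling: sqrt(M) = sqrt(norm) * sqrt(M / norm)
        error = np.linalg.norm(matrix - s @ s) / norm
        if error < tol:            # the only stopping rule: a small residual
            break
    return s, error, iters


# Usage on a well-conditioned symmetric positive definite test matrix.
rng = np.random.default_rng(0)
a = rng.standard_normal((50, 50))
m = a @ a.T + 50.0 * np.eye(50)
s, err, n = sqrtm_newton_schulz(m)
# Exact trace of sqrt(m): for a symmetric PD matrix it is the sum of sqrt(eigenvalues).
ref_trace = np.sqrt(np.linalg.eigvalsh(m)).sum()
print(n, err, np.trace(s), ref_trace)
```

Note that the loop never looks at the trace of `s`: it only checks how well `s @ s` reproduces `M` overall.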
However, this error is not the quantity FID depends on: what matters is the trace of the resulting square root.
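For reference, the standard FID between two feature distributions with means $\mu_1, \mu_2$ and covariances $\Sigma_1, \Sigma_2$ is shown below. The matrix whose square root is taken is $\Sigma_1 \Sigma_2$, and it enters only through its trace with a factor of $-2$, so underestimating that trace by $\Delta$ inflates the FID by $2\Delta$.

$$
\mathrm{FID} = \lVert \mu_1 - \mu_2 \rVert_2^2 + \operatorname{Tr}(\Sigma_1) + \operatorname{Tr}(\Sigma_2) - 2\,\operatorname{Tr}\big((\Sigma_1 \Sigma_2)^{1/2}\big)
$$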
For instance, on a real dataset, I modified the code to keep iterating until the error above falls below `1e-7` instead. Here are the values of the trace of `s_matrix` at each iteration:

| Iteration | Error | Trace of `s_matrix` |
|---:|---:|---:|
| 0 | 0.3495 | 495.6085 |
| 1 | 0.2281 | 650.9324 |
| 2 | 0.1414 | 825.6144 |
| 3 | 0.0814 | 1010.0005 |
| 4 | 0.0444 | 1190.7549 |
| 5 | 0.0230 | 1355.3980 |
| 6 | 0.0115 | 1496.0165 |
| 7 | 0.0057 | 1610.0487 |
| 8 | 0.0027 | 1698.8550 |
| 9 | 0.0013 | 1764.7837 |
| 10 | 0.0006 | 1811.5501 |
| 11 | 0.0003 | 1843.3400 |
| 12 | 0.0001 | 1864.3024 |
| 13 | 5.0639e-05 | 1877.7849 |
| 14 | 2.2064e-05 | 1886.2064 |
| 15 | 9.5923e-06 | 1891.3386 |
| 16 | 4.1472e-06 | 1894.3981 |
| 17 | 1.7905e-06 | 1896.1724 |
| 18 | 7.7794e-07 | 1897.1686 |
| 19 | 3.4087e-07 | 1897.7101 |
| 20 | 1.4866e-07 | 1897.9915 |
| 21 | 6.5038e-08 | 1898.1279 |

With the `1e-5` threshold, the iteration stops at iteration 15 and uses `1891.34` as the trace, whereas the actual trace in this case is `1898.22` (computed with SciPy's `linalg.sqrtm`). This yields an FID of `90.74` instead of `76.98` (a significant difference!).
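The numbers above can be cross-checked with a few lines of plain Python. The rows are copied from the table (iterations 13 to 21) and the reference trace is the SciPy value quoted above; the helper names `first_below` and `fid_inflation` are hypothetical. Because FID contains the term minus 2 times the trace of the square root, a trace deficit of `d` adds `2 * d` to the FID, which accounts exactly for the reported gap.

```python
# (iteration, error, trace) rows copied from the table above, iterations 13-21.
ROWS = [
    (13, 5.0639e-05, 1877.7849),
    (14, 2.2064e-05, 1886.2064),
    (15, 9.5923e-06, 1891.3386),
    (16, 4.1472e-06, 1894.3981),
    (17, 1.7905e-06, 1896.1724),
    (18, 7.7794e-07, 1897.1686),
    (19, 3.4087e-07, 1897.7101),
    (20, 1.4866e-07, 1897.9915),
    (21, 6.5038e-08, 1898.1279),
]
REFERENCE_TRACE = 1898.22  # trace from scipy.linalg.sqrtm, as reported above


def first_below(rows, tol):
    """Return (iteration, trace) for the first row whose error is below `tol`."""
    for iteration, error, trace in rows:
        if error < tol:
            return iteration, trace
    raise ValueError(f"no row has error below {tol}")


def fid_inflation(trace_estimate, reference=REFERENCE_TRACE):
    """FID contains -2 * Tr(sqrt(M)), so a trace deficit d adds 2 * d to the FID."""
    return 2.0 * (reference - trace_estimate)


it, tr = first_below(ROWS, 1e-5)
print(it, tr, round(fid_inflation(tr), 2))  # 15 1891.3386 13.76
print(round(90.74 - 76.98, 2))              # 13.76, the FID gap reported above
print(first_below(ROWS, 1e-7))              # (21, 1898.1279)
```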
This issue was introduced in this commit (@denproc).
One potential fix would be to also require the absolute difference in the trace between two consecutive iterations to be small, as an additional stopping criterion, although in theory there could still be cases where this is not enough. I can submit a PR with that change if you want.
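A sketch of that proposed fix, written with NumPy rather than PyTorch for illustration: the function name `sqrtm_newton_schulz_trace_stop`, the `trace_atol` parameter and its default of `1e-3` are all hypothetical. The loop stops only once the relative error is below `tol` **and** the trace has changed by less than `trace_atol` since the previous iteration.

```python
import numpy as np


def sqrtm_newton_schulz_trace_stop(matrix, num_iters=100, tol=1e-5, trace_atol=1e-3):
    """Coupled Newton-Schulz iteration that stops only when BOTH hold (hypothetical helper):

    * relative error ||M - S @ S||_F / ||M||_F < tol (the existing rule), and
    * |Tr(S_k) - Tr(S_{k-1})| < trace_atol (the proposed extra rule).
    """
    if num_iters < 1:
        raise ValueError("num_iters must be at least 1")
    dim = matrix.shape[0]
    norm = np.linalg.norm(matrix)  # Frobenius norm, used to scale M before iterating
    y = matrix / norm
    z = np.eye(dim)
    identity = np.eye(dim)
    prev_trace = np.inf            # no previous trace yet, so iteration 1 never stops
    for iters in range(1, num_iters + 1):
        t = 0.5 * (3.0 * identity - z @ y)
        y = y @ t
        z = t @ z
        s = y * np.sqrt(norm)      # rescale back to an approximation of sqrt(M)
        error = np.linalg.norm(matrix - s @ s) / norm
        trace = np.trace(s)
        if error < tol and abs(trace - prev_trace) < trace_atol:
            break
        prev_trace = trace
    return s, error, iters


# Usage on a symmetric positive definite test matrix.
rng = np.random.default_rng(0)
a = rng.standard_normal((50, 50))
m = a @ a.T + 50.0 * np.eye(50)
s, err, n = sqrtm_newton_schulz_trace_stop(m)
ref_trace = np.sqrt(np.linalg.eigvalsh(m)).sum()  # exact trace of sqrt(m)
print(n, err, np.trace(s), ref_trace)
```

An absolute tolerance is what the proposal describes; since the traces here are around 1900, a tolerance relative to the current trace may transfer better across datasets. Either way this remains a heuristic: a slowly converging iteration can change the trace by less than `trace_atol` per step while still being far from its limit.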