mseitzer / pytorch-fid

Compute FID scores with PyTorch.
Apache License 2.0

A better way to compute the FID #95

Open francois-rozet opened 1 year ago


Hello, I think the following implementation of the Fréchet distance is faster than the current one and would allow dropping the SciPy dependency.

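```python
import torch
from torch import Tensor

def frechet_distance(mu_x: Tensor, sigma_x: Tensor, mu_y: Tensor, sigma_y: Tensor) -> Tensor:
    a = (mu_x - mu_y).square().sum(dim=-1)  # squared distance between the means
    b = sigma_x.trace() + sigma_y.trace()  # Tr(sigma_x) + Tr(sigma_y)
    # Tr(sqrt(sigma_x @ sigma_y)) = sum of square roots of the eigenvalues of sigma_x @ sigma_y
    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum(dim=-1)

    return a + b - 2 * c
```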
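One way to check that the eigenvalue route agrees with the usual matrix-square-root formula, without SciPy, is to compare it against an explicit reference. This is an illustrative sketch, not code from pytorch-fid: `sqrtm_psd`, `frechet_distance_reference` and `random_stats` are hypothetical helpers. The reference avoids a general `sqrtm` by square-rooting the symmetric PSD matrix $\sqrt{\Sigma_x}\,\Sigma_y\sqrt{\Sigma_x}$ with `torch.linalg.eigh`; that matrix has the same eigenvalues as $\Sigma_x \Sigma_y$ because $AB$ and $BA$ share eigenvalues. It assumes float64 inputs and a PyTorch version that provides `torch.linalg` and `torch.cov`.

```python
import torch
from torch import Tensor


def frechet_distance(mu_x: Tensor, sigma_x: Tensor, mu_y: Tensor, sigma_y: Tensor) -> Tensor:
    # Eigenvalue-based formula from this issue (repeated so the block runs on its own).
    a = (mu_x - mu_y).square().sum(dim=-1)
    b = sigma_x.trace() + sigma_y.trace()
    c = torch.linalg.eigvals(sigma_x @ sigma_y).sqrt().real.sum(dim=-1)
    return a + b - 2 * c


def sqrtm_psd(m: Tensor) -> Tensor:
    # Hypothetical helper: principal square root of a symmetric PSD matrix via eigh.
    # Tiny negative eigenvalues caused by round-off are clamped to zero.
    w, v = torch.linalg.eigh(m)
    return v @ torch.diag(w.clamp(min=0).sqrt()) @ v.T


def frechet_distance_reference(mu_x: Tensor, sigma_x: Tensor, mu_y: Tensor, sigma_y: Tensor) -> Tensor:
    # Hypothetical reference: with S = sqrt(sigma_x), sigma_x @ sigma_y = S @ (S @ sigma_y)
    # and S @ sigma_y @ S = (S @ sigma_y) @ S share eigenvalues (AB vs. BA),
    # so Tr(sqrt(sigma_x @ sigma_y)) = Tr(sqrt(S @ sigma_y @ S)).
    s = sqrtm_psd(sigma_x)
    covmean = sqrtm_psd(s @ sigma_y @ s)
    diff = mu_x - mu_y
    return diff @ diff + sigma_x.trace() + sigma_y.trace() - 2 * covmean.trace()


def random_stats(dim: int, n: int, gen: torch.Generator) -> tuple[Tensor, Tensor]:
    # Hypothetical helper: mean and covariance of n random float64 samples.
    x = torch.randn(n, dim, generator=gen, dtype=torch.float64)
    return x.mean(dim=0), torch.cov(x.T)  # torch.cov expects variables as rows


if __name__ == "__main__":
    gen = torch.Generator().manual_seed(0)
    mu_x, sigma_x = random_stats(8, 500, gen)
    mu_y, sigma_y = random_stats(8, 500, gen)
    fast = frechet_distance(mu_x, sigma_x, mu_y, sigma_y)
    ref = frechet_distance_reference(mu_x, sigma_x, mu_y, sigma_y)
    print(f"eigvals: {fast.item():.10f}  reference: {ref.item():.10f}")

    # Diagonal covariances have a closed form:
    # |mu_x - mu_y|^2 + sum_i (sqrt(p_i) - sqrt(q_i))^2 = 25 + (1 - 2)^2 + (2 - 3)^2 = 27.
    mu_a = torch.tensor([0.0, 0.0], dtype=torch.float64)
    mu_b = torch.tensor([3.0, 4.0], dtype=torch.float64)
    sig_a = torch.diag(torch.tensor([1.0, 4.0], dtype=torch.float64))
    sig_b = torch.diag(torch.tensor([4.0, 9.0], dtype=torch.float64))
    print(frechet_distance(mu_a, sig_a, mu_b, sig_b).item())  # approximately 27.0
```

The eigenvalue version needs a single decomposition of the non-symmetric product, while this symmetric reference needs two `eigh` calls; on well-conditioned inputs the two should agree up to round-off.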

The implementation is based on two facts:

  1. The trace of $A$ equals the sum of its eigenvalues.
  2. The eigenvalues of $\sqrt{A}$ are the square roots of the eigenvalues of $A$.
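Combining these two facts with the usual closed form of the Fréchet distance between Gaussians $\mathcal{N}(\mu_x, \Sigma_x)$ and $\mathcal{N}(\mu_y, \Sigma_y)$ shows where each of `a`, `b` and `c` in the function comes from. Here $\lambda_i$ denotes the eigenvalues of $\Sigma_x \Sigma_y$:

```math
\begin{aligned}
d^2 &= \lVert \mu_x - \mu_y \rVert_2^2 + \operatorname{Tr}(\Sigma_x) + \operatorname{Tr}(\Sigma_y) - 2 \operatorname{Tr}\big((\Sigma_x \Sigma_y)^{1/2}\big) \\
    &= \underbrace{\lVert \mu_x - \mu_y \rVert_2^2}_{a} + \underbrace{\operatorname{Tr}(\Sigma_x) + \operatorname{Tr}(\Sigma_y)}_{b} - 2 \underbrace{\textstyle\sum_i \sqrt{\lambda_i}}_{c}
\end{aligned}
```

Since $\Sigma_x \Sigma_y$ has the same eigenvalues as the symmetric PSD matrix $\Sigma_x^{1/2} \Sigma_y \Sigma_x^{1/2}$, the $\lambda_i$ are real and non-negative in exact arithmetic. In floating point, `eigvals` can still return tiny imaginary or slightly negative parts, which is why keeping only `.real` after `.sqrt()` is reasonable.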