tensorflow / gan

Tooling for GANs in TensorFlow
Apache License 2.0

Computation of Sliced Wasserstein Distance numerically unstable #28

Closed: banelli closed this issue 4 years ago.

banelli commented 4 years ago

The function _normalize_patches(), defined and used in tensorflow_gan.python.eval.sliced_wasserstein, divides each patch by the standard deviation of the values within that patch. When a patch has zero standard deviation (for example, a uniformly colored patch), this is a 0 / 0 division and produces NaN. To handle that case, I suggest adding a small constant for numerical stability:

patches = (patches - mean) / (tf.sqrt(variance) + 1.0e-12)
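Below is a minimal sketch of the failure mode and the proposed fix. It assumes patches shaped (batch, size, size, channels), each normalized by the mean and standard deviation over all of its values. normalize_patches_sketch is a hypothetical helper, not the library's _normalize_patches(). NumPy stands in for TensorFlow here because floating-point 0 / 0 yields NaN in both.

```python
import numpy as np


def normalize_patches_sketch(patches, eps=0.0):
    """Hypothetical NumPy re-creation of per-patch normalization.

    patches: array of shape (batch, size, size, channels).
    eps: constant added to the standard deviation. 0.0 mimics plain
         division by the standard deviation; 1.0e-12 mimics the proposal.
    Returns an array of shape (batch, size * size * channels).
    """
    patches = np.asarray(patches, dtype=np.float32)
    # Mean and variance over every value within each patch (axes 1, 2, 3).
    mean = patches.mean(axis=(1, 2, 3), keepdims=True)
    variance = patches.var(axis=(1, 2, 3), keepdims=True)
    # Silence NumPy's 0/0 warning so the resulting NaN is simply returned.
    with np.errstate(divide="ignore", invalid="ignore"):
        normalized = (patches - mean) / (np.sqrt(variance) + eps)
    # Flatten each patch into one vector.
    return normalized.reshape(patches.shape[0], -1)


# One constant (zero-variance) patch and one patch with varied values.
rng = np.random.default_rng(0)
constant = np.full((1, 7, 7, 3), 0.5, dtype=np.float32)
varied = rng.uniform(size=(1, 7, 7, 3)).astype(np.float32)
batch = np.concatenate([constant, varied], axis=0)

plain = normalize_patches_sketch(batch)               # no epsilon
stable = normalize_patches_sketch(batch, eps=1.0e-12)  # proposed epsilon

print(np.isnan(plain[0]).all())  # True: constant patch gives 0 / 0 = NaN
print(np.all(stable[0] == 0.0))  # True: constant patch gives 0 / 1e-12 = 0
```

A single NaN patch is enough to poison any downstream statistic computed from these vectors, which is why the degenerate case matters even if it is rare.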

joel-shor commented 4 years ago

This is a good observation, but it would change the value of the eval metric and make it incomparable to previously reported numbers. It might be worth proposing an alternate metric and naming it something different!