openai / guided-diffusion

MIT License

Implicit assumption of random ordering of generated images in calculation of Inception Score leads to underestimated ISC. #153

Open adilhasan927 opened 1 month ago

adilhasan927 commented 1 month ago

Hello, in the provided evaluator.py:

    def compute_inception_score(self, activations: np.ndarray, split_size: int = 5000) -> float:
        softmax_out = []
        for i in range(0, len(activations), self.softmax_batch_size):
            acts = activations[i : i + self.softmax_batch_size]
            softmax_out.append(self.sess.run(self.softmax, feed_dict={self.softmax_input: acts}))
        preds = np.concatenate(softmax_out, axis=0)
        # https://github.com/openai/improved-gan/blob/4f5d1ec5c16a7eceb206f42bfc652693601e1d5c/inception_score/model.py#L46
        scores = []
        for i in range(0, len(preds), split_size):
            part = preds[i : i + split_size]
            kl = part * (np.log(part) - np.log(np.expand_dims(np.mean(part, 0), 0)))
            kl = np.mean(np.sum(kl, 1))
            scores.append(np.exp(kl))
        return float(np.mean(scores))

Of interest is that the KL divergence is computed over splits of 5000 predictions (split_size), and the per-split scores are then averaged. This implicitly assumes that, having generated 50,000 images conditioned on, say, the 1000 ImageNet classes, the images are ordered randomly in the provided array. Only then does each split's marginal class distribution (np.mean(part, 0)) approximate the marginal over the whole set, so that each per-split score approaches the score computed over the whole set.
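For reference, what each split contributes can be rewritten with a standard identity. The notation is introduced here and is not from evaluator.py: S is one split of n images, p(y | x_i) is the softmax row for image i, \hat{p}_S is the split's mean prediction (what np.mean(part, 0) computes), H is entropy, and m is the number of splits.

```math
\begin{aligned}
\hat{p}_S(y) &= \frac{1}{n}\sum_{i \in S} p(y \mid x_i) \\
\log \mathrm{IS}_S &= \frac{1}{n}\sum_{i \in S} \mathrm{KL}\big(p(y \mid x_i)\,\|\,\hat{p}_S\big) = H(\hat{p}_S) - \frac{1}{n}\sum_{i \in S} H\big(p(y \mid x_i)\big) \\
\mathrm{ISC} &= \frac{1}{m}\sum_{S} \mathrm{IS}_S
\end{aligned}
```

Only the H(\hat{p}_S) term depends on which images end up in the same split. If \hat{p}_S puts nearly all of its mass on k classes, that entropy is at most about log k, so IS_S is capped near k.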

If instead the images are ordered by class in the provided array (i.e. images of class 0 as the first 50 images, class 1 as the next 50, and so on), each split contains only 100 of the 1000 classes. Each split's marginal class distribution then deviates sharply from the marginal over the whole set. As a result, the mean KL divergence within each split is much smaller than it would be over the whole set, and the calculated ISC is artificially low.
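A minimal sketch to reproduce the effect without the TensorFlow session. It applies the same split/KL loop to precomputed softmax outputs. The predictions are hypothetical synthetic data: each image puts 0.9 of its mass on its own class. The setup is also scaled down 10x (100 classes, 50 images each, split size 500), so each class-sorted split again covers only a tenth of the classes. The helper names are made up for illustration.

```python
import numpy as np


def split_inception_score(preds: np.ndarray, split_size: int = 5000) -> float:
    """Same split/KL loop as compute_inception_score, applied directly to
    softmax outputs of shape [num_images, num_classes]."""
    scores = []
    for i in range(0, len(preds), split_size):
        part = preds[i : i + split_size]
        marginal = np.mean(part, axis=0, keepdims=True)  # p(y) estimated from this split only
        kl = part * (np.log(part) - np.log(marginal))
        scores.append(np.exp(np.mean(np.sum(kl, axis=1))))
    return float(np.mean(scores))


def class_sorted_preds(num_classes: int = 100, per_class: int = 50, confidence: float = 0.9) -> np.ndarray:
    """Hypothetical synthetic softmax outputs, grouped by class: rows for
    class 0 first, then class 1, and so on. Each row puts `confidence` on its
    own class and spreads the remainder evenly over the other classes."""
    labels = np.repeat(np.arange(num_classes), per_class)
    preds = np.full((labels.size, num_classes), (1.0 - confidence) / (num_classes - 1))
    preds[np.arange(labels.size), labels] = confidence
    return preds


preds = class_sorted_preds()                   # 5,000 images x 100 classes, sorted by class
rng = np.random.default_rng(0)
shuffled = preds[rng.permutation(len(preds))]  # same rows in a random order

split_size = 500                               # 10 splits, mirroring 50,000 / 5,000
sorted_isc = split_inception_score(preds, split_size)       # each split sees only 10 classes
shuffled_isc = split_inception_score(shuffled, split_size)  # each split sees ~all classes
whole_isc = split_inception_score(preds, len(preds))        # one split over everything
print(f"sorted={sorted_isc:.2f} shuffled={shuffled_isc:.2f} whole={whole_isc:.2f}")
```

With this synthetic data the class-sorted order scores below 10, the cap for splits covering 10 classes. The shuffled order lands close to the single-split score. The absolute numbers depend only on the made-up confidence value and are not meant to match a real model.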

This issue can be fixed by adding the following line at the top of compute_inception_score, before the splits are formed:

    np.random.shuffle(activations)
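Two properties of np.random.shuffle matter for this one-line fix. First, on a 2-D array it permutes only along the first axis, so each image's activation vector stays intact and only the order of the images changes. Second, it shuffles in place and returns None, so it has to be a bare statement: activations = np.random.shuffle(activations) would replace the array with None. A small illustration on a made-up toy array:

```python
import numpy as np

# Toy stand-in for the activations array: 6 "images" x 3 features,
# where every entry in row i equals i so rows are easy to follow.
activations = np.repeat(np.arange(6, dtype=np.float32)[:, None], 3, axis=1)

np.random.seed(0)                        # only to make this illustration reproducible
result = np.random.shuffle(activations)  # permutes rows (axis 0) in place
print(result)                            # None: the array itself was modified
print(activations[:, 0])                 # a new row order; each row still holds equal values

# Non-mutating alternative: index with a permutation to get a shuffled copy.
original = np.arange(6)
shuffled_copy = original[np.random.permutation(len(original))]
print(original)                          # unchanged: [0 1 2 3 4 5]
```

Because the shuffle is random, the reported ISC will vary slightly from run to run unless the NumPy seed is fixed. Because it reorders the caller's array in place, the permutation-indexing variant is an option if that side effect is unwanted.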

As a demonstration: if I sort my own generated ImageNet sample set of 50K images by class, I get an ISC of about 50. The unsorted version gets an ISC of about 366.

Fortunately, this bug does not affect the academic research that uses this script for evaluations. Authors save the images to disk as individual files and then use Python to read the files back in. By happenstance, this yields an order random enough that the per-split statistics are close to those computed over the whole set.