AlexHex7 / SimGAN_pytorch

[Refer to wayaai/SimGAN(Keras&Tensorflow)] Implementation of Apple's Learning from Simulated and Unsupervised Images through Adversarial Training

Bug in ImageHistoryBuffer implementation #5

Open tonyzhang2035 opened 3 years ago

tonyzhang2035 commented 3 years ago

I did some training with this repo and the results were not very satisfactory. After a closer look at the code, I found a bug in the ImageHistoryBuffer implementation:

np.append(self.image_history_buffer, images[:nb_to_add], axis=0)

Per the np.append documentation:

Returns: append : ndarray. A copy of arr with values appended to axis. Note that append does not occur in-place: a new array is allocated and filled. If axis is None, out is a flattened array.

Because the return value of the call is discarded, self.image_history_buffer is never modified, so the buffer is never filled with anything. This can be verified by printing the buffer size during training.
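A minimal reproduction of the bug and of the one-line fix (keeping the array that np.append returns). The toy 2x2 "images" and the names `buffer` and `fixed` are illustrative only, not taken from the repo:

```python
import numpy as np

# Toy batch: 4 "images" of shape 2x2 (illustrative sizes only).
images = np.ones((4, 2, 2), dtype=np.float32)

# Buggy pattern: the new array returned by np.append is thrown away.
buffer = np.zeros((0, 2, 2), dtype=np.float32)
np.append(buffer, images[:2], axis=0)
print("buggy buffer size:", len(buffer))  # prints 0: buffer is unchanged

# Fixed pattern: rebind the name to the array np.append returns.
fixed = np.zeros((0, 2, 2), dtype=np.float32)
fixed = np.append(fixed, images[:2], axis=0)
print("fixed buffer size:", len(fixed))  # prints 2
```

Rebinding works, but every call still copies the whole buffer, which is what motivates switching to a list.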

Here's a fix, also modified for better efficiency: the buffer is now a Python list, and list.extend works in place, so the entire image buffer is not copied on every call.
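A small sketch of the difference: np.append always builds a new array (copying every stored image), while list.extend mutates the existing list and only adds the new entries. Filling a buffer of N images through repeated np.append therefore copies O(N^2) image data in total. The variable names are illustrative:

```python
import numpy as np

batch = np.ones((2, 3), dtype=np.float32)

# np.append allocates a fresh array each call; the original object is untouched.
arr = np.zeros((0, 3), dtype=np.float32)
new_arr = np.append(arr, batch, axis=0)
print(new_arr is arr)  # prints False: a new copy holding all the data

# list.extend grows the same list object in place.
lst = []
lst_before = lst
lst.extend(batch.tolist())
print(lst is lst_before)  # prints True: same object, now with 2 rows
```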

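import random

import numpy as np


class ImageHistoryBuffer(object):
    def __init__(self, max_size, batch_size):
        # A Python list, so that extend and slice assignment modify it in place.
        self.image_history_buffer = []
        self.max_size = max_size
        self.batch_size = batch_size

    def add_to_buffer(self, images, num_to_add=None):
        # `images` is a NumPy array of shape (batch, ...).
        if num_to_add is None:
            num_to_add = self.batch_size // 2
        num_to_add = min(num_to_add, len(images))
        current_size = len(self.image_history_buffer)
        if current_size < self.max_size:
            # Still filling up: append only as many images as still fit.
            num_to_add = min(num_to_add, self.max_size - current_size)
            self.image_history_buffer.extend(images[:num_to_add].tolist())
        elif current_size == self.max_size:
            # Full: overwrite the first num_to_add entries. The buffer was shuffled
            # after the previous call, so these are randomly chosen old images.
            num_to_add = min(num_to_add, self.max_size)
            self.image_history_buffer[:num_to_add] = images[:num_to_add].tolist()
        else:
            raise RuntimeError("Image history buffer overflow")

        # Shuffle so that the next read and the next overwrite pick random images.
        random.shuffle(self.image_history_buffer)

    def get_from_buffer(self, num_to_get=None):
        if num_to_get is None:
            num_to_get = self.batch_size // 2
        # Slicing never raises IndexError: while the buffer is warming up this
        # returns fewer than num_to_get images (shape (0,) if it is empty).
        return np.array(self.image_history_buffer[:num_to_get], dtype=np.float32)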
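get_from_buffer slices the list, and slicing past the end of a Python list returns a shorter (possibly empty) list rather than raising IndexError. Early in training it can therefore return fewer than num_to_get images, or an empty array of shape (0,). A quick check of that behaviour, with `history` standing in for an empty buffer:

```python
import numpy as np

history = []  # an empty history buffer, as at the start of training

# Slicing past the end of a list yields a shorter (here empty) list; no IndexError.
sample = history[:16]
print(sample)  # prints []

# Converting it gives an empty float32 array of shape (0,), not a batch of images.
batch = np.array(sample, dtype=np.float32)
print(batch.shape)  # prints (0,)
```

So a caller that mixes history images into the discriminator batch presumably needs to check `len(batch)` before concatenating.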

Hope this helps!