alecGraves / BVAE-tf

Disentangled Variational Auto-Encoder in TensorFlow / Keras (Beta-VAE)
The Unlicense

better way to handle batch_size #1

Closed schlerp closed 5 years ago

schlerp commented 5 years ago


really like your implementation but noticed the static batch size was causing me all sorts of grief when i wanted to play around with training. after a bit of mucking around i came up with a solution that i feel is a little more elegant.

basically the issue arises because during construction there is a call to the instantiated layer. at that point the tensor being passed in as "x" to the sampling layer's call() function has an undefined batch_size. at build time all we need to do is return a tensor with the appropriate shape; we don't actually need to call K.random_normal(), which is the only part of this function that needs the batch_size explicitly.

long story short, stick this in your function:

        # trick to allow setting batch at train/eval time
        if x[0].shape[0].value is None:
            return mean + 0*stddev

in context that is (i made some other slight changes to the function but you can ignore them, this is just so you can see how my fix would fit into the function):

    def call(self, x):
        if len(x) != 2:
            raise Exception('input layers must be a list: mean and stddev')
        if len(x[0].shape) != 2 or len(x[1].shape) != 2:
            raise Exception('input shape is not a vector [batchSize, latentSize]')

        mean = x[0]
        stddev = x[1]        

        # trick to allow setting batch at train/eval time
        if x[0].shape[0].value is None:
            return mean + 0*stddev

        if self.reg:
            # kl divergence:
            latent_loss = -0.5 * K.mean(1 + stddev
                                        - K.square(mean)
                                        - K.exp(stddev), axis=-1)        

            if self.reg == 'bvae':
                # use beta to force less usage of vector space:
                # also try to use <capacity> dimensions of the space:
                latent_loss = self.beta * K.abs(latent_loss - self.capacity/self.shape.as_list()[1])

            self.add_loss(latent_loss, x)

        epsilon = K.random_normal(shape=self.shape,
                                  mean=0., stddev=1.)
        if self.random:
            # 'reparameterization trick':
            return mean + K.exp(stddev / 2) * epsilon
        else:  # do not perform random sampling, simply grab the impulse value
            return mean + 0*stddev  # Keras needs the *0 so the gradient is not None
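
As a quick sanity check of the KL term in the function above, here is a minimal NumPy sketch of the same arithmetic. It assumes the `stddev` input is treated as a log-variance, which is how `K.exp(stddev)` and `K.exp(stddev / 2)` appear to use it. The names `kl_per_sample` and `bvae_penalty` are hypothetical helpers for illustration, not part of the repo.

```python
import numpy as np


def kl_per_sample(mean, log_var):
    """KL term as written in call(), averaged over the latent axis.

    Assumption: `log_var` plays the role of the layer's `stddev` input.
    """
    return -0.5 * np.mean(1.0 + log_var - np.square(mean) - np.exp(log_var),
                          axis=-1)


def bvae_penalty(kl, beta, capacity, latent_size):
    """Beta-weighted distance between the KL term and a per-dimension target."""
    return beta * np.abs(kl - capacity / latent_size)


# Posterior equal to the prior (mean 0, log-variance 0): the KL term vanishes.
kl_zero = kl_per_sample(np.zeros((3, 4)), np.zeros((3, 4)))

# Shifting every mean to 1 gives 0.5 per sample, whatever the batch size.
shifted = kl_per_sample(np.ones((5, 4)), np.zeros((5, 4)))

# With capacity / latent_size equal to the KL value, the penalty is zero.
penalty = bvae_penalty(shifted, beta=100.0, capacity=2.0, latent_size=4)
```

Nothing in this arithmetic depends on the batch size, which is consistent with the point above that only the random sampling needs it.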

alecGraves commented 5 years ago

Ok, thanks for the recommendation! I was not sure how to get around the issue of not knowing batch size at runtime, this makes a lot of sense. I will try to update the repo when I have some free time.
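
For readers hitting the same problem, another possible way around the unknown batch size is to read the noise shape from the input tensor at run time rather than from a stored static shape. The sketch below assumes TensorFlow 2.x, which is newer than the version this thread was written against, and `sample_latent` is a hypothetical stand-alone function rather than the repo's layer. The analogous change inside the layer would presumably be to pass a dynamic shape such as `K.shape(mean)` to `K.random_normal`; that has not been tested against this repo.

```python
import tensorflow as tf


def sample_latent(mean, log_var):
    """Reparameterization trick with the noise shape read at run time."""
    # tf.shape returns the dynamic shape, so the batch dimension may be
    # unknown (None) while the graph is being built.
    epsilon = tf.random.normal(shape=tf.shape(mean), mean=0.0, stddev=1.0)
    return mean + tf.exp(log_var / 2.0) * epsilon


# Trace once with an unknown batch dimension, then reuse for any batch size.
spec = tf.TensorSpec(shape=[None, 4], dtype=tf.float32)
sample_fn = tf.function(sample_latent, input_signature=[spec, spec])

small = sample_fn(tf.zeros([2, 4]), tf.zeros([2, 4]))
large = sample_fn(tf.zeros([7, 4]), tf.zeros([7, 4]))
```

With this approach no build-time special case is needed, so the loss term would still be registered during construction. Whether that matters depends on how the rest of the model is assembled.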

beldaz commented 5 years ago

Note your refactoring of KLD needs fixing to remain compatible with