tbepler / spatial-VAE

Source code for "Explicitly disentangling image content from translation and rotation with spatial-VAE" - NeurIPS 2019
MIT License

Reproduce Results on Galaxy Dataset #2

Open mmubeen-6 opened 3 years ago

mmubeen-6 commented 3 years ago

Hi @tbepler, I am trying to reproduce the results of your paper on the galaxy dataset but have been unable to match them exactly. Could you please share the exact training parameters? I am currently using the following command to train:

python3 train_galaxy.py galaxy_zoo/galaxy_zoo_train.npy galaxy_zoo/galaxy_zoo_test.npy -d 0 --num-epochs 300 --save-prefix galaxy_zoo_models/testing -z 100 --minibatch-size 100 --dx-scale 0.125

Moreover, in order to visualize the reconstructed images, I am using the following code snippet. Please have a look at it.

import numpy as np
import torch
from torch.autograd import Variable
import matplotlib.pyplot as plt


def get_reconstruction(iterator, x_coord, p_net, q_net, img_size=64, rotate=True, translate=True,
                       dx_scale=0.1, theta_prior=np.pi, augment_rotation=False, z_scale=1, use_cuda=False):

    def decode_tensor(input_tensor, img_size):
        # convert a flattened (1, img_size*img_size, 3) tensor into a uint8 HxWx3 image
        input_tensor = input_tensor.view(input_tensor.shape[0], img_size, img_size, 3)
        input_tensor = input_tensor.cpu().detach().numpy()

        input_tensor = input_tensor.clip(0., 1.)
        input_tensor = input_tensor * 255.
        input_tensor = input_tensor.reshape(img_size, img_size, 3)
        input_tensor = input_tensor.astype("uint8")

        print(input_tensor.shape, input_tensor.dtype)
        return input_tensor

    for y, in iterator:
        b = y.size(0)
        assert b == 1
        x = Variable(x_coord)
        y = Variable(y)

        x = x.expand(b, x.size(0), x.size(1))
        n = int(np.sqrt(y.size(1)))

        if use_cuda:
            y = y.cuda()
            x = x.cuda()

        # first do inference on the latent variables
        # (no rotation augmentation at evaluation time, so the raw image is encoded)
        z_mu, z_logstd = q_net(y.view(b, -1))
        z_std = torch.exp(z_logstd)
        z_dim = z_mu.size(1)

        # draw samples from variational posterior to calculate
        # E[p(x|z)]
        r = Variable(x.data.new(b,z_dim).normal_())
        z = z_std*r + z_mu

        if rotate:
            # z[0] is the rotation
            theta_mu = z_mu[:,0]
            theta_std = z_std[:,0]
            theta_logstd = z_logstd[:,0]
            theta = z[:,0]
            z = z[:,1:]
            z_mu = z_mu[:,1:]
            z_std = z_std[:,1:]
            z_logstd = z_logstd[:,1:]

            # calculate rotation matrix
            rot = Variable(theta.data.new(b,2,2).zero_())
            rot[:,0,0] = torch.cos(theta)
            rot[:,0,1] = torch.sin(theta)
            rot[:,1,0] = -torch.sin(theta)
            rot[:,1,1] = torch.cos(theta)
            x = torch.bmm(x, rot) # rotate coordinates by theta

            # use modified KL for rotation with no penalty on mean
            sigma = theta_prior

        if translate:
            # z[0,1] are the translations
            dx_mu = z_mu[:,:2]
            dx_std = z_std[:,:2]
            dx_logstd = z_logstd[:,:2]
            dx = z[:,:2]*dx_scale # scale dx by standard deviation
            dx = dx.unsqueeze(1)
            z = z[:,2:]

            x = x + dx # translate coordinates

        z = z*z_scale

        # reconstruct
        y_hat = p_net(x.contiguous(), z)
        y_hat = y_hat.view(b, -1, 3)

        input_image = decode_tensor(y, img_size)
        recon_image = decode_tensor(y_hat, img_size)

        # plot the input image next to its reconstruction
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 6))
        ax1.imshow(input_image)
        ax2.imshow(recon_image)
        fig.savefig('foo.png')
        plt.show()

        break
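
For reference, a minimal, hypothetical usage sketch of the snippet above. The checkpoint paths, the normalization of the .npy images, and the coordinate-grid construction are all assumptions and may need to be adapted to match train_galaxy.py:

import numpy as np
import torch

# load the trained generator and inference networks
# (hypothetical checkpoint paths saved under --save-prefix)
p_net = torch.load('galaxy_zoo_models/testing_generator_epoch300.sav')
q_net = torch.load('galaxy_zoo_models/testing_inference_epoch300.sav')
p_net.eval()
q_net.eval()

# pixel coordinate grid over [-1, 1] x [-1, 1], one (x, y) row per pixel,
# assumed to match how train_galaxy.py builds its grid
n = 64
xgrid = np.linspace(-1, 1, n)
ygrid = np.linspace(1, -1, n)
x0, x1 = np.meshgrid(xgrid, ygrid)
x_coord = torch.from_numpy(np.stack([x0.ravel(), x1.ravel()], 1)).float()

# wrap the test images so the loader yields one image per batch,
# matching the assert b == 1 in get_reconstruction
images = np.load('galaxy_zoo/galaxy_zoo_test.npy')  # assumed uint8, shape (N, 64, 64, 3)
y = torch.from_numpy(images).float().view(len(images), -1, 3) / 255.
loader = torch.utils.data.DataLoader(torch.utils.data.TensorDataset(y), batch_size=1)

get_reconstruction(loader, x_coord, p_net, q_net, img_size=n,
                   dx_scale=0.125, theta_prior=np.pi, use_cuda=False)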
tomouellette commented 1 year ago

Hi @mmubeen-6, any luck reproducing this work? I re-implemented the model from the ground up and have yet to generate results comparable to the original paper. I'll try a few more hyperparameter settings and train for longer than in the original paper (which may help?), but I'd be interested to hear if you've solved any of your pre-existing issues.

tbepler commented 1 year ago

It's been a long time since I ran those experiments, but a few things that will improve the generated images are:

It's also worth noting that the encoder in spatial-VAE can sometimes get stuck in bad local optima (especially regarding rotation inference), which then leads to bad generator performance. There are a few tricks implemented here to try to avoid those, but they only work so-so (e.g., including rotated images as input, but decoding the unrotated image by shifting the predicted rotation by the known augmentation rotation). You might want to take a look at some newer work (https://arxiv.org/abs/2210.12918, https://github.com/SMLC-NYSBC/TARGET-VAE) where we improved the encoder to address some of these issues. We lightly tested that on galaxy zoo but didn't push it as far as it should be able to go with a larger spatial generator and/or better initial featurization of the coordinates.
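
Roughly, the augmentation trick looks something like the following. This is just a sketch of the idea, not the exact code in this repo; q_net, n, and the sign convention of the angle shift are assumptions:

import numpy as np
import torch
import torch.nn.functional as F

def encode_with_rotation_augmentation(q_net, y, b, n):
    # sample a random augmentation angle per image
    aug_theta = torch.rand(b) * 2 * np.pi

    # rotate the input images by aug_theta using an affine grid (bilinear resampling)
    cos, sin = torch.cos(aug_theta), torch.sin(aug_theta)
    affine = torch.zeros(b, 2, 3)
    affine[:, 0, 0], affine[:, 0, 1] = cos, -sin
    affine[:, 1, 0], affine[:, 1, 1] = sin, cos
    grid = F.affine_grid(affine, (b, 3, n, n), align_corners=False)
    y_img = y.view(b, n, n, 3).permute(0, 3, 1, 2)
    y_rot = F.grid_sample(y_img, grid, align_corners=False)

    # infer latents from the rotated image
    y_rot_flat = y_rot.permute(0, 2, 3, 1).reshape(b, -1)
    z_mu, z_logstd = q_net(y_rot_flat)

    # shift the predicted rotation by the known augmentation angle so that
    # decoding targets the *unrotated* image y
    # (the sign depends on the rotation convention used for the coordinates)
    z_mu = z_mu.clone()
    z_mu[:, 0] = z_mu[:, 0] - aug_theta
    return z_mu, z_logstd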

tomouellette commented 1 year ago

Hi @tbepler, thanks for your comments! I spent the evening tinkering with my code and did notice some of the local optima issues on a few runs, so I will play with it a bit more, taking your suggestions into account. I also saw your TARGET-VAE paper pop up, congrats on that! I will try playing around with the group convolutions if I get a chance.

If you don't mind, I might create another GitHub repository with a somewhat refactored/re-engineered version of spatial-VAE (with proper attributions and references, of course). I think this could be a nice architecture for some of the applied bio work I'm doing, so I will probably run some additional experiments with more expressive encoders, different ways to condition the decoder on the latent variables, cyclical annealing of the KLD, maybe swapping out linear layers for 1x1 convolutions, etc. I may also add in reflection if I have time.

I also wonder if semi-supervised learning could help with convergence, since it seems reasonably straightforward to aggregate some ground-truth rotations/translations, either through augmentation or by extracting them via fits (e.g. the major axis or something).

tbepler commented 1 year ago

@tomouellette you're welcome to fork the code and use it however you like. I agree semi-supervised learning could help with convergence if you have labeled angles for a subset of images. It should be pretty straightforward to include in the objective.
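
For example, something along these lines could be added to the loss for the labeled subset of a minibatch. This is only a sketch; theta_mu, theta_label, and the mask are hypothetical names, and a von Mises likelihood on the angle would be another option:

import torch

def supervised_rotation_loss(theta_mu, theta_label, labeled_mask):
    # wrap the angular difference into (-pi, pi] so the penalty respects periodicity
    diff = theta_mu - theta_label
    diff = torch.atan2(torch.sin(diff), torch.cos(diff))
    # only labeled examples contribute; unlabeled entries are masked out
    return (labeled_mask * diff**2).sum() / labeled_mask.sum().clamp(min=1)

# e.g. total = elbo_loss + lambda_sup * supervised_rotation_loss(theta_mu, theta_label, mask)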