nmichlo / disent

🧶 Modular VAE disentanglement framework for python built with PyTorch Lightning ▸ Including metrics and datasets ▸ With strongly supervised, weakly supervised and unsupervised methods ▸ Easily configured and run with Hydra config ▸ Inspired by disentanglement_lib
https://disent.michlo.dev
MIT License

visualisation with trained model #43

Closed: dgm2 closed this issue 3 months ago

dgm2 commented 1 year ago

Hi, great package!

I am looking at the examples in the plotting_examples folder. These seem to work independently of a trained torch model? What would be a minimal way (or example) to use them with a trained model, e.g. how do I visualise the latent traversal of a trained model?

Best regards

nmichlo commented 1 year ago

Hi @dgm2, thank you!

My apologies for the delayed response. You are right: most of them are independent of a trained model.

However, you can have a look at the helper code inside the PyTorch Lightning utilities; one of the callback classes is specifically for generating latent traversals:

https://github.com/nmichlo/disent/blob/ff462ba567a734041874cc584b97695a81729498/disent/util/lightning/callbacks/_callback_vis_latents.py#L197-L261

This method uses helper functions from disent.util.visualize.vis_img to convert tensors to images, disent.util.visualize.vis_latents to generate latent sequences, and disent.util.visualize.vis_util to combine the images together into a grid or make sequential frames.
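A minimal sketch of those steps (building the latent sequences, decoding them, and converting the decoded tensors into images), written in plain NumPy so it runs without disent or a GPU. The names latent_cycle, to_uint8_images, latent_traversal_stills and toy_decode are hypothetical and are not part of disent. The traversal here is a plain linear sweep over a fixed range around a reference latent, which may differ from how the callback chooses its traversal range. decode stands for any function mapping a (N, z_size) float array to images of shape (N, C, H, W) with values in [0, 1].

```python
import numpy as np


def latent_cycle(z_ref: np.ndarray, dim: int, num_frames: int, lo: float = -2.0, hi: float = 2.0) -> np.ndarray:
    """Repeat z_ref num_frames times, sweeping only dimension `dim` linearly from lo to hi."""
    zs = np.repeat(z_ref[None, :], num_frames, axis=0).astype(np.float32)
    zs[:, dim] = np.linspace(lo, hi, num_frames, dtype=np.float32)
    return zs


def to_uint8_images(x: np.ndarray) -> np.ndarray:
    """Convert a float batch (N, C, H, W) with values in [0, 1] to uint8 images (N, H, W, C)."""
    x = np.clip(x, 0.0, 1.0)
    return np.round(np.moveaxis(x, 1, -1) * 255).astype(np.uint8)


def latent_traversal_stills(decode, z_ref: np.ndarray, num_frames: int = 8, lo: float = -2.0, hi: float = 2.0) -> np.ndarray:
    """Return stills of shape (z_size, num_frames, H, W, C): one row of frames per latent dimension."""
    rows = [to_uint8_images(decode(latent_cycle(z_ref, d, num_frames, lo, hi))) for d in range(z_ref.shape[0])]
    return np.stack(rows, axis=0)


# Toy stand-in for a trained decoder: maps (N, 3) latents to (N, 1, 4, 4) images in [0, 1].
def toy_decode(z: np.ndarray) -> np.ndarray:
    brightness = 1.0 / (1.0 + np.exp(-z.sum(axis=1)))  # sigmoid of the latent sum, shape (N,)
    return brightness[:, None, None, None] * np.ones((1, 1, 4, 4))


# With a trained PyTorch model, a wrapper might look like this instead (the decode method and
# the sigmoid are assumptions about your model, not a confirmed disent API):
# decode = lambda z: torch.sigmoid(module.decode(torch.as_tensor(z, device=module.device))).detach().cpu().numpy()
stills = latent_traversal_stills(toy_decode, np.zeros(3, dtype=np.float32), num_frames=5)
print(stills.shape)  # (3, 5, 4, 4, 1)
```

The reference latent z_ref would typically be the encoder's mean for one observation. If your decoder outputs logits or normalised values rather than values in [0, 1], they need to be mapped back first; handling that is part of the extra complexity in the callback.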

The code is more complicated than it needs to be for most cases because of some additional handling and quirks. Maybe we can add a specific docs example for latent traversals.

dgm2 commented 1 year ago

Thanks! The callback method returns stills, frames and image. How should I pass these into plot_dataset_traversals, or into visualize_dataset_traversal, and what are the corresponding values there? E.g. does stills correspond to grid as the input to plt_subplots_imshow?

Does this example make sense? Many thanks!

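```python
import matplotlib.pyplot as plt
import pytorch_lightning as pl
import torch
# import paths assumed from the disent repository layout; they may differ between versions
from disent.util.lightning.callbacks import VaeLatentCycleLoggingCallback
from disent.util.visualize.plot import plt_subplots_imshow

# `module` (a disent framework, e.g. BetaVae) and `dataloader` are assumed to be defined already.
# `gpus` and `checkpoint_callback` are older pytorch-lightning arguments; newer releases use
# `accelerator`/`devices` and `enable_checkpointing` instead.
trainer = pl.Trainer(
    max_steps=2048,
    gpus=1 if torch.cuda.is_available() else None,
    logger=False,
    checkpoint_callback=False,
    max_epochs=1,
)
trainer.fit(module, dataloader)
# trainer.save_checkpoint("trained.ckpt")

# call the callback's generator directly instead of registering it with the trainer
viz = VaeLatentCycleLoggingCallback()
stills, frames_, image_ = viz.generate_visualisations(trainer_or_dataset=trainer, pl_module=trainer.lightning_module,
                                                      num_frames=4, num_stats_samples=15)

# stills is unpacked as (rows, frames, H, W, C), presumably one row of frames per traversed latent
plt_scale = 4.5
offset = 0.75
factors, frames, _, _, c = stills.shape

# matplotlib figsize is in inches; multiplying by 1 / 2.54 converts these centimetre values
plt_subplots_imshow(grid=stills, title=None, row_labels=None, subplot_padding=None,
                    figsize=(offset + (1 / 2.54) * frames * plt_scale, (1 / 2.54) * (factors + 0.45) * plt_scale),
                    show=False)
plt.show()
```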
nmichlo commented 1 year ago

Your example makes sense, but admittedly it has been a while since I last touched the code (I realize the current system is not optimal for these custom scripts, so this will need to be fixed in the future).

You can try plotting the image directly with plt.imshow(image), or create your own visualization/animation from the frames or stills.
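A small NumPy sketch of two ways to rearrange stills yourself, assuming it has the (rows, frames, H, W, C) layout unpacked in the example above: stills_to_grid tiles everything into one image that plt.imshow can display, and stills_to_animation produces one image per time step with all traversals side by side. Both function names are hypothetical helpers, not disent functions, and the toy stills array only stands in for real output.

```python
import numpy as np


def stills_to_grid(stills: np.ndarray) -> np.ndarray:
    """Tile stills of shape (rows, frames, H, W, C) into a single (rows*H, frames*W, C) image."""
    rows, frames, h, w, c = stills.shape
    # move H next to the row axis and W next to the frame axis, then merge each pair
    return stills.transpose(0, 2, 1, 3, 4).reshape(rows * h, frames * w, c)


def stills_to_animation(stills: np.ndarray) -> np.ndarray:
    """Return (frames, H, rows*W, C): at each time step, all traversals placed side by side."""
    rows, frames, h, w, c = stills.shape
    return stills.transpose(1, 2, 0, 3, 4).reshape(frames, h, rows * w, c)


# toy stills: 2 rows, 3 frames, 2x2 single-channel images
stills = np.arange(2 * 3 * 2 * 2 * 1, dtype=np.uint8).reshape(2, 3, 2, 2, 1)
grid = stills_to_grid(stills)
anim = stills_to_animation(stills)
print(grid.shape, anim.shape)  # (4, 6, 1) (3, 2, 4, 1)
```

matplotlib's imshow does not accept a trailing channel axis of size 1, so for single-channel data pass grid[..., 0] (e.g. with cmap='gray'). The animation frames could then be written to a GIF with a separate library such as imageio (imageio.mimsave), again dropping a size-1 channel axis first.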