Closed: dgm2 closed this issue 3 months ago
Hi @dgm2, thank you!
My apologies for the delayed response. You are right: most of them are independent.
However, you can have a look at the helper code inside the PyTorch Lightning utilities; one of the callback classes is specifically for generating latent traversals. This method uses helper functions from disent.util.visualize.vis_img to convert tensors to images, disent.util.visualize.vis_latents to generate latent sequences, and disent.util.visualize.vis_util to combine the images together into a grid or into sequential frames.
The code is more complicated than it needs to be for most cases because of some additional handling and quirks. Maybe we can add a specific docs example for latent traversals.
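For reference, the general idea behind a latent traversal is independent of the library: encode an input, then sweep one latent dimension at a time over a range of values while decoding each modified code. Below is a minimal, library-agnostic sketch; the toy encoder/decoder here are placeholder linear layers standing in for a trained VAE, not disent classes:

```python
import torch
import torch.nn as nn

# Toy stand-ins for a trained VAE's encoder/decoder (placeholders, not the disent API).
encoder = nn.Linear(3 * 64 * 64, 6)   # flattened 64x64 RGB image -> 6 latent dims
decoder = nn.Linear(6, 3 * 64 * 64)   # 6 latent dims -> flattened image

def latent_traversal(x, num_frames=4, bound=2.0):
    """For each latent dim, sweep its value over [-bound, bound] and decode each step."""
    with torch.no_grad():
        z = encoder(x.flatten())                        # base latent code, shape (6,)
        values = torch.linspace(-bound, bound, num_frames)
        rows = []
        for dim in range(z.shape[0]):                   # one row per latent dimension
            row = []
            for v in values:                            # one column per frame
                z_mod = z.clone()
                z_mod[dim] = v                          # override a single latent dim
                row.append(decoder(z_mod).reshape(3, 64, 64))
            rows.append(torch.stack(row))
        return torch.stack(rows)                        # (num_latents, num_frames, C, H, W)

x = torch.rand(3, 64, 64)
stills = latent_traversal(x)
print(stills.shape)  # torch.Size([6, 4, 3, 64, 64])
```

The disent callback does essentially this, plus extra handling to convert tensors to displayable images and assemble them into grids.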
Thanks! The callback method returns stills, frames, and image. How should I input these into plot_dataset_traversals or into visualize_dataset_traversal? What are the corresponding values there? E.g. does stills correspond to grid as input into plt_subplots_imshow? Does this example make sense? Many thanks!
trainer = pl.Trainer(
    max_steps=2048,
    max_epochs=1,
    gpus=1 if torch.cuda.is_available() else None,
    logger=False,
    checkpoint_callback=False,
)
trainer.fit(module, dataloader)
# trainer.save_checkpoint("trained.ckpt")

viz = VaeLatentCycleLoggingCallback()
stills, frames_, image_ = viz.generate_visualisations(
    trainer_or_dataset=trainer,
    pl_module=trainer.lightning_module,
    num_frames=4,
    num_stats_samples=15,
)

plt_scale = 4.5
offset = 0.75
factors, frames, _, _, c = stills.shape
plt_subplots_imshow(
    grid=stills,
    title=None,
    row_labels=None,
    subplot_padding=None,
    figsize=(offset + (1 / 2.54) * frames * plt_scale,
             (1 / 2.54) * (factors + 0.45) * plt_scale),
    show=False,
)
Your example makes sense, but admittedly it has been a while since I last touched the code (I realize the current system is not optimal for these custom scripts, so this will need to be fixed in the future).
stills should be an array of shape (num_latents, num_frames, 64, 64, 3) containing the individual latent traversals.
frames is a concatenated version of stills intended for creating videos: the individual stills are combined over the factors dimension into an image grid, so the final array has approximate shape (num_frames, ~(64 * grid_h), ~(64 * grid_w), 3).
image is a single image that you can plot, with all the latent traversals merged together into one grid; the x axis of this grid corresponds to num_latents and the y axis to num_frames (or vice versa), so the shape is approximately (~(64 * num_latents), ~(64 * num_frames), 3).
You can try plotting the image directly with plt.imshow(image), or create your own visualization/animation from the frames or stills.
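To illustrate how the three outputs relate, here is a small numpy sketch of the tiling logic. The shapes follow the description above, but the exact axis order and grid layout used by disent are assumptions for illustration, not the library's actual implementation:

```python
import numpy as np

num_latents, num_frames, H, W, C = 6, 4, 64, 64, 3

# stills: one row of frames per latent dimension
stills = np.random.rand(num_latents, num_frames, H, W, C)

# image: merge all traversals into one grid, latents along one axis and
# frames along the other (the axis assignment here is an assumption)
image = stills.transpose(0, 2, 1, 3, 4).reshape(num_latents * H, num_frames * W, C)
print(image.shape)  # (384, 256, 3)

# frames: for videos, combine the stills for each time step into one grid
# per frame, e.g. laying out the 6 latent rows in a 2x3 grid
grid_h, grid_w = 2, 3
frames = (
    stills.reshape(grid_h, grid_w, num_frames, H, W, C)
          .transpose(2, 0, 3, 1, 4, 5)
          .reshape(num_frames, grid_h * H, grid_w * W, C)
)
print(frames.shape)  # (4, 128, 192, 3)
```

Each frame of frames is then one grid image per time step, which can be written out as a video or animated GIF.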
Hi, great package!
I am looking at the examples in the plotting_examples folder. These seem to work independently of a trained torch model? What would be a minimal way/example to use them with a trained model? E.g. how to visualise the latent traversals of a trained model.
Best regards