greentfrapp / lucent

Lucid library adapted for PyTorch
Apache License 2.0

Support for torchvision.models.video models #2

Open greentfrapp opened 4 years ago

greentfrapp commented 4 years ago

Video models require 5D inputs (batch, channels, frames, height, width). However, most of the parameterization, transform and rendering functions in Lucent assume 4D image inputs (batch, channels, height, width).

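A simple workaround is to initialize a batch of 2D images whose batch size is batch * frames, treating each frame as a separate image. Then, inside the `render_vis` function, just before the input is passed to the model, we transpose and unsqueeze it into a 5D tensor. Note that the transpose/unsqueeze below only gives the right layout for a single video (batch = 1): it turns a (frames, channels, height, width) tensor into (1, channels, frames, height, width).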

Specifically, in render.py, replace https://github.com/greentfrapp/lucent/blob/044317a7b395220e6a27fd890c35abc081c5d1c8/lucent/optvis/render.py#L73 with

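```python
image_t = transform_f(image_f())                       # (frames, C, H, W): one image per frame
image_t = torch.transpose(image_t, 0, 1).unsqueeze(0)  # -> (1, C, frames, H, W) for the video model
model(image_t)
```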
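The transpose/unsqueeze above only covers one video. A minimal sketch of the general case, assuming the 4D image batch stores each video's frames contiguously (rows `[0, frames)` are video 0, the next `frames` rows are video 1, and so on). `frames_to_video` is a hypothetical helper, not part of Lucent:

```python
import torch


def frames_to_video(images: torch.Tensor, batch: int) -> torch.Tensor:
    """Reshape a (batch * frames, C, H, W) image batch into a
    (batch, C, frames, H, W) clip tensor for a video model.

    Assumption: each video's frames are stored contiguously in the batch.
    """
    n, c, h, w = images.shape
    if n % batch != 0:
        raise ValueError(f"batch size {n} is not divisible by batch={batch}")
    frames = n // batch
    # (batch * frames, C, H, W) -> (batch, frames, C, H, W)
    clips = images.reshape(batch, frames, c, h, w)
    # swap the frame and channel axes -> (batch, C, frames, H, W)
    return clips.permute(0, 2, 1, 3, 4)


# With batch = 1 this matches the transpose/unsqueeze workaround.
x = torch.randn(8, 3, 112, 112)  # 8 frames of a single video
assert torch.equal(frames_to_video(x, 1), torch.transpose(x, 0, 1).unsqueeze(0))
print(frames_to_video(x, 1).shape)  # torch.Size([1, 3, 8, 112, 112])

y = torch.randn(2 * 8, 3, 112, 112)  # two videos, 8 frames each
print(frames_to_video(y, 2).shape)  # torch.Size([2, 3, 8, 112, 112])
```

Both `reshape` and `permute` are differentiable, so gradients still flow back to the 4D image parameterization, and the existing 4D transforms can keep running before this step.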

But I'm wondering if there is a better solution.

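Also, ideally we want the frames to be temporally continuous, which suggests adding an objective that encourages neighboring frames to align. Since each frame is a separate image in the batch, maybe `objectives.alignment("input")`, which penalizes differences between neighboring images in the batch, will be sufficient for this.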
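To make the idea concrete, here is a minimal plain-PyTorch sketch of a temporal smoothness penalty in the spirit of Lucid's `alignment` objective: squared differences between frames t and t + d for small offsets d, down-weighted by `decay_ratio ** d`. This is an illustrative re-implementation, not Lucent's code; `temporal_alignment`, `max_offset` and `decay_ratio` are hypothetical names, and the defaults (offsets up to 4, decay 2) are assumptions:

```python
import torch


def temporal_alignment(clip: torch.Tensor, max_offset: int = 4,
                       decay_ratio: float = 2.0) -> torch.Tensor:
    """Penalty that is small when nearby frames look alike.

    clip: (frames, C, H, W), i.e. the frames-as-batch layout.
    For each offset d in 1..max_offset, sums the per-pair mean squared
    difference between frames t and t + d, divided by decay_ratio ** d
    so that adjacent frames matter most.
    """
    frames = clip.shape[0]
    total = clip.new_zeros(())
    for d in range(1, min(max_offset, frames - 1) + 1):
        diff = clip[d:] - clip[:-d]                    # all (t, t + d) pairs at once
        per_pair = (diff ** 2).flatten(1).mean(dim=1)  # MSE for each pair
        total = total + per_pair.sum() / decay_ratio ** d
    return total


# Identical frames give zero penalty; independent random frames do not.
static = torch.randn(1, 3, 32, 32).repeat(8, 1, 1, 1)
noisy = torch.randn(8, 3, 32, 32)
print(temporal_alignment(static).item())  # 0.0
print(temporal_alignment(noisy).item() > 0)  # True
```

Because this is a penalty and `render_vis` maximizes the objective, a term like this would presumably be subtracted with some weight (for example `objective - weight * objectives.alignment("input")`), and the weight would need tuning against the feature objective.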