plenoptic-org / plenoptic

Visualize/test models for visual representation by synthesizing images.
https://plenoptic.readthedocs.io/en/latest/
MIT License

Make Portilla-Simoncelli code more efficient #222

Open billbrod opened 1 year ago

billbrod commented 1 year ago

Because of the multiscale representation used in the Portilla-Simoncelli texture model, it's hard to write the code in a GPU-friendly way. I attempted to do so in this commit by using the non-downsampled pyramid, which let me store the coefficients as a single tensor of shape (batch, channel, n_scales, n_orientations, height, width) and made a variety of things easier. It is probably worth comparing against the version here, which uses the downsampled pyramid: it packs the coefficients into a list of length n_scales, where entry s is a tensor of shape (batch, channel, n_orientations, height/2^s, width/2^s) (s is the scale, i.e. the index in the list), and uses list comprehensions for many of the computations.
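To make the difference between the two layouts concrete, here is a minimal sketch that only does shape arithmetic. The sizes (a 256x256 image, 4 scales, 4 orientations) are illustrative assumptions, not values taken from the plenoptic code, and the pyramid's residual bands are ignored.

```python
from math import prod

# Illustrative sizes (assumptions, not read from plenoptic).
batch, channel = 1, 1
n_scales, n_orientations = 4, 4
height, width = 256, 256

# Non-downsampled pyramid: every scale keeps the full resolution,
# so all coefficients fit in one dense tensor of this shape.
dense_shape = (batch, channel, n_scales, n_orientations, height, width)

# Downsampled pyramid: scale s has resolution (height / 2^s, width / 2^s),
# so the coefficients form a ragged list with one shape per scale.
list_shapes = [
    (batch, channel, n_orientations, height // 2**s, width // 2**s)
    for s in range(n_scales)
]

dense_count = prod(dense_shape)
list_count = sum(prod(shape) for shape in list_shapes)
print(dense_count, list_count, dense_count / list_count)
```

With these sizes the dense layout stores about three times as many coefficients as the list layout, which is consistent with the non-downsampled version being more memory-intensive.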

However, the version that uses the non-downsampled pyramid is significantly slower on the CPU (though slightly faster on the GPU) and much more memory-intensive. It's also really hard to compare against the earlier code, because in general it's just not the same computation: for example, computing the autocorrelation of the coefficients at the coarsest scale gives a different answer depending on whether that scale has size (64, 64) or (256, 256). You can downsample the coefficient image before computing the autocorrelation, or downsample it afterwards, but there's no way to do that efficiently on the GPU: vmap doesn't let you convert something being vmapped over to an int, and it doesn't allow dynamically-sized inputs, one of which would be necessary to do the downsampling and center cropping efficiently. And it's still generally not the same value (something like allclose(rtol=1e-1, atol=1e-1)). But you must do the downsampling, or the definition of "autocorrelations up to 4 shifts in all directions" is completely different at the coarser scales.
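The dependence on resolution can be seen with a toy example. The sketch below uses NumPy rather than torch, a hypothetical helper `central_autocorr`, a synthetic sinusoid in place of real pyramid coefficients, and naive strided subsampling in place of a proper downsampling step; all of these are assumptions made for illustration only.

```python
import numpy as np

def central_autocorr(img, max_shift=4):
    """Circular autocorrelation via FFT, center-cropped to shifts in
    [-max_shift, max_shift] along both axes (hypothetical helper)."""
    power = np.abs(np.fft.fft2(img)) ** 2
    ac = np.real(np.fft.ifft2(power)) / img.size
    ac = np.fft.fftshift(ac)  # move the zero-shift term to the center
    cy, cx = img.shape[0] // 2, img.shape[1] // 2
    return ac[cy - max_shift:cy + max_shift + 1,
              cx - max_shift:cx + max_shift + 1]

# Stand-in for a coarse-scale coefficient image stored at full resolution:
# a horizontal sinusoid with a period of 16 pixels.
x = np.arange(256)
img_fine = np.tile(np.cos(2 * np.pi * x / 16), (256, 1))  # (256, 256)
img_coarse = img_fine[::4, ::4]                           # (64, 64), naive subsampling

fine = central_autocorr(img_fine)      # shifts measured in fine pixels
coarse = central_autocorr(img_coarse)  # shifts measured in coarse pixels

# Index [4, 4] is zero shift, [4, 5] is a one-pixel horizontal shift.
print(fine[4, 5], coarse[4, 5], fine[4, 8])
```

Both outputs are 9x9, but a one-pixel shift at (64, 64) covers the same distance as a four-pixel shift at (256, 256), so `coarse[4, 5]` lines up with `fine[4, 8]` rather than `fine[4, 5]`. In other words, "up to 4 shifts" spans a four times smaller neighborhood of the underlying image when computed on the non-downsampled coefficients.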

With the downsampled + list comprehension version of the code, I tried treating the lists as pytrees and using tree_map (from torch.utils._pytree, jax, or optree) and got no speedup on either GPU or CPU, probably because a list of 4d/5d tensors is not that difficult a pytree to parse. vmap doesn't like ragged data (a list of variably-sized arrays) or lists of structs, so I couldn't figure out a way to make anything vmap-able.
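As a rough illustration of the ragged-data problem, the sketch below builds a list of per-scale arrays with NumPy (standing in for torch tensors) and computes a per-scale statistic with a list comprehension. The sizes and the choice of statistic (spatial variance) are assumptions for the example, not the model's actual statistics.

```python
import numpy as np

rng = np.random.default_rng(0)
n_scales, n_orientations, height, width = 4, 4, 64, 64

# Downsampled layout: one array per scale, with shrinking spatial size.
coeffs = [
    rng.standard_normal((1, 1, n_orientations, height // 2**s, width // 2**s))
    for s in range(n_scales)
]

# The list is ragged, so it cannot be stacked into a single array
# that a batched map could then run over.
try:
    np.stack(coeffs)
    stackable = True
except ValueError:
    stackable = False

# A list comprehension handles each scale separately. The spatial
# dimensions are reduced away, so the per-scale results share a shape.
per_scale_var = [c.var(axis=(-2, -1)) for c in coeffs]  # each (1, 1, n_orientations)
stats = np.stack(per_scale_var, axis=2)                 # (1, 1, n_scales, n_orientations)
print(stackable, stats.shape)
```

The inputs cannot be combined into one array, but the outputs can, which is why the loop over scales ends up living in Python rather than inside a single batched call.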

So I'm unsure how to improve this.

billbrod commented 1 year ago

A good chunk of time is also spent in the forward and recon_pyr methods of the steerable pyramid, so improving the efficiency of those would also help.