NVlabs / RADIO

Official repository for "AM-RADIO: Reduce All Domains Into One"

Per-pixel features #55

Open SimonGeb opened 2 months ago

SimonGeb commented 2 months ago

Hi,

Thanks for your work, I found it very interesting. I was wondering whether it is possible to get more per-pixel features using your pre-trained model. Currently, using the provided example scripts on a custom image returns a high-dimensional vector but at low spatial resolution.

I'm looking into zero-shot semantic segmentation, and for that it would be beneficial to get pixel-level features instead. I used the code in visualize_features.py to get a PCA map, but it is not as detailed as the examples from your paper. [PCA visualization attached]

Eventually, I would like to use RADIO for open-vocabulary semantic segmentation, similar to Grounded-SAM, and for other downstream tasks. Any help would be greatly appreciated.

Kind regards,

Simon

mranzinger commented 2 months ago

Hi Simon,

The output features will have a spatial resolution that is downsampled by 16x on each dimension. All of the examples from the paper come from the visualize_features.py script you're referencing; the only difference is what we set the input image size to be.
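To make the 16x downsampling concrete, here is a tiny sketch (the helper name and the 512/1024 sizes are just illustrative, not from the repo) showing how the output feature-grid size scales with the input image size:

```python
def feature_grid_size(height, width, patch=16):
    """Spatial size of the output feature map for a backbone that
    downsamples by `patch` on each dimension."""
    return height // patch, width // patch

# A 512x512 input yields a 32x32 feature grid; doubling the input
# to 1024x1024 yields a 64x64 grid, i.e. 4x more feature vectors.
print(feature_grid_size(512, 512))    # (32, 32)
print(feature_grid_size(1024, 1024))  # (64, 64)
```

So simply feeding a larger image already buys you a denser feature map, at the cost of quadratically more compute.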

So, for semseg, you have a couple of options (or a combination thereof):

1. Increase the input image size that you're feeding RADIO.
2. Use a transposed convolution or pixel shuffling as a final learnable layer.

I know that (2) is a bit awkward when you're interested in zero-shot, but, you might be able to train that layer on a small-ish segmentation dataset and then use it on open world segmentation. It might look something like this:

Train upsample projector

[(frozen) radio backbone] -> [(learnable) upsample deconv] -> [(learnable) linear classifier] <-> Loss(semseg)

Open-world segmentation

[(frozen) radio backbone] -> [(frozen) upsample deconv] -> (per-pixel features)
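The two-stage recipe above might look like the following PyTorch sketch. Everything here is illustrative: the dummy backbone, the feature dimension (1280), the number of classes, and the 4x upsampling factor are assumptions; in practice you'd replace the stand-in with the real (frozen) RADIO backbone.

```python
import torch
import torch.nn as nn

class DummyBackbone(nn.Module):
    """Stand-in for the frozen RADIO backbone: any module producing
    (B, C, H/16, W/16) spatial features. Swap in the real model and
    freeze it with requires_grad_(False)."""
    def __init__(self, dim=1280):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        b, _, h, w = x.shape
        return torch.randn(b, self.dim, h // 16, w // 16)

class SemsegHead(nn.Module):
    """Learnable upsampler + linear classifier on top of frozen features."""
    def __init__(self, in_dim=1280, num_classes=21, up_factor=4):
        super().__init__()
        # Transposed conv upsamples the 1/16-resolution grid by `up_factor`.
        self.upsample = nn.ConvTranspose2d(in_dim, 256,
                                           kernel_size=up_factor,
                                           stride=up_factor)
        self.classifier = nn.Conv2d(256, num_classes, kernel_size=1)

    def forward(self, feats):
        return self.classifier(self.upsample(feats))

backbone = DummyBackbone().eval().requires_grad_(False)
head = SemsegHead()  # train this on a small segmentation dataset

img = torch.randn(1, 3, 512, 512)
with torch.no_grad():
    feats = backbone(img)   # (1, 1280, 32, 32)
logits = head(feats)        # (1, 21, 128, 128)
print(tuple(logits.shape))
```

After training, you can drop the classifier and keep only the frozen backbone plus the frozen upsampler to get denser per-pixel features for open-world use.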

SimonGeb commented 2 months ago

Thanks a lot for the pointers, I will definitely explore this more!

One more code-specific question: in visualize_features.py there is a get_cluster_map function. It instantiates a Kmeans clusterer, but that class is not defined anywhere in this repository. I thought I would just use the sklearn implementation instead, but that does not allow cosine similarity as a metric and raises some other errors, which makes me think you used a custom implementation. Is that the case? And if so, would it be possible to make it available?

Many thanks!

mranzinger commented 2 months ago

I actually borrowed the code from https://github.com/Jiawei-Yang/Denoising-ViT/blob/adeff838169152a6e55bd8e3d7f1f1befe006ff2/utils/visualization_tools.py#L42, so I expect you'll be able to find the rest of the necessary clustering code there.
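As an aside, if you do want to stay with sklearn: its KMeans only supports Euclidean distance, but a common workaround (not the repo's own code) is to L2-normalize the feature vectors first. On unit vectors, squared Euclidean distance is a monotonic function of cosine distance, so standard KMeans then clusters approximately by cosine similarity. A minimal sketch, with made-up cluster count and data:

```python
import numpy as np
from sklearn.cluster import KMeans

def cosine_kmeans(features, n_clusters=8, seed=0):
    """Approximate cosine-similarity KMeans by L2-normalizing rows
    and then running standard Euclidean KMeans."""
    feats = features / np.linalg.norm(features, axis=1, keepdims=True)
    km = KMeans(n_clusters=n_clusters, n_init=10, random_state=seed)
    return km.fit_predict(feats)

# Example with random stand-in features (N tokens x feature dim).
feats = np.random.rand(100, 64).astype(np.float32)
labels = cosine_kmeans(feats)
print(labels.shape)
```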