ziplab / LITv2

[NeurIPS 2022 Spotlight] This is the official PyTorch implementation of "Fast Vision Transformers with HiLo Attention"
Apache License 2.0

code to visualize the frequency #3

Closed Ga-Lee closed 2 years ago

Ga-Lee commented 2 years ago

Hello! Thank you for your wonderful work! The figures in the paper are beautiful. I would like to visualize frequencies in my own work the way you do in Fig. 5. Could you share the visualization code? Thank you! :)

HubHop commented 2 years ago

Hi @Ga-Lee, thanks for your interest. We use the following script for visualizing our Figure 5.

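import os

import matplotlib.pyplot as plt
import torch

def visualize_freq(x, title):
    '''
    x : The output feature maps from either Hi-Fi or Lo-Fi attention.
        Tensor shape: (batch_size, height, width, hidden_dim)
    '''
    # Move channels first: (batch_size, hidden_dim, height, width).
    data = x.detach().permute(0, 3, 1, 2)
    # 2D FFT over the last two (spatial) dimensions.
    fft_input = torch.fft.fft2(data.float())
    # Shift only the spatial dims so the zero frequency sits at the center;
    # fftshift without `dim` would also roll the batch and channel axes.
    fft_shifted = torch.fft.fftshift(fft_input, dim=(-2, -1))
    # Log-magnitude spectrum; the small epsilon avoids log(0).
    freq_img = torch.log(torch.abs(fft_shifted) + 1e-8)

    num_plots = 8  # plot the first 8 channels
    freq_img_mean = freq_img.mean(dim=0).cpu()  # average over the batch

    fig, axis = plt.subplots(1, num_plots, figsize=(num_plots * 4, 4))
    for i in range(num_plots):
        axis[i].imshow(freq_img_mean[i, ...].numpy())
        axis[i].axis('off')  # hide ticks and frame on every subplot

    plt.tight_layout()
    out_dir = '/data1/visualizations/litv2_small'
    os.makedirs(out_dir, exist_ok=True)
    plt.savefig(os.path.join(out_dir, f'{title}.pdf'))
    plt.close(fig)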

Hope this can help your research.
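For anyone adapting the script, a quick sanity check on synthetic inputs can confirm that the centered spectrum reads as expected before plotting real feature maps. The sketch below is only an illustration: `log_spectrum` is a hypothetical helper (not part of LITv2), the epsilon and the restriction of `fftshift` to the spatial dimensions are assumptions, and the inputs are toy tensors rather than Hi-Fi or Lo-Fi outputs.

```python
import torch


def log_spectrum(x, eps=1e-8):
    """Hypothetical helper: batch-averaged, centered log-magnitude spectrum.

    x: tensor of shape (batch_size, height, width, hidden_dim)
    returns: tensor of shape (hidden_dim, height, width)
    """
    data = x.permute(0, 3, 1, 2).float()
    # fft2 transforms the last two dims; shift only those dims so that
    # the zero-frequency bin lands at (height // 2, width // 2).
    spec = torch.fft.fftshift(torch.fft.fft2(data), dim=(-2, -1))
    return torch.log(spec.abs() + eps).mean(dim=0)


h = w = 16

# A constant map contains only the zero-frequency (DC) component.
smooth = torch.ones(2, h, w, 4)

# A +/-1 checkerboard alternates every pixel, i.e. the highest
# representable (Nyquist) frequency along both axes.
ys, xs = torch.meshgrid(torch.arange(h), torch.arange(w), indexing="ij")
checker = (((ys + xs) % 2).float() * 2 - 1)[None, :, :, None].repeat(2, 1, 1, 4)

smooth_spec = log_spectrum(smooth)
checker_spec = log_spectrum(checker)

# Location (row, col) of the strongest bin in the first channel.
smooth_peak = divmod(int(smooth_spec[0].argmax()), w)
checker_peak = divmod(int(checker_spec[0].argmax()), w)
print(smooth_peak, checker_peak)
```

In this centered layout, low frequencies sit in the middle of each plot and high frequencies toward the borders, so the constant map peaks at the center bin and the checkerboard peaks at a corner.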

Ga-Lee commented 2 years ago

Thanks a lot! I appreciate your help!