Ga-Lee closed this issue 2 years ago.
Hi @Ga-Lee, thanks for your interest. We use the following script for visualizing our Figure 5.
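```python
import os

import matplotlib.pyplot as plt
import torch


def visualize_freq(x, title):
    '''
    x : The output feature maps from either Hi-Fi or Lo-Fi attention.
        Tensor shape: (batch_size, height, width, hidden_dim)
    '''
    # (B, H, W, C) -> (B, C, H, W) so that fft2 runs over the spatial dims (the last two)
    data = x.detach().permute(0, 3, 1, 2)
    fft_input = torch.fft.fft2(data.float())
    # Shift only the spatial dims so the zero frequency sits at the centre;
    # the small epsilon avoids log(0) = -inf
    freq_img = torch.log(torch.abs(torch.fft.fftshift(fft_input, dim=(-2, -1))) + 1e-8)
    # Average the log-amplitude spectrum over the batch -> (C, H, W)
    freq_img_mean = freq_img.mean(dim=0).cpu()
    num_plots = min(8, freq_img_mean.shape[0])  # plot the first (up to) 8 channels
    fig, axis = plt.subplots(1, num_plots, figsize=(num_plots * 4, 4), squeeze=False)
    for i in range(num_plots):
        axis[0, i].imshow(freq_img_mean[i].numpy())
        axis[0, i].axis('off')  # hide ticks and frame on every panel
    fig.tight_layout()
    save_dir = '/data1/visualizations/litv2_small'
    os.makedirs(save_dir, exist_ok=True)
    fig.savefig(os.path.join(save_dir, f'{title}.pdf'))
    plt.close(fig)  # release the figure so repeated calls do not accumulate memory
```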
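The core of the script is the centred log-amplitude spectrum. The sketch below isolates that step without plotting so it can be checked on synthetic input. Note that by default `torch.fft.fftshift` rearranges every dimension, including batch and channel; passing `dim=(-2, -1)` restricts the shift to height and width, so channel `i` in the output is still channel `i` of the input. The function name `centered_log_spectrum` and the `eps` value are assumptions for illustration, not part of the authors' script.

```python
import torch


def centered_log_spectrum(x: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Batch-averaged, centred log-amplitude spectrum (hypothetical helper).

    x: feature maps of shape (batch_size, height, width, hidden_dim).
    Returns a tensor of shape (hidden_dim, height, width).
    """
    data = x.detach().permute(0, 3, 1, 2).float()  # -> (B, C, H, W)
    spec = torch.fft.fft2(data)                     # FFT over the last two (spatial) dims
    spec = torch.fft.fftshift(spec, dim=(-2, -1))   # zero frequency to the centre, channels untouched
    return torch.log(spec.abs() + eps).mean(dim=0)  # log-amplitude, averaged over the batch


if __name__ == "__main__":
    # A spatially constant map has all its energy at zero frequency,
    # so after the shift the peak should sit at (H // 2, W // 2).
    x = torch.ones(2, 14, 14, 8)
    spec = centered_log_spectrum(x)
    peak = divmod(spec[0].argmax().item(), spec.shape[-1])
    print(spec.shape, peak)  # expected peak: (7, 7)
```

For an even side length `n`, the shift moves index 0 to index `n // 2`, which is why the DC peak of a 14 x 14 map lands at (7, 7).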
Hope this helps your research.
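The script expects feature maps of shape `(batch_size, height, width, hidden_dim)`. One common way to collect them is a PyTorch forward hook. The sketch below assumes the attention block returns tokens of shape `(B, N, C)` with `N = height * width` in row-major order; `ToyAttention`, `save_as_feature_map`, and the `captured` dict are hypothetical stand-ins, since the actual module names inside the model are not given in this thread.

```python
import torch
import torch.nn as nn


class ToyAttention(nn.Module):
    """Hypothetical stand-in for a Hi-Fi / Lo-Fi attention block: (B, N, C) -> (B, N, C)."""

    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        return self.proj(tokens)


captured = {}  # name -> feature map of shape (B, H, W, C)


def save_as_feature_map(name: str, height: int, width: int):
    """Build a forward hook that stores the module output reshaped to (B, H, W, C)."""
    def hook(module, inputs, output):
        b, n, c = output.shape
        if n != height * width:
            raise ValueError("token count must equal height * width")
        # detach so the stored tensor can later be converted with .numpy()
        captured[name] = output.detach().reshape(b, height, width, c)
    return hook


if __name__ == "__main__":
    H, W, C = 14, 14, 64
    block = ToyAttention(C)
    handle = block.register_forward_hook(save_as_feature_map("hifi", H, W))
    with torch.no_grad():
        block(torch.randn(2, H * W, C))
    handle.remove()  # stop capturing once the maps are collected
    print(captured["hifi"].shape)  # torch.Size([2, 14, 14, 64])
```

In a real model you would register the hook on the Hi-Fi or Lo-Fi output module and pass `captured[...]` to `visualize_freq`. Removing the handle afterwards keeps the hook from firing on later forward passes.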
Thank you so much! I appreciate your help!
Hello! Thank you for your wonderful work! The figures in the paper are beautiful. I would like to visualize the frequency spectra in my own work the way you did in Fig. 5. Could you share the visualization code? Thank you! :)