cfzd / FcaNet

FcaNet: Frequency Channel Attention Networks
MIT License
503 stars 100 forks source link

Visualization of paper Figure 5 & paper Figure 6 #6

Closed TianhaoFu closed 3 years ago

TianhaoFu commented 3 years ago

Hi, I've read your paper. It's great work. Could you provide the visualization code for Figure 5 and Figure 6 of the paper (FcaNet)?

Thanks a lot!

cfzd commented 3 years ago

@TianhaoFu Hi, sorry for the late response. The code for Figure 5:

import matplotlib.pyplot as plt
import numpy as np
import random
import seaborn

# Per-cell values for the 7x7 heatmap, written out row by row.
# NOTE(review): presumably the accuracy obtained with each individual
# frequency component -- confirm against the paper.
freq = np.array([
    [76.69399990722657, 76.54799998779296, 76.49199985839844, 76.37199993652344,
     76.38599986083985, 76.5140000366211, 76.38399990722657],
    [76.4820000390625, 76.25799990722656, 76.4680000390625, 76.29600001220703,
     76.19000004150391, 76.27599993408204, 76.39799995361328],
    [76.29799993896485, 76.31999999267578, 76.35599988037109, 76.30400008789063,
     76.25999998535156, 76.27999995361328, 76.20999998535156],
    [76.39400001464844, 76.30799993164062, 76.30999999023437, 76.33599993652344,
     76.19200006835938, 76.35999993652344, 76.2060000390625],
    [76.4419999609375, 76.30599993652343, 76.27800001220703, 76.21800006347657,
     76.26799993408203, 76.26800008789063, 76.34400001220703],
    [76.43799998535157, 76.28000008544922, 76.31400001220703, 76.33199995361328,
     76.31000003662109, 76.3259998803711, 76.27000001464843],
    [76.53399998779297, 76.31600001464844, 76.2780000366211, 76.3379999633789,
     76.2839999584961, 76.29600003417968, 75.71999993896485],
])

# Text label for each cell, laid out in the same 7x7 arrangement as `freq`.
# The labels appear to rank the cells by descending value (Rank1 sits on the
# largest entry, Rank49 on the smallest).
annot = np.array([
    ['Rank1', 'Rank2', 'Rank5', 'Rank14', 'Rank12', 'Rank4', 'Rank13'],
    ['Rank6', 'Rank43', 'Rank7', 'Rank32', 'Rank48', 'Rank38', 'Rank10'],
    ['Rank30', 'Rank22', 'Rank16', 'Rank29', 'Rank42', 'Rank35', 'Rank45'],
    ['Rank11', 'Rank27', 'Rank26', 'Rank19', 'Rank47', 'Rank15', 'Rank46'],
    ['Rank8', 'Rank28', 'Rank37', 'Rank44', 'Rank41', 'Rank40', 'Rank17'],
    ['Rank9', 'Rank34', 'Rank24', 'Rank20', 'Rank25', 'Rank21', 'Rank39'],
    ['Rank3', 'Rank23', 'Rank36', 'Rank18', 'Rank33', 'Rank31', 'Rank49'],
])

# Build the 120-entry colour list used as the heatmap colormap from seaborn's
# "Blues" palette: the first 70 slots all reuse the palette's first colour,
# while slots 70..119 keep the palette's own colour at the same index.
temp = seaborn.color_palette("Blues", 120)
new = [temp[0] if slot < 70 else temp[slot] for slot in range(120)]

# Draw the 7x7 heatmap of `freq`, writing the matching `annot` label in each
# cell and colouring cells with the custom colour list `new`.
f, ax = plt.subplots(1, 1, figsize=(8,7))

# fmt='' because the annotations are ready-made strings, not numbers to format;
# cbar=False suppresses the colour bar.
seaborn.heatmap(freq, annot=annot, fmt='', ax=ax, cmap=new, linewidths=1., annot_kws={"fontsize":13}, cbar=False)

# No tick marks or tick labels on either axis.
plt.xticks([])
plt.yticks([])

# Put the x-axis label above the grid instead of below it.
ax.xaxis.set_ticks_position('top')
ax.xaxis.set_label_position('top')
# Raw strings: '\l' is not a valid Python escape sequence, so the non-raw
# form triggers an invalid-escape warning on recent interpreters. The text
# handed to matplotlib's mathtext is unchanged.
plt.xlabel(r'Low Frequency $\longrightarrow$  High Frequency', fontdict={'fontsize':18})
plt.ylabel(r'High Frequency $\longleftarrow$  Low Frequency', fontdict={'fontsize':18})
plt.show()

The code for Figure 6:

import numpy as np
import cv2
import math

def get_1d_dct_basis(length, pos, freq):
    """Return one sample of a 1-D cosine (DCT-style) basis function.

    Evaluates ``cos(pi * freq * (pos + 0.5) / length)``.

    Args:
        length: Number of samples along the axis (the divisor).
        pos: Sample position; the 0.5 offset evaluates at the sample centre.
        freq: Frequency index of the basis function (need not be an integer).

    Returns:
        The cosine value as a float in [-1, 1].
    """
    angle = math.pi * freq * (pos + 0.5) / length
    return math.cos(angle)

def get_dct_basis(size, fi, fj):
    """Build the 2-D cosine basis image for the frequency pair (fi, fj).

    Cell ``[i, j]`` holds
    ``get_1d_dct_basis(size, i, fi) * get_1d_dct_basis(size, j, fj)``.

    Args:
        size: Side length of the square basis.
        fi: Frequency along the first (row) axis.
        fj: Frequency along the second (column) axis.

    Returns:
        A float ``numpy`` array of shape ``(int(size * 1.0), int(size * 1.0))``.
        Cells outside the centred ``size`` x ``size`` window would hold the
        filler value -2; with the current scale of 1.0 there are none.
    """
    canvas_scale = 1.0
    real_size = int(size * canvas_scale)
    pad = int((real_size - size) / 2)

    # The 2-D basis is separable: one cosine profile per axis, combined with
    # an outer product.
    row_profile = [get_1d_dct_basis(size, i, fi) for i in range(size)]
    col_profile = [get_1d_dct_basis(size, j, fj) for j in range(size)]

    basis = np.full((real_size, real_size), -2.0)
    basis[pad:pad + size, pad:pad + size] = np.outer(row_profile, col_profile)
    return basis

if __name__ == "__main__":
    # Render every (i, j) basis of a dct_grid x dct_grid layout as a
    # JET-coloured PNG named "<i>_<j>.png" (two-digit, zero-padded indices).
    dct_grid = 7
    dct_size = 7
    vis_resize_size = 100
    out_shape = (vis_resize_size, vis_resize_size)
    for i in range(dct_grid):
        freq_i = i * dct_size * 1.0 / dct_grid
        for j in range(dct_grid):
            freq_j = j * dct_size * 1.0 / dct_grid
            basis = get_dct_basis(dct_size, freq_i, freq_j)
            # Shift and scale the basis values into 8-bit range before the
            # uint8 cast (cosine products in [-1, 1] map to [0, 254]).
            gray = ((basis + 1) * 127).astype(np.uint8)
            # Nearest-neighbour upscaling keeps each basis cell a flat square.
            enlarged = cv2.resize(gray, dsize=out_shape, interpolation=cv2.INTER_NEAREST)
            coloured = cv2.applyColorMap(enlarged, cv2.COLORMAP_JET)
            cv2.imwrite('%02d_%02d.png'%(i,j), coloured)
TianhaoFu commented 3 years ago

Thanks