microsoft / FocalNet

[NeurIPS 2022] Official code for "Focal Modulation Networks"
MIT License

Visualization of gated outputs #32

Closed: Tajamul21 closed this issue 1 year ago

Tajamul21 commented 1 year ago

RuntimeError                              Traceback (most recent call last)
/tmp/ipykernel_63586/1379561863.py
     23     fig.add_subplot(1, 5, i+2)
     24     gates_i = (upsampler(gates[:, i:i+1])).cpu().detach()
---> 25     plt.imshow(gates_i.permute(1,2,0).numpy())
     26     plt.axis('off')
     27     x.axes.get_xaxis().set_visible(False)

RuntimeError: number of dims don't match in permute
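The error is a tensor-rank mismatch: `gates[:, i:i+1]` keeps the channel axis, so `nn.Upsample` returns a 4-D (N, C, H, W) tensor, while `permute(1,2,0)` lists only three dimensions. The sketch below reproduces this with a random tensor. The shape (1, 4, 7, 7) is a made-up stand-in for the real gates, whose size depends on the model configuration and input resolution.

```python
import torch
import torch.nn as nn

# Hypothetical gates: batch 1, four gates (focal_level + 1), 7x7 spatial map.
gates = torch.rand(1, 4, 7, 7)
upsampler = nn.Upsample(scale_factor=4, mode='bilinear')

# Slicing with i:i+1 keeps the channel axis, so the result is still 4-D.
gates_i = upsampler(gates[:, 0:1])
print(gates_i.shape)  # torch.Size([1, 1, 28, 28])

permute_failed = False
try:
    gates_i.permute(1, 2, 0)  # three dims requested for a 4-D tensor
except RuntimeError as err:
    permute_failed = True
    print("permute failed:", err)

# Dropping the batch and channel axes leaves the 2-D map plt.imshow expects.
gate_map = gates_i[0, 0]
print(gate_map.shape)  # torch.Size([28, 28])
```

Indexing with `[0, 0]` is equivalent to the `.squeeze(0).squeeze(0)` used in the working version later in the thread.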

Tajamul21 commented 1 year ago

import os

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from PIL import Image

# visualize gating maps
# `model`, `eval_transforms` and `display_transforms` are assumed to be defined earlier in the notebook,
# with the model already on the GPU.
model.eval()
upsampler = nn.Upsample(scale_factor=4, mode='bilinear')

img_folder = "/home/tajamul/scratch/FocalNet/FocalNet/demo_fig/"
img_paths = os.listdir(img_folder)
for img_path in img_paths:
    # convert('RGB') makes sure the image has 3 channels (e.g. for RGBA PNGs)
    img = Image.open(os.path.join(img_folder, img_path)).convert('RGB')
    img_t = eval_transforms(img)
    img_d = display_transforms(img)
    with torch.no_grad():
        out = model(img_t.unsqueeze(0).cuda())  # forward pass fills modulation.gates

    fig = plt.figure(figsize=(16, 8))

    # first panel: the input image, converted from (C, H, W) to (H, W, C)
    fig.add_subplot(1, 5, 1)
    img2d = img_d.permute(1, 2, 0).cpu().detach().contiguous().numpy()
    plt.imshow(img2d)
    plt.axis('off')

    # gates of the last block in the last stage, shape (1, focal_level + 1, h, w)
    gates = model.layers[-1].blocks[-1].modulation.gates
    for i in range(4):
        fig.add_subplot(1, 5, i+2)
        gates_i = upsampler(gates[:, i:i+1]).cpu().detach()  # shape (1, 1, 4h, 4w)
        # drop the batch and channel dims: imshow needs a 2-D map, not a 4-D tensor
        plt.imshow(gates_i.squeeze(0).squeeze(0).numpy())
        plt.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

Tajamul21 commented 1 year ago

This updated code is working. Thanks
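A possible generalization of the working loop, sketched as a hypothetical helper `plot_gates` (it is not part of the FocalNet repository). It reads the number of gates from the tensor instead of hard-coding 4, and resizes each gate to the image resolution with `torch.nn.functional.interpolate` instead of a fixed `scale_factor=4`. It assumes `gates` is channels-first with shape (1, G, h, w), as the slicing `gates[:, i:i+1]` in the thread implies, and that `image` is a display-ready (3, H, W) tensor with values in [0, 1].

```python
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F


def plot_gates(image, gates):
    """Plot an image followed by each of its gating maps (hypothetical helper).

    image: (3, H, W) float tensor in [0, 1].
    gates: (1, G, h, w) tensor, e.g. a captured modulation.gates.
    Returns the matplotlib Figure.
    """
    num_gates = gates.shape[1]
    # Resize every gate to the image resolution -> (G, H, W).
    maps = F.interpolate(gates.detach().float(), size=tuple(image.shape[-2:]),
                         mode="bilinear", align_corners=False)[0].cpu()

    # squeeze=False always returns a 2-D array of Axes; take its single row.
    fig, axes = plt.subplots(1, num_gates + 1, figsize=(4 * (num_gates + 1), 4),
                             squeeze=False)
    axes = axes[0]
    axes[0].imshow(image.permute(1, 2, 0).detach().cpu().numpy())
    for i in range(num_gates):
        axes[i + 1].imshow(maps[i].numpy())  # 2-D map, drawn with a colormap
    for ax in axes:
        ax.axis("off")
    fig.subplots_adjust(wspace=0, hspace=0)
    return fig


# Usage with random stand-in tensors (shapes are illustrative only):
fig = plot_gates(torch.rand(3, 224, 224), torch.rand(1, 4, 7, 7))
```

In the notebook above, the call would be something like `plot_gates(img_d, model.layers[-1].blocks[-1].modulation.gates)` after the forward pass, followed by `plt.show()`.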