Yao-DD / S3N


How should I load my checkpoints and visualize response maps like those in the paper? #4

Open xmczh003 opened 4 years ago

xmczh003 commented 4 years ago

Thanks for your work. I have successfully achieved 88.5% accuracy with 2 Titan XPs. Now I want to load my checkpoints and visualize some response maps to further understand this awesome work. Could you give me some suggestions? Thank you very much.

Yao-DD commented 4 years ago

You can load checkpoints in Jupyter as follows:

```python
import torch
from nest import register, modules

# build the S3N model with the same settings used for training
model = modules.s3n(mode='resnet50', num_classes=200, task_input_size=448,
                    base_ratio=0.09, radius=0.09, radius_inv=0.25)

# load the checkpoint and strip the 'module.' prefix added by DataParallel
checkpoint = torch.load('/userhome/Generate/cub_sss_net_448/run027/model_best_test_best.pt')['model']
model.load_state_dict({k.replace('module.', ''): v for k, v in checkpoint.items()})
```
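Before running the visualization example below, you will probably also want to move the model to the GPU and switch it to evaluation mode; this step is not part of the snippet above:

```python
# assumed setup: the visualization code sends the input image to the GPU,
# so the model should be there as well; eval() fixes BatchNorm/dropout behavior
model = model.cuda().eval()
```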

Here is an example of visualizing the response maps of the origin branch:

```python
from PIL import Image
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torch.autograd import Variable

# load and preprocess the image
img = Image.open(image_path).convert('RGB')
image = transform(img)
image = Variable(image).unsqueeze(0).cuda()

# backbone features and the raw (origin-branch) classification logits
feature_raw = model.features(image)
agg_origin = model.raw_classifier(model.avg(feature_raw).view(-1, 2048))

# class response maps of the origin branch, upsampled to the sampling grid size
class_response_maps = F.interpolate(model.map_origin(feature_raw), size=model.grid_size,
                                    mode='bilinear', align_corners=True)

plt.imshow(class_response_maps[0, 0, :, :].detach().cpu().numpy())
```
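If you want overlays closer to the figures in the paper, one common matplotlib recipe (not from the repository, just a sketch that assumes the input is resized to 448x448) is to upsample the response map of the predicted class to the input resolution and blend it with the image:

```python
# pick the response map of the predicted class rather than class 0
pred_class = agg_origin[0].argmax().item()
response = F.interpolate(model.map_origin(feature_raw), size=(448, 448),
                         mode='bilinear', align_corners=True)
response = response[0, pred_class].detach().cpu().numpy()

plt.imshow(img.resize((448, 448)))           # assumes the transform resizes to 448x448
plt.imshow(response, cmap='jet', alpha=0.5)  # translucent heat map over the image
plt.axis('off')
plt.show()
```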

JessicaChanzc commented 4 years ago

Excellent work! I tried to visualize the maps with the code you provided, but I can't find the definition of `transform` in `image = transform(img)`. Could you provide the code for this function? Thank you.
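The thread does not include the definition of `transform`; for reference, a typical test-time pipeline for a 448-pixel input would look something like the sketch below. The resize strategy and normalization statistics are assumptions (standard ImageNet values), so check the repository's data-loading configuration for the exact ones.

```python
import torchvision.transforms as transforms

# assumed preprocessing for a 448x448 input; the exact resize/crop scheme and
# normalization statistics should be taken from the repository's config
transform = transforms.Compose([
    transforms.Resize((448, 448)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])
```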