import cv2
import numpy
import torch
import matplotlib.pyplot as plt


def imshow_attention(img, attn_weights, out_file):
    # BGR (as returned by mmcv.imread) -> RGB for matplotlib.
    img = numpy.ascontiguousarray(img)[:, :, ::-1]
    h, w = img.shape[:2]
    # attn_weights: [4 * num_decoder_layers, 400]; average the maps of each
    # decoding step over its decoder layers (the [3, 3, 3, 3] split assumes
    # a 3-layer decoder).
    attn_weights = torch.cat(list(map(lambda weights: torch.mean(
        weights, dim=0, keepdim=True), torch.split(attn_weights, [3, 3, 3, 3]))), dim=0)
    assert attn_weights.size(0) == 4 and attn_weights.ndim == 2
    # Restore the 20x20 feature-map grid and crop away the padded region.
    attn_weights = attn_weights.reshape(4, 20, 20)
    attn_weights = attn_weights[:, :h // 32 + 1, :w // 32 + 1]
    attn_weights = attn_weights.cpu().numpy()
    for i in range(4):
        plt.clf()
        plt.axis('off')
        plt.imshow(img, alpha=0.7)
        # Upsample the step-i attention map to image resolution and overlay it.
        attn_mask = cv2.resize(attn_weights[i], (w, h))
        attn_mask = (attn_mask * 255).astype(numpy.uint8)
        plt.imshow(attn_mask, alpha=0.3,
                   interpolation="bilinear", cmap="jet")
        plt.savefig(out_file.replace(".jpg", f"_{i}th_step.jpg"), dpi=300)
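A minimal usage sketch with dummy inputs (shapes follow the function above, which the script below imports as seqtr.core.imshow_attention; the 12 rows assume a 3-layer decoder, and demo.jpg is just a placeholder path):

import numpy
import torch

dummy_img = numpy.zeros((480, 640, 3), dtype=numpy.uint8)  # BGR image, as mmcv.imread returns
dummy_weights = torch.rand(12, 400)  # [4 steps * 3 decoder layers, 20 * 20 encoder tokens]
imshow_attention(dummy_img, dummy_weights, "demo.jpg")  # writes demo_0th_step.jpg ... demo_3th_step.jpg

The full script that produces these weights follows.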
from typing import Sequence

import mmcv
import torch
import numpy
import argparse
import os.path as osp
import torch.nn.functional as F
from mmcv import Config, DictAction
from mmcv.utils import mkdir_or_exist

from seqtr.models import build_model
from seqtr.core import imshow_attention
from seqtr.utils import load_checkpoint, get_root_logger
from seqtr.datasets import extract_data, build_dataset, build_dataloader

try:
    import apex
except ImportError:
    pass
def parse_args():
    parser = argparse.ArgumentParser(description="SeqTR")
    parser.add_argument('config', help='visualize config file path')
    parser.add_argument(
        'checkpoint', help='the checkpoint file to load from.')
    parser.add_argument(
        '--output-dir', help='directory where visualized results will be saved.')
    parser.add_argument(
        '--cfg-options',
        nargs='+',
        action=DictAction,
        help='override some settings in the used config; key-value pairs '
        'in xxx=yyy format will be merged into the config file. If the value to '
        'be overwritten is a list, it should be like key="[a,b]" or key=a,b. '
        'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]". '
        'Note that the quotation marks are necessary and that no white space '
        'is allowed.')
    parser.add_argument(
        '--visualize',
        type=str,
        nargs='+',
        default='testA',
        help="evaluation set, which depends on the dataset, e.g., "
             "'val', 'testA', 'testB' for RefCOCO(Plus)UNC, and "
             "'val', 'test' for RefCOCOgUMD.")
    args = parser.parse_args()
    return args
def main(cfg):
    # The train set is built only to provide the vocabulary (word_emb/num_token).
    datasets_cfg = [cfg.data.train]
    for vis_set in cfg.visualize:
        datasets_cfg.append(cfg.data[vis_set])
    datasets = list(map(build_dataset, datasets_cfg))
    dataloaders = list(
        map(lambda dataset: build_dataloader(cfg, dataset), datasets))

    model = build_model(cfg,
                        word_emb=datasets[0].word_emb,
                        num_token=datasets[0].num_token)
    model = model.cuda()
    if cfg.use_fp16:
        model = apex.amp.initialize(model, opt_level="O1")
        for m in model.modules():
            if hasattr(m, "fp16_enabled"):
                m.fp16_enabled = True
    load_checkpoint(model, None, None, cfg.checkpoint)
    model.eval()
    # Ask the transformer to return attention weights during the forward pass.
    model.head.transformer.need_weights = True
    model.head.transformer.decoder.need_weights = True
    decoder_layers = model.head.transformer.decoder.num_layers

    logger = get_root_logger()
    for i, vis_set in enumerate(cfg.visualize):
        logger.info(f"visualizing attention on set {vis_set}")
        output_dir = osp.join(cfg.output_dir, cfg.dataset, vis_set)
        mkdir_or_exist(output_dir)
        with torch.no_grad():
            prog_bar = mmcv.ProgressBar(len(datasets[i + 1]))
            for batch, inputs in enumerate(dataloaders[i + 1]):
                inputs = extract_data(inputs)
                """
                inputs (Dict): {
                    'img_metas' (List[Dict]): {
                        'filename' (str): './data/images/train2014/COCO_train2014_000000580957.jpg',
                        'expression' (str): 'bowl behind the others can only see part',
                        'ori_shape' (tuple): (h_ori, w_ori, 3),
                        'img_shape' (tuple): (h_img, w_img, 3),
                        'pad_shape' (tuple): (h_pad, w_pad, 3),
                        'scale_factor' (Array): (w_scale, h_scale, w_scale, h_scale),
                        'img_norm_cfg' (dict): {
                            'mean' (Array): [0., 0., 0.],
                            'std' (Array): [1., 1., 1.]
                        },
                        'to_rgb': True
                    },
                    'img' (Tensor): [batch_size, 3, h_batch, w_batch].
                    'ref_expr_inds' (Tensor): [batch_size, max_token].
                    'gt_bbox' (List[Tensor]): [
                        [tl_x, tl_y, br_x, br_y], in (h_img, w_img) coordinate system.
                    ]
                }
                """
                img, ref_expr_inds, img_metas = inputs['img'], inputs['ref_expr_inds'], inputs['img_metas']
                batch_size = img.size(0)
                batch_input_shape = tuple(img.size()[-2:])
                for img_meta in img_metas:
                    img_meta['batch_input_shape'] = batch_input_shape

                # Encode image and expression, then run the encoder once.
                x, y, y_word, y_mask = model.extract_visual_language(
                    img, ref_expr_inds)
                if model.with_neck:
                    x, y = model.neck(x, y, y_word, y_mask)
                x_mask, x_pos_embeds = model.head.transformer.x_mask_pos_enc(
                    x, img_metas)
                if model.with_fusion:
                    y = model.fusion(x, y_word, x_mask, y_mask)
                memory = model.head.transformer.forward_encoder(
                    x, x_mask, x_pos_embeds)

                # Autoregressively decode the four box-coordinate tokens,
                # collecting the decoder's cross-attention at every step.
                attn_weights_all_coordinates = []
                seq_in_embeds = y
                for step in range(4):
                    out, attn_weights_all_layers = model.head.transformer.forward_decoder(
                        seq_in_embeds, memory, x_pos_embeds, x_mask)
                    attn_weights_all_coordinates.append(
                        attn_weights_all_layers)
                    logits = out[:, -1, :]
                    logits = model.head.predictor(logits)
                    logits = logits[:, :-1]  # [batch_size, num_bin], drop the EOS logit
                    probability = F.softmax(logits, dim=-1)
                    probability, next_token = probability.topk(
                        dim=-1, k=1, largest=True, sorted=True)
                    # Greedily feed the embedding of the predicted token back in.
                    seq_in_embeds = torch.cat(
                        [seq_in_embeds, model.head.transformer.query_embedding(next_token)], dim=1)

                # Gather, per image, the last-position attention of every
                # decoding step and decoder layer: [4 * decoder_layers, 400].
                attn_weights_all_images = []
                for b in range(batch_size):
                    attn_weights_this_img = []
                    for j in range(4):
                        for k in range(decoder_layers):
                            attn_weights_this_img.append(
                                attn_weights_all_coordinates[j][k][b, -1, :])
                    attn_weights_this_img = torch.vstack(attn_weights_this_img)
                    attn_weights_all_images.append(attn_weights_this_img)

                for img_meta, attn_weights in zip(img_metas, attn_weights_all_images):
                    filename = img_meta['filename']
                    expression = img_meta['expression'].replace(" ", "")
                    out_file = osp.join(
                        output_dir, expression + "_" + osp.basename(filename))
                    img = mmcv.imread(filename).astype(numpy.uint8)
                    imshow_attention(img, attn_weights, out_file)
                    prog_bar.update()


if __name__ == "__main__":
    args = parse_args()
    cfg = Config.fromfile(args.config)
    if args.cfg_options is not None:
        cfg.merge_from_dict(args.cfg_options)
    if isinstance(args.visualize, str):
        cfg.visualize = [args.visualize]
    elif isinstance(args.visualize, Sequence):
        cfg.visualize = args.visualize
    cfg.checkpoint = args.checkpoint
    cfg.output_dir = args.output_dir
    main(cfg)
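For reference, an invocation would look something like the following (the script name and paths are placeholders, not taken from the repo; use_fp16 is read by main above and can be overridden via --cfg-options):

python visualize_attention.py configs/seqtr_det.py work_dirs/seqtr/latest.pth --output-dir visualizations --visualize testA testB --cfg-options use_fp16=False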
Hi, you can reference the above code for attention visualization. The code may be buggy because I renamed several APIs and possibly changed attributes of the transformer decoder (e.g., whether it returns attention weights) during open-sourcing. Nevertheless, this is the code that was used to visualize the attention maps.
Well, thank you very much for your generous sharing. I think many people who come after will learn a lot about visualization techniques from this, e.g. me, haha. Thanks again for your quick reply!
Hi,
Congratulations!
I want to visualize the attention weights of the segmentation points, similar to Fig. 5.
According to the paper, "We visualize the cross attention map averaged over decoder layers and attention heads in Fig. 5.", but I am not sure how to incorporate these weights into the original image.
Would you be willing to share the script, or suggest a workable approach?
Thanks~
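The quoted averaging step is what the script above implements; distilled to its core, it is a mean over the layer and head axes, a reshape back to the feature-map grid, and an upsample to the image size. A minimal sketch, assuming cross_attn holds per-layer, per-head maps of shape [num_layers, num_heads, feat_h * feat_w] (all names here are illustrative, not the repo's API):

import torch
import torch.nn.functional as F

def averaged_attention(cross_attn, feat_h, feat_w, img_h, img_w):
    # Mean over decoder layers and attention heads -> [feat_h * feat_w].
    attn = cross_attn.mean(dim=(0, 1))
    # Back to the 2D feature-map grid, then upsample to image resolution.
    attn = attn.reshape(1, 1, feat_h, feat_w)
    attn = F.interpolate(attn, size=(img_h, img_w),
                         mode='bilinear', align_corners=False)
    # [img_h, img_w]; overlay this on the image, e.g. with a colormap and alpha.
    return attn[0, 0]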