Event-AHU / MambaEVT

Event stream based visual object tracking using Mamba/State Space Model
MIT License

Visualization code #1

Open goodnight111111111 opened 2 months ago

goodnight111111111 commented 2 months ago

Thank you for your contribution. I am very interested in your paper. Could you share the visualization code? Thank you very much!

TomX32 commented 2 months ago

Which part specifically?

goodnight111111111 commented 2 months ago

Thank you very much for taking the time to reply. I need the visualizations for Figure 4 and Figure 5. Thanks a lot!

TomX32 commented 2 months ago

For Figure 4, you can refer to the code below; place it in the track() function of lib/test/tracker/yourTracker.py

import os

import cv2
import numpy as np


def getCAM2(features, img, idx):
    """Overlay a response map on the frame and save it as attn_<idx>.png."""
    save_path = '/path/to/save/'
    os.makedirs(save_path, exist_ok=True)
    # Detach from the graph, move to the CPU, drop the singleton dim 1
    features = features.detach().to("cpu").squeeze(1).numpy()
    img = cv2.resize(img, (256, 256))
    img = np.array(img, dtype=np.uint8)
    # Move axis 0 to the end (channel-first to channel-last) for cv2.resize
    mask = features.transpose((1, 2, 0))
    # Min-max normalize to [0, 1]; the epsilon guards against a constant map
    mask = (mask - mask.min()) / (mask.max() - mask.min() + 1e-8)
    mask = cv2.resize(mask, (256, 256))
    mask = (255 * mask).astype(np.uint8)
    # The mask is inverted (255 - mask) before the JET colormap is applied
    heatmap = cv2.applyColorMap(255 - mask, cv2.COLORMAP_JET)

    # Blend: 60% original frame, 40% heatmap
    img = cv2.addWeighted(src1=img, alpha=0.6, src2=heatmap, beta=0.4, gamma=0)
    cv2.imwrite(os.path.join(save_path, 'attn_%d.png' % idx), img)
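
To make the two numeric steps in the snippet above easier to check in isolation, here is a minimal NumPy-only sketch of the min-max normalization and the weighted blend. The helper names normalize_to_uint8 and blend are hypothetical and not part of the MambaEVT code base. The blend follows the documented cv2.addWeighted formula (src1*alpha + src2*beta + gamma, saturated to the uint8 range) with gamma fixed at 0, and returning an all-zero mask for a constant map is an assumption made here to avoid a division by zero.

```python
import numpy as np


def normalize_to_uint8(response, eps=1e-8):
    """Min-max scale a response map to [0, 255] and return it as uint8."""
    response = np.asarray(response, dtype=np.float32)
    span = float(response.max() - response.min())
    if span < eps:
        # Assumption: a constant map carries no spatial information,
        # so return zeros instead of dividing by zero.
        return np.zeros(response.shape, dtype=np.uint8)
    scaled = (response - response.min()) / span
    return (255 * scaled).astype(np.uint8)


def blend(img, heatmap, alpha=0.6, beta=0.4):
    """Weighted sum of two uint8 images, rounded and clipped to [0, 255]."""
    mixed = alpha * img.astype(np.float32) + beta * heatmap.astype(np.float32)
    return np.clip(np.rint(mixed), 0, 255).astype(np.uint8)
```

In the original function the mask is additionally resized to 256x256 and passed through cv2.applyColorMap before blending; those steps are left out here so the sketch runs without OpenCV.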

For Figure 5, you can refer to the code below; set up the required files beforehand

import cv2
import math
import os
import numpy as np
from tqdm import tqdm

seq_name = "recording_2022-10-10_17-28-24"  # sequence folder in the current directory
# seq_name ="recording_2022-10-10_17-28-47"
# seq_name ="recording_2023-04-28_17-17-47"
root = r"/path/to/project"
img_folds = '/' + seq_name
gt_folds = root + img_folds
save_folds = root + img_folds + '/img_bbox'
# Create save_folds if it does not exist
os.makedirs(save_folds, exist_ok=True)
img_root = root + img_folds + '/img'
img_list = sorted(os.listdir(img_root))

trackers = {
    'OSTrack': (gt_folds + '/bbox/OSTrack.txt', (0, 255, 255)),  # note: channel order is BGR
    'MambaEVT': (gt_folds + '/bbox/MambaEVT.txt', (0, 0, 255)),  # note: channel order is BGR
    # ... add more as needed
}

# Read the box coordinates from every tracker's result file
tracker_data = {}
for tracker, (file_name, color) in trackers.items():
    with open(file_name, 'r') as file:
        coords = []
        for line in file:
            if not line.strip():
                continue  # skip blank lines, e.g. a trailing newline at EOF
            # Ground truth is comma-separated; tracker results are tab-separated
            sep = ',' if tracker == 'groundtruth' else '\t'
            x, y, w, h = map(float, line.split(sep))
            coords.append((x, y, w, h))
    tracker_data[tracker] = (coords, color)

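# Process the images and show a progress bar
for k, img_name in enumerate(tqdm(img_list, desc="Processing images")):
    image_path = os.path.join(img_root, img_name)
    image = cv2.imread(image_path)
    if image is None:  # cv2.imread returns None for unreadable or non-image files
        continue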

    text_num = '#%04d' % (k + 1)
    cv2.putText(image, text_num, (50, 50), cv2.FONT_HERSHEY_SIMPLEX, 2, (0, 255, 0), 8)

    for tracker, (coords, color) in tracker_data.items():
        if k < len(coords):  # make sure the index stays in range
            x, y, w, h = coords[k]
            first_point = (math.ceil(x), math.ceil(y))
            last_point = (math.ceil(x + w), math.ceil(y + h))
            cv2.rectangle(image, first_point, last_point, color, 4)
    save_path = os.path.join(save_folds, img_name)
    cv2.imwrite(save_path, image)
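
The script above expects tab-separated tracker results and comma-separated ground truth. If your result files mix separators or end with blank lines, a more tolerant parser may help. The following is a minimal sketch using only the standard library; parse_bbox_line and to_corners are hypothetical helper names, not part of the repository, and accepting any mix of commas and whitespace is an assumption about the file format.

```python
import math
import re


def parse_bbox_line(line):
    """Return (x, y, w, h) as floats, or None for a blank line."""
    # Split on any run of commas and/or whitespace (tabs included)
    parts = [p for p in re.split(r"[,\s]+", line.strip()) if p]
    if not parts:
        return None
    if len(parts) != 4:
        raise ValueError(f"expected 4 values, got {len(parts)}: {line!r}")
    x, y, w, h = map(float, parts)
    return x, y, w, h


def to_corners(x, y, w, h):
    """Convert (x, y, w, h) to integer top-left and bottom-right points."""
    return (math.ceil(x), math.ceil(y)), (math.ceil(x + w), math.ceil(y + h))
```

For example, parse_bbox_line("10.5\t20\t30\t40") gives (10.5, 20.0, 30.0, 40.0), and passing that tuple to to_corners produces the two integer points that cv2.rectangle takes, rounded up with math.ceil as in the script.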

goodnight111111111 commented 1 month ago

Thank you very much for the code!