Jingkang50 / OpenPSG

Benchmarking Panoptic Scene Graph Generation (PSG), ECCV'22
https://psgdataset.org
MIT License
407 stars 68 forks source link

update the work to new version #118

Open jiugexuan opened 2 months ago

jiugexuan commented 2 months ago

I want update all dependencies to new version, so I need to remove the [Detectron2].For it,I write a new script to visual the result:

import mmcv
from mmdet.apis import init_detector, inference_detector, show_result_pyplot

# 使用你自己的配置文件和训练好的模型检查点
config_file = 'configs/psgtr/psgtr_r50_psg_inference.py'
checkpoint_file = 'work_dirs/psgtr_r50_e60/epoch_60.pth'

# 初始化检测器
model = init_detector(config_file, checkpoint_file, device='cuda:0')

# 测试单张图片
img = "./data/coco/val2017/000000568439.jpg"  # 或者 img = mmcv.imread(img), 只加载一次
# img = 'bw25.png'
# 运行推理
result = inference_detector(model, img)

import networkx as nx
from pyvis.network import Network
import mmcv
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Polygon
import cv2
from mmdet.datasets.coco_panoptic import INSTANCE_OFFSET

# 定义类别
CLASSES = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train', 'truck', 'boat', 
           'traffic light', 'fire hydrant', 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 
           'dog', 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 
           'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball', 
           'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket', 
           'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 
           'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair', 
           'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop', 'mouse', 'remote', 
           'keyboard', 'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'book', 
           'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner', 'blanket', 
           'bridge', 'cardboard', 'counter', 'curtain', 'door-stuff', 'floor-wood', 'flower', 'fruit', 
           'gravel', 'house', 'light', 'mirror-stuff', 'net', 'pillow', 'platform', 'playingfield', 
           'railroad', 'river', 'road', 'roof', 'sand', 'sea', 'shelf', 'snow', 'stairs', 'tent', 
           'towel', 'wall-brick', 'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'window-blind', 
           'window-other', 'tree-merged', 'fence-merged', 'ceiling-merged', 'sky-other-merged', 
           'cabinet-merged', 'table-merged', 'floor-other-merged', 'pavement-merged', 'mountain-merged', 
           'grass-merged', 'dirt-merged', 'paper-merged', 'food-other-merged', 'building-other-merged', 
           'rock-merged', 'wall-other-merged', 'rug-merged', 'background']

# 定义关系
PREDICATES = [
    'over', 'in front of', 'beside', 'on', 'in', 'attached to', 'hanging from', 'on back of', 
    'falling off', 'going down', 'painted on', 'walking on', 'running on', 'crossing', 'standing on', 
    'lying on', 'sitting on', 'flying over', 'jumping over', 'jumping from', 'wearing', 'holding', 
    'carrying', 'looking at', 'guiding', 'kissing', 'eating', 'drinking', 'feeding', 'biting', 
    'catching', 'picking', 'playing with', 'chasing', 'climbing', 'cleaning', 'playing', 'touching', 
    'pushing', 'pulling', 'opening', 'cooking', 'talking to', 'throwing', 'slicing', 'driving', 
    'riding', 'parked on', 'driving on', 'about to hit', 'kicking', 'swinging', 'entering', 'exiting', 
    'enclosing', 'leaning on'
]

# 读取图像
img_path = img  # 替换为你的图像路径
img = mmcv.imread(img_path)
img_h, img_w = img.shape[:-1]

# 获取 pan_results
pan_results = result.pan_results

# 处理 ids
ids = np.unique(pan_results)[::-1]
num_classes = 133
legal_indices = (ids != num_classes)  # 过滤掉 VOID 标签
ids = ids[legal_indices]

# 获取预测标签
# INSTANCE_OFFSET = 1000  # 确保这是正确的偏移量
labels = np.array([id % INSTANCE_OFFSET for id in ids], dtype=np.int64)

# 创建标签计数器,用于生成唯一编号
label_counter = {}
unique_labels = []
for label in labels:
    if label not in label_counter:
        label_counter[label] = 0
    label_counter[label] += 1
    if label_counter[label] > 1:
        unique_labels.append(f'{CLASSES[label]}_{label_counter[label]}')
    else:
        same_label_count = labels.tolist().count(label)
        if same_label_count == 1:
            unique_labels.append(f'{CLASSES[label]}')
        else:
            unique_labels.append(f'{CLASSES[label]}_1')

# 获取分割掩码
segms = pan_results[None] == ids[:, None, None]

# 绘制分割结果和标签
# 绘制图像
plt.figure(figsize=(15, 15))
plt.imshow(img)

# 绘制分割结果和标签
for i, segm in enumerate(segms):
    # 生成浅色颜色掩码
    color_mask = np.random.rand(3) * 0.7 + 0.3  # 确保颜色较浅
    mask = segm.astype(np.uint8)

    # 找到掩码的轮廓
    contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
    for contour in contours:
        polygon = Polygon(contour.reshape(-1, 2), fill=True, edgecolor=color_mask, facecolor=color_mask, linewidth=0.5, alpha=0.5)
        plt.gca().add_patch(polygon)

    # 添加标签
    y, x = np.where(mask)
    if len(y) > 0 and len(x) > 0:
        center_y, center_x = int(y.mean()), int(x.mean())
        offset_y, offset_x = center_y - 5, center_x - 5  # 调整标签位置
        label = unique_labels[i]  # 使用唯一标签
        plt.text(offset_x, offset_y, label, color=color_mask, fontsize=8, ha='center', va='center', bbox=dict(facecolor='black', alpha=0.7, edgecolor='none'))

# 显示结果
plt.axis('off')
plt.show()

# 创建关系图
G = nx.DiGraph()

# 创建颜色映射
node_colors = {}
for cls in CLASSES:
    node_colors[cls] = np.random.rand(3) * 0.7 + 0.3  # 生成浅色颜色

# 添加节点
for i, label in enumerate(unique_labels):
    node_type = label.split('_')[0]
    G.add_node(str(i), label=label, color=node_colors[node_type], size=20)

# 添加关系边
rels = result.rels
colors = plt.cm.tab20(np.linspace(0, 1, len(PREDICATES)))
for rel in rels:
    subj_idx, obj_idx, rel_label = rel
    subj_str = str(subj_idx)
    obj_str = str(obj_idx)
    color = colors[int(rel_label)]  # 转换为整数索引
    G.add_edge(subj_str, obj_str, label=PREDICATES[int(rel_label)], color=color)

# 确保关系图中包含所有对象的标签
for rel in rels:
    subj_idx, obj_idx, rel_label = rel
    subj_str = str(subj_idx)
    obj_str = str(obj_idx)
    if subj_str not in G.nodes:
        G.add_node(subj_str, label=unique_labels[subj_idx], color=node_colors[unique_labels[subj_idx].split('_')[0]], size=20)
    if obj_str not in G.nodes:
        G.add_node(obj_str, label=unique_labels[obj_idx], color=node_colors[unique_labels[obj_idx].split('_')[0]], size=20)

# 使用 spring_layout 生成节点布局,使图更加离散
pos = nx.spring_layout(G, k=0.5)

# 创建 PyVis 网络图
net = Network(notebook=True, width="1500px", height="1500px", directed=True, cdn_resources='remote')

# 从 NetworkX 图导入 PyVis 图
net.from_nx(G)

# 设置节点标签和颜色
for node in net.nodes:
    node['title'] = node['label']
    node['label'] = node['label']
    color = node_colors[node['label'].split('_')[0]]
    node['color'] = 'rgba({}, {}, {}, 1)'.format(int(color[0]*255), int(color[1]*255), int(color[2]*255))

# 设置边标签和颜色
for edge in net.edges:
    edge['title'] = edge['label']
    edge['color'] = 'rgba({}, {}, {}, 1)'.format(int(edge['color'][0]*255), int(edge['color'][1]*255), int(edge['color'][2]*255))

# 显示 PyVis 图
net.show('relationship_graph.html')

but it has some issue,could any one help me? output