Shuyu-XJTU / APTM

The official code of "Towards Unified Text-based Person Retrieval: A Large-scale Multi-Attribute and Language Search Benchmark"
https://arxiv.org/abs/2306.02898
MIT License

Hello authors, how can I obtain good feature vectors from the model? #5

Closed zzk2021 closed 10 months ago

zzk2021 commented 1 year ago

I tested with my own Evaluator:

import torch


class Evaluator():
    def __init__(self, img_loader, txt_loader):
        self.img_loader = img_loader  # gallery (images)
        self.txt_loader = txt_loader  # query (captions)

    def _compute_embedding(self, model):
        model = model.eval()
        device = next(model.parameters()).device

        qids, gids, qfeats, gfeats = [], [], [], []
        # text queries: encode each caption and project its [CLS] token
        for pid, caption in self.txt_loader:
            for k, v in caption.items():
                caption[k] = v.to(device)
            with torch.no_grad():
                caption["input_ids"] = caption["input_ids"].squeeze(1)
                text_embeds = model.get_text_embeds(caption["input_ids"], caption["attention_mask"])
                text_feat = model.text_proj(text_embeds[:, 0, :])
            qids.append(pid.view(-1))  # flatten
            qfeats.append(text_feat)
        qids = torch.cat(qids, 0)
        qfeats = torch.cat(qfeats, 0)

        # gallery images: encode each image and project its [CLS] token
        for pid, img in self.img_loader:
            img = img.to(device)
            with torch.no_grad():
                image_embeds, image_atts = model.get_vision_embeds(img)
                img_feat = model.vision_proj(image_embeds[:, 0, :])
            gids.append(pid.view(-1))  # flatten
            gfeats.append(img_feat)
        gids = torch.cat(gids, 0)
        gfeats = torch.cat(gfeats, 0)
        return qfeats, gfeats, qids, gids

With this, Rank-1 on CUHK-PEDES is only 55 (I used the BERT tokenizer). The repository code does not evaluate the similarity matrix directly; instead it performs some kind of iteration over it. How should I understand that process, and how can I obtain good text and image feature vectors?
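
For reference, in ALBEF/X-VLM-style retrieval code the "iteration over the similarity matrix" usually refers to a two-stage evaluation: a dot-product similarity matrix between the projected features is computed first, and the top-k gallery candidates per query are then re-scored with the cross-modal image-text matching head. Independently of that re-ranking step, a direct evaluation of the features returned by the Evaluator above can be sketched as below. It assumes the projected features should be L2-normalized before computing cosine similarity (as is typical for such contrastive features); the helper name direct_rank1 is hypothetical and not part of the repository.

import torch
import torch.nn.functional as F


@torch.no_grad()
def direct_rank1(qfeats, qids, gfeats, gids):
    # Cosine similarity: L2-normalize both projected feature sets before the dot product.
    qfeats = F.normalize(qfeats.float(), dim=-1)
    gfeats = F.normalize(gfeats.float(), dim=-1)
    sims = qfeats @ gfeats.t()  # [num_queries, num_gallery]

    # Rank-1: the top-scoring gallery image must carry the query's person ID.
    top1 = sims.argmax(dim=1)
    return (gids[top1] == qids).float().mean().item()


# Hypothetical usage with the Evaluator above:
# qfeats, gfeats, qids, gids = Evaluator(img_loader, txt_loader)._compute_embedding(model)
# print("Rank-1:", direct_rank1(qfeats, qids, gfeats, gids))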

Shuyu-XJTU commented 1 year ago

Hello, which code are you referring to when you say "the code does not evaluate directly"?