IMCCretrieval / MomentDiff

MomentDiff: Generative Video Moment Retrieval from Random to Real--NeurIPS 2023
Other
75 stars 0 forks source link

glove特征 #6

Closed Song200103 closed 7 months ago

Song200103 commented 7 months ago

您好,在运行您的源代码时发现缺少glove特征,请问该如何获得glove特征?

Song200103 commented 7 months ago

我使用NLTKTokenizer提取glove特征的代码如下:

import os
import json
import numpy as np

from tqdm import tqdm
import nltk
# Fetch the "punkt" tokenizer data at import time; nltk.word_tokenize is
# called later in this script and relies on NLTK tokenizer resources.
# NOTE(review): newer NLTK releases may look for "punkt_tab" instead — confirm
# against the installed NLTK version.
nltk.download('punkt')
# Load GloVe vectors into memory
def load_glove_vectors(glove_file):
    """Read a GloVe text file into a ``{word: vector}`` dictionary.

    Each non-empty line is expected to hold a token followed by its
    whitespace-separated float components.

    Args:
        glove_file: Path to the GloVe ``.txt`` file (read as UTF-8).

    Returns:
        dict mapping each word to a 1-D ``float32`` NumPy array.
    """
    print("Loading GloVe vectors...")
    embeddings_index = {}
    with open(glove_file, "r", encoding="utf-8") as f:
        for line in f:
            values = line.split()
            # Blank / whitespace-only lines (e.g. a trailing newline at the
            # end of the file) carry no token; skip them instead of failing
            # with IndexError on values[0].
            if not values:
                continue
            word = values[0]
            embeddings_index[word] = np.asarray(values[1:], dtype="float32")
    print("Total GloVe vectors:", len(embeddings_index))
    return embeddings_index

def get_sentence_embedding(sentence, embeddings_index):
    """Return the mean GloVe vector of the known words in ``sentence``.

    The sentence is lower-cased and tokenized with ``nltk.word_tokenize``.
    Tokens missing from ``embeddings_index`` are ignored. If no token is
    found, an all-zero vector is returned.

    Args:
        sentence: Text to embed.
        embeddings_index: Mapping from word to embedding vector; the
            dimension is taken from its first value.

    Returns:
        1-D NumPy array with the averaged embedding.
    """
    tokens = nltk.word_tokenize(sentence.lower())
    first_vector = next(iter(embeddings_index.values()))
    accumulator = np.zeros(len(first_vector))

    # Keep only the vectors of in-vocabulary tokens, in sentence order.
    hits = [vec for vec in map(embeddings_index.get, tokens) if vec is not None]
    for vec in hits:
        accumulator += vec

    if hits:
        accumulator /= len(hits)
    return accumulator

def load_jsonl(filename):
    """Parse a JSON Lines file and return its records as a list.

    Args:
        filename: Path to a file holding one JSON document per line.

    Returns:
        list of decoded JSON objects, in file order.
    """
    records = []
    with open(filename, "r") as handle:
        for raw_line in handle:
            # Drop the newline terminator before decoding the record.
            records.append(json.loads(raw_line.strip("\n")))
    return records

if __name__ == '__main__':
    # Script entry point: build one GloVe text feature file per annotated
    # query and store it as "qid<id>.npz" in the feature directory.
    glove_file = 'glove.6B.300d.txt'
    anno_path = '/data3/sxf/MomentDiff/data/charades/charades_sta_test_tvr_format.jsonl'
    save_dir = '/data3/sxf/MomentDiff/data/charades_features/glove_txt_features/'

    annotations = load_jsonl(anno_path)
    embeddings_index = load_glove_vectors(glove_file)

    # np.savez does not create missing directories, so make sure it exists.
    os.makedirs(save_dir, exist_ok=True)

    for anno in tqdm(annotations, desc="Generating text features"):
        query = anno['query']
        qid = anno['qid']

        text_embedding = get_sentence_embedding(query, embeddings_index)

        save_npz_name = f"qid{qid}.npz"
        save_path = os.path.join(save_dir, save_npz_name)

        # The feature consumer reads the array named "last_hidden_state" from
        # each archive; saving it under any other key fails there with
        # KeyError: 'last_hidden_state is not a file in the archive'.
        # NOTE(review): this stores a single averaged vector per query —
        # confirm the loader does not expect per-token features.
        np.savez(save_path, last_hidden_state=text_embedding)

当我使用我提取的glove特征时,会出现以下错误:

KeyError: 'last_hidden_state is not a file in the archive' 是否是因为我提取的glove不正确呢?

IMCCretrieval commented 7 months ago

你好,请把特征提取代码中保存时使用的键名text_embedding改成last_hidden_state,即 np.savez(save_path, last_hidden_state=text_embedding)。谢谢关注!