amankhullar / mast

Code for the paper Multimodal Abstractive Summarization with Trimodal Hierarchical Attention
19 stars 2 forks source link

where is pretrained file? #4

Open Kingdu97 opened 1 year ago

Kingdu97 commented 1 year ago

pretrained_file: /home/aman_khullar/multimodal/save_mm3/Feb28/text_actions_rnn_hier_mm3_300/attentivemnmtfeatures-dec128-emb256-enc128-adam_4e-04-do_ctx_0.3-do_emb_0.3-do_out_0.3-att_mlp-ctx_hierarchical-bs4-rouge-each1000-2waytied-di_meanctx-r658a0.best.rouge.ckpt

I can't figure this out. When I run `nmtpy train` with the config file, I don't know where the pretrained_file is supposed to come from.

Where is this???????!!

Chenxuanqi666 commented 10 months ago

You need to divide the downloaded files into cv, devtest and train according to the audio id, specifically according to the directories in the configuration file. The following is my code; you can change the directory or file paths:

import os
import shutil

import torch

# Dataset split names; every helper in this script loops over them.
dir_list = ["devtest", "train", "cv"]

def reload_vid(features_dir='E:/data/how2/how2-300h-v1/features', tasks=None):
    """Write the ids of the video feature files of each split to ``vid_id.txt``.

    For every split, the part of each file name before the first ``.`` in
    ``<features_dir>/vid_<split>_300/video`` is appended, one per line, to
    ``<features_dir>/vid_<split>_300/vid_id.txt``.

    Args:
        features_dir: Directory that holds the per-split feature folders.
        tasks: Split names to process; defaults to the module-level
            ``dir_list``.
    """
    if tasks is None:
        tasks = dir_list
    for task in tasks:
        # NOTE(review): the pasted code read 'vid{}_300'; the underscore before
        # the placeholder is restored to match the 'vid_{}_300/vid_id.txt'
        # path that getVid_id reads. Confirm against the real directory layout.
        split_dir = '{}/vid_{}_300'.format(features_dir, task)
        # Sorted so the generated file does not depend on listing order.
        vid_ids = sorted(
            filename.split('.')[0]
            for filename in os.listdir(split_dir + '/video')
        )
        # Append mode: running this twice duplicates the ids in the file.
        with open(split_dir + '/vid_id.txt', 'a') as f:
            f.writelines(vid + '\n' for vid in vid_ids)

        print("{}完成视频数据构建".format(task))

def getVid_id(how2_dir="E:/data/how2/how2-300h-v1", tasks=None):
    """Delete audio feature files whose id has no entry in ``vid_id.txt``.

    For every split, audio ids are read from the first column of
    ``<how2_dir>/data/<split>/cmvn.scp`` and video ids from
    ``<how2_dir>/features/vid_<split>_300/vid_id.txt``. Each audio id that is
    not a video id has its
    ``<how2_dir>/features/aud_<split>_300/audio/<id>.npy`` file removed.

    WARNING: this permanently deletes files.

    Args:
        how2_dir: Root of the how2 dataset.
        tasks: Split names to process; defaults to the module-level
            ``dir_list``.
    """
    if tasks is None:
        tasks = dir_list

    for task in tasks:
        aud_ids = []
        with open(how2_dir + '/data/{}/cmvn.scp'.format(task)) as f:
            for line in f:
                fields = line.split()
                if fields:  # tolerate blank lines in the scp file
                    aud_ids.append(fields[0])
        with open(how2_dir + '/features/vid_{}_300/vid_id.txt'.format(task)) as f1:
            vid_ids = [line.strip('\n') for line in f1]

        print(len(vid_ids))
        print(len(aud_ids))

        known_vids = set(vid_ids)  # O(1) membership tests in the loop below
        for aud in aud_ids:
            if aud in known_vids:
                continue
            # Remove the surplus audio file, e.g.
            # <how2_dir>/features/aud_dev5_300/audio/<id>.npy
            surplus = how2_dir + "/features/aud_{}_300/audio/{}.npy".format(task, aud)
            try:
                os.remove(surplus)
            except FileNotFoundError:
                # Already gone (e.g. a previous run); report and carry on.
                print("任务 {} 中文件 {}.npy 不存在，已跳过".format(task, aud))
                continue
            print("已删除任务 {} 中多余文件 {}.npy".format(task, aud))

        # NOTE: an earlier, disabled step appended every video id to
        # vid_id.txt and copied <id>.npy from the video_action_features
        # folder into vid_<split>_300/video/, reporting files it could not
        # find.

        print("{}完成视频数据构建".format(task))

def update_aud_txt(features_dir='E:/data/how2/how2-300h-v1/features', tasks=None):
    """Write, per split, a text file listing the audio feature file paths.

    For every split, each entry of ``<features_dir>/audio_<split>_300`` is
    appended to ``<features_dir>/audio_<split>_300.txt`` as
    ``<features_dir>/aud_<split>_300/audio/<entry>``, one path per line.

    Args:
        features_dir: Directory that holds the per-split feature folders.
        tasks: Split names to process; defaults to the module-level
            ``dir_list``.
    """
    if tasks is None:
        tasks = dir_list
    for task in tasks:
        # NOTE(review): the pasted code read 'audio{}_300' / 'aud{}_300'; the
        # underscore before each placeholder is restored to match the
        # 'aud_{}_300/audio' and 'actions_{}_300.txt' paths used elsewhere in
        # this script. Confirm against the real directory layout.
        # Sorted so the generated list does not depend on listing order.
        file_names = sorted(os.listdir("{}/audio_{}_300".format(features_dir, task)))
        audio_dir = '{}/aud_{}_300/audio'.format(features_dir, task)
        # Append mode: running this twice duplicates the entries.
        with open("{}/audio_{}_300.txt".format(features_dir, task), 'a') as f:
            f.writelines(
                os.path.join(audio_dir, filename) + '\n' for filename in file_names
            )
        print("{}任务的音频已经完成 audio_{}_300.txt 已经完成构建".format(task, task))

def update_act_txt(features_dir='E:/data/how2/how2-300h-v1/features', tasks=None):
    """Write, per split, a text file listing the action feature file paths.

    The entry names come from the audio folder, so the audio ids are the
    reference. For every split, each entry of
    ``<features_dir>/aud_<split>_300`` is appended to
    ``<features_dir>/actions_<split>_300.txt`` as
    ``<features_dir>/actions_<split>_300/<entry>``, one path per line.

    Args:
        features_dir: Directory that holds the per-split feature folders.
        tasks: Split names to process; defaults to the module-level
            ``dir_list``.
    """
    if tasks is None:
        tasks = dir_list
    for task in tasks:
        # Use the audio ids as the reference.
        # NOTE(review): getVid_id deletes files from aud_<split>_300/audio/,
        # while this lists aud_<split>_300 itself. Confirm which layout is
        # intended.
        # Sorted so the generated list does not depend on listing order.
        file_names = sorted(os.listdir("{}/aud_{}_300".format(features_dir, task)))
        actions_dir = '{}/actions_{}_300/'.format(features_dir, task)
        # Append mode: running this twice duplicates the entries.
        with open("{}/actions_{}_300.txt".format(features_dir, task), 'a') as f:
            f.writelines(
                os.path.join(actions_dir, file_name) + '\n' for file_name in file_names
            )
        print("{}任务的音频已经完成 actions_{}_300.txt 已经完成构建".format(task, task))

def creat_text_300(text_dir='E:/data/how2/text', mast_dir='E:/data/how2/mast', tasks=None):
    """Build per-split description/transcript files restricted to the audio ids.

    All lines of ``<text_dir>/sum_<split>/desc.tok.txt`` and ``tran.tok.txt``
    are first collected across every split. Then, for each split, the ids are
    taken from the file names in ``<mast_dir>/aud_<split>_300`` (the part
    before the first ``.``), and every collected line whose first
    space-separated token is one of those ids is appended to
    ``<mast_dir>/text_300/sum_<split>_300/<same file name>``. Ids for which no
    line was found are printed.

    Args:
        text_dir: Directory holding the ``sum_<split>`` text folders.
        mast_dir: Directory holding the ``aud_<split>_300`` folders; output is
            written below it.
        tasks: Split names to process; defaults to the module-level
            ``dir_list``.
    """
    if tasks is None:
        tasks = dir_list

    # Collect all text lines, descriptions first and then transcripts.
    sources = [('desc.tok.txt', []), ('tran.tok.txt', [])]
    for task in tasks:
        for file_name, lines in sources:
            # Both files are read as UTF-8 so the result does not depend on
            # the platform's default encoding.
            with open('{}/sum_{}/{}'.format(text_dir, task, file_name),
                      'r', encoding='utf-8') as f:
                lines.extend(f)

    for task in tasks:
        # Use the audio ids as the reference.
        ids = [
            file_name.split('.')[0]
            for file_name in os.listdir("{}/aud_{}_300".format(mast_dir, task))
        ]
        out_dir = "{}/text_300/sum_{}_300".format(mast_dir, task)
        os.makedirs(out_dir, exist_ok=True)
        for file_name, lines in sources:
            # The missing-id list is computed fresh for each file, so ids
            # missing from the descriptions no longer leak into the
            # transcript report.
            missing = _write_matching_lines(lines, ids, out_dir + '/' + file_name)
            print("剩下是{}任务中没有找到文本的id".format(task))
            print(missing)
            print("{}任务的 sum_{}_300/{} 已经构造完毕".format(task, task, file_name))


def _write_matching_lines(lines, ids, out_path):
    """Append lines whose first token is in *ids* to *out_path*; return unmatched ids."""
    wanted = set(ids)
    found = set()
    with open(out_path, 'a', encoding='utf-8') as f:
        for line in lines:
            key = line.split(' ')[0]
            if key in wanted:
                f.write(line)
                # A set is used so a repeated id cannot raise, unlike
                # list.remove on an id that was already removed.
                found.add(key)
    return [item for item in ids if item not in found]

if __name__ == '__main__':
    # Run the preparation steps one at a time; uncomment the step you need.
    getVid_id()

    # update_aud_txt()
    # update_act_txt()
    # creat_text_300()
    # Diagnostics only: current_device() raises when CUDA is unavailable,
    # so it is queried only when a device can be used.
    if torch.cuda.is_available():
        print(torch.cuda.current_device())
    print(os.environ.get('CUDA_VISIBLE_DEVICES'))
    # reload_vid()