这是一个基于Pytorch平台、Transformer框架实现的视频描述生成 (Video Captioning) 深度学习模型。
视频描述生成任务指的是:输入一个视频,输出一句描述整个视频内容的文字(前提是视频较短且可以用一句话来描述)。本repo主要目的是帮助视力障碍者欣赏网络视频、感知周围环境,促进“无障碍视频”的发展。
:yum: 这个repo是第七届“互联网+”北京赛区三等奖项目「以声绘影——基于人工智能的无障碍视频自动生成技术」的一部分。
:yum: 这个repo是北京市级大学生创新训练项目「基于深度学习的视频画面描述及无障碍视频研究」的一部分。
:yum: 这个repo的一部分已登记软件著作权2022SR0269902。
:warning: 本repo遵守Apache-2.0 License,详情请看库内LICENSE文件。不包括使用的数据集版权、submodule子目录下任何文件的版权。
当视频太长或较复杂时效果可能就很差了,针对长视频,目前有密集视频描述生成任务,即Dense Video Captioning,本项目暂时不涉及,但欢迎魔改这个repo。
如下图:CLIP是一个视觉-语言的大规模预训练模型,Clip4clip是将CLIP运用在视频检索任务的一种方法,SCE-loss是一个针对噪声较大数据集的损失函数。我们通过CLIP提取视频的特征,然后作为输入送入Transformer中,输出Caption,训练阶段用SCE-loss优化。
Java JRE (用来调用MS COCO eval server进行Bleu等评估)
Python3.6+
torch 1.8.2+
transformers 4.17.0
tensorboardX
tqdm
mmcv
numpy
pathlib
PIL
模型在训练集集上训练,根据验证集进行earlystop,下表结果为测试集结果。
训练数据集 | Bleu@4 | METEOR | ROUGE_L | CIDEr | 下载 | 配置文件 |
---|---|---|---|---|---|---|
MSVD | 58.0 | 39.9 | 77.0 | 113.8 | 百度网盘zmr4 Drive | config |
MSR-VTT | 48.1 | 31.1 | 65.1 | 60.2 | 百度网盘xy7e Drive | config |
git clone https://github.com/Kamino666/Video-Captioning-Transformer.git --recurse-submodules
python predict.py -c <config> -m <model> -v <video> \
--feat_type CLIP4CLIP-ViT-B-32 \
--ext_type uni_12 \
--greedy \
[--gpu/--cpu]
predict.py
内的注释。效果有的好有的差吧hhhhh |
---|
本repo使用MSR-VTT数据集和MSVD数据集
原始视频
特征文件
百度网盘 aupi包含了两个数据集的标注和特征
特征提取方法
特征提取使用我的另一个repo:Kamino666/video_features。
:warning:使用前请配置好Java
PATH=$PATH:<java_root> \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python eval.py -c <config> -m <model> [--gpu/--cpu]
PATH=$PATH:<java_root> \
CUDA_VISIBLE_DEVICES=0,1,2,3 \
python -m torch.distributed.run --nproc_per_node 4 train.py \
-c <config> --multi_gpu -ws 4
-ws
参数的值是使用的显卡数量。python train.py -c <config> --gpu
,若懒得改也可以直接把4换成1。配置文件是json格式的一个文件,在训练和预测时都需要用到。简单的说明如下:(部分配置可能不起作用或令人迷惑,是实验时添加的其他部分,可用性暂无法保证)
├── data
│ ├── train 训练数据
│ ├── validation 验证数据(用来计算loss)
│ └── eval 验证数据(用来计算Bleu等)
├── train 和训练方法有关的参数
│ ├── earlystop earlystop的patience
│ ├── epoch 最大epoch数
│ ├── save_dir 模型保存路径
│ ├── log_dir 日志保存路径(tensorboard读取)
│ └── tag 模型保存名称
├── test
└── model 和模型结构有关的参数
├── video_encoder Transformer编码器的参数
├── caption_decoder Transformer解码器的参数
└── modal_shape 模态的维度
Q:下载来自hugging face的模型失败
A:以bert-base-uncased
模型为例,在hugging face的模型网站上的下载页面可以看到一系列文件,如果是模型下载失败BertModel.from_pretrained()
,则下载.bin
文件,并把参数改成.bin
的路径;如果是tokenizer下载失败AutoTokenizer.from_pretrained()
,则下载config.json
、tokenizer.json
、tokenizer_config.json
、vocab.txt
四个文件,并把参数改成这四个文件所处目录路径。如果不想这么麻烦,可以科学上网。
Q:这个模型多大?
A:主要参数是总共4层的Transformer。
@misc{video,
author = {Zihao, Liu},
title = {{video captioning transformer}},
howpublished = {\url{https://github.com/Kamino666/Video-Captioning-Transformer}},
year = {2022}
}