Source code of our paper Reading-strategy Inspired Visual Representation Learning for Text-to-Video Retrieval.
- CUDA 10.1
- Python 3.8.5
- PyTorch 1.7.0
We used Anaconda to set up a deep learning workspace that supports PyTorch. Run the following script to install the required packages.
conda create --name rivrl_env python=3.8.5
conda activate rivrl_env
git clone https://github.com/LiJiaBei-7/rivrl.git
cd rivrl
pip install -r requirements.txt
conda deactivate
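Before training, you may want to confirm that the activated environment matches the versions listed above. The following is a minimal sketch. `version_prefix_ok` and `check_rivrl_env` are hypothetical helper names and are not part of the repository. The check assumes that `python` inside `rivrl_env` can import `torch`. Note that `torch.version.cuda` reports the CUDA version PyTorch was built against, which is not necessarily the system driver version.

```shell
#!/usr/bin/env bash
# Hypothetical helpers (not part of the RIVRL repository) for comparing the
# active environment with the versions listed above.

# Succeeds if version $1 equals $2, or extends it with "." or a "+" build
# suffix (so "1.7.0+cu101" is accepted for an expected "1.7.0").
version_prefix_ok() {
  case "$1" in
    "$2" | "$2".* | "$2"+*) return 0 ;;
    *) return 1 ;;
  esac
}

# Prints the detected versions and returns non-zero if any differs from
# CUDA 10.1 / Python 3.8.5 / PyTorch 1.7.0. Run inside rivrl_env.
check_rivrl_env() {
  local py torch_ver cuda_ver status=0
  py="$(python -c 'import platform; print(platform.python_version())')"
  torch_ver="$(python -c 'import torch; print(torch.__version__)')"
  # CUDA version PyTorch was compiled with ("None" for CPU-only builds).
  cuda_ver="$(python -c 'import torch; print(torch.version.cuda)')"
  echo "Python $py, PyTorch $torch_ver, CUDA (PyTorch build) $cuda_ver"
  version_prefix_ok "$py" "3.8.5" || { echo "expected Python 3.8.5" >&2; status=1; }
  version_prefix_ok "$torch_ver" "1.7.0" || { echo "expected PyTorch 1.7.0" >&2; status=1; }
  version_prefix_ok "$cuda_ver" "10.1" || { echo "expected CUDA 10.1" >&2; status=1; }
  return "$status"
}
```

To use it, `source` the snippet and run `check_rivrl_env` after `conda activate rivrl_env`. A non-zero exit status signals a mismatch. Build suffixes such as `+cu101` are accepted on purpose, because PyTorch wheels commonly carry them.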
We use three public datasets: MSR-VTT, VATEX, and TGIF. Please refer to the linked page for a detailed description of the datasets and download instructions. Since our model uses additional BERT features, you should also download the pre-extracted BERT features from Baidu pan (url, password: 6knd). Alternatively, run the following script to download the BERT features; the extracted data is placed in $HOME/VisualSearch/.
ROOTPATH=$HOME/VisualSearch
mkdir -p $ROOTPATH && cd $ROOTPATH
mkdir bert_extract && cd bert_extract
# download the features of BERT
wget https://8.210.46.84:8787/rivrl/bert/<bert-Name>.tar.gz
tar zxf <bert-Name>.tar.gz -C $ROOTPATH
# <bert-Name> is msrvtt_bert, vatex_bert, and tgif_bert respectively.
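The same two commands are repeated for each of msrvtt_bert, vatex_bert, and tgif_bert, so they can be wrapped in a loop. This is a minimal sketch that assumes the archive names and server URL given above. The function name `fetch_all_bert_features` and the `DRY_RUN` switch are hypothetical additions and are not part of the repository.

```shell
#!/usr/bin/env bash
# Hypothetical convenience loop (not part of the RIVRL repository): fetch and
# unpack the BERT features for all three datasets.
# Set DRY_RUN=1 to print the commands instead of executing them.
ROOTPATH="${ROOTPATH:-$HOME/VisualSearch}"
BERT_BASE_URL="https://8.210.46.84:8787/rivrl/bert"

fetch_all_bert_features() {
  local name archive run=""
  # In dry-run mode every command is prefixed with "echo" and only printed.
  if [ -n "${DRY_RUN:-}" ]; then run="echo"; fi
  $run mkdir -p "$ROOTPATH/bert_extract" || return 1
  for name in msrvtt_bert vatex_bert tgif_bert; do
    archive="$ROOTPATH/bert_extract/${name}.tar.gz"
    $run wget -O "$archive" "$BERT_BASE_URL/${name}.tar.gz" || return 1
    # Extract into ROOTPATH, as in the per-dataset commands above.
    $run tar zxf "$archive" -C "$ROOTPATH" || return 1
  done
}
```

For example, `DRY_RUN=1 fetch_all_bert_features` prints the mkdir, wget, and tar commands without running them. Calling `fetch_all_bert_features` alone performs the downloads.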
Run the following script to train and evaluate the RIVRL network on MSR-VTT. Specifically, it trains the RIVRL network and selects the checkpoint that performs best on the validation set as the final model. Note that only this best-performing checkpoint is saved, to save disk space.
ROOTPATH=$HOME/VisualSearch
conda activate rivrl_env
# Train the model on MSR-VTT using the resnext-101_resnet152-13k video feature
# Template:
./do_all_msrvtt.sh $ROOTPATH <split-Name> <useBert> <gpu-id>
# Example:
# Train RIVRL with BERT features on MV-Yu
./do_all_msrvtt.sh $ROOTPATH msrvtt10yu 1 0
<split-Name> indicates the dataset partition: msrvtt10yu, msrvtt10k, and msrvtt10kmiech denote the MV-Yu, MV-Xu, and MV-Miech partitions, respectively.
<useBert> indicates whether to use BERT as an additional text feature during training: 1 uses the BERT feature, 0 does not.
<gpu-id> is the index of the GPU to train on.
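You can check the three arguments before launching a long training run. This sketch wraps `./do_all_msrvtt.sh` in two hypothetical functions, `validate_msrvtt_args` and `train_msrvtt`, which are not part of the repository. The accepted split names are taken from the list above; adjust them if your data directory uses different names.

```shell
#!/usr/bin/env bash
# Hypothetical wrapper (not part of the RIVRL repository) that checks the
# <split-Name> <useBert> <gpu-id> arguments before calling ./do_all_msrvtt.sh.

validate_msrvtt_args() {
  if [ "$#" -ne 3 ]; then
    echo "usage: <split-Name> <useBert> <gpu-id>" >&2; return 1
  fi
  # Split names as listed above.
  case "$1" in
    msrvtt10yu | msrvtt10k | msrvtt10kmiech) ;;
    *) echo "unknown split: $1" >&2; return 1 ;;
  esac
  # 1 = train with the additional BERT feature, 0 = without it.
  case "$2" in
    0 | 1) ;;
    *) echo "useBert must be 0 or 1, got: $2" >&2; return 1 ;;
  esac
  # GPU index: a non-empty string of digits.
  case "$3" in
    '' | *[!0-9]*) echo "gpu-id must be a non-negative integer, got: $3" >&2; return 1 ;;
  esac
}

# Usage: train_msrvtt <rootpath> <split-Name> <useBert> <gpu-id>
train_msrvtt() {
  local rootpath="$1"
  shift
  validate_msrvtt_args "$@" || return 1
  ./do_all_msrvtt.sh "$rootpath" "$@"
}
```

Run from the repository root, `train_msrvtt "$ROOTPATH" msrvtt10yu 1 0` behaves like the example above. The difference is that it rejects a mistyped argument before the script starts.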
Run the following script to download our trained checkpoints and evaluate them on MSR-VTT. The trained checkpoints can also be downloaded from Baidu pan (url, password: wb3c).
ROOTPATH=$HOME/VisualSearch
# download trained checkpoints
wget -P $ROOTPATH https://8.210.46.84:8787/rivrl/best_model/msrvtt/<best_model>.pth.tar
# <best_model> is mv_yu_best, mv_yu_Bert_best, mv_miech_best, mv_miech_Bert_best, mv_xu_best, or mv_xu_Bert_best.
tar zxf $ROOTPATH/<best_model>.pth.tar -C $ROOTPATH
# evaluate on MSR-VTT
# Template:
./do_test.sh $ROOTPATH <split-Name> $MODELDIR <gpu-id>
# $MODELDIR is the directory containing the checkpoint, $ROOTPATH/.../runs_0; set it before running
# Example:
# evaluate on MV-Yu
./do_test.sh $ROOTPATH msrvtt10kyu $MODELDIR 0
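The evaluation commands use `$MODELDIR` without setting it; the comment only states that it is a checkpoint directory ending in `runs_0`. One way to set it is to search for that directory. The sketch below assumes the extracted archive creates exactly one `runs_0` directory somewhere under `$ROOTPATH`. `find_model_dirs` and `pick_model_dir` are hypothetical helpers.

```shell
#!/usr/bin/env bash
# Hypothetical helpers (not part of the RIVRL repository) for locating the
# runs_0 checkpoint directory under ROOTPATH.

find_model_dirs() {
  # Every directory named runs_0 below the given root, in a stable order.
  find "$1" -type d -name runs_0 2>/dev/null | sort
}

pick_model_dir() {
  local found count
  found="$(find_model_dirs "$1")"
  # Number of matching directories (0 when nothing was found).
  count="$(printf '%s' "$found" | awk 'END { print NR }')"
  if [ "$count" -ne 1 ]; then
    echo "expected exactly one runs_0 directory under $1, found $count" >&2
    if [ -n "$found" ]; then printf '%s\n' "$found" >&2; fi
    return 1
  fi
  printf '%s\n' "$found"
}
```

Usage: `MODELDIR="$(pick_model_dir "$ROOTPATH")" && ./do_test.sh $ROOTPATH <split-Name> $MODELDIR <gpu-id>`. Several `runs_0` directories can exist, for example after both training a model and downloading one. In that case the helper refuses to guess, which avoids silently evaluating the wrong checkpoint.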
The expected performance of RIVRL on MSR-VTT and the corresponding pre-trained checkpoints are as follows. Note that, due to randomness in SGD-based training, the numbers differ slightly from those reported in the paper.
Text-to-video retrieval:

| Dataset | Split | BERT | R@1 | R@5 | R@10 | MedR | mAP | SumR | Pre-trained Checkpoint |
|---|---|---|---|---|---|---|---|---|---|
| MSR-VTT | MV-Yu | w/o | 24.2 | 51.5 | 63.8 | 5 | 36.86 | 139.5 | mv_yu_best.pth.tar |
| MSR-VTT | MV-Yu | with | 27.9 | 59.3 | 71.3 | 4 | 42.0 | 158.4 | mv_yu_Bert_best.pth.tar |
| MSR-VTT | MV-Miech | w/o | 25.3 | 53.6 | 67.0 | 4 | 38.5 | 145.9 | mv_miech_best.pth.tar |
| MSR-VTT | MV-Miech | with | 26.2 | 56.6 | 68.2 | 4 | 39.92 | 151.0 | mv_miech_Bert_best.pth.tar |
| MSR-VTT | MV-Xu | w/o | 12.9 | 33.0 | 44.6 | 14 | 23.07 | 90.5 | mv_xu_best.pth.tar |
| MSR-VTT | MV-Xu | with | 13.7 | 34.6 | 46.4 | 13 | 24.19 | 94.6 | mv_xu_Bert_best.pth.tar |
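The metric columns can be recomputed from the rank of the ground-truth video for each text query. The sketch below makes two assumptions. First, each query has exactly one relevant video, so average precision reduces to 1/rank and mAP to the mean reciprocal rank. Second, SumR is R@1 + R@5 + R@10, which matches the MSR-VTT rows above up to rounding (e.g. 24.2 + 51.5 + 63.8 = 139.5). This is an illustration, not the repository's evaluation code. The input format (one 1-based rank per line) and the function name `retrieval_metrics` are hypothetical.

```shell
#!/usr/bin/env bash
# Hypothetical illustration, not the RIVRL evaluation code.
# Input: one line per text query holding the 1-based rank of its ground-truth video.
# Output: R@1, R@5, R@10 and mAP in percent, MedR, and SumR = R@1 + R@5 + R@10.
retrieval_metrics() {
  # Ranks are sorted numerically so the median can be read off directly.
  sort -n | awk '
    {
      r[NR] = $1
      if ($1 <= 1)  r1++
      if ($1 <= 5)  r5++
      if ($1 <= 10) r10++
      inv += 1 / $1          # reciprocal rank = AP with one relevant video
    }
    END {
      n = NR
      if (n == 0) exit 1
      # Median rank (mean of the two middle values when n is even).
      med = (n % 2) ? r[(n + 1) / 2] : (r[n / 2] + r[n / 2 + 1]) / 2
      R1 = 100 * r1 / n; R5 = 100 * r5 / n; R10 = 100 * r10 / n
      printf "R@1 %.1f R@5 %.1f R@10 %.1f MedR %g mAP %.2f SumR %.1f\n",
             R1, R5, R10, med, 100 * inv / n, R1 + R5 + R10
    }'
}

# Example: five queries whose ground-truth videos were ranked 1, 2, 3, 6 and 20.
printf '%s\n' 1 2 3 6 20 | retrieval_metrics
# prints: R@1 20.0 R@5 60.0 R@10 80.0 MedR 3 mAP 41.00 SumR 160.0
```

SumR here is computed from unrounded recalls. This is one reason a SumR in the table can differ by 0.1 from the sum of the rounded R@K values shown next to it.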
Run the following script to train and evaluate the RIVRL network on VATEX.
ROOTPATH=$HOME/VisualSearch
conda activate rivrl_env
# Train the model on VATEX
./do_all_vatex.sh $ROOTPATH <useBert> <gpu-id>
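`<useBert>` is the only difference between the two rows of each results table, so both variants can be trained concurrently on two GPUs. The sketch below uses a hypothetical function, `train_both_variants`. It assumes the two runs write to separate output directories; check this in the training script before running them side by side.

```shell
#!/usr/bin/env bash
# Hypothetical helper (not part of the RIVRL repository): train the w/o-BERT
# and with-BERT variants concurrently on two GPUs and wait for both.
# Usage: train_both_variants <train-script> <rootpath> <gpu-a> <gpu-b>
train_both_variants() {
  local script="$1" rootpath="$2" gpu_a="$3" gpu_b="$4"
  local pid_a pid_b status=0
  # Variant without BERT (<useBert> = 0) on the first GPU.
  "$script" "$rootpath" 0 "$gpu_a" &
  pid_a=$!
  # Variant with BERT (<useBert> = 1) on the second GPU.
  "$script" "$rootpath" 1 "$gpu_b" &
  pid_b=$!
  # Wait for both runs; report failure if either exits non-zero.
  wait "$pid_a" || status=1
  wait "$pid_b" || status=1
  return "$status"
}
```

For example, `train_both_variants ./do_all_vatex.sh "$ROOTPATH" 0 1`. The same call works for `./do_all_tgif_li.sh` and `./do_all_tgif_chen.sh`, which take the same `$ROOTPATH <useBert> <gpu-id>` arguments.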
Run the following script to download our trained model and evaluate it on VATEX. The trained checkpoints can also be downloaded from Baidu pan (url, password: wb3c).
ROOTPATH=$HOME/VisualSearch
# download trained checkpoints and evaluate
wget -P $ROOTPATH https://8.210.46.84:8787/rivrl/best_model/vatex/<best_model>.pth.tar
# <best_model> is vatex_best or vatex_Bert_best
tar zxf $ROOTPATH/<best_model>.pth.tar -C $ROOTPATH
# evaluate on VATEX
./do_test.sh $ROOTPATH vatex $MODELDIR <gpu-id>
# $MODELDIR is the directory containing the checkpoint, $ROOTPATH/.../runs_0; set it before running
The expected performance of RIVRL on VATEX and the corresponding pre-trained checkpoints are as follows.
Text-to-video retrieval:

| Dataset | BERT | R@1 | R@5 | R@10 | MedR | mAP | SumR | Pre-trained Checkpoint |
|---|---|---|---|---|---|---|---|---|
| VATEX | w/o | 39.4 | 76.1 | 84.8 | 2 | 55.3 | 200.4 | vatex_best.pth.tar |
| VATEX | with | 39.1 | 76.7 | 85.4 | 2 | 55.4 | 201.0 | vatex_Bert_best.pth.tar |
Run the following script to train and evaluate the RIVRL network on TGIF.
ROOTPATH=$HOME/VisualSearch
conda activate rivrl_env
# Train the model on TGIF-Li
./do_all_tgif_li.sh $ROOTPATH <useBert> <gpu-id>
# Train the model on TGIF-Chen
./do_all_tgif_chen.sh $ROOTPATH <useBert> <gpu-id>
Run the following script to download our trained models and evaluate them on TGIF. The trained checkpoints can also be downloaded from Baidu pan (url, password: wb3c).
ROOTPATH=$HOME/VisualSearch
# download trained checkpoints
wget -P $ROOTPATH https://8.210.46.84:8787/rivrl/best_model/tgif/<best_model>.pth.tar
# <best_model> is tgif_li_best, tgif_li_Bert_best, tgif_chen_best, or tgif_chen_Bert_best.
tar zxf $ROOTPATH/<best_model>.pth.tar -C $ROOTPATH
# evaluate on TGIF-Li
./do_test.sh $ROOTPATH tgif-li $MODELDIR <gpu-id>
# evaluate on TGIF-Chen
./do_test.sh $ROOTPATH tgif-chen $MODELDIR <gpu-id>
# $MODELDIR is the directory containing the checkpoint, $ROOTPATH/.../runs_0; set it before running
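The checkpoint names follow a fixed pattern: the split name with an underscore, plus `_Bert` when BERT is used. This means the download step can be derived from the evaluation split. The sketch below uses `tgif_checkpoint_name` and `fetch_tgif_checkpoint`, hypothetical helpers built on the URL and names listed above.

```shell
#!/usr/bin/env bash
# Hypothetical helpers (not part of the RIVRL repository): map an evaluation
# split (tgif-li or tgif-chen) and a useBert flag to the checkpoint name,
# then download and unpack that checkpoint.
tgif_checkpoint_name() {
  local base
  case "$1" in
    tgif-li) base="tgif_li" ;;
    tgif-chen) base="tgif_chen" ;;
    *) echo "unknown TGIF split: $1" >&2; return 1 ;;
  esac
  case "$2" in
    0) echo "${base}_best" ;;        # trained without BERT
    1) echo "${base}_Bert_best" ;;   # trained with BERT
    *) echo "useBert must be 0 or 1, got: $2" >&2; return 1 ;;
  esac
}

# Usage: fetch_tgif_checkpoint <rootpath> <tgif-li|tgif-chen> <useBert>
fetch_tgif_checkpoint() {
  local rootpath="$1" name
  name="$(tgif_checkpoint_name "$2" "$3")" || return 1
  wget -P "$rootpath" "https://8.210.46.84:8787/rivrl/best_model/tgif/${name}.pth.tar" || return 1
  tar zxf "$rootpath/${name}.pth.tar" -C "$rootpath"
}
```

For example, `fetch_tgif_checkpoint "$ROOTPATH" tgif-chen 1` downloads and unpacks `tgif_chen_Bert_best.pth.tar`. You can then pass the extracted `runs_0` directory as `$MODELDIR` to `./do_test.sh $ROOTPATH tgif-chen $MODELDIR <gpu-id>`.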
The expected performance of RIVRL on TGIF and the corresponding pre-trained checkpoints are as follows.
Text-to-video retrieval:

| Dataset | Split | BERT | R@1 | R@5 | R@10 | MedR | mAP | SumR | Pre-trained Checkpoint |
|---|---|---|---|---|---|---|---|---|---|
| TGIF | TGIF-Li | w/o | 11.3 | 25.3 | 33.6 | 34 | 18.7 | 70.3 | tgif_li_best.pth.tar |
| TGIF | TGIF-Li | with | 12.1 | 26.6 | 35.1 | 29 | 19.75 | 73.8 | tgif_li_Bert_best.pth.tar |
| TGIF | TGIF-Chen | w/o | 6.4 | 16.1 | 22.4 | 91 | 11.81 | 44.9 | tgif_chen_best.pth.tar |
| TGIF | TGIF-Chen | with | 6.8 | 17.2 | 23.5 | 79 | 12.45 | 47.4 | tgif_chen_Bert_best.pth.tar |
If you find the package useful, please consider citing our paper:
@article{dong2022reading,
title={Reading-strategy Inspired Visual Representation Learning for Text-to-Video Retrieval},
author={Dong, Jianfeng and Wang, Yabing and Chen, Xianke and Qu, Xiaoye and Li, Xirong and He, Yuan and Wang, Xun},
journal={IEEE Transactions on Circuits and Systems for Video Technology},
year={2022}
}