
Dense Video Captioning on raw input videos #11

Open
harpavatkeerti opened this issue Jun 16, 2020 · 23 comments
@harpavatkeerti

It seems a nice work. I wanted to test it on custom input videos. It would be very helpful if you can provide a script for generating video captions for a raw input video.

@v-iashin
Owner

Hi. Thanks for the interest in our work!

I am afraid there are several obstacles that make this request hard to fulfil, namely:

  1. The Google ASR system used on YouTube is proprietary. We don't have access to it and merely used its output, downloaded for each video from the service.
  2. We used the predictions of BAFCG (Bi-SST) for the proposals. The repo is great, yet it is missing a script for generating proposals for a single video.

Therefore, we are not supporting this feature 🙁.

At the same time, let me provide some advice on how to try to mitigate the issues.

  1. If your custom video is not on YouTube, start by looking for a solution that will do the ASR for it. If your videos are from YouTube, try to obtain the CC (closed captions) for the video (how? I can only give you a hint: "Google is your friend" 😉). In many cases, you will need to pre-process the subtitles because the lines overlap temporally with their neighbors. If you get this far, create an issue, and I will provide the script we used for this work. We also describe the problem in the last paragraph of the README.md in the supplementary material.
  2. Try to ask the authors of BAFCG. I contacted the first author on a different matter about the paper, and they helped me a lot. Also, make sure to specify what you want: their model may predict the proposals in two different settings, with and without joint ranking of the proposals. The joint ranking uses the scores of the captions predicted from these proposals. In our work, we use the predictions before joint ranking.

On a good note, we made the rest of the ingredients to be more accessible. Specifically:

  • You can extract visual and audio features using our library called video_features. If I do say so myself, it is easy to use and convenient when you want to extract features from a ton of videos, especially if you have several GPUs at your disposal (see the short sketch after this list).
  • Check out the predict_1by1_for_TBoard function in ./epoch_loop/run_epoch.py. At each epoch, it runs one-by-one caption inference on a pre-defined set of videos (in Config) from the validation dataset, using the timestamps of the ground-truth proposals.
  • Also, feel free to check out our other work on the topic. In particular, see the script for the single video prediction. It does exactly what you asked but with a better, more efficient audio-visual model.
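
A rough, untested sketch, assuming the per-video .npy layout discussed later in this thread (one file per modality) and the d_aud=128 / d_vid=1024 feature sizes used in this repo; the paths below are placeholders, so adjust them to wherever video_features saved your files:

import numpy as np

# placeholder paths -- one .npy file per modality per video
vggish = np.load('./features/video_1_vggish.npy')  # expected shape: (num_audio_segments, 128)
rgb = np.load('./features/video_1_rgb.npy')        # expected shape: (num_visual_stacks, 1024)
flow = np.load('./features/video_1_flow.npy')      # expected shape: (num_visual_stacks, 1024)

print('vggish:', vggish.shape, 'rgb:', rgb.shape, 'flow:', flow.shape)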

Of course, let me know if you have any further questions about the procedure via e-mail or in Issues in this repo.

@v-iashin v-iashin pinned this issue Jun 16, 2020
@harpavatkeerti
Author

Hi, thanks for your prompt and detailed reply.
For 1, I currently have the captions as well as timestamps, generated using Google ASR. So your scripts would be a great help.

Regarding the single video prediction in BMT, I already used the video_features repo and the scripts you provided. Running that code is very convenient thanks to the detailed steps you have given. But the work that I am currently pursuing requires me to generate descriptions grounded in the audio transcript. That is why I was trying to get this code running, as it also takes speech into account.

v-iashin added a commit that referenced this issue Jun 16, 2020
@harpavatkeerti
Author

Thank you very much. I will try this script and hopefully get the code running for a single video.

@harpavatkeerti
Author

Sorry to bother you again. I tried that script, and it is working fine. I have a few other doubts.

  1. For the feature extraction, the repo you mentioned gives the output as .npy files (flow, rgb and vggish). But this code requires the hdf5 format. What script should I use to get the features in hdf5 format?
  2. Also, I saw run_epoch.py (predict_1by1_for_TBoard and validate 1by1). Both of them seem to be very close to what I need, but I can't figure out how to pass the csv file of captions and the feature files directly instead of the loader. Any guidance on this?

@v-iashin
Owner

No problem. Thanks for your questions.

  1. Well, it doesn't matter whether you use features stored in an .hdf5 file or as separate files. One benefit of storing them in .hdf5 was the ease of experimentation on a cluster with a Lustre file system. Initially, we were using the hdf5 format, and features were appended to the file during extraction. I decided to omit this part as it is a bit cumbersome and can be done afterwards, given a folder of features. Hence, in BMT (load_features_from_npy) we load them from separate files. So, I think you may just adapt the code from BMT. On the other hand, if you really want to store them in an .hdf5 file, here are some snippets. Start by initializing the hdf5 files:
import h5py
# create empty hdf5 files on your disk
h5py.File(hdf5_vggish_features_path, 'w').close()
h5py.File(hdf5_i3d_features_path, 'w').close()

I will assume that you extracted features for your custom videos into some folder. This means that you may have several feature files per video. For instance, for video_1.mp4 the folder contains video_1_vggish.npy, video_1_rgb.npy, and video_1_flow.npy, which should be stored in two different hdf5 files. Therefore, we care only about the beginning of each file path:

import os
# returns an unsorted list of filenames from `features_path`
list_of_files = os.listdir(features_path)
# join filenames with the parent directory (ignoring non-`.npy` files)
list_of_paths = [os.path.join(features_path, fname) for fname in list_of_files if fname.endswith('.npy')]
# we will care only for the beginning of a file path (only `/some_path/video_1`)
paths = [path.replace('_vggish.npy', '').replace('_rgb.npy', '').replace('_flow.npy', '') for path in list_of_paths]
# we expect to have 3 duplicate paths for each video. Remove duplicates
paths = list(set(paths))

Then, open the files, read the features from the numpy files, and append them in a for-loop:

import numpy as np
# start context managers ('a' == append) for both files 
with h5py.File(hdf5_vggish_features_path, 'a') as hd5vgg, h5py.File(hdf5_i3d_features_path, 'a') as hd5i3d:
    # the for-loop
    for path in paths:
        # construct new paths
        vggish_path = f'{path}_vggish.npy'
        rgb_path = f'{path}_rgb.npy'
        flow_path = f'{path}_flow.npy'

        # loading numpy files
        vggish = np.load(vggish_path)
        rgb = np.load(rgb_path)
        flow = np.load(flow_path)

        # extract video names from the paths (relying on the rgb path only)
        # os.path.split() outputs a list which contains parent dir path [0] and filename [1]
        # and removing the part with '_rgb.npy' (`video_1_rgb.npy` -> `video_1`)
        video_name = os.path.split(rgb_path)[-1].replace('_rgb.npy', '')

        # append features to the hdf5 files
        # VGGish
        hd5vgg.create_dataset(f'{video_name}/vggish_features', vggish.size(), data=vggish)
        # RGB
        hd5i3d.create_dataset(f'{video_name}/i3d_features/rgb', rgb.size(), data=rgb)
        hd5i3d.create_dataset(f'{video_name}/i3d_features/flow', flow.size(), data=flow)

Please note that I provide this code just for guidance as I neither compiled this code nor tested it locally. Adapt it to your needs.
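
In the same just-for-guidance spirit, an equally untested sketch for sanity-checking the resulting files: it lists every dataset stored in the two hdf5 files together with its shape (the path variables are the same placeholders as above).

import h5py

# print every dataset name and shape stored in the two hdf5 files
for h5_path in [hdf5_vggish_features_path, hdf5_i3d_features_path]:
    with h5py.File(h5_path, 'r') as f:
        print(h5_path)
        f.visititems(
            lambda name, obj: print(f'  {name}: {obj.shape}') if isinstance(obj, h5py.Dataset) else None
        )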

That's it. If you like, you can update the code of video_features to initialize the hdf5 files at the beginning of parallel_feature_extraction() and add another on_extraction action which starts the context manager (the with h5py.File(..., 'a') part) and appends the extracted features to the created file; this would resemble our implementation of feature extraction for MDVC.

  2. Sure. First, you will need to initialize the validation dataset with the files provided in this repo. This is done just to build the train vocabs (for captions and subs) and to extract some technical things that were used at training (special tokens, device, etc.):
import torch
from torch.utils.data import DataLoader
from dataset.dataset import ActivityNetCaptionsIteratorDataset

val_dataset = ActivityNetCaptionsIteratorDataset(
    '<s>', '</s>', '<blank>', 1, 
    28, './data/sub_activitynet_v1-3.i3d_25fps_stack24step24_2stream.hdf5', 'i3d', 
    False, False,
    './data/sub_activitynet_v1-3.vggish.hdf5', 'vggish', 
    False, False, 
    './data/train_meta.csv', './data/val_1_meta.csv', './data/val_2_meta.csv', 
    torch.device('cuda:0'), 'val_1', 'subs_audio_video', 
    False, props_are_gt=True, get_full_feat=False
)

val_loader = DataLoader(val_dataset, collate_fn=val_dataset.dont_collate)

Next, you will want to update the predict_1by1_for_TBoard function a bit. Something like this:

import pandas as pd
import h5py

def predict_1by1_for_TBoard(your_meta_path, vggish_hdf5_features_path, i3d_hdf5_features_path, 
    vid_ids_list, val_loader, decoder, model, max_len=100):
    '''
        your_meta_path: path to your .csv
        *_hdf5_features_path: path to your hdf5 files
        vid_ids_list: the ids which will be used to filter meta file. For example: ['video_1', 'video_2']
        val_loader: object defined above
        decoder: just pass the greedy_decoder function
        model: pass the pre-trained model
        max_len: largest caption possible. The generation will stop if exceeded
    '''
    # for dataframe example see `./data/val_1_meta.csv`. Make sure your video_id will correspond to the
    # filenames w/o extension (`video_1_rgb.npy` -> `video_1`) which you used to create hdf5 files 
    # because load_multimodal_features_from_h5() will use them
    # as well as they should be present in `vid_ids_list` variable
    meta = pd.read_csv(your_meta_path, sep='\t')

    # re-define hdf5 files with your custom ones
    feat_h5_audio = h5py.File(vggish_hdf5_features_path, 'r')
    feat_h5_video = h5py.File(i3d_hdf5_features_path, 'r')

    feature_names = val_loader.dataset.feature_names
    device = val_loader.dataset.device
    start_idx = val_loader.dataset.start_idx
    end_idx = val_loader.dataset.end_idx
    pad_idx = val_loader.dataset.pad_idx
    modality = val_loader.dataset.modality
    
    text = ''

    for vid_id in vid_ids_list:
        meta_subset = meta[meta['video_id'] == vid_id]
        text += f'\t {vid_id} \n'

        for (video_id, cap, start, end, duration, category, subs, phase, idx) in meta_subset.values:
            
            feature_names_list = val_loader.dataset.features_dataset.feature_names_list
            train_subs_vocab = val_loader.dataset.train_subs_vocab

            # rgb is padded with pad_idx; flow is padded with 0s: expected to be summed later
            video_stack_rgb, video_stack_flow, audio_stack = load_multimodal_features_from_h5(
                feat_h5_video, feat_h5_audio, feature_names_list, video_id, start, end, duration
            )

            subs_stack = encode_subs(train_subs_vocab, idx, meta, start_idx, end_idx)
            
            video_stack_rgb = video_stack_rgb.unsqueeze(0).to(device)
            video_stack_flow = video_stack_flow.unsqueeze(0).to(device)
            audio_stack = audio_stack.unsqueeze(0).to(device)
            subs_stack = subs_stack.unsqueeze(0).to(device)

            stack = video_stack_rgb + video_stack_flow, audio_stack, subs_stack
            
            trg_ints = decoder(model, stack, max_len, start_idx, end_idx, pad_idx, modality)
            trg_ints = trg_ints.cpu().numpy()[0]
            trg_words = [val_loader.dataset.train_vocab.itos[i] for i in trg_ints]
            en_sent = ' '.join(trg_words)

            text += f'\t P sent: {en_sent} \n'
            text += f'\t P proposals: {start//60:.0f}:{start%60:02.0f} {end//60:.0f}:{end%60:02.0f} '
            
        text += '\t \n'
    
    return text

Please use ./data/val_1_meta.csv as guidance, along with the comment before pd.read_csv(), to construct a proper .csv. You can use some gibberish text for the caption column and 1.0 for category_32.
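
For illustration only, here is a minimal, untested sketch of building such a meta file for a single custom video with pandas; the column names follow the tuple unpacked in the loop above, and all values are placeholders:

import pandas as pd

# hypothetical two-proposal meta file for a video called `video_1`
meta = pd.DataFrame({
    'video_id': ['video_1', 'video_1'],
    'caption': ['none', 'none'],          # gibberish is fine, captions are not used at inference
    'start': [0.2, 12.4],                 # proposal start times (seconds)
    'end': [12.4, 16.1],                  # proposal end times (seconds)
    'duration': [53.2, 53.2],             # full video duration (seconds)
    'category_32': [1.0, 1.0],
    'subs': ['first subtitle line', 'second subtitle line'],
    'phase': ['val_1', 'val_1'],
    'idx': [0, 1],
})
meta.to_csv('your_meta.csv', index=None, sep='\t')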

Again, I haven't tested the code, but I really hope it will work. Let me know if you have any further questions.

@harpavatkeerti
Author

Thank you very much

@harpavatkeerti
Author

Thanks a lot for your help; this code works for predicting captions for a single video, and the results are also good.
(Just a minor change for anybody else who uses this: change hd5vgg.create_dataset(f'{video_name}/vggish_features', vggish.size(), data=vggish) to hd5vgg.create_dataset(f'{video_name}/vggish_features', data=vggish), and similarly for the others.) The rest of the code is working perfectly fine.

@v-iashin
Owner

Great 🎉!

Yep, I also noticed it today 🙂. .size() should fail with an error as numpy.ndarray doesn't have a .size() method (PyTorch tensors do). I don't remember if there was any reason to specify the size of the dataset or why it was there in the first place. Anyway, .shape might work, but if you say that it can be omitted, then it is even better.

If you have a working example that you are comfortable sharing, just post it here, or we may also discuss how to form a pull request. I think it would be interesting even if it only works for YouTube videos.

@harpavatkeerti
Author

Sure!
This is one of the examples I tested it on.
Adidas ad https://www.youtube.com/watch?v=DpR50O1nGNs&feature=youtu.be
Below are the subtitles I generated using ASR and the corresponding captions generated by your code.
1
0.2-->12.4
have a great first day lockers on the left good morning everyone activities are about to commence 42.

2
12.4-->16.1
I just want to say what's remember why we're here we're here to change things.

3
16.1-->23.4
we can all do amazing things in our own but together we can do so much more to make things better than they were before.

4
23.4-->52.1
having a teacher in your craft and a student in someone else's oh nevermind me say hi to Jen soda our newest team member share your strength and receive strength and most importantly have fun if you're smiling you're doing it right change is a team sport and we'd.

5
52.1-->53.2
be honored to be on a team with you

 adidas
 P sent: <s> the man is now standing in the court and the other men are running and the other man is shown in the middle of the court </s>
 P proposals: 0:00 0:12
 P sent: <s> a man is talking to the camera </s>
 P proposals: 0:12 0:16
 P sent: <s> a man is seen standing in front of a large crowd </s>
 P proposals: 0:16 0:23
 P sent: <s> the man is then seen speaking to the camera and leads into several clips of people playing the game </s>
 P proposals: 0:23 0:52
 P sent: <s> a man is standing in front of a large green field </s>
 P proposals: 0:52 0:53

@v-iashin
Owner

🙂 I meant the script which takes a subs file, vggish and i3d features and outputs a set of predictions for, at least, the GT proposals.

@harpavatkeerti
Author

harpavatkeerti commented Jun 19, 2020

Oh, my bad!
I created a .csv file from the subtitles using your previous code:

import os
import re
import pandas as pd

import subprocess
import argparse

def get_length(filename):
    result = subprocess.run(["ffprobe", "-v", "error", "-show_entries",
                             "format=duration", "-of",
                             "default=noprint_wrappers=1:nokey=1", filename],
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT)
    return float(result.stdout)

def parse_timestamp(timestamp):
    '''
        Extracts start and end points of a subtitles
        '00:00:01,320 --> 00:00:19,609'
    '''
    start, end = timestamp.split('-->')
#     s_hours, s_mins, s_secs = start.split(':')
#     e_hours, e_mins, e_secs = end.split(':')
#     s_secs = int(s_hours)*3600 + int(s_mins)*60 + round(float(s_secs), 2)
#     e_secs = int(e_hours)*3600 + int(e_mins)*60 + round(float(e_secs), 2)

    # floats
#     return s_secs, e_secs
    return float(start), float(end)

def clean_text(text):
    '''
    cleans things like:
        <font color="#E5E5E5">you</font><font color="#CCCCCC"> just better</font> you just better
        [Music]
        [Applause]
    '''
    text = re.sub('<[^>]*>', ' ', text).strip()
    text = re.sub('\s{2,}', ' ', text)
    text = text.replace('[Music]', '')
    text = text.replace('[Applause]', '')

    # str
    return text

def parse_sub(in_stream):
    '''
        Extracts start, end, and subtitle from a two lines given the input stream open()
    '''
    # read new lines that contain the content of a sub
    num = in_stream.readline().replace('\n', '')

    # end file
    if len(num) == 0:
        return None, None, None

    timestamp = in_stream.readline().replace('\n', '')
    assert len(timestamp) != 0
    text = in_stream.readline().replace('\n', '')

    # moving pointer over the empty line
    in_stream.readline()

    # extract start and end times
    start_secs, end_secs = parse_timestamp(timestamp)
    # clean the content of a sub
    text = clean_text(text)

    # floats and str
    return start_secs, end_secs, text

def parse_sub_file(path):
    '''
        Parses a subtitle file.
    '''
    starts, ends, texts = [], [], []
    in_stream = open(path, 'r')

    # loop until the end of the file is reached
    while in_stream:
        start_secs, end_secs, text = parse_sub(in_stream)

        if (start_secs is None) or (end_secs is None) or (text is None):
            break
        else:
            starts.append(start_secs)
            ends.append(end_secs)
            texts.append(text)

    # sanity check
    line_number = len(open(path, 'r').readlines())
    if (line_number - len(texts) * 4) > 1:
        print(path, line_number, len(texts) * 4)

    # lists
    return starts, ends, texts

def add_adjusted_end_2_df(subs_dataframe):
    '''
        Given a pandas dataframe, adjusts the start and end points to address the following problem:

        YouTube displays the previous speech segment as well as the new one when
        somebody is speaking. When the current line is finished, it replaces the previous one
        while the new one starts to appear on the screen, and so on. Therefore, the starting
        point is quite accurate while the ending point is not. Considering the fact that the
        previous speech segment is ended by the start of the next one, we may adjust the
        ending point to be the start of the next segment within one video.
    '''
    subs_dataframe['video_id_next'] = subs_dataframe['video_id'].shift(periods=-1)
    subs_dataframe['start_next'] = subs_dataframe['start'].shift(periods=-1)
    subs_dataframe['end_next'] = subs_dataframe['end'].shift(periods=-1)

    # defining it here to use in dataframe.apply instead of a lambda function
    def adjust_end_time(row):
        if row['video_id_next'] == row['video_id']:
            return min(row['end'], row['start_next'])
        else:
            return row['end']

    subs_dataframe['end_adj'] = subs_dataframe.apply(adjust_end_time, axis=1)
    # filter columns that end with '_next' (temp columns)
    subs_dataframe = subs_dataframe.filter(regex='.+(?<!_next)$')

    return subs_dataframe

def filter_dataframe(dataframe):
    '''
        Some sanity check filtering: drop subs whose start point is too far
        or whose text is an empty string
    '''
    dataframe = dataframe[dataframe['start'] < 5000].reset_index(drop=True)
    dataframe = dataframe[dataframe['subs'].apply(lambda x: len(x) > 0)].reset_index(drop=True)
    return dataframe

def subtitles_dataframe(video_name, save_path=None):
    '''
        creates a pd.DataFrame object and saves .csv
    '''

    video_ids_acc = []
    starts_acc = []
    ends_acc = []
    subs_acc = []
    comments_acc = []

    filename_path = f'./../video/videos/{video_name}.srt'

    # repeats the same procedure for each folder with subs (en, translated, other)

    comment = "asr_en"
#     for i, filename in enumerate(sorted(os.listdir(subs_folder))):
#         filename_path = os.path.join(subs_folder, filename)
    filename=video_name + ".srt"
    starts, ends, subs = parse_sub_file(filename_path)
    video_id = video_name
    video_ids_acc += [video_id] * len(starts)
    starts_acc += starts
    ends_acc += ends
    subs_acc += subs
    comments_acc += [comment] * len(starts)

    dur = get_length('./../video/videos/' + video_name + '.mp4')

    dataframe = pd.DataFrame({
        'video_id': video_ids_acc,
        'caption': "none",
        'start': starts_acc,
        'end': ends_acc,
        'duration': dur,
        'category_32': "none",
        'subs': subs_acc,
        'phase': 'single_vid',
        'idx': 0,
    })

#     dataframe = add_adjusted_end_2_df(dataframe)
    print(f'Dataset size before filtering: {dataframe.shape}')
    dataframe = filter_dataframe(dataframe)
    print(f'Dataset size after filtering: {dataframe.shape}')
    print(f'save_path: {save_path}')
    if save_path is not None:
        dataframe.to_csv(save_path, index=None, sep='\t')
    return dataframe


if __name__ == "__main__":
    '''
        (mdvc) $ python ./utils/parse_subs.py
    '''
    # make sure to unzip the subs.zip
    # we are using only `en` folder but you can play with other ones
    # we tried with `en` + `translated` but it didn't improve the results
    parser = argparse.ArgumentParser()

    parser.add_argument(
        '-v', '--video', type=str, default='Skate',
        help='name of the srt file w/o .srt'
    )
    args = parser.parse_args()

    video_name = str(args.video)

    save_path = f'./../video/videos/meta_{video_name}.csv'
    subs_dataframe = subtitles_dataframe(video_name, save_path)
    print(subs_dataframe.tail())

    # check if both files are the same up to sorting: os.listdir(subs_folder) doesn't guarantee to return
    # a sorted list. Now the code ensures the list of filenames is lexicographically sorted.
    # old = pd.read_csv('./data/asr_en_new.csv', sep='\t').values.tolist()
    # new = pd.read_csv('./data/asr_en.csv', sep='\t').values.tolist()
    # print(len(old), len(new))
    # from tqdm import tqdm
    # for line in tqdm(old):
    #     assert line in new
    # for line in tqdm(new):
    #     assert line in old

Since I didn't have ground-truth proposals for this module of my project, I made some changes (put "none" where I didn't have anything, etc.). I also made some changes to the timestamp parsing because of the format of my ASR output.
I ran it as python parse_subs.py -v <video.mp4>
I changed main.py to this:

import argparse
import os
from time import strftime, localtime
from shutil import copytree, ignore_patterns
import h5py

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils import tensorboard as tensorboard
# import tensorboardX as tensorboard

from model.transformer import SubsAudioVideoTransformer
from dataset.dataset import ActivityNetCaptionsIteratorDataset
from loss.loss import LabelSmoothing, SimpleLossCompute
from scheduler.lr_scheduler import SimpleScheduler
from epoch_loop.run_epoch import training_loop, validation_next_word_loop, greedy_decoder
from epoch_loop.run_epoch import save_model, validation_1by1_loop, average_metrics_in_two_dicts
from utils import timer

from dataset.dataset import load_multimodal_features_from_h5, filter_features
from run import predict

class Config(object):
    '''
    Note: don't change the methods of this class later in code.
    '''

    def __init__(self, args):
        '''
        Try not to create anything here, like new folders or something
        '''
        self.curr_time = strftime('%y%m%d%H%M%S', localtime())
        # dataset
        self.train_meta_path = args.train_meta_path
        self.val_1_meta_path = args.val_1_meta_path
        self.val_2_meta_path = args.val_2_meta_path
        self.val_prop_meta_path = args.val_prop_meta_path
        self.modality = args.modality
        self.video_feature_name = args.video_feature_name
        self.video_features_path = args.video_features_path
        self.filter_video_feats = args.filter_video_feats
        self.average_video_feats = args.average_video_feats
        self.audio_feature_name = args.audio_feature_name
        self.audio_features_path = args.audio_features_path
        self.filter_audio_feats = args.filter_audio_feats
        self.average_audio_feats = args.average_audio_feats
        self.use_categories = args.use_categories
        if self.use_categories:
            self.video_categories_meta_path = args.video_categories_meta_path
        # make them d_video and d_audio
        self.d_vid = args.d_vid
        self.d_aud = args.d_aud
        self.start_token = args.start_token
        self.end_token = args.end_token
        self.pad_token = args.pad_token
        self.max_len = args.max_len
        self.min_freq = args.min_freq
        # model
        self.model = args.model
        self.dout_p = args.dout_p
        self.N = args.N
        self.use_linear_embedder = args.use_linear_embedder
        if args.use_linear_embedder:
            self.d_model_video = args.d_model_video
            self.d_model_audio = args.d_model_audio
        else:
            self.d_model_video = self.d_vid
            self.d_model_audio = self.d_aud
        self.d_model_subs = args.d_model_subs
        if self.model == 'transformer':
            self.H = args.H
            self.d_ff_video = args.d_ff_video
            self.d_ff_audio = args.d_ff_audio
            self.d_ff_subs = args.d_ff_subs
            if self.use_categories:
                self.d_cat = args.d_cat
        elif self.model == 'bi_gru':
            pass
        else:
            raise Exception(f'Undefined model: "{self.model}"')

        # training
        self.device_ids = args.device_ids
        self.device = f'cuda:{self.device_ids[0]}'
        self.train_batch_size = args.B * len(self.device_ids)
        self.inference_batch_size = args.inf_B_coeff * self.train_batch_size
        self.start_epoch = args.start_epoch # todo: pretraining
        self.epoch_num = args.epoch_num
        self.one_by_one_starts_at = args.one_by_one_starts_at
        self.early_stop_after = args.early_stop_after
        # criterion
        self.criterion = args.criterion
        self.smoothing = args.smoothing # 0 == cross entropy
        # optimizer
        self.optimizer = args.optimizer
        if self.optimizer == 'adam':
            self.beta1, self.beta2 = args.betas
            self.eps = args.eps
        else:
            raise Exception(f'Undefined optimizer: "{self.optimizer}"')
        # lr scheduler
        self.scheduler = args.scheduler
        if self.scheduler == 'attention_is_all_you_need':
            self.lr_coeff = args.lr_coeff
            self.warmup_steps = args.warmup_steps
        elif self.scheduler == 'constant':
            self.lr = args.lr
        else:
            raise Exception(f'Undefined scheduler: "{self.scheduler}"')
        # evaluation
        self.reference_paths = args.reference_paths
        self.tIoUs = args.tIoUs
        self.max_prop_per_vid = args.max_prop_per_vid
        self.verbose_evaluation = args.verbose_evaluation
        # logging
        self.to_log = args.to_log
        self.videos_to_monitor = args.videos_to_monitor
        if args.to_log:
            self.log_dir = args.log_dir
            self.checkpoint_dir = self.log_dir # the same yes
            exper_name = self.make_experiment_name()
            self.comment = args.comment
            self.log_path = os.path.join(self.log_dir, exper_name)
            self.model_checkpoint_path = os.path.join(self.checkpoint_dir, exper_name)
        else:
            self.log_dir = None
            self.log_path = None

    def make_experiment_name(self):
        return self.curr_time[2:]

    def get_params(self, out_type):

        if out_type == 'md_table':
            table  = '| Parameter | Value | \n'
            table += '|-----------|-------| \n'

            for par, val in vars(self).items():
                table += f'| {par} | {val}| \n'

            return table

        elif out_type == 'dict':
            params_to_filter = [
                'model_checkpoint_path', 'log_path', 'comment', 'curr_time',
                'checkpoint_dir', 'log_dir', 'videos_to_monitor', 'to_log',
                'verbose_evaluation', 'tIoUs', 'reference_paths',
                'one_by_one_starts_at', 'device', 'device_ids', 'pad_token',
                'end_token', 'start_token', 'val_1_meta_path', 'video_feature_name',
                'val_2_meta_path', 'train_meta_path', 'betas', 'path'
            ]
            dct = vars(self)
            dct = {k: v for k, v in dct.items() if (k not in params_to_filter) and (v is not None)}

            return dct

    def self_copy(self):

        if self.to_log:
            # let it be in method's arguments (for TBoard)
            self.path = os.path.realpath(__file__)
            pwd = os.path.split(self.path)[0]
            cp_path = os.path.join(self.model_checkpoint_path, 'wdir_copy')
            copytree(pwd, cp_path, ignore=ignore_patterns('todel', 'submodules', '.git'))


def main(cfg, video_name):
    ###########################################################################
    ######################### Some reminders to print #########################
    ###########################################################################
    if cfg.to_log:
        print(f'log_path: {cfg.log_path}')
        print(f'model_checkpoint_path: {cfg.model_checkpoint_path}')
    ###########################################################################
    torch.manual_seed(0)
    np.random.seed(0)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    torch.cuda.set_device(cfg.device_ids[0])

    train_dataset = ActivityNetCaptionsIteratorDataset(
        cfg.start_token, cfg.end_token, cfg.pad_token, cfg.min_freq,
        cfg.train_batch_size, cfg.video_features_path, cfg.video_feature_name,
        cfg.filter_video_feats, cfg.average_video_feats,
        cfg.audio_features_path, cfg.audio_feature_name,
        cfg.filter_audio_feats, cfg.average_audio_feats,
        cfg.train_meta_path, cfg.val_1_meta_path,
        cfg.val_2_meta_path, torch.device(cfg.device), 'train', cfg.modality,
        cfg.use_categories, props_are_gt=True, get_full_feat=False
    )
#     val_1_dataset = ActivityNetCaptionsIteratorDataset(
#         cfg.start_token, cfg.end_token, cfg.pad_token, cfg.min_freq,
#         cfg.inference_batch_size, cfg.video_features_path, cfg.video_feature_name,
#         cfg.filter_video_feats, cfg.average_video_feats,
#         cfg.audio_features_path, cfg.audio_feature_name,
#         cfg.filter_audio_feats, cfg.average_audio_feats,  cfg.train_meta_path, cfg.val_1_meta_path,
#         cfg.val_2_meta_path, torch.device(cfg.device), 'val_1', cfg.modality,
#         cfg.use_categories, props_are_gt=True, get_full_feat=False
#     )
#     val_2_dataset = ActivityNetCaptionsIteratorDataset(
#         cfg.start_token, cfg.end_token, cfg.pad_token, cfg.min_freq,
#         cfg.inference_batch_size, cfg.video_features_path, cfg.video_feature_name,
#         cfg.filter_video_feats, cfg.average_video_feats,
#         cfg.audio_features_path, cfg.audio_feature_name,
#         cfg.filter_audio_feats, cfg.average_audio_feats, cfg.train_meta_path, cfg.val_1_meta_path,
#         cfg.val_2_meta_path, torch.device(cfg.device), 'val_2', cfg.modality,
#         cfg.use_categories, props_are_gt=True, get_full_feat=False
#     )
#     # 'val_1' in phase doesn't really matter because props are for validation set
#     # cfg.val_1_meta_path -> cfg.val_prop_meta
#     val_pred_prop_dataset = ActivityNetCaptionsIteratorDataset(
#         cfg.start_token, cfg.end_token, cfg.pad_token, cfg.min_freq,
#         cfg.inference_batch_size, cfg.video_features_path, cfg.video_feature_name,
#         cfg.filter_video_feats, cfg.average_video_feats,
#         cfg.audio_features_path, cfg.audio_feature_name,
#         cfg.filter_audio_feats, cfg.average_audio_feats, cfg.train_meta_path,
#         cfg.val_prop_meta_path,
#         cfg.val_2_meta_path, torch.device(cfg.device), 'val_1', cfg.modality,
#         cfg.use_categories, props_are_gt=False, get_full_feat=False
#     )

    # make sure that DataLoader has batch_size = 1!
    train_loader = DataLoader(train_dataset, collate_fn=train_dataset.dont_collate)
#     val_1_loader = DataLoader(val_1_dataset, collate_fn=val_1_dataset.dont_collate)
#     val_2_loader = DataLoader(val_2_dataset, collate_fn=val_2_dataset.dont_collate)
#     val_pred_prop_loader = DataLoader(val_pred_prop_dataset, collate_fn=val_2_dataset.dont_collate)

#     best_model_path = os.path.join(, 'best_model.pt')


    model = SubsAudioVideoTransformer(
        train_dataset.trg_voc_size, train_dataset.subs_voc_size,
        cfg.d_aud, cfg.d_vid, cfg.d_model_audio, cfg.d_model_video,
        cfg.d_model_subs,
        cfg.d_ff_audio, cfg.d_ff_video, cfg.d_ff_subs,
        cfg.N, cfg.N, cfg.N, cfg.dout_p, cfg.H, cfg.use_linear_embedder
    )



#     criterion = LabelSmoothing(cfg.smoothing, train_dataset.pad_idx)

#     # lr = 0 here have no impact on training (see lr scheduler)
#     optimizer = torch.optim.Adam(
#         model.parameters(), 0, (cfg.beta1, cfg.beta2), cfg.eps
#     )
#     lr_scheduler = SimpleScheduler(optimizer, cfg.lr)
#     loss_compute = SimpleLossCompute(criterion, lr_scheduler)

    model.to(torch.device(cfg.device))
    # haven't tested for multi GPU for a while -- might not work.
    model = torch.nn.DataParallel(model, cfg.device_ids)

    checkpoint = torch.load('./model/best_model.pt')
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()


    hdf5_vggish_features_path = './video/videos/' + video_name + '_vggish.hdf5'
    hdf5_i3d_features_path = './video/videos/' + video_name + '_i3d.hdf5'

    h5py.File(hdf5_vggish_features_path, 'w').close()
    h5py.File(hdf5_i3d_features_path, 'w').close()

    # start context managers ('a' == append) for both files
    with h5py.File(hdf5_vggish_features_path, 'a') as hd5vgg, h5py.File(hdf5_i3d_features_path, 'a') as hd5i3d:
        # the for-loop
    #     for path in paths:
            # construct new paths
        path = './video/videos/' + video_name
        vggish_path = f'{path}_vggish.npy'
        rgb_path = f'{path}_rgb.npy'
        flow_path = f'{path}_flow.npy'

            # loading numpy files
        vggish = np.load(vggish_path)
        rgb = np.load(rgb_path)
        flow = np.load(flow_path)

            # extract video names from the paths (relying on the rgb path only)
            # os.path.split() outputs a list which contains parent dir path [0] and filename [1]
            # and removing the part with '_rgb.npy' (`video_1_rgb.npy` -> `video_1`)
#         video_name = 'adidas' #os.path.split(rgb_path)[-1].replace('_rgb.npy', '')

            # append features to the hdf5 files
            # VGGish
        hd5vgg.create_dataset(f'{video_name}/vggish_features', data=vggish)
            # RGB
        hd5i3d.create_dataset(f'{video_name}/i3d_features/rgb', data=rgb)
        hd5i3d.create_dataset(f'{video_name}/i3d_features/flow', data=flow)



    val_dataset = ActivityNetCaptionsIteratorDataset(
        '<s>', '</s>', '<blank>', 1,
        28, './data/sub_activitynet_v1-3.i3d_25fps_stack24step24_2stream.hdf5', 'i3d',
        False, False,
        './data/sub_activitynet_v1-3.vggish.hdf5', 'vggish',
        False, False,
        './data/train_meta.csv', './data/val_1_meta.csv', './data/val_2_meta.csv',
        torch.device('cuda:0'), 'val_1', 'subs_audio_video',
        False, props_are_gt=True, get_full_feat=False
    )

    val_loader = DataLoader(val_dataset, collate_fn=val_dataset.dont_collate)


#     vid_id_list = [video_name]




    res = predict('./video/videos/meta_' + video_name + '.csv', hdf5_vggish_features_path, hdf5_i3d_features_path, video_name, val_loader, greedy_decoder, model)
    print(res)

    file1 = open('./video/videos/' + video_name + ".txt", "w")
    file1.write(res)
    file1.close()



#     param_num = sum(p.numel() for p in model.parameters() if p.requires_grad)
#     print(f'Param Num: {param_num}')



#     if cfg.to_log:
#         os.makedirs(cfg.log_path)
#         os.makedirs(cfg.model_checkpoint_path, exist_ok=True) # handles the case when model_checkpoint_path = log_path
#         TBoard = tensorboard.SummaryWriter(log_dir=cfg.log_path)
#         TBoard.add_text('config', cfg.get_params('md_table'), 0)
#         TBoard.add_text('config/comment', cfg.comment, 0)
#         TBoard.add_scalar('debug/param_number', param_num, 0)
#     else:
#         TBoard = None

#     # keeping track of the best model
#     best_metric = 0
#     # "early stopping" thing
#     num_epoch_best_metric_unchanged = 0


#     val_1_metrics = validation_1by1_loop(
#                 model, val_1_loader, greedy_decoder, loss_compute, lr_scheduler,
#                 epoch, cfg.max_len, cfg.log_path,
#                 cfg.verbose_evaluation, [cfg.reference_paths[0]], cfg.tIoUs,
#                 cfg.max_prop_per_vid, TBoard, cfg.modality, cfg.use_categories,
#             )

#     for epoch in range(cfg.start_epoch, cfg.epoch_num):
#         num_epoch_best_metric_unchanged += 1

#         if (num_epoch_best_metric_unchanged == cfg.early_stop_after) or (timer(cfg.curr_time) > 67):
#             print(f'Early stop at {epoch}: unchanged for {num_epoch_best_metric_unchanged} epochs')
#             print(f'Current timer: {timer(cfg.curr_time)}')
#             break


#         # train
#         training_loop(
#             model, train_loader, loss_compute, lr_scheduler, epoch, TBoard,
#             cfg.modality, cfg.use_categories
#         )
#         # validation (next word)
#         val_1_loss = validation_next_word_loop(
#             model, val_1_loader, greedy_decoder, loss_compute, lr_scheduler,
#             epoch, cfg.max_len, cfg.videos_to_monitor, TBoard, cfg.modality,
#             cfg.use_categories
#         )
#         val_2_loss = validation_next_word_loop(
#             model, val_2_loader, greedy_decoder, loss_compute, lr_scheduler,
#             epoch, cfg.max_len, cfg.videos_to_monitor, TBoard, cfg.modality,
#             cfg.use_categories
#         )

#         val_loss_avg = (val_1_loss + val_2_loss) / 2

#         # validation (1-by-1 word)
#         if epoch >= cfg.one_by_one_starts_at:
#             # validation with g.t. proposals
#             val_1_metrics = validation_1by1_loop(
#                 model, val_1_loader, greedy_decoder, loss_compute, lr_scheduler,
#                 epoch, cfg.max_len, cfg.log_path,
#                 cfg.verbose_evaluation, [cfg.reference_paths[0]], cfg.tIoUs,
#                 cfg.max_prop_per_vid, TBoard, cfg.modality, cfg.use_categories,
#             )
#             val_2_metrics = validation_1by1_loop(
#                 model, val_2_loader, greedy_decoder, loss_compute, lr_scheduler,
#                 epoch, cfg.max_len, cfg.log_path,
#                 cfg.verbose_evaluation, [cfg.reference_paths[1]], cfg.tIoUs,
#                 cfg.max_prop_per_vid, TBoard, cfg.modality, cfg.use_categories,
#             )

#             if cfg.to_log:
#                 # averaging metrics obtained from val_1 and val_2
#                 metrics_avg = average_metrics_in_two_dicts(val_1_metrics, val_2_metrics)
#                 metrics_avg = metrics_avg['Average across tIoUs']

#                 TBoard.add_scalar('metrics/val_loss_avg', val_loss_avg, epoch)
#                 TBoard.add_scalar('metrics/meteor', metrics_avg['METEOR'] * 100, epoch)
#                 TBoard.add_scalar('metrics/bleu4', metrics_avg['Bleu_4'] * 100, epoch)
#                 TBoard.add_scalar('val_avg/bleu3', metrics_avg['Bleu_3'] * 100, epoch)
#                 TBoard.add_scalar('val_avg/bleu2', metrics_avg['Bleu_2'] * 100, epoch)
#                 TBoard.add_scalar('val_avg/bleu1', metrics_avg['Bleu_1'] * 100, epoch)
#                 TBoard.add_scalar('val_avg/rouge_l', metrics_avg['ROUGE_L'] * 100, epoch)
#                 TBoard.add_scalar('val_avg/cider', metrics_avg['CIDEr'] * 100, epoch)
#                 TBoard.add_scalar('val_avg/precision', metrics_avg['Precision'] * 100, epoch)
#                 TBoard.add_scalar('val_avg/recall', metrics_avg['Recall'] * 100, epoch)

#                 # saving the model if it is better than the best so far
#                 if best_metric < metrics_avg['METEOR']:
#                     best_metric = metrics_avg['METEOR']

#                     save_model(
#                         cfg, epoch, model, optimizer, val_1_loss, val_2_loss,
#                         val_1_metrics, val_2_metrics, train_dataset.trg_voc_size
#                     )
#                     # reset the early stopping criterion
#                     num_epoch_best_metric_unchanged = 0

#                 # put it after: so on zeroth epoch it is not zero
#                 TBoard.add_scalar('val_avg/best_metric_meteor', best_metric * 100, epoch)

#     if cfg.to_log:
#         # load the best model
#         best_model_path = os.path.join(cfg.model_checkpoint_path, 'best_model.pt')
#         checkpoint = torch.load(best_model_path)
#         model.load_state_dict(checkpoint['model_state_dict'])
#         val_metrics_pred_prop = validation_1by1_loop(
#             model, val_pred_prop_loader, greedy_decoder, loss_compute, lr_scheduler,
#             checkpoint['epoch'], cfg.max_len, cfg.log_path,
#             cfg.verbose_evaluation, cfg.reference_paths, cfg.tIoUs,
#             cfg.max_prop_per_vid, TBoard, cfg.modality, cfg.use_categories
#         )
#         best_metric_pred_prop = val_metrics_pred_prop['Average across tIoUs']['METEOR']
#         print(f'best_metric: {best_metric}')
#         print(f'best_metric_pred_prop: {best_metric_pred_prop}')
#         TBoard.close()


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description='Run experiment')

    parser.add_argument(
        '-v', '--video', type=str, default='Skate',
        help='name of the srt file w/o .srt'
    )
    parser.add_argument(
        '--train_meta_path', type=str, default='./data/train_meta.csv',
        help='path to the precalculated train meta file'
    )
    parser.add_argument(
        '--val_1_meta_path', type=str, default='./data/val_1_meta.csv',
        help='path to the precalculated val 1 meta file'
    )
    parser.add_argument(
        '--val_2_meta_path', type=str, default='./data/val_2_meta.csv',
        help='path to the precalculated val 2 meta file'
    )
    parser.add_argument(
        '--val_prop_meta_path', type=str, default='./data/bafcg_val_100_proposal_result.csv',
        help='path to the precalculated proposals on the validation set'
    )
    parser.add_argument(
        '--dont_log', dest='to_log', action='store_false',
        help='Prevent logging in the experiment.'
    )
    parser.add_argument(
        '--device_ids', type=int, nargs='+', default=[0],
        help='device indices separated by a whitespace'
    )
    parser.add_argument(
        '--use_categories', dest='use_categories', action='store_false',
        help='whether to condition the model on categories'
    )
    parser.add_argument(
        '--video_categories_meta_path', type=str, default='./data/videoCategoriesMetaUS.json',
        help='Path to the categories meta from Youtube API: \
        https://developers.google.com/youtube/v3/docs/videoCategories/list'
    )
    parser.add_argument(
        '--d_cat', type=int,
        help='size of the category embedding layer'
    )
    parser.add_argument(
        '--modality', type=str, default='subs_audio_video',
        choices=['audio', 'video', 'audio_video', 'subs_audio_video'],
    )
    parser.add_argument('--video_feature_name', type=str, default='i3d')
    parser.add_argument(
        '--video_features_path', type=str,
        default='./data/sub_activitynet_v1-3.i3d_25fps_stack24step24_2stream.hdf5'
    )
    parser.add_argument('--audio_feature_name', type=str, default='vggish')
    parser.add_argument(
        '--audio_features_path', type=str, default='./data/sub_activitynet_v1-3.vggish.hdf5'
    )
    parser.add_argument('--d_vid', type=int, default=1024)
    parser.add_argument('--d_aud', type=int, default=128)
    parser.add_argument(
        '--filter_video_feats', dest='filter_video_feats', action='store_true',
        help='filter video features (removes overlap 16/8 -> 16/16).'
    )
    parser.add_argument(
        '--average_video_feats', dest='average_video_feats', action='store_true',
        help='averages video features (designed for c3d: 16x4 -> 16 (the same time span)).'
    )
    parser.add_argument(
        '--filter_audio_feats', dest='filter_audio_feats', action='store_true',
        help='filter video features (removes overlap 16/8 -> 16/16).'
    )
    parser.add_argument(
        '--average_audio_feats', dest='average_audio_feats', action='store_true',
        help='averages audio features.'
    )
    parser.add_argument(
        '--start_token', type=str, default='<s>',
        help='starting token'
    )
    parser.add_argument(
        '--end_token', type=str, default='</s>',
        help='ending token'
    )
    parser.add_argument(
        '--pad_token', type=str, default='<blank>',
        help='padding token'
    )
    parser.add_argument(
        '--max_len', type=int, default=50,
        help='maximum size of 1by1 prediction'
    )
    parser.add_argument(
        '--min_freq', type=int, default=1,
        help='to be in the vocab a word should appear min_freq times in train dataset'
    )
    parser.add_argument('--model', type=str, default='transformer')
    parser.add_argument('--dout_p', type=float, default=0.1)
    parser.add_argument('--N', type=int, default=1, help='number of layers in a model')
    parser.add_argument(
        '--use_linear_embedder', dest='use_linear_embedder', action='store_true',
        help='Whether to include a dense layer between vid features and RNN'
    )
    parser.add_argument(
        '--d_model_video', type=int,
        help='If use_linear_embedder is true, this is going to be the d_model size for video model'
    )
    parser.add_argument(
        '--d_model_audio', type=int,
        help='If use_linear_embedder is true, this is going to be the d_model size for audio model'
    )
    parser.add_argument('--d_model_subs', type=int, default=512)
    parser.add_argument(
        '--H', type=int, default=4,
        help='number of heads in multiheaded attention in Transformer'
    )
    parser.add_argument(
        '--d_ff_video', type=int, default=2048,
        help='size of the internal layer of PositionwiseFeedForward net in Transformer (Video)'
    )
    parser.add_argument(
        '--d_ff_audio', type=int, default=2048,
        help='size of the internal layer of PositionwiseFeedForward net in Transformer (Audio)'
    )
    parser.add_argument(
        '--d_ff_subs', type=int, default=2048,
        help='size of the internal layer of PositionwiseFeedForward net in Transformer (Subs)'
    )
    parser.add_argument(
        '--B', type=int, default=28,
        help='batch size per a device'
    )
    parser.add_argument(
        '--inf_B_coeff', type=int, default=2,
        help='the batch size on inference is inf_B_coeff times the B'
    )
    parser.add_argument(
        '--start_epoch', type=int, default=0, choices=[0],
        help='the epoch number to start training (if specified, pretraining a net from start_epoch epoch)'
    )
    parser.add_argument(
        '--epoch_num', type=int, default=45,
        help='number of epochs to train'
    )
    parser.add_argument(
        '--one_by_one_starts_at', type=int, default=0,
        help='number of epochs to skip before starting 1-by-1 validation'
    )
    parser.add_argument(
        '--early_stop_after', type=int, default=50,
        help='number of epochs to wait for best metric to change before stopping'
    )
    parser.add_argument(
        '--criterion', type=str, default='label_smoothing', choices=['label_smoothing'],
        help='criterion to measure the loss'
    )
    parser.add_argument(
        '--smoothing', type=float, default=0.7,
        help='smoothing coeff (= 0 cross ent loss; -> 1 more smoothing, random labels) must be in [0, 1]'
    )
    parser.add_argument(
        '--optimizer', type=str, default='adam', choices=['adam'],
        help='optimizer'
    )
    parser.add_argument(
        '--betas', type=float, nargs=2, default=[0.9, 0.98],
        help='beta 1 and beta 2 parameters in adam'
    )
    parser.add_argument(
        '--eps', type=float, default=1e-8,
        help='eps parameter in adam'
    )
    parser.add_argument(
        '--scheduler', type=str, default='constant', choices=['attention_is_all_you_need', 'constant'],
        help='lr scheduler'
    )
    parser.add_argument(
        '--lr_coeff', type=float,
        help='lr scheduler coefficient (if scheduler is attention_is_all_you_need)'
    )
    parser.add_argument(
        '--warmup_steps', type=int,
        help='number of "warmup steps" (if scheduler is attention_is_all_you_need)'
    )
    parser.add_argument('--lr', type=float, default=1e-5, help='lr (if scheduler is constant)')
    parser.add_argument(
        '--reference_paths', type=str, default=['./data/val_1.json', './data/val_2.json'],
        nargs='+',
        help='reference paths for 1-by-1 validation'
    )
    parser.add_argument(
        '--tIoUs', type=float, default=[0.3, 0.5, 0.7, 0.9], nargs='+',
        help='thresholds for tIoU to be used for 1-by-1 validation'
    )
    parser.add_argument(
        '--max_prop_per_vid', type=int, default=1000,
        help='max number of proposals to take into consideration for 1-by-1 validation'
    )
    parser.add_argument(
        '--dont_verbose_evaluation', dest='verbose_evaluation', action='store_false',
        help='dont verbose the evaluation server in 1-by-1 validation (no Precision and R)'
    )
    parser.add_argument('--log_dir', type=str, default='./log/')
    parser.add_argument(
        '--videos_to_monitor', type=str, nargs='+',
        default=['v_GGSY1Qvo990', 'v_bXdq2zI1Ms0', 'v_aLv03Fznf5A'],
        help='the videos to monitor on validation loop with 1 by 1 prediction'
    )
    parser.add_argument('--comment', type=str, default='', help='comment for the experiment')

    parser.set_defaults(to_log=True)
    parser.set_defaults(filter_video_feats=False)
    parser.set_defaults(average_video_feats=False)
    parser.set_defaults(filter_audio_feats=False)
    parser.set_defaults(average_audio_feats=False)
    parser.set_defaults(use_linear_embedder=False)
    parser.set_defaults(verbose_evaluation=True)
    parser.set_defaults(use_categories=False)

    args = parser.parse_args()

    video_name = str(args.video)
    # print(args)
    cfg = Config(args)
    main(cfg, video_name)

And I created run.py in the same folder (/MDVC):

import h5py
import numpy as np

import torch
from torch.utils.data import DataLoader
from dataset.dataset import ActivityNetCaptionsIteratorDataset
from dataset.dataset import load_multimodal_features_from_h5, filter_features
from epoch_loop.run_epoch import greedy_decoder, encode_subs

import pandas as pd
# import h5py

# create empty hdf5 files on your disk

def predict(your_meta_path, vggish_hdf5_features_path, i3d_hdf5_features_path,
    vid_id, val_loader, decoder, model, max_len=100):
    '''
        your_meta_path: path to your .csv
        *_hdf5_features_path: path to your hdf5 files
        vid_ids_list: the ids which will be used to filter meta file. For example: ['video_1', 'video_2']
        val_loader: object defined above
        decoder: just pass the greedy_decoder function
        model: pass the pre-trained model
        max_len: largest caption possible. The generation will stop if exceeded
    '''
    # for dataframe example see `./data/val_1_meta.csv`. Make sure your video_id will correspond to the
    # filenames w/o extension (`video_1_rgb.npy` -> `video_1`) which you used to create hdf5 files
    # because load_multimodal_features_from_h5() will use them
    # as well as they should be present in `vid_ids_list` variable
    meta = pd.read_csv(your_meta_path, sep='\t')

#     print(meta)
    # re-define hdf5 files with your custom ones
    feat_h5_audio = h5py.File(vggish_hdf5_features_path, 'r')
    feat_h5_video = h5py.File(i3d_hdf5_features_path, 'r')

    feature_names = val_loader.dataset.feature_names
    device = val_loader.dataset.device
    start_idx = val_loader.dataset.start_idx
    end_idx = val_loader.dataset.end_idx
    pad_idx = val_loader.dataset.pad_idx
    modality = val_loader.dataset.modality

    text = ''

#     for vid_id in vid_ids_list:
    meta_subset = meta[meta['video_id'] == vid_id]
#         print(meta_subset)
    text += f'\t {vid_id} \n'

    for (video_id, cap, start, end, duration, category, subs, phase, idx) in meta_subset.values:
#             print(subs)
        feature_names_list = val_loader.dataset.features_dataset.feature_names_list
        train_subs_vocab = val_loader.dataset.train_subs_vocab

            # rgb is padded with pad_idx; flow is padded with 0s: expected to be summed later
        video_stack_rgb, video_stack_flow, audio_stack = load_multimodal_features_from_h5(
            feat_h5_video, feat_h5_audio, feature_names_list, video_id, start, end, duration
        )

        subs_stack = encode_subs(train_subs_vocab, idx, meta, start_idx, end_idx)

        video_stack_rgb = video_stack_rgb.unsqueeze(0).to(device)
        video_stack_flow = video_stack_flow.unsqueeze(0).to(device)
        audio_stack = audio_stack.unsqueeze(0).to(device)
        subs_stack = subs_stack.unsqueeze(0).to(device)

        stack = video_stack_rgb + video_stack_flow, audio_stack, subs_stack

        trg_ints = decoder(model, stack, max_len, start_idx, end_idx, pad_idx, modality)
        trg_ints = trg_ints.cpu().numpy()[0]
        trg_words = [val_loader.dataset.train_vocab.itos[i] for i in trg_ints]
        en_sent = ' '.join(trg_words)

        text += f'\t P sent: {en_sent} \n'
        text += f'\t P proposals: {start//60:.0f}:{start%60:02.0f} {end//60:.0f}:{end%60:02.0f} '

    text += '\t \n'

    return text

Then I call python main.py -v <video_name (without extension)>
This is set up for the configuration I have in the rest of my code, so it is not very flexible for taking input from different folders, etc.

@v-iashin
Owner

Cool, thanks.

@v-iashin v-iashin mentioned this issue Oct 8, 2020
Closed
@siyamsajeebkhan

Hi Vladimir,
I read the whole thread and wanted to ask you a question regarding the proposal generator. You already mentioned that you used BAFCG (Bi-SST) for proposal generation; I checked the repo, and it seems their implementation is in TensorFlow. Do you have a PyTorch implementation of it?

@v-iashin
Owner

Hi,

Unfortunately, we don’t have a PyTorch implementation.

@siyamsajeebkhan

Thanks for your prompt reply. So, did you use their TensorFlow implementation to generate the proposals and then use them for the captioning?

@v-iashin
Owner

Nope, we just used the predicted proposals.

@wjy3326

wjy3326 commented Jan 6, 2021

Could you please share your code for dense video captioning on raw input videos? Thanks!

@harpavatkeerti
Author

Unfortunately, I don't have the code for this, as I switched to the BMT code provided by @v-iashin (https://github.com/v-iashin/BMT#single-video-prediction).

@wjy3326

wjy3326 commented Jan 7, 2021

Unfortunately, I don't have the code for this, as I switched to the BMT code provided by @v-iashin (https://github.com/v-iashin/BMT#single-video-prediction).

Thanks for providing the code. How good are the results of dense video caption generation? Could it be used in practical applications? Thanks!

@taruntiwarihp

taruntiwarihp commented Jun 6, 2021

@v-iashin
@harpavatkeerti

Thanks a lot for your help; this code works for predicting captions for a single video, and the results are also good.
(Just a minor change for anybody else who uses this: change hd5vgg.create_dataset(f'{video_name}/vggish_features', vggish.size(), data=vggish) to hd5vgg.create_dataset(f'{video_name}/vggish_features', data=vggish), and similarly for the others.) The rest of the code is working perfectly fine.

Hi Keerti, I hope you are doing great!
(I want to get single-video caption predictions.) I was also doing the same thing with BMT, and it was easy to run that script, but I want to run it with MDVC. I had a lot of issues running that code; can you please help me with that? Thanks

@harpavatkeerti
Author

Hi @taruntiwarihp, unfortunately, I don't have the code regarding MDVC. As I mentioned above, I switched to BMT myself. All I had for MDVC is mentioned in the issues above.

@taruntiwarihp

Hi @taruntiwarihp, unfortunately, I don't have the code regarding MDVC. As I mentioned above, I switched to BMT myself. All I had for MDVC is mentioned in the issues above.

No problem @harpavatkeerti, I can do it myself. Once I do, I'll share it with you.
Thanks for the reply.

@mtthwryn

mtthwryn commented Aug 3, 2023

Does anyone have code to run MDVC on our own videos?
