Skip to content

Commit

Permalink
Add AutoConfig import and update least_token_number calculation
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhangYuanhan committed Mar 8, 2024
1 parent 9049f5c commit 376d988
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion llava/eval/model_video_description_from_t2v.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

import math

from transformers import AutoConfig


def split_list(lst, n):
Expand Down Expand Up @@ -107,7 +108,15 @@ def run_inference(args):
overwrite_config["mm_spatial_pool_stride"] = args.mm_spatial_pool_stride
overwrite_config["mm_spatial_pool_out_channels"] = args.mm_spatial_pool_out_channels
overwrite_config["mm_spatial_pool_mode"] = args.mm_spatial_pool_mode
least_token_number = args.for_get_frames_num*(24//args.mm_spatial_pool_stride)**2
overwrite_config["patchify_video_feature"] = False

cfg_pretrained = AutoConfig.from_pretrained(args.model_path)

if "224" in cfg_pretrained.mm_vision_tower:
# Assume the text tokens occupy roughly 1000 positions (per Bo's report); the "+ 1000" below reserves room for them on top of the video tokens.
least_token_number = args.for_get_frames_num*(16//args.mm_spatial_pool_stride)**2 + 1000
else:
least_token_number = args.for_get_frames_num*(24//args.mm_spatial_pool_stride)**2 + 1000

scaling_factor = math.ceil(least_token_number/4096)
if scaling_factor >= 2:
Expand Down

0 comments on commit 376d988

Please sign in to comment.