Update dependencies and enable flash attention 2; update demo
ZhangYuanhan-AI committed May 6, 2024
1 parent 8b2e15b commit 9fcbebe
Showing 8 changed files with 127 additions and 28 deletions.
56 changes: 56 additions & 0 deletions .gitignore
@@ -0,0 +1,56 @@
# Python
__pycache__
*.pyc
*.egg-info
dist

# Log
*.log
*.log.*
# *.json
# *.jsonl

# Data
!**/alpaca-data-conversation.json
# Editor
.idea
*.swp
.vscode

# Other
.DS_Store
wandb
output

checkpoints
project_checkpoints
debug_checkpoints
playground/data
playground/cc3m_llava34b_cap
ckpts*

.ipynb_checkpoints
chunyl_scripts
*.ipynb

# DevContainer
!.devcontainer/*

# Demo
serve_images/
notebooks/
logs
playground/cc3m_llava34b_cap/progress_0_or_24.json
playground/cc3m_llava34b_cap/progress_1_or_24.json
playground/cc3m_llava34b_cap/progress_3_or_24.json
playground/cc3m_llava34b_cap/progress_5_or_24.json
playground/cc3m_llava34b_cap/progress_6_or_24.json
llava_instruct_json/combined_staged_instruct.json
scripts/dist_*
logs/
submissions/
cn_scripts/
internal_project_checkpoints/
scripts/cn_boli01_lf/.nfs00770000002b96f300000005
scripts/cn_boli01_lf/.nfs007700000045ae9000000001
work_dirs
6 changes: 3 additions & 3 deletions llava/model/builder.py
Original file line number Diff line number Diff line change
@@ -105,7 +105,7 @@ def load_from_hf(repo_id, filename, subfolder=None):
model = LlavaMptForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, **kwargs)
elif "mixtral" in model_name.lower() and "vicuna" not in model_name.lower() and "mistral" not in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path)
- model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, use_flash_attention_2=False, **kwargs)
+ model = LlavaMixtralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, use_flash_attention_2=True, **kwargs)
elif "mistral" in model_name.lower() or "zephyr" in model_name.lower():
tokenizer = AutoTokenizer.from_pretrained(model_path)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
@@ -114,15 +114,15 @@ def load_from_hf(repo_id, filename, subfolder=None):
print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(cfg_pretrained, k, v)
- model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, use_flash_attention_2=False, config=cfg_pretrained, **kwargs)
+ model = LlavaMistralForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, use_flash_attention_2=True, config=cfg_pretrained, **kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
cfg_pretrained = AutoConfig.from_pretrained(model_path)
if overwrite_config is not None:
print(f"Overwriting config with {overwrite_config}")
for k, v in overwrite_config.items():
setattr(cfg_pretrained, k, v)
- model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, config=cfg_pretrained, **kwargs)
+ model = LlavaLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=True, use_flash_attention_2=True, config=cfg_pretrained, **kwargs)
else:
# Load language model
if model_base is not None:
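Note on the `use_flash_attention_2=True` flags added above: this is the transformers switch that routes attention through FlashAttention-2. Since transformers 4.36 the preferred spelling is `attn_implementation="flash_attention_2"`, and the kernel requires half-precision weights plus the separate `flash-attn` package. A minimal standalone sketch of the equivalent call (the checkpoint name is a placeholder, not from this repo):

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "lmsys/vicuna-7b-v1.5"  # placeholder checkpoint
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    model_path,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,  # FlashAttention-2 supports only fp16/bf16
    attn_implementation="flash_attention_2",  # newer equivalent of use_flash_attention_2=True
)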
3 changes: 2 additions & 1 deletion llava/model/llava_arch.py
@@ -41,7 +41,8 @@ def __init__(self, config):
self.vision_tower = build_vision_tower(config, delay_load=True)
self.vision_resampler = build_vision_resampler(config, vision_tower=self.vision_tower)
self.mm_projector = build_vision_projector(config, vision_cfg=self.vision_tower.config)

+ self.vision_resampler.mm_projector = self.mm_projector

if 'unpad' in getattr(config, 'mm_patch_merge_type', ''):
self.image_newline = nn.Parameter(
torch.empty(config.hidden_size, dtype=self.dtype)
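The one-line addition above hands the resampler a reference to the projector it feeds. Schematically, the multimodal stem built in this constructor flows tower → resampler → projector; a rough sketch of that wiring (illustrative only, not the repo's actual forward):

import torch.nn as nn

class VisionStem(nn.Module):
    # Illustrative wiring; the real modules come from build_vision_tower,
    # build_vision_resampler, and build_vision_projector.
    def __init__(self, tower: nn.Module, resampler: nn.Module, projector: nn.Module):
        super().__init__()
        self.vision_tower = tower
        self.vision_resampler = resampler
        self.mm_projector = projector
        self.vision_resampler.mm_projector = self.mm_projector  # shared handle, as in this commit

    def forward(self, images):
        feats = self.vision_tower(images)     # per-patch visual features
        feats = self.vision_resampler(feats)  # optional token reduction / resampling
        return self.mm_projector(feats)       # map into the LLM embedding space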
48 changes: 37 additions & 11 deletions llava/model/multimodal_encoder/clip_encoder.py
@@ -12,33 +12,51 @@ def __init__(self, vision_tower, args, delay_load=False):

self.vision_tower_name = vision_tower
self.select_layer = args.mm_vision_select_layer
- self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')
+ self.select_feature = getattr(args, "mm_vision_select_feature", "patch")

if not delay_load:
self.load_model()
- elif getattr(args, 'unfreeze_mm_vision_tower', False):
+ elif getattr(args, "unfreeze_mm_vision_tower", False):
# TODO: better detector is needed.
print(f"The checkpoint seems to contain `vision_tower` weights: `unfreeze_mm_vision_tower`: True.")
self.load_model()
else:
self.cfg_only = CLIPVisionConfig.from_pretrained(self.vision_tower_name)

- def load_model(self):
+ def load_model(self, device_map=None):
if self.is_loaded:
print("{} is already loaded, `load_model` called again, skipping.".format(self.vision_tower_name))
return

# import pdb; pdb.set_trace()
self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
- self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name)
+ self.vision_tower = CLIPVisionModel.from_pretrained(self.vision_tower_name, device_map=device_map)
self.vision_tower.requires_grad_(False)

self.is_loaded = True

def feature_select(self, image_forward_outs):
- image_features = image_forward_outs.hidden_states[self.select_layer]
- if self.select_feature == 'patch':
+ select_feature_type = self.select_feature
+
+ if self.select_feature in ["slicefour_patch", "slicefour_cls_patch"]:
+ select_every_k_layer = len(image_forward_outs.hidden_states) // 4
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in range(select_every_k_layer + self.select_layer, len(image_forward_outs.hidden_states), select_every_k_layer)], dim=-1)
+ select_feature_type = select_feature_type.replace("slicefour_", "")
+ elif self.select_feature in ["slice_m25811_f6_patch", "slice_m25811_f6_cls_patch"]:
+ select_layers = [-2, -5, -8, -11, 6]
+ image_features = torch.cat([image_forward_outs.hidden_states[i] for i in select_layers], dim=-1)
+ select_feature_type = select_feature_type.replace("slice_m25811_f6_", "")
+ else:
+ image_features = image_forward_outs.hidden_states[self.select_layer]
+
+ if select_feature_type == "patch":
image_features = image_features[:, 1:]
- elif self.select_feature == 'cls_patch':
+ elif select_feature_type == "cls_patch":
image_features = image_features
else:
- raise ValueError(f'Unexpected select feature: {self.select_feature}')
+ raise ValueError(f"Unexpected select feature: {select_feature_type}")
return image_features

@torch.no_grad()
def forward(self, images):
if type(images) is list:
image_features = []
@@ -73,12 +91,20 @@ def config(self):

@property
def hidden_size(self):
- return self.config.hidden_size
+ _hidden_size = self.config.hidden_size
+ if "slicefour" in self.select_feature:
+ _hidden_size *= 4
+ if "slice_m25811_f6" in self.select_feature:
+ _hidden_size *= 5
+ return _hidden_size

@property
def num_patches_per_side(self):
return self.config.image_size // self.config.patch_size

@property
def num_patches(self):
- return (self.config.image_size // self.config.patch_size) ** 2
+ _num_patches = (self.config.image_size // self.config.patch_size) ** 2
+ if "cls_patch" in self.select_feature:
+ _num_patches += 1
+ return _num_patches
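The new `slicefour_*` and `slice_m25811_f6_*` modes concatenate hidden states from several vision-tower layers along the channel dimension, which is why `hidden_size` above scales by 4 or 5 (and `num_patches` grows by one when the CLS token is kept). A self-contained illustration with random stand-in tensors, assuming CLIP ViT-L/14-336 shapes (24 layers plus embeddings, width 1024, 577 tokens):

import torch

hidden_states = [torch.randn(1, 577, 1024) for _ in range(25)]  # embeddings + 24 layers
select_layer = -2  # the usual mm_vision_select_layer

# "slicefour_*": sample every (len // 4)-th layer, offset by select_layer
k = len(hidden_states) // 4                              # 6
picked = range(k + select_layer, len(hidden_states), k)  # layers 4, 10, 16, 22
features = torch.cat([hidden_states[i] for i in picked], dim=-1)
print(features.shape)  # torch.Size([1, 577, 4096]) -> base width * 4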
24 changes: 16 additions & 8 deletions playground/demo/video_demo.py
@@ -1,11 +1,11 @@
import argparse
import torch

- from llavavid.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
- from llavavid.conversation import conv_templates, SeparatorStyle
- from llavavid.model.builder import load_pretrained_model
- from llavavid.utils import disable_torch_init
- from llavavid.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria
+ from llava.constants import IMAGE_TOKEN_INDEX, DEFAULT_IMAGE_TOKEN, DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN
+ from llava.conversation import conv_templates, SeparatorStyle
+ from llava.model.builder import load_pretrained_model
+ from llava.utils import disable_torch_init
+ from llava.mm_utils import tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria

import json
import os
@@ -15,6 +15,7 @@

from transformers import AutoConfig

+ import time

import numpy as np

@@ -87,6 +88,7 @@ def run_inference(args):

cfg_pretrained = AutoConfig.from_pretrained(args.model_path)


if "224" in cfg_pretrained.mm_vision_tower:
# suppose the length of text tokens is around 1000, from bo's report
least_token_number = args.for_get_frames_num*(16//args.mm_spatial_pool_stride)**2 + 1000
@@ -95,8 +97,9 @@

scaling_factor = math.ceil(least_token_number/4096)
if scaling_factor >= 2:
- print(float(scaling_factor))
- overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
+ if "vicuna" in cfg_pretrained._name_or_path.lower():
+ print(float(scaling_factor))
+ overwrite_config["rope_scaling"] = {"factor": float(scaling_factor), "type": "linear"}
overwrite_config["max_sequence_length"] = 4096 * scaling_factor
overwrite_config["tokenizer_model_max_length"] = 4096 * scaling_factor

@@ -114,7 +117,8 @@

video_path = args.video_path
sample_set = {}
question = "Please provide a detailed description of the video, focusing on the main subjects, their actions, and the background scenes"
# question = "Please provide a detailed description of the video, focusing on the main subjects, their actions, and the background scenes"
question = "What does this video describe? A. Buiding B.Forest C.coutryside D.Moon \nAnswer with the option's letter from the given choices directly."
sample_set["Q"] = question
sample_set["video_name"] = args.video_path

@@ -150,7 +154,11 @@ def run_inference(args):
with torch.inference_mode():
model.update_prompt([[cur_prompt]])
# import pdb;pdb.set_trace()
+ start_time = time.time()
output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, max_new_tokens=1024, use_cache=True, stopping_criteria=[stopping_criteria])
+ end_time = time.time()
+ print(f"Time taken for inference: {end_time - start_time} seconds")
# import pdb;pdb.set_trace()
# output_ids = model.generate(inputs=input_ids, images=video, attention_mask=attention_masks, modalities="video", do_sample=True, temperature=0.2, use_cache=True, stopping_criteria=[stopping_criteria])

outputs = tokenizer.batch_decode(output_ids, skip_special_tokens=True)[0].strip()
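The RoPE-scaling logic in this file is easy to sanity-check by hand: with a 224px tower, each frame contributes a (16 // stride)² token grid, plus roughly 1000 text tokens per the in-code comment, and once the total exceeds the 4096-token context the config stretches RoPE linearly. A worked sketch with illustrative values (64 frames, stride 2):

import math

for_get_frames_num = 64      # illustrative frame count
mm_spatial_pool_stride = 2   # illustrative pooling stride

visual_tokens = for_get_frames_num * (16 // mm_spatial_pool_stride) ** 2  # 64 * 64 = 4096
least_token_number = visual_tokens + 1000                                 # ~1000 text tokens

scaling_factor = math.ceil(least_token_number / 4096)  # ceil(5096 / 4096) = 2
if scaling_factor >= 2:
    overwrite_config = {
        "rope_scaling": {"factor": float(scaling_factor), "type": "linear"},
        "max_sequence_length": 4096 * scaling_factor,           # 8192
        "tokenizer_model_max_length": 4096 * scaling_factor,
    }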
10 changes: 8 additions & 2 deletions pyproject.toml
@@ -18,7 +18,7 @@ classifiers = [

dependencies = [
"torch==2.1.0", "torchvision==0.16.0",
"transformers==4.36.2", "tokenizers==0.15.2", "sentencepiece==0.1.99", "shortuuid",
"transformers==4.39.2", "tokenizers==0.15.2", "sentencepiece==0.1.99", "shortuuid",
"accelerate==0.27.2", "peft==0.4.0", "bitsandbytes==0.41.0",
"pydantic<2,>=1", "markdown2[all]", "numpy", "scikit-learn==1.2.2",
"gradio==3.35.2", "gradio_client==0.2.9",
@@ -28,7 +28,7 @@ dependencies = [

[project.urls]
"Homepage" = "https://llava-vl.github.io"
"Bug Tracker" = "https://github.com/LLaVA-VL/LLaVA-NeXT-Video/issues"
"Bug Tracker" = "https://github.com/haotian-liu/LLaVA/issues"

[tool.setuptools.packages.find]
exclude = [
@@ -41,6 +41,9 @@ exclude = [
"tests*",
"checkpoints*",
"project_checkpoints*",
"work_dirs*",
"data*",
"trl*",
]

[tool.wheel]
@@ -54,4 +57,7 @@ exclude = [
"tests*",
"checkpoints*",
"project_checkpoints*",
"work_dirs*",
"data*",
"trl*",
]
6 changes: 4 additions & 2 deletions scripts/video/demo/video_demo.sh
@@ -1,11 +1,13 @@
#!/bin/bash
ROOT_DIR="root to LLaVA-NeXT-Video"
ROOT_DIR="/mnt/bn/vl-research/workspace/yhzhang/llava-next-video"

if [ ! -e $ROOT_DIR ]; then
echo "The root dir does not exist. Exiting the script."
exit 1
fi

cd $ROOT_DIR

+ export PYTHONWARNINGS=ignore
export TOKENIZERS_PARALLELISM=false

Expand All @@ -24,7 +26,7 @@ else
SAVE_DIR=$(basename $CKPT)_${CONV_MODE}_frames_${FRAMES}_stride_${POOL_STRIDE}
fi

- torchrun playground/demo/video_demo.py \
+ python3 playground/demo/video_demo.py \
--model-path $CKPT \
--video_path ${VIDEO_PATH} \
--output_dir ./work_dirs/video_demo/$SAVE_DIR \
2 changes: 1 addition & 1 deletion scripts/video/eval/video_description_from_t2v.sh
@@ -17,7 +17,7 @@ FRAMES=$3
POOL_STRIDE=$4
OVERWRITE=$5
CHUNKS=${6:-1}
- DO_CENTER_CROP=${7:-True}
+ DO_CENTER_CROP=${7:-False}

echo "Using $CHUNKS GPUs"
