inference.py
import torch

from video_chatgpt.video_conversation import conv_templates, SeparatorStyle
from video_chatgpt.model.utils import KeywordsStoppingCriteria

# Special tokens used to splice the video features into the text prompt
DEFAULT_VIDEO_TOKEN = "<video>"
DEFAULT_VIDEO_PATCH_TOKEN = "<vid_patch>"
DEFAULT_VID_START_TOKEN = "<vid_start>"
DEFAULT_VID_END_TOKEN = "<vid_end>"

def get_spatio_temporal_features_torch(features):
    """
    Computes spatio-temporal features from per-frame patch features.

    Parameters:
        features (torch.Tensor): Input features of shape (t, s, c), where t is
            the number of frames, s the number of spatial patches per frame,
            and c the channel dimension.

    Returns:
        torch.Tensor: Spatio-temporal features of shape (max(t, 100) + s, c).
    """
    # Extract the dimensions of the features
    t, s, c = features.shape

    # Compute temporal tokens as the mean over the spatial axis (one token per frame)
    temporal_tokens = torch.mean(features, dim=1)

    # Zero-pad the temporal tokens up to a fixed budget of 100 frames
    padding_size = 100 - t
    if padding_size > 0:
        padding = torch.zeros(padding_size, c, device=features.device)
        temporal_tokens = torch.cat((temporal_tokens, padding), dim=0)

    # Compute spatial tokens as the mean over the time axis (one token per patch)
    spatial_tokens = torch.mean(features, dim=0)

    # Concatenate temporal and spatial tokens and cast to half precision
    concat_tokens = torch.cat([temporal_tokens, spatial_tokens], dim=0).half()
    return concat_tokens
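
# A minimal shape sketch for the function above (illustrative only; the frame
# count 50, patch count 256, and width 1024 are assumed values, not constants
# from this repo):
#
#     feats = torch.randn(50, 256, 1024)
#     out = get_spatio_temporal_features_torch(feats)
#     out.shape  # torch.Size([356, 1024]): 100 temporal + 256 spatial tokens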

def video_chatgpt_infer(video_frames, question, conv_mode, model, vision_tower,
                        tokenizer, image_processor, video_token_len):
    """
    Run inference using the Video-ChatGPT model.

    Parameters:
        video_frames (torch.Tensor): Video frames to process.
        question (str): The question string.
        conv_mode: Conversation mode.
        model: The pretrained Video-ChatGPT model.
        vision_tower: Vision model to extract video features.
        tokenizer: Tokenizer for the model.
        image_processor: Image processor to preprocess video frames.
        video_token_len (int): The length of video tokens.

    Returns:
        str: The model's decoded answer.
    """
    # Prepare the question string for the model: append one placeholder patch
    # token per video token, optionally wrapped in start/end markers
    if model.get_model().vision_config.use_vid_start_end:
        qs = question + '\n' + DEFAULT_VID_START_TOKEN + DEFAULT_VIDEO_PATCH_TOKEN * video_token_len + DEFAULT_VID_END_TOKEN
    else:
        qs = question + '\n' + DEFAULT_VIDEO_PATCH_TOKEN * video_token_len

    # Prepare conversation prompt
    conv = conv_templates[conv_mode].copy()
    conv.append_message(conv.roles[0], qs)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    # Tokenize the prompt
    inputs = tokenizer([prompt])

    # Preprocess video frames and get the image tensor
    image_tensor = image_processor.preprocess(video_frames, return_tensors='pt')['pixel_values']

    # Move the image tensor to the GPU and reduce precision to half
    image_tensor = image_tensor.half().cuda()

    # Generate video spatio-temporal features
    with torch.no_grad():
        image_forward_outs = vision_tower(image_tensor, output_hidden_states=True)
        frame_features = image_forward_outs.hidden_states[-2][:, 1:]  # Use second-to-last layer as in LLaVA
        video_spatio_temporal_features = get_spatio_temporal_features_torch(frame_features)

    # Move input ids to the GPU
    input_ids = torch.as_tensor(inputs.input_ids).cuda()

    # Define stopping criteria for generation
    stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    stopping_criteria = KeywordsStoppingCriteria([stop_str], tokenizer, input_ids)

    # Run model inference
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            video_spatio_temporal_features=video_spatio_temporal_features.unsqueeze(0),
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            stopping_criteria=[stopping_criteria])

    # Sanity check: the prompt tokens should be echoed back unchanged
    n_diff_input_output = (input_ids != output_ids[:, :input_ids.shape[1]]).sum().item()
    if n_diff_input_output > 0:
        print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')

    # Decode only the newly generated tokens
    outputs = tokenizer.batch_decode(output_ids[:, input_ids.shape[1]:], skip_special_tokens=True)[0]

    # Clean the output string: remove the stop string as a suffix rather than
    # with rstrip(stop_str), which would strip any trailing characters that
    # happen to appear in stop_str
    outputs = outputs.strip()
    if outputs.endswith(stop_str):
        outputs = outputs[:-len(stop_str)]
    return outputs.strip()
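
# A minimal, hypothetical driver showing how video_chatgpt_infer might be
# called end to end. The helpers initialize_model and load_video, the
# checkpoint paths, and the "video-chatgpt_v1" conversation mode are
# assumptions for illustration, not guarantees about this repo's API.
if __name__ == "__main__":
    from video_chatgpt.eval.model_utils import initialize_model, load_video  # assumed helpers

    model, vision_tower, tokenizer, image_processor, video_token_len = initialize_model(
        "path/to/Video-ChatGPT-7B",                # hypothetical model path
        projection_path="path/to/projection.bin",  # hypothetical projection weights
    )
    video_frames = load_video("sample.mp4")        # hypothetical input video
    answer = video_chatgpt_infer(
        video_frames, "What is happening in the video?", "video-chatgpt_v1",
        model, vision_tower, tokenizer, image_processor, video_token_len)
    print(answer)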