# run_inference_lavis.py
import os
import json
import torch
from PIL import Image
from lavis.models import load_model_and_preprocess
import argparse
import time


def load_and_process_image(item):
    # Load the raw image for this example, apply the "eval" visual
    # preprocessor, and move the tensor to the target device. Relies on the
    # module-level id2path, vis_processors, and device set up in __main__.
    raw_image = Image.open(id2path[item["image_id"]]).convert("RGB")
    processed_image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    return processed_image, item["question"], item["data_id"]
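
# Image loading above is sequential. If disk I/O becomes the bottleneck, the
# per-item loop in process_images_in_batches below could be handed to a thread
# pool instead; a minimal sketch (the worker count of 8 is an assumption, and
# ThreadPool is not part of the original script):
#
#   from multiprocessing.pool import ThreadPool
#   with ThreadPool(8) as pool:
#       loaded = pool.map(load_and_process_image, batch_subset)
#   batch_images, batch_questions, batch_ids = map(list, zip(*loaded))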


def process_images_in_batches(batch_data, batch_size, prompt):
    output = []
    print("Generate predictions...")
    # Total number of batches (ceiling division)
    num_batches = (len(batch_data) + batch_size - 1) // batch_size
    # Process images in batches
    for idx, i in enumerate(range(0, len(batch_data), batch_size)):
        if (idx + 1) % 100 == 0:
            print(f"Processing batch {idx + 1}/{num_batches}")
        # Subset the data for the current batch
        batch_subset = batch_data[i:i + batch_size]
        # Separate the images, questions, and ids
        batch_images, batch_questions, batch_ids = [], [], []
        # Load and preprocess the images
        for item in batch_subset:
            tmp_img, tmp_q, tmp_id = load_and_process_image(item)
            batch_images.append(tmp_img)
            batch_questions.append(tmp_q)
            batch_ids.append(tmp_id)
        # Concatenate the batch images along the batch dimension
        image_batch = torch.cat(batch_images, dim=0)
        # Wrap each question in the prompt template
        batch_questions = [prompt.format(q) for q in batch_questions]
        # Generate predictions for the batch
        start_time = time.time()
        answers = model.generate({"image": image_batch, "prompt": batch_questions},
                                 length_penalty=-1)
        print(f"Time for batch {idx}: {time.time() - start_time}")
        for data_id, ans in zip(batch_ids, answers):
            output.append({"data_id": data_id, "prediction": ans})
    return output
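
# Each returned record pairs a question id with the generated answer, matching
# the JSONL records written at the bottom of the script, e.g. (both values
# here are illustrative):
#   {"data_id": "infoseek_val_0", "prediction": "paris"}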


if __name__ == "__main__":
    # Parse command-line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument("--split", type=str, default="val", help="val, test, or human")
    parser.add_argument("--model_name", type=str, default="blip2_t5", help="blip2_t5 | blip2_vicuna_instruct | blip2_t5_instruct")
    parser.add_argument("--model_type", type=str, default="pretrain_flant5xxl", help="pretrain_flant5xxl | vicuna13b | flant5xxl")
    parser.add_argument("--output_dir", type=str, default="predictions_ft", help="output directory")
    parser.add_argument("--model_ckpt", type=str, default="development/blip2_t5_pretrain_flant5xxl_399_val=14.89.pt", help="path to the fine-tuned model checkpoint")
    parser.add_argument("--batch_size", type=int, default=8, help="batch size")
    args = parser.parse_args()
    split2data = {
        "val": "infoseek/infoseek_val.jsonl",
        "test": "infoseek/infoseek_test.jsonl",
        "human": "infoseek/infoseek_human.jsonl",
    }
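    # Each split file is JSONL; the expected fields are inferred from the
    # accesses below (values are illustrative):
    #   {"data_id": "infoseek_val_0", "image_id": "oven_0", "question": "..."}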
    id2path = dict()
    # Load image paths
    with open("id2image.jsonl", "r") as f:
        for line in f:
            line = json.loads(line)
            image_id = line["image_id"]
            path = line["image_path"]
            id2path[image_id] = path
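    # Each line of id2image.jsonl holds one mapping, e.g. (the path shown is
    # illustrative): {"image_id": "oven_0", "image_path": "images/oven_0.jpg"}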
    # Read the input JSONL file for the chosen split
    with open(split2data[args.split], "r") as f:
        batch_data = [json.loads(line) for line in f]
    # Double-check that every referenced image exists on disk
    not_exist = []
    clean_batch_data = []
    for idx, item in enumerate(batch_data):
        if idx % 10000 == 0:
            print(f"Processing {idx}/{len(batch_data)}")
        path = id2path[item["image_id"]]
        if not os.path.exists(path):
            not_exist.append(item["image_id"])
        else:
            clean_batch_data.append(item)
    print(f"Images not found: {len(not_exist)}")
    # Set up the device to use
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Load pretrained model...")
    # Load the BLIP-2 pre-trained model and its image preprocessors
    model, vis_processors, _ = load_model_and_preprocess(name=args.model_name,
                                                         model_type=args.model_type,
                                                         is_eval=True, device=device)
    # Load fine-tuned weights from the checkpoint
    model.load_state_dict(torch.load(args.model_ckpt, map_location=device), strict=False)
    model.eval()
    # Desired batch size
    batch_size = args.batch_size
    PROMPT = "Question: {} Short answer:"
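    # After formatting, a question such as "Where is this building located?"
    # becomes: "Question: Where is this building located? Short answer:"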
    # Run the batch processing function
    output = process_images_in_batches(clean_batch_data, batch_size, prompt=PROMPT)
    # Save the output as JSONL (create the output directory if needed)
    os.makedirs(args.output_dir, exist_ok=True)
    with open(os.path.join(args.output_dir, "zeroshot_{}_{}_{}.jsonl".format(
            args.model_name, args.model_type, args.split
    )), "w") as f:
        for item in output:
            f.write(json.dumps(item, ensure_ascii=False) + "\n")
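
# Example invocation (the checkpoint path is just the script's default and
# will differ per setup):
#   python run_inference_lavis.py --split val --model_name blip2_t5 \
#       --model_type pretrain_flant5xxl --output_dir predictions_ft \
#       --model_ckpt development/blip2_t5_pretrain_flant5xxl_399_val=14.89.pt \
#       --batch_size 8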