-
Notifications
You must be signed in to change notification settings - Fork 244
/
inference.py
46 lines (31 loc) · 1.29 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
'''
* The Inference of RAM and Tag2Text Models
* Written by Xinyu Huang
'''
import torch
def inference_tag2text(image, model, input_tag="None"):
    '''Run Tag2Text inference on a preprocessed image batch.

    Args:
        image: batched, preprocessed image tensor accepted by model.generate.
        model: a Tag2Text model exposing generate(..., return_tag_predict=True).
        input_tag: optional comma-separated user tags. "", "none" or "None"
            (case-sensitive sentinels, kept for backward compatibility)
            mean "no user-specified tags".

    Returns:
        Tuple (predicted_tags, user_tags_or_None, caption) for the first
        image in the batch.
    '''
    # First pass: let the model predict its own tags and a caption.
    with torch.no_grad():
        caption, tag_predict = model.generate(image,
                                              tag_input=None,
                                              max_length=50,
                                              return_tag_predict=True)

    # No user-specified tags: return the model's own tags and caption.
    if input_tag in ('', 'none', 'None'):
        return tag_predict[0], None, caption[0]

    # User-specified tags: re-generate the caption conditioned on them.
    # The model expects tags joined by ' | ' rather than commas.
    input_tag_list = [input_tag.replace(',', ' | ')]
    with torch.no_grad():
        caption, input_tag = model.generate(image,
                                            tag_input=input_tag_list,
                                            max_length=50,
                                            return_tag_predict=True)
    # tag_predict still holds the first-pass model tags; input_tag now holds
    # the tag string the model actually conditioned on.
    return tag_predict[0], input_tag[0], caption[0]
def inference_ram(image, model):
    '''Run RAM tagging on a batched image tensor.

    Returns (english_tags, chinese_tags) for the first image in the batch.
    '''
    with torch.no_grad():
        english_tags, chinese_tags = model.generate_tag(image)
    return english_tags[0], chinese_tags[0]
def inference_ram_openset(image, model):
    '''Run open-set RAM tagging on a batched image tensor.

    Returns the predicted tag string for the first image in the batch.
    '''
    with torch.no_grad():
        predictions = model.generate_tag_openset(image)
    return predictions[0]