pokemon_gradio.py
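
# Gradio demo: embeds Kanto pokedex artwork with CLIP through a vexpresso collection
# and serves text/image similarity search with optional metadata filters.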
import gradio as gr
import json
import pandas as pd
from PIL import Image
import requests
import numpy as np
import daft
import vexpresso
from vexpresso.utils import ResourceRequest
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizerFast
import torch

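
# Wraps CLIP (openai/clip-vit-base-patch32) so one callable can embed either a batch
# of PIL images or a batch of text strings into the shared CLIP embedding space.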
class ClipEmbeddingsFunction:
    def __init__(self):
        self.model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
        self.processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
        self.tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32")
        self.device = torch.device('cpu')
        if torch.cuda.is_available():
            self.device = torch.device('cuda')
        self.model = self.model.to(self.device)

    def __call__(self, inp, inp_type):
        if inp_type == "image":
            inputs = self.processor(images=inp, return_tensors="pt", padding=True)['pixel_values'].to(self.device)
            return self.model.get_image_features(inputs).detach().cpu().numpy()
        if inp_type == "text":
            inputs = self.tokenizer(inp, padding=True, return_tensors="pt")
            inputs["input_ids"] = inputs["input_ids"].to(self.device)
            inputs["attention_mask"] = inputs["attention_mask"].to(self.device)
            return self.model.get_text_features(**inputs).detach().cpu().numpy()

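
# Fetch each entry's high-resolution artwork from its "hires" URL and decode it into
# a PIL image in the requested mode (e.g. "RGB").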
def download_images(images, image_type):
    return [Image.open(requests.get(im["hires"], stream=True).raw).convert(image_type) for im in images]

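
# Gradio callbacks that manage the list of metadata filters held in gr.State. Each
# filter is a nested dict of the form {field: {method: value}}; remove_row drops an
# entry by index (defined here but not wired to a control in the UI below).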
def add_filter(state, filter_var, filter_method, filter_value):
    f = {filter_var: {filter_method: filter_value}}
    state.append(f)
    return f"{state}"

def remove_row(state, row):
    row = max(row, 0)
    if int(row) < len(state) and len(state) > 0:
        del state[int(row)]
    return f"{state}"

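
# Illustrative filter entry (hypothetical field and value; the usable fields depend
# on the columns present in pokedex.json):
#   {"type": {"contains": "Grass"}}
# find_image_vectors passes each such entry to queried.filter(filter_conditions=...).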
if __name__ == "__main__":
    with open("./data/pokedex.json", 'r') as f:
        stuff = json.load(f)
    df = pd.DataFrame(stuff)
    # include only Kanto
    df = df.iloc[:151]
    # create collection
    collection = vexpresso.create(data=df, backend="ray")
    resource_request = ResourceRequest(num_gpus=0)  # change this to 1 to use gpus
    # download and embed the images
    print("Downloading and embedding images...")
    collection = collection.apply(
        download_images,
        collection["image"],
        image_type="RGB",
        to="downloaded_image",
        datatype=daft.DataType.python()
    ).embed(
        "downloaded_image",
        embedding_fn=ClipEmbeddingsFunction,
        inp_type="image",
        to="clip_embeddings",
        resource_request=resource_request
    ).execute()

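    # Search callback: embed the text and/or image query with CLIP, average the two
    # embeddings when both are given, run a top-k similarity query against the
    # collection, apply any metadata filters from state, and return the top matches.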
    def find_image_vectors(text_query, image_query, state):
        if text_query is None and image_query is None:
            raise ValueError("Image or text query must be provided")
        embeddings = []
        if text_query is not None:
            embeddings.append(collection.embed_query(text_query, embedding_fn="clip_embeddings", inp_type="text"))
        if image_query is not None:
            embeddings.append(collection.embed_query(image_query, embedding_fn="clip_embeddings", inp_type="image"))
        if len(embeddings) > 1:
            query_embedding = 0.5 * embeddings[0] + 0.5 * embeddings[1]
        else:
            query_embedding = embeddings[0]
        queried = collection.query(
            "clip_embeddings",
            query_embedding=query_embedding,
            k=10,
            inp_type="text",
        ).execute()
        if state is not None and len(state) > 0:
            for filt in state:
                queried = queried.filter(filter_conditions=filt)
        images = queried["downloaded_image"].to_list()[:4]
        return images

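    # Gradio UI: text/image query inputs, metadata filter controls, and a results gallery.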
    with gr.Blocks() as demo:
        state = gr.State([])
        with gr.Row():
            with gr.Column():
                gr.Markdown(
                    """
                    ### Query Input! Add a text prompt, upload an image, or provide both to average their embeddings.
                    """)
                vector_query = gr.Textbox(placeholder="Type in a text prompt! Ex: Looks like a plant", show_label=False)
                image_query = gr.Image(show_label=False)
            with gr.Column():
                gr.Markdown(
                    """
                    ### Filter method. Use this to filter results based on metadata fields.
                    """)
                filter_var = gr.Textbox(label="filter_var")
                filter_method = gr.Dropdown(
                    choices=[
                        "eq", "neq", "gt", "gte", "lt", "lte", "isin", "notin", "contains", "notcontains"
                    ],
                    label="filter_method"
                )
                filter_value = gr.Textbox(label="filter_value")
            with gr.Column():
                gr.Markdown(
                    """
                    ### Current Filter Methods
                    """)
                current_filters = gr.Textbox(label="Current Filters")
                filter_button = gr.Button("Add filter")
                filter_button.click(fn=add_filter, inputs=[state, filter_var, filter_method, filter_value], outputs=current_filters)
        with gr.Row():
            button = gr.Button("Submit")
        with gr.Row():
            gallery = gr.Gallery(
                label="Queried Pokemon!", show_label=False, elem_id="gallery", preview=True
            ).style(columns=[2], rows=[2], object_fit="contain", height="auto")
        button.click(find_image_vectors, inputs=[vector_query, image_query, state], outputs=[gallery])
        gr.Examples(
            examples=[
                ["Turtle pokemon, has blue skin", None, []],
                [None, "data/gradio-demo/mewtwo.jpeg", []],
                ["Looks like a plant", "data/gradio-demo/bulbasaur.png", []],
                [None, "data/gradio-demo/pikachu-dog.jpg", []],
                [None, "data/gradio-demo/charmander.png", []],
            ],
            inputs=[vector_query, image_query, state],
            # outputs=[gallery],
            # fn=find_image_vectors,
            cache_examples=False,
        )
    demo.launch()