Gradio add automatic (#82)
* gradio_app add automatic
tuofeilunhifi committed Apr 11, 2023
1 parent c6f8ffb commit 7c16d30
Showing 5 changed files with 73 additions and 296 deletions.
Binary file modified assets/gradio_demo.png
72 changes: 59 additions & 13 deletions gradio_app.py
@@ -1,11 +1,13 @@
import os
# os.system('pip install v0.1.0-alpha2.tar.gz')
import gradio as gr

import argparse
import os
import copy

import numpy as np
import torch
import torchvision
from PIL import Image, ImageDraw, ImageFont

# Grounding DINO
@@ -30,6 +32,10 @@
from diffusers import StableDiffusionInpaintPipeline
from huggingface_hub import hf_hub_download

# BLIP
from transformers import BlipProcessor, BlipForConditionalGeneration


def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
args = SLConfig.fromfile(model_config_path)
model = build_model(args)
@@ -42,6 +48,13 @@ def load_model_hf(model_config_path, repo_id, filename, device='cpu'):
_ = model.eval()
return model

def generate_caption(processor, blip_model, raw_image):
# unconditional image captioning
inputs = processor(raw_image, return_tensors="pt").to("cuda", torch.float16)
out = blip_model.generate(**inputs)
caption = processor.decode(out[0], skip_special_tokens=True)
return caption
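# (Editor's aside, not part of the commit) generate_caption assumes a CUDA
# device because the BLIP weights are loaded in float16 below; a hypothetical
# CPU-only fallback would drop the half-precision dtype and the device cast:
#
#   processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
#   blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")
#   inputs = processor(raw_image, return_tensors="pt")
#   caption = processor.decode(blip_model.generate(**inputs)[0], skip_special_tokens=True)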

def plot_boxes_to_image(image_pil, tgt):
H, W = tgt["size"]
boxes = tgt["boxes"]
@@ -135,14 +148,16 @@ def get_grounding_output(model, image, caption, box_threshold, text_threshold, w
tokenized = tokenlizer(caption)
# build pred
pred_phrases = []
scores = []
for logit, box in zip(logits_filt, boxes_filt):
pred_phrase = get_phrases_from_posmap(logit > text_threshold, tokenized, tokenlizer)
if with_logits:
pred_phrases.append(pred_phrase + f"({str(logit.max().item())[:4]})")
else:
pred_phrases.append(pred_phrase)
scores.append(logit.max().item())

return boxes_filt, pred_phrases
return boxes_filt, torch.Tensor(scores), pred_phrases
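# (Editor's aside) each entry of `scores` is the phrase's maximum token logit,
# collected so the 'automatic' branch can rank boxes for NMS; torch.Tensor(scores)
# converts the Python list into the 1-D float tensor torchvision.ops.nms expects.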

def show_mask(mask, ax, random_color=False):
if random_color:
@@ -168,27 +183,37 @@ def show_box(box, ax, label):
output_dir="outputs"
device="cuda"

def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold):
assert text_prompt, 'text_prompt is not found!'
def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode):

# make dir
os.makedirs(output_dir, exist_ok=True)
# load image
image_pil, image = load_image(image_path.convert("RGB"))
# load model
model = load_model_hf(config_file, ckpt_repo_id, ckpt_filenmae)
# model = load_model(config_file, ckpt_filenmae, device=device)

# visualize raw image
image_pil.save(os.path.join(output_dir, "raw_image.jpg"))

if task_type == 'automatic':
# generate caption and tags
# Tag2Text can generate better captions,
# see https://huggingface.co/spaces/xinyu1205/Tag2Text,
# but it still has some bugs...
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
blip_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large", torch_dtype=torch.float16).to("cuda")
text_prompt = generate_caption(processor, blip_model, image_pil)
print(f"Caption: {text_prompt}")

# run grounding dino model
boxes_filt, pred_phrases = get_grounding_output(
boxes_filt, scores, pred_phrases = get_grounding_output(
model, image, text_prompt, box_threshold, text_threshold, device=device
)

size = image_pil.size

if task_type == 'seg' or task_type == 'inpainting':
if task_type == 'seg' or task_type == 'inpainting' or task_type == 'automatic':
# initialize SAM
predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
image = np.array(image_path)
@@ -201,6 +226,16 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
boxes_filt[i][2:] += boxes_filt[i][:2]

boxes_filt = boxes_filt.cpu()
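# (Editor's aside) Grounding DINO emits boxes as normalized (cx, cy, w, h);
# the loop above (only partially shown in this hunk) scales them by the image
# size and converts them to absolute (x1, y1, x2, y2) corners, the format both
# torchvision.ops.nms and SAM's box transform expect.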

if task_type == 'automatic':
# use NMS to handle overlapped boxes
print(f"Before NMS: {boxes_filt.shape[0]} boxes")
nms_idx = torchvision.ops.nms(boxes_filt, scores, iou_threshold).numpy().tolist()
boxes_filt = boxes_filt[nms_idx]
pred_phrases = [pred_phrases[idx] for idx in nms_idx]
print(f"After NMS: {boxes_filt.shape[0]} boxes")
print(f"Revise caption with number: {text_prompt}")

transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])

masks, _, _ = predictor.predict_torch(
@@ -224,7 +259,7 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
image_with_box.save(image_path)
image_result = cv2.cvtColor(cv2.imread(image_path), cv2.COLOR_BGR2RGB)
return image_result
elif task_type == 'seg':
elif task_type == 'seg' or task_type == 'automatic':
assert sam_checkpoint, 'sam_checkpoint is not found!'

# draw output image
@@ -234,6 +269,8 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
for box, label in zip(boxes_filt, pred_phrases):
show_box(box.numpy(), plt.gca(), label)
if task_type == 'automatic':
plt.title(text_prompt)
plt.axis('off')
image_path = os.path.join(output_dir, "grounding_dino_output.jpg")
plt.savefig(image_path, bbox_inches="tight")
@@ -242,7 +279,11 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
elif task_type == 'inpainting':
assert inpaint_prompt, 'inpaint_prompt is not found!'
# inpainting pipeline
mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refined in a future release
if inpaint_mode == 'merge':
masks = torch.sum(masks, dim=0).unsqueeze(0)
masks = torch.where(masks > 0, True, False)
else:
mask = masks[0][0].cpu().numpy() # simply choose the first mask, which will be refined in a future release
mask_pil = Image.fromarray(mask)
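# (Editor's aside, not part of the commit) 'merge' unions every instance mask
# into one binary mask via the sum/where above, while 'first' keeps only the
# first box's mask. Note that as rendered here only the 'else' branch assigns
# `mask`, so the merge branch would still need something like
#   mask = masks[0][0].cpu().numpy()
# for Image.fromarray(mask) above to work.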

pipe = StableDiffusionInpaintPipeline.from_pretrained(
@@ -268,15 +309,16 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
parser = argparse.ArgumentParser("Grounded SAM demo", add_help=True)
parser.add_argument("--debug", action="store_true", help="using debug mode")
parser.add_argument("--share", action="store_true", help="share the app")
parser.add_argument('--port', type=int, default=7589, help='port to run the server')
args = parser.parse_args()

block = gr.Blocks().queue()
with block:
with gr.Row():
with gr.Column():
input_image = gr.Image(source='upload', type="pil")
text_prompt = gr.Textbox(label="Detection Prompt")
task_type = gr.Textbox(label="task type: det/seg/inpainting")
input_image = gr.Image(source='upload', type="pil", value="assets/demo1.jpg")
task_type = gr.Dropdown(["det", "seg", "inpainting", "automatic"], value="automatic", label="task_type")
text_prompt = gr.Textbox(label="Text Prompt")
inpaint_prompt = gr.Textbox(label="Inpaint Prompt")
run_button = gr.Button(label="Run")
with gr.Accordion("Advanced options", open=False):
@@ -286,14 +328,18 @@ def run_grounded_sam(image_path, text_prompt, task_type, inpaint_prompt, box_thr
text_threshold = gr.Slider(
label="Text Threshold", minimum=0.0, maximum=1.0, value=0.25, step=0.001
)
iou_threshold = gr.Slider(
label="IOU Threshold", minimum=0.0, maximum=1.0, value=0.5, step=0.001
)
inpaint_mode = gr.Dropdown(["merge", "first"], value="merge", label="inpaint_mode")

with gr.Column():
gallery = gr.outputs.Image(
type="pil",
).style(full_width=True, full_height=True)

run_button.click(fn=run_grounded_sam, inputs=[
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold], outputs=[gallery])
input_image, text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold, iou_threshold, inpaint_mode], outputs=[gallery])
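# (Editor's aside) Gradio passes `inputs` positionally, so the list above must
# stay in the same order as run_grounded_sam's parameters: image_path,
# text_prompt, task_type, inpaint_prompt, box_threshold, text_threshold,
# iou_threshold, inpaint_mode.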


block.launch(server_name='0.0.0.0', server_port=7589, debug=args.debug, share=args.share)
block.launch(server_name='0.0.0.0', server_port=args.port, debug=args.debug, share=args.share)
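
With the port now exposed as a flag, a typical launch (assuming the Grounding DINO, SAM, and BLIP dependencies and checkpoints are in place; the port value here is only an example) would be:

    python gradio_app.py --port 7860 --share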