use chatgpt in sam-whisper-inpaint
CiaoHe committed Apr 11, 2023
1 parent 6078dd1 commit 0f62596
Showing 5 changed files with 72 additions and 9 deletions.
30 changes: 28 additions & 2 deletions README.md
@@ -28,6 +28,11 @@ Using BLIP to generate caption, extract tags and using Grounded-SAM for box and

![](./assets/automatic_label_output_demo3.jpg)


**🔈Speak to edit🎨: Whisper + ChatGPT + Grounded-SAM + SD**

![](assets/acoustics/gsam_whisper_inpainting_demo.png)

**Imagine Space**

Some possible avenues for future work ...
@@ -258,9 +263,30 @@ python grounded_sam_whisper_demo.py \

**Run Voice-to-inpaint Demo**

You can enable ChatGPT to automatically extract the object to detect and the inpainting instruction from a single voice recording with `--enable_chatgpt`.

Or you can specify the object you want to inpaint [stored in `args.det_speech_file`] and the text you want to inpaint with [stored in `args.inpaint_speech_file`].

```bash
# Example: enable chatgpt
export CUDA_VISIBLE_DEVICES=0
export OPENAI_KEY=your_openai_key
python grounded_sam_whisper_inpainting_demo.py \
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
--grounded_checkpoint groundingdino_swint_ogc.pth \
--sam_checkpoint sam_vit_h_4b8939.pth \
--input_image assets/inpaint_demo.jpg \
--output_dir "outputs" \
--box_threshold 0.3 \
--text_threshold 0.25 \
--prompt_speech_file assets/acoustics/prompt_speech_file.mp3 \
--enable_chatgpt \
--openai_key $OPENAI_KEY \
--device "cuda"
```
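If your network requires a proxy, this commit also adds an `--openai_proxy` flag. A minimal sketch, assuming a hypothetical local proxy endpoint (all other flags are taken from the example above):

```bash
# Example: enable chatgpt behind an HTTP proxy
# (the proxy address below is illustrative, not part of the repo)
export CUDA_VISIBLE_DEVICES=0
export OPENAI_KEY=your_openai_key
python grounded_sam_whisper_inpainting_demo.py \
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
--grounded_checkpoint groundingdino_swint_ogc.pth \
--sam_checkpoint sam_vit_h_4b8939.pth \
--input_image assets/inpaint_demo.jpg \
--output_dir "outputs" \
--prompt_speech_file assets/acoustics/prompt_speech_file.mp3 \
--enable_chatgpt \
--openai_key $OPENAI_KEY \
--openai_proxy http://127.0.0.1:7890 \
--device "cuda"
```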

```bash
# Example: without chatgpt
export CUDA_VISIBLE_DEVICES=0
python grounded_sam_whisper_inpainting_demo.py \
--config GroundingDINO/groundingdino/config/GroundingDINO_SwinT_OGC.py \
@@ -275,7 +301,7 @@ python grounded_sam_whisper_inpainting_demo.py \
--device "cuda"
```

![](assets/acoustics/gsam_whisper_inpainting_demo.png)
![](./assets/acoustics/gsam_whisper_inpainting_pipeline.png)


## :cupid: Acknowledgements
Binary file modified assets/acoustics/gsam_whisper_inpainting_demo.png
Binary file added assets/acoustics/gsam_whisper_inpainting_pipeline.png
Binary file added assets/acoustics/prompt_speech_file.mp3
Binary file not shown.
51 changes: 44 additions & 7 deletions grounded_sam_whisper_inpainting_demo.py
@@ -1,6 +1,6 @@
import argparse
import os
import copy
from warnings import warn

import numpy as np
import torch
@@ -30,6 +30,9 @@
# whisper
import whisper

# ChatGPT
import openai


def load_image(image_path):
# load image
@@ -131,6 +134,27 @@ def speech_recognition(speech_file, model):
return speech_text, speech_language


def filter_prompts_with_chatgpt(caption, max_tokens=100, model="gpt-3.5-turbo"):
    prompt = [
        {
            'role': 'system',
            'content': "Extract the main object to be replaced and mark it as 'main_object', " + \
                       "extract the remaining part as 'other prompt'. " + \
                       "Return (main_object, other prompt). " + \
                       f'Given caption: {caption}.'
        }
    ]
    response = openai.ChatCompletion.create(model=model, messages=prompt, temperature=0.6, max_tokens=max_tokens)
    reply = response['choices'][0]['message']['content']
    try:
        # expect a two-line reply: "main_object: ..." then "other prompt: ..."
        det_prompt, inpaint_prompt = reply.split('\n')[0].split(':')[-1].strip(), \
                                     reply.split('\n')[1].split(':')[-1].strip()
    except IndexError:
        # fall back to using the raw caption for both prompts
        warn("Failed to extract tags from caption; using the caption itself as det_prompt and inpaint_prompt.")
        det_prompt, inpaint_prompt = caption, caption
    return det_prompt, inpaint_prompt
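# Illustrative only (not part of this commit): a hypothetical ChatGPT reply such as
#   main_object: dog
#   other prompt: a cute white cat
# parses to det_prompt == "dog" and inpaint_prompt == "a cute white cat";
# a reply with fewer than two lines triggers the caption fallback above.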


if __name__ == "__main__":

parser = argparse.ArgumentParser("Grounded-Segment-Anything Demo", add_help=True)
@@ -145,8 +169,12 @@ def speech_recognition(speech_file, model):
parser.add_argument(
"--output_dir", "-o", type=str, default="outputs", required=True, help="output directory"
)
parser.add_argument("--det_speech_file", type=str, required=True, help="grounding speech file")
parser.add_argument("--inpaint_speech_file", type=str, required=True, help="inpaint speech file")
parser.add_argument("--det_speech_file", type=str, help="grounding speech file")
parser.add_argument("--inpaint_speech_file", type=str, help="inpaint speech file")
parser.add_argument("--prompt_speech_file", type=str, help="prompt speech file, no need to provide det_speech_file")
parser.add_argument("--enable_chatgpt", action="store_true", help="enable chatgpt")
parser.add_argument("--openai_key", type=str, help="key for chatgpt")
parser.add_argument("--openai_proxy", default=None, type=str, help="proxy for chatgpt")
parser.add_argument("--whisper_model", type=str, default="small", help="whisper model version: tiny, base, small, medium, large")
parser.add_argument("--box_threshold", type=float, default=0.3, help="box threshold")
parser.add_argument("--text_threshold", type=float, default=0.25, help="text threshold")
@@ -176,10 +204,19 @@ def speech_recognition(speech_file, model):

# recognize speech
whisper_model = whisper.load_model(args.whisper_model)
det_prompt, det_speech_language = speech_recognition(args.det_speech_file, whisper_model)
inpaint_prompt, inpaint_speech_language = speech_recognition(args.inpaint_speech_file, whisper_model)
print(f"det_prompt: {det_prompt}, using language: {det_speech_language}")
print(f"inpaint_prompt: {inpaint_prompt}, using language: {inpaint_speech_language}")

if args.enable_chatgpt:
    openai.api_key = args.openai_key
    if args.openai_proxy:
        openai.proxy = {"http": args.openai_proxy, "https": args.openai_proxy}
    speech_text, _ = speech_recognition(args.prompt_speech_file, whisper_model)
    det_prompt, inpaint_prompt = filter_prompts_with_chatgpt(speech_text)
    print(f"det_prompt: {det_prompt}, inpaint_prompt: {inpaint_prompt}")
else:
    det_prompt, det_speech_language = speech_recognition(args.det_speech_file, whisper_model)
    inpaint_prompt, inpaint_speech_language = speech_recognition(args.inpaint_speech_file, whisper_model)
    print(f"det_prompt: {det_prompt}, using language: {det_speech_language}")
    print(f"inpaint_prompt: {inpaint_prompt}, using language: {inpaint_speech_language}")

# run grounding dino model
boxes_filt, pred_phrases = get_grounding_output(
