Merge branch 'main' into lkc/develop
Andy1621 committed Apr 13, 2023
2 parents 3d733c7 + 78567b0 commit 7dbc8e6
Showing 8 changed files with 20 additions and 12 deletions.
2 changes: 1 addition & 1 deletion automatic_label_demo.py
@@ -286,7 +286,7 @@ def save_mask_data(output_dir, caption, mask_list, box_list, label_list):
 caption = check_caption(caption, pred_phrases)
 print(f"Revise caption with number: {caption}")

-transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

 masks, _, _ = predictor.predict_torch(
 point_coords = None,
6 changes: 4 additions & 2 deletions gradio_app.py
@@ -183,7 +183,9 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
 if sam_predictor is None:
 # initialize SAM
 assert sam_checkpoint, 'sam_checkpoint is not found!'
-sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
+sam = build_sam(checkpoint=sam_checkpoint)
+sam.to(device=device)
+sam_predictor = SamPredictor(sam)

 image = np.array(image_pil)
 sam_predictor.set_image(image)
@@ -197,7 +199,7 @@ def run_grounded_sam(input_image, text_prompt, task_type, inpaint_prompt, box_th
 print(f"After NMS: {boxes_filt.shape[0]} boxes")
 print(f"Revise caption with number: {text_prompt}")

-transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

 masks, _, _ = sam_predictor.predict_torch(
 point_coords = None,
6 changes: 4 additions & 2 deletions gradio_auto_label.py
@@ -284,7 +284,9 @@ def run_grounded_sam(image_path, openai_key, box_threshold, text_threshold, iou_
 size = image_pil.size

 # initialize SAM
-predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))
+sam = build_sam(checkpoint=sam_checkpoint)
+sam.to(device=device)
+predictor = SamPredictor(sam)
 image = np.array(image_path)
 predictor.set_image(image)

@@ -304,7 +306,7 @@ def run_grounded_sam(image_path, openai_key, box_threshold, text_threshold, iou_
 caption = check_caption(caption, pred_phrases)
 print(f"Revise caption with number: {caption}")

-transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

 masks, _, _ = predictor.predict_torch(
 point_coords = None,
6 changes: 4 additions & 2 deletions grounded_sam.ipynb
@@ -224,7 +224,9 @@
 "outputs": [],
 "source": [
 "sam_checkpoint = 'sam_vit_h_4b8939.pth'\n",
-"sam_predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint))"
+"sam = build_sam(checkpoint=sam_checkpoint)\n",
+"sam.to(device=device)\n",
+"sam_predictor = SamPredictor(sam)"
 ]
 },
 {
@@ -404,7 +406,7 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2])\n",
+"transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_xyxy, image_source.shape[:2]).to(device)\n",
 "masks, _, _ = sam_predictor.predict_torch(\n",
 " point_coords = None,\n",
 " point_labels = None,\n",
2 changes: 1 addition & 1 deletion grounded_sam_demo.py
@@ -190,7 +190,7 @@ def save_mask_data(output_dir, mask_list, box_list, label_list):
 boxes_filt[i][2:] += boxes_filt[i][:2]

 boxes_filt = boxes_filt.cpu()
-transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

 masks, _, _ = predictor.predict_torch(
 point_coords = None,
2 changes: 1 addition & 1 deletion grounded_sam_inpainting_demo.py
@@ -171,7 +171,7 @@ def show_box(box, ax, label):
 boxes_filt[i][2:] += boxes_filt[i][:2]

 boxes_filt = boxes_filt.cpu()
-transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

 masks, _, _ = predictor.predict_torch(
 point_coords = None,
6 changes: 4 additions & 2 deletions grounded_sam_whisper_demo.py
@@ -209,7 +209,9 @@ def speech_recognition(speech_file, model):
 )

 # initialize SAM
-predictor = SamPredictor(build_sam(checkpoint=sam_checkpoint).to(args.device))
+sam = build_sam(checkpoint=sam_checkpoint)
+sam.to(device=device)
+predictor = SamPredictor(sam)
 image = cv2.imread(image_path)
 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
 predictor.set_image(image)
@@ -229,7 +231,7 @@ def speech_recognition(speech_file, model):
 pred_phrases = [pred_phrases[idx] for idx in nms_idx]
 print(f"After NMS: {boxes_filt.shape[0]} boxes")

-transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

 masks, _, _ = predictor.predict_torch(
 point_coords = None,
2 changes: 1 addition & 1 deletion grounded_sam_whisper_inpainting_demo.py
@@ -240,7 +240,7 @@ def filter_prompts_with_chatgpt(caption, max_tokens=100, model="gpt-3.5-turbo"):
 boxes_filt[i][2:] += boxes_filt[i][:2]

 boxes_filt = boxes_filt.cpu()
-transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2])
+transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

 masks, _, _ = predictor.predict_torch(
 point_coords = None,
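Taken together, the hunks above apply the same device-placement fix in every demo: the SAM model is moved to the target device before being wrapped in SamPredictor, and the transformed box prompts are moved to that same device before predict_torch is called, so the prompts and the model weights no longer end up split between CPU and GPU. Below is a minimal, self-contained sketch of that pattern; the placeholder image, box tensor, and checkpoint filename are illustrative stand-ins, not taken from any single changed file.

import numpy as np
import torch
from segment_anything import build_sam, SamPredictor

device = "cuda" if torch.cuda.is_available() else "cpu"

# Placeholder inputs standing in for the grounding step's outputs.
image = np.zeros((480, 640, 3), dtype=np.uint8)          # HxWx3 RGB image
boxes_filt = torch.tensor([[50.0, 60.0, 200.0, 220.0]])  # one xyxy box, still on CPU

# Build SAM and move the whole model to the device *before* wrapping it
# in SamPredictor, rather than constructing the predictor around a CPU model.
sam = build_sam(checkpoint="sam_vit_h_4b8939.pth")
sam.to(device=device)
predictor = SamPredictor(sam)

predictor.set_image(image)

# The boxes may still live on the CPU (e.g. after NMS), so move the
# transformed box prompts onto the same device as the model.
transformed_boxes = predictor.transform.apply_boxes_torch(boxes_filt, image.shape[:2]).to(device)

masks, _, _ = predictor.predict_torch(
    point_coords=None,
    point_labels=None,
    boxes=transformed_boxes,
    multimask_output=False,
)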
