Skip to content

Commit

Permalink
[fix]sam_hq (#164)
Browse files Browse the repository at this point in the history
* modified:   sam/_wsgi.py
	modified:   sam/mmdetection.py

* modified:   readme.md
	modified:   readme_zh.md

* modified:   readme.md
	modified:   readme_zh.md

---------

Co-authored-by: JimmyMa99 <[email protected]>
  • Loading branch information
JimmyMa99 and JimmyMa99 authored Oct 16, 2023
1 parent c06ae5d commit f5f5e80
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 8 deletions.
6 changes: 4 additions & 2 deletions label_anything/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ cd path/to/playground/label_anything
# conda install pycocotools -c conda-forge
pip install opencv-python pycocotools matplotlib onnxruntime onnx
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install segment-anything-hq
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

# If you're on a Windows machine, you can use the following in place of wget
Expand Down Expand Up @@ -127,10 +128,11 @@ label-studio-ml start sam --port 8003 --with \
# inference on HQ-SAM
label-studio-ml start sam --port 8003 --with \
sam_config=vit_b \
sam_checkpoint_file=./sam_hq_vit_l.pth \
sam_checkpoint_file=./sam_hq_vit_b.pth \
out_mask=True \
out_bbox=True \
device=cuda:0
device=cuda:0 \
model_name=sam_hq
# device=cuda:0 runs inference on the GPU. To run inference on the CPU, replace cuda:0 with cpu.
# out_poly=True returns the annotation of the bounding polygon.

Expand Down
4 changes: 3 additions & 1 deletion label_anything/readme_zh.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ cd path/to/playground/label_anything
# conda install pycocotools -c conda-forge
pip install opencv-python pycocotools matplotlib onnxruntime onnx timm
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install segment-anything-hq

# 下载sam预训练模型
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth
Expand Down Expand Up @@ -127,10 +128,11 @@ device=cuda:0
# 采用 HQ-SAM 进行后端推理
label-studio-ml start sam --port 8003 --with \
sam_config=vit_b \
sam_checkpoint_file=./sam_hq_vit_l.pth \
sam_checkpoint_file=./sam_hq_vit_b.pth \
out_mask=True \
out_bbox=True \
device=cuda:0 \
model_name=sam_hq
# device=cuda:0 为使用 GPU 推理,如果使用 cpu 推理,将 cuda:0 替换为 cpu
# out_poly=True 返回外接多边形的标注

Expand Down
1 change: 1 addition & 0 deletions label_anything/sam/_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ def get_kwargs_from_config(config_path=_DEFAULT_CONFIG_PATH):
'--model-dir',
dest='model_dir',
default=os.path.dirname(__file__),
# default='./sam_hq_vit_b.pth',
help='Directory models are store',
)
parser.add_argument(
Expand Down
15 changes: 10 additions & 5 deletions label_anything/sam/mmdetection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,17 @@
get_single_tag_keys)
from label_studio_tools.core.utils.io import get_data_dir
from filter_poly import NearNeighborRemover
import pdb
# from mmdet.apis import inference_detector, init_detector

logger = logging.getLogger(__name__)


def load_my_model(
model_name="sam",
model_name="sam_hq",
device="cuda:0",
sam_config="vit_b",
sam_checkpoint_file="sam_vit_b_01ec64.pth"):
sam_checkpoint_file="sam_hq_vit_b.pth"):
"""
Loads the Segment Anything model when Label Studio initializes, so that calling it outside MyModel does not reload the model every time you try to make a prediction.
Returns the predictor object. For more, look at Facebook's SAM docs
Expand All @@ -42,6 +43,12 @@ def load_my_model(
sam.to(device=device)
predictor = SamPredictor(sam)
return predictor
elif model_name == "sam_hq":
from segment_anything_hq import sam_model_registry, SamPredictor
sam = sam_model_registry[sam_config](checkpoint=sam_checkpoint_file)
sam.to(device=device)
predictor = SamPredictor(sam)
return predictor
elif model_name == "mobile_sam":
from models.mobile_sam import SamPredictor, sam_model_registry
sam = sam_model_registry[sam_config](checkpoint=sam_checkpoint_file)
Expand All @@ -56,7 +63,7 @@ class MMDetection(LabelStudioMLBase):
"""Object detector based on https://github.com/open-mmlab/mmdetection."""

def __init__(self,
model_name="sam",
model_name="sam_hq",
config_file=None,
checkpoint_file=None,
sam_config='vit_b',
Expand All @@ -79,7 +86,6 @@ def __init__(self,
self.out_mask = out_mask
self.out_bbox = out_bbox
self.out_poly = out_poly

# config_file = config_file or os.environ['config_file']
# checkpoint_file = checkpoint_file or os.environ['checkpoint_file']
# self.config_file = config_file
Expand Down Expand Up @@ -227,7 +233,6 @@ def predict(self, tasks, **kwargs):
point_labels=np.array([1]),
multimask_output=False,
)

mask = masks[0].astype(np.uint8) # each mask has shape [H, W]
# converting the mask from the model to RLE format which is usable in Label Studio

Expand Down

0 comments on commit f5f5e80

Please sign in to comment.