Skip to content

Commit

Permalink
merged build_sam methods
Browse files Browse the repository at this point in the history
  • Loading branch information
KabirSubbiah committed Aug 22, 2023
1 parent f3c4058 commit 307cbd9
Showing 1 changed file with 24 additions and 24 deletions.
48 changes: 24 additions & 24 deletions lang_sam/lang_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,33 +53,33 @@ def __init__(self, sam_type="vit_h", ckpt_path=None):
self.sam_type = sam_type
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
self.build_groundingdino()
if ckpt_path == None:
self.build_sam(sam_type)
self.build_sam(ckpt_path)

def build_sam(self, ckpt_path):
if self.sam_type is None or ckpt_path is None:
if self.sam_type is None:
print("No sam type indicated. Using vit_h by default.")
self.sam_type = "vit_h"
checkpoint_url = SAM_MODELS[self.sam_type]
try:
sam = sam_model_registry[self.sam_type]()
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
sam.load_state_dict(state_dict, strict=True)
except:
raise ValueError(f"Problem loading SAM please make sure you have the right model type: {sam_type} \
and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and \
re-downloading it.")
sam.to(device=self.device)
self.sam = SamPredictor(sam)
else:
self.build_sam_with_ckpt(sam_type, ckpt_path)

def build_sam_with_ckpt(self, sam_type, ckpt_path):
try:
sam = sam_model_registry[sam_type](ckpt_path)
except:
raise ValueError(f"Problem loading SAM. Your model type: {sam_type} \
try:
sam = sam_model_registry[self.sam_type](ckpt_path)
except:
raise ValueError(f"Problem loading SAM. Your model type: {sam_type} \
should match your checkpoint path: {ckpt_path}. Recommend calling LangSAM \
using matching model type AND checkpoint path")
sam.to(device=self.device)
self.sam = SamPredictor(sam)

def build_sam(self, sam_type):
checkpoint_url = SAM_MODELS[sam_type]
try:
sam = sam_model_registry[sam_type]()
state_dict = torch.hub.load_state_dict_from_url(checkpoint_url)
sam.load_state_dict(state_dict, strict=True)
except:
raise ValueError(f"Problem loading SAM please make sure you have the right model type: {sam_type} \
and a working checkpoint: {checkpoint_url}. Recommend deleting the checkpoint and \
re-downloading it.")
sam.to(device=self.device)
self.sam = SamPredictor(sam)
sam.to(device=self.device)
self.sam = SamPredictor(sam)

def build_groundingdino(self):
ckpt_repo_id = "ShilongLiu/GroundingDINO"
Expand Down

0 comments on commit 307cbd9

Please sign in to comment.