Skip to content

Commit

Permalink
fix checkpoint storage
Browse files Browse the repository at this point in the history
  • Loading branch information
luca-medeiros committed Apr 16, 2023
1 parent d47b9ca commit 0009167
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions lang_sam/lang_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"vit_b": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
}

CACHE_PATH = os.environ.get("TORCH_HOME", "~/.cache/torch/hub")
CACHE_PATH = os.environ.get("TORCH_HOME", "~/.cache/torch/hub/checkpoints")


def load_model_hf(repo_id, filename, ckpt_config_filename, device='cpu'):
Expand Down Expand Up @@ -58,8 +58,10 @@ def __init__(self, sam_type="vit_h"):

def build_sam(self, sam_type):
url = SAM_MODELS[sam_type]
sam_checkpoint = os.path.join(CACHE_PATH, 'checkpoints', os.path.basename(url))
sam_checkpoint = os.path.join(CACHE_PATH, os.path.basename(url))
if not os.path.exists(sam_checkpoint):
if not os.path.exists(CACHE_PATH):
os.makedirs(CACHE_PATH)
request.urlretrieve(url, sam_checkpoint)
sam = sam_model_registry[sam_type](checkpoint=sam_checkpoint)
sam.to(device=self.device)
Expand Down

0 comments on commit 0009167

Please sign in to comment.