Skip to content

Commit

Permalink
add pytorch hub support
Browse files Browse the repository at this point in the history
  • Loading branch information
ZachL1 committed May 8, 2024
1 parent cb5ddf3 commit 1b77c64
Showing 1 changed file with 20 additions and 6 deletions.
26 changes: 20 additions & 6 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,10 @@ def metric3d_convnext_large(pretrain=False, **kwargs):
cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(torch.load(ckpt_file)['model_state_dict'], strict=False)
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model

def metric3d_vit_small(pretrain=False, **kwargs):
Expand All @@ -65,7 +68,10 @@ def metric3d_vit_small(pretrain=False, **kwargs):
cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(torch.load(ckpt_file)['model_state_dict'], strict=False)
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model

def metric3d_vit_large(pretrain=False, **kwargs):
Expand All @@ -83,7 +89,10 @@ def metric3d_vit_large(pretrain=False, **kwargs):
cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(torch.load(ckpt_file)['model_state_dict'], strict=False)
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model

def metric3d_vit_giant2(pretrain=False, **kwargs):
Expand All @@ -101,11 +110,16 @@ def metric3d_vit_giant2(pretrain=False, **kwargs):
cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(torch.load(ckpt_file)['model_state_dict'], strict=False)
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model



import os
os.environ['HTTP_PROXY'] = 'http:https://192.168.195.225:7890'
os.environ['HTTPS_PROXY'] = 'http:https://192.168.195.225:7890'


if __name__ == '__main__':
Expand Down Expand Up @@ -145,7 +159,7 @@ def metric3d_vit_giant2(pretrain=False, **kwargs):

###################### canonical camera space ######################
# inference
model = torch.hub.load('yvanyin/metric3d', 'metric3d_vit_large', pretrain=True)
model = torch.hub.load('yvanyin/metric3d', 'metric3d_vit_small', pretrain=True)
model.cuda().eval()
with torch.no_grad():
pred_depth, confidence, output_dict = model.inference({'input': rgb})
Expand Down

0 comments on commit 1b77c64

Please sign in to comment.