Skip to content

Commit

Permalink
update convnext-tiny and add normal estimation for torchhub
Browse files Browse the repository at this point in the history
  • Loading branch information
ZachL1 committed Jun 9, 2024
1 parent af511bc commit ff0ada4
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 6 deletions.
6 changes: 4 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -281,10 +281,12 @@ Now you can use Metric3D via Pytorch Hub with just few lines of code:
import torch
model = torch.hub.load('yvanyin/metric3d', 'metric3d_vit_small', pretrain=True)
pred_depth, confidence, output_dict = model.inference({'input': rgb})
pred_normal = output_dict['prediction_normal'][:, :3, :, :] # only available for Metric3Dv2 i.e., ViT models
normal_confidence = output_dict['prediction_normal'][:, 3, :, :] # see https://arxiv.org/abs/2109.09881 for details
```
Supported models: `metric3d_convnext_large`, `metric3d_vit_small`, `metric3d_vit_large`, `metric3d_vit_giant2`.
Supported models: `metric3d_convnext_tiny`, `metric3d_convnext_large`, `metric3d_vit_small`, `metric3d_vit_large`, `metric3d_vit_giant2`.

We also provided a minimal working example in [hubconf.py](https://github.com/YvanYin/Metric3D/blob/main/hubconf.py#L122), which hopefully makes everything clearer.
We also provided a minimal working example in [hubconf.py](https://github.com/YvanYin/Metric3D/blob/main/hubconf.py#L145), which hopefully makes everything clearer.

### News: ONNX Exportation and Inference are supported

Expand Down
40 changes: 38 additions & 2 deletions hubconf.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@

MODEL_TYPE = {
'ConvNeXt-Tiny': {
# TODO
'cfg_file': f'{metric3d_dir}/mono/configs/HourglassDecoder/convtiny.0.3_150.py',
'ckpt_file': 'https://huggingface.co/JUGGHM/Metric3D/blob/main/convtiny_hourglass_v1.pth',
},
'ConvNeXt-Large': {
'cfg_file': f'{metric3d_dir}/mono/configs/HourglassDecoder/convlarge.0.3_150.py',
Expand All @@ -34,6 +35,27 @@



def metric3d_convnext_tiny(pretrain=False, **kwargs):
'''
Return a Metric3D model with ConvNeXt-Large backbone and Hourglass-Decoder head.
For usage examples, refer to: https://github.com/YvanYin/Metric3D/blob/main/hubconf.py
Args:
pretrain (bool): whether to load pretrained weights.
Returns:
model (nn.Module): a Metric3D model.
'''
cfg_file = MODEL_TYPE['ConvNeXt-Tiny']['cfg_file']
ckpt_file = MODEL_TYPE['ConvNeXt-Tiny']['ckpt_file']

cfg = Config.fromfile(cfg_file)
model = get_configured_monodepth_model(cfg)
if pretrain:
model.load_state_dict(
torch.hub.load_state_dict_from_url(ckpt_file)['model_state_dict'],
strict=False,
)
return model

def metric3d_convnext_large(pretrain=False, **kwargs):
'''
Return a Metric3D model with ConvNeXt-Large backbone and Hourglass-Decoder head.
Expand Down Expand Up @@ -122,6 +144,7 @@ def metric3d_vit_giant2(pretrain=False, **kwargs):

if __name__ == '__main__':
import cv2
import numpy as np
#### prepare data
rgb_file = 'data/kitti_demo/rgb/0000000050.png'
depth_file = 'data/kitti_demo/depth/0000000050.png'
Expand Down Expand Up @@ -185,4 +208,17 @@ def metric3d_vit_giant2(pretrain=False, **kwargs):

mask = (gt_depth > 1e-8)
abs_rel_err = (torch.abs(pred_depth[mask] - gt_depth[mask]) / gt_depth[mask]).mean()
print('abs_rel_err:', abs_rel_err.item())
print('abs_rel_err:', abs_rel_err.item())

#### normal are also available
if 'prediction_normal' in output_dict: # only available for Metric3Dv2, i.e. vit model
pred_normal = output_dict['prediction_normal'][:, :3, :, :]
normal_confidence = output_dict['prediction_normal'][:, 3, :, :] # see https://arxiv.org/abs/2109.09881 for details
# un pad and resize to some size if needed
pred_normal = pred_normal.squeeze()
pred_normal = pred_normal[:, pad_info[0] : pred_normal.shape[1] - pad_info[1], pad_info[2] : pred_normal.shape[2] - pad_info[3]]
# you can now do anything with the normal
# such as visualize pred_normal
pred_normal_vis = pred_normal.cpu().numpy().transpose((1, 2, 0))
pred_normal_vis = (pred_normal_vis + 1) / 2
cv2.imwrite('normal_vis.png', (pred_normal_vis * 255).astype(np.uint8))
4 changes: 2 additions & 2 deletions mono/utils/do_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ def do_scalecano_test_with_custom_data(

for j, gt_depth in enumerate(gt_depths):
normal_out = None
if 'normal_out_list' in outputs.keys():
normal_out = outputs['normal_out_list'][0][j, :]
if 'prediction_normal' in outputs.keys():
normal_out = outputs['prediction_normal'][j, :]

postprocess_per_image(
i*bs+j,
Expand Down

0 comments on commit ff0ada4

Please sign in to comment.