# Copyright (c) OpenMMLab. All rights reserved. from typing import Dict, List, Optional, Sequence, Union import torch import torch.nn as nn import torch.nn.functional as F from mmcv.cnn import build_conv_layer, build_norm_layer, build_upsample_layer from mmengine.model import BaseModule from torch import Tensor from mmseg.registry import MODELS from mmseg.utils import SampleList from ..utils import resize from .decode_head import BaseDecodeHead class VPDDepthDecoder(BaseModule): """VPD Depth Decoder class. Args: in_channels (int): Number of input channels. out_channels (int): Number of output channels. num_deconv_layers (int): Number of deconvolution layers. num_deconv_filters (List[int]): List of output channels for deconvolution layers. init_cfg (Optional[Union[Dict, List[Dict]]], optional): Configuration for weight initialization. Defaults to Normal for Conv2d and ConvTranspose2d layers. """ def __init__(self, in_channels: int, out_channels: int, num_deconv_layers: int, num_deconv_filters: List[int], init_cfg: Optional[Union[Dict, List[Dict]]] = dict( type='Normal', std=0.001, layer=['Conv2d', 'ConvTranspose2d'])): super().__init__(init_cfg=init_cfg) self.in_channels = in_channels self.deconv_layers = self._make_deconv_layer( num_deconv_layers, num_deconv_filters, ) conv_layers = [] conv_layers.append( build_conv_layer( dict(type='Conv2d'), in_channels=num_deconv_filters[-1], out_channels=out_channels, kernel_size=3, stride=1, padding=1)) conv_layers.append(build_norm_layer(dict(type='BN'), out_channels)[1]) conv_layers.append(nn.ReLU(inplace=True)) self.conv_layers = nn.Sequential(*conv_layers) self.up_sample = nn.Upsample( scale_factor=2, mode='bilinear', align_corners=False) def forward(self, x): """Forward pass through the decoder network.""" out = self.deconv_layers(x) out = self.conv_layers(out) out = self.up_sample(out) out = self.up_sample(out) return out def _make_deconv_layer(self, num_layers, num_deconv_filters): """Make deconv layers.""" layers = [] in_channels = self.in_channels for i in range(num_layers): num_channels = num_deconv_filters[i] layers.append( build_upsample_layer( dict(type='deconv'), in_channels=in_channels, out_channels=num_channels, kernel_size=2, stride=2, padding=0, output_padding=0, bias=False)) layers.append(nn.BatchNorm2d(num_channels)) layers.append(nn.ReLU(inplace=True)) in_channels = num_channels return nn.Sequential(*layers) @MODELS.register_module() class VPDDepthHead(BaseDecodeHead): """Depth Prediction Head for VPD. .. _`VPD`: https://arxiv.org/abs/2303.02153 Args: max_depth (float): Maximum depth value. Defaults to 10.0. in_channels (Sequence[int]): Number of input channels for each convolutional layer. embed_dim (int): Dimension of embedding. Defaults to 192. feature_dim (int): Dimension of aggregated feature. Defaults to 1536. num_deconv_layers (int): Number of deconvolution layers in the decoder. Defaults to 3. num_deconv_filters (Sequence[int]): Number of filters for each deconv layer. Defaults to (32, 32, 32). fmap_border (Union[int, Sequence[int]]): Feature map border for cropping. Defaults to 0. align_corners (bool): Flag for align_corners in interpolation. Defaults to False. loss_decode (dict): Configurations for the loss function. Defaults to dict(type='SiLogLoss'). init_cfg (dict): Initialization configurations. Defaults to dict(type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']). """ num_classes = 1 out_channels = 1 input_transform = None def __init__( self, max_depth: float = 10.0, in_channels: Sequence[int] = [320, 640, 1280, 1280], embed_dim: int = 192, feature_dim: int = 1536, num_deconv_layers: int = 3, num_deconv_filters: Sequence[int] = (32, 32, 32), fmap_border: Union[int, Sequence[int]] = 0, align_corners: bool = False, loss_decode: dict = dict(type='SiLogLoss'), init_cfg=dict( type='TruncNormal', std=0.02, layer=['Conv2d', 'Linear']), ): super(BaseDecodeHead, self).__init__(init_cfg=init_cfg) # initialize parameters self.in_channels = in_channels self.max_depth = max_depth self.align_corners = align_corners # feature map border if isinstance(fmap_border, int): fmap_border = (fmap_border, fmap_border) self.fmap_border = fmap_border # define network layers self.conv1 = nn.Sequential( nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1), nn.GroupNorm(16, in_channels[0]), nn.ReLU(), nn.Conv2d(in_channels[0], in_channels[0], 3, stride=2, padding=1), ) self.conv2 = nn.Conv2d( in_channels[1], in_channels[1], 3, stride=2, padding=1) self.conv_aggregation = nn.Sequential( nn.Conv2d(sum(in_channels), feature_dim, 1), nn.GroupNorm(16, feature_dim), nn.ReLU(), ) self.decoder = VPDDepthDecoder( in_channels=embed_dim * 8, out_channels=embed_dim, num_deconv_layers=num_deconv_layers, num_deconv_filters=num_deconv_filters) self.depth_pred_layer = nn.Sequential( nn.Conv2d( embed_dim, embed_dim, kernel_size=3, stride=1, padding=1), nn.ReLU(inplace=False), nn.Conv2d(embed_dim, 1, kernel_size=3, stride=1, padding=1)) # build loss if isinstance(loss_decode, dict): self.loss_decode = MODELS.build(loss_decode) elif isinstance(loss_decode, (list, tuple)): self.loss_decode = nn.ModuleList() for loss in loss_decode: self.loss_decode.append(MODELS.build(loss)) else: raise TypeError(f'loss_decode must be a dict or sequence of dict,\ but got {type(loss_decode)}') def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tensor: gt_depth_maps = [ data_sample.gt_depth_map.data for data_sample in batch_data_samples ] return torch.stack(gt_depth_maps, dim=0) def forward(self, x): x = [ x[0], x[1], torch.cat([x[2], F.interpolate(x[3], scale_factor=2)], dim=1) ] x = torch.cat([self.conv1(x[0]), self.conv2(x[1]), x[2]], dim=1) x = self.conv_aggregation(x) x = x[:, :, :x.size(2) - self.fmap_border[0], :x.size(3) - self.fmap_border[1]].contiguous() x = self.decoder(x) out = self.depth_pred_layer(x) depth = torch.sigmoid(out) * self.max_depth return depth def loss_by_feat(self, pred_depth_map: Tensor, batch_data_samples: SampleList) -> dict: """Compute depth estimation loss. Args: pred_depth_map (Tensor): The output from decode head forward function. batch_data_samples (List[:obj:`SegDataSample`]): The seg data samples. It usually includes information such as `metainfo` and `gt_dpeth_map`. Returns: dict[str, Tensor]: a dictionary of loss components """ gt_depth_map = self._stack_batch_gt(batch_data_samples) loss = dict() pred_depth_map = resize( input=pred_depth_map, size=gt_depth_map.shape[2:], mode='bilinear', align_corners=self.align_corners) if not isinstance(self.loss_decode, nn.ModuleList): losses_decode = [self.loss_decode] else: losses_decode = self.loss_decode for loss_decode in losses_decode: if loss_decode.loss_name not in loss: loss[loss_decode.loss_name] = loss_decode( pred_depth_map, gt_depth_map) else: loss[loss_decode.loss_name] += loss_decode( pred_depth_map, gt_depth_map) return loss