From 6274fca833bc521c915e344963c05e6248e2976e Mon Sep 17 00:00:00 2001
From: wangyingming
Date: Mon, 11 Oct 2021 14:33:28 +0800
Subject: [PATCH] Remove unused code and update the URLs of the retrained models

---
 README.md             | 12 ++++++------
 main.py               |  5 +----
 models/anchor_detr.py |  8 ++------
 models/transformer.py | 27 +++++++++++----------------
 4 files changed, 20 insertions(+), 32 deletions(-)

diff --git a/README.md b/README.md
index 6655b05..f3a0cf1 100644
--- a/README.md
+++ b/README.md
@@ -22,7 +22,7 @@ We also propose an attention variant RCDA to reduce the memory cost for high-res
 | SMCA | multi-level | 50 | 43.7 | 152 | 10 |
 | Deformable DETR | multi-level | 50 | 43.8 | 173 | 15 |
 | Conditional DETR | DC5 | 50 | 43.8 | 195 | 10 |
-| Anchor DETR | DC5 | 50 | 44.2 | 151 | 16 (19) |
+| Anchor DETR | DC5 | 50 | 44.3 | 151 | 16 (19) |
 
 *Note:*
 
@@ -34,12 +34,12 @@ We also propose an attention variant RCDA to reduce the memory cost for high-res
 ## Model
 | name | backbone | AP | URL |
 |:----------------:|:---------:|:-------:|:-----:|
-| AnchorDETR-C5 | R50 | 42.1 | [model](https://drive.google.com/file/d/1FKDrTL7qg9riNN5a910Gzf4aZYJTHdT-/view?usp=sharing) / [log](https://drive.google.com/file/d/1b3jy9xkpLA0vi0GWlchtg4SY5jIqVz5S/view?usp=sharing) |
-| AnchorDETR-DC5 | R50 | 44.2 | [model](https://drive.google.com/file/d/1ggsdoBOZa53S4h6Ur3rlK1-7-eABBlid/view?usp=sharing) / [log](https://drive.google.com/file/d/1S3rtBYMsAv437hGL0nm3JlYp6P0nqZfj/view?usp=sharing) |
-| AnchorDETR-C5 | R101 | 43.5 | [model](https://drive.google.com/file/d/19CQqNvrrpdpSxIyn-2IPmLZOf2KP-Zft/view?usp=sharing) / [log](https://drive.google.com/file/d/1O4K00CLiMBaNu0x61xECg7Kek2Rf-tUr/view?usp=sharing) |
-| AnchorDETR-DC5 | R101 | 45.1 | [model](https://drive.google.com/file/d/1bEnFnHCoDSVQ1u_q7B0gR3yxhq12Wevp/view?usp=sharing) / [log](https://drive.google.com/file/d/1wPeEf84zil8yPBLEnweONXadr5LrwXXv/view?usp=sharing) |
+| AnchorDETR-C5 | R50 | 42.1 | [model](https://drive.google.com/file/d/1ktLJyw4PGdaXkOn61W537Z67WHcttXDs/view?usp=sharing) / [log](https://drive.google.com/file/d/1CoEUzs6pxYw-z1ew04qC1jFJwVjdDlPv/view?usp=sharing) |
+| AnchorDETR-DC5 | R50 | 44.3 | [model](https://drive.google.com/file/d/1lJZWdIlHj6KKmAdU28Y01tTyO0hc6Jxs/view?usp=sharing) / [log](https://drive.google.com/file/d/1ywmE02P7ORj_1HQOR2lYW11kfuqX00v-/view?usp=sharing) |
+| AnchorDETR-C5 | R101 | 43.5 | [model](https://drive.google.com/file/d/1eBLYzlKWwSF_RRcfjgRXqIplRKetsvtg/view?usp=sharing) / [log](https://drive.google.com/file/d/1XIDSpYCioYlK5NwdJnbUHQls-PUr_xwi/view?usp=sharing) |
+| AnchorDETR-DC5 | R101 | 45.1 | [model](https://drive.google.com/file/d/1irmZPSALME4Nht3_qhM9WLExDyO9Sj-J/view?usp=sharing) / [log](https://drive.google.com/file/d/1KIIYid8mmoAWX7w6T6VPhORc86STqoXR/view?usp=sharing) |
 
-*Note:* the models and logs are also available at [Baidu Netdisk](https://pan.baidu.com/s/1Fgx-YPQ0WdTuZIsbOv6hLw) with code `f56r`.
+*Note:* the models and logs are also available at [Baidu Netdisk](https://pan.baidu.com/s/1iB8qtVPb9dWHYgA5z1I4xg) with code `hh13`.
 
 ## Usage
 
diff --git a/main.py b/main.py
index 31c04aa..dfdcef4 100644
--- a/main.py
+++ b/main.py
@@ -41,11 +41,8 @@ def get_args_parser():
     parser.add_argument('--clip_max_norm', default=0.1, type=float,
                         help='gradient clipping max norm')
 
-    parser.add_argument('--sgd', action='store_true')
 
-    parser.add_argument('--with_box_refine', default=False, action='store_true')
-
     # Model parameters
     parser.add_argument('--frozen_weights', type=str, default=None,
                         help="Path to the pretrained model. If set, only the mask head will be trained")
@@ -209,7 +206,7 @@ def match_name_keywords(n, name_keywords):
     lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
 
     if args.distributed:
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu],find_unused_parameters=True)
+        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
     model_without_ddp = model.module
 
     if args.dataset_file == "coco_panoptic":
diff --git a/models/anchor_detr.py b/models/anchor_detr.py
index 3e27b39..c2adc27 100644
--- a/models/anchor_detr.py
+++ b/models/anchor_detr.py
@@ -30,15 +30,13 @@ class AnchorDETR(nn.Module):
     """ This is the AnchorDETR module that performs object detection """
-    def __init__(self, backbone, transformer, num_feature_levels,
-                 aux_loss=True, with_box_refine=False):
+    def __init__(self, backbone, transformer, num_feature_levels, aux_loss=True):
         """ Initializes the model.
         Parameters:
             backbone: torch module of the backbone to be used. See backbone.py
             transformer: torch module of the transformer architecture. See transformer.py
             num_classes: number of object classes
             aux_loss: True if auxiliary decoding losses (loss at each decoder layer) are to be used.
-            with_box_refine: iterative bounding box refinement
         """
         super().__init__()
         self.transformer = transformer
@@ -69,7 +67,6 @@ def __init__(self, backbone, transformer, num_feature_levels,
         )])
         self.backbone = backbone
         self.aux_loss = aux_loss
-        self.with_box_refine = with_box_refine
 
         for proj in self.input_proj:
             nn.init.xavier_uniform_(proj[0].weight, gain=1)
@@ -367,8 +364,7 @@ def build(args):
         backbone,
         transformer,
         num_feature_levels=args.num_feature_levels,
-        aux_loss=args.aux_loss,
-        with_box_refine=args.with_box_refine
+        aux_loss=args.aux_loss
     )
     if args.masks:
         model = DETRsegm(model, freeze_detr=(args.frozen_weights is not None))
diff --git a/models/transformer.py b/models/transformer.py
index ff5ffd3..f72f1e8 100644
--- a/models/transformer.py
+++ b/models/transformer.py
@@ -52,7 +52,8 @@ def __init__(self, d_model=256, nhead=8,
 
         self.spatial_prior=spatial_prior
 
-        self.level_embed = nn.Embedding(num_feature_levels, d_model)
+        if num_feature_levels>1:
+            self.level_embed = nn.Embedding(num_feature_levels, d_model)
         self.num_pattern = num_query_pattern
         self.pattern = nn.Embedding(self.num_pattern, d_model)
 
@@ -77,8 +78,6 @@ def __init__(self, d_model=256, nhead=8,
         self.class_embed = nn.Linear(d_model, num_classes)
         self.bbox_embed = MLP(d_model, d_model, 4, 3)
 
-        self.refine_box = False
-
         self._reset_parameters()
 
     def _reset_parameters(self):
@@ -94,14 +93,10 @@ def _reset_parameters(self):
 
         if self.spatial_prior == "learned":
             nn.init.uniform_(self.position.weight.data, 0, 1)
 
-        if self.refine_box:
-            self.class_embed = _get_clones(self.class_embed, num_pred)
-            self.bbox_embed = _get_clones(self.bbox_embed, num_pred)
-            nn.init.constant_(self.bbox_embed[0].layers[-1].bias.data[2:], -2.0)
-        else:
-            nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
-            self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
-            self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
+        nn.init.constant_(self.bbox_embed.layers[-1].bias.data[2:], -2.0)
+        self.class_embed = nn.ModuleList([self.class_embed for _ in range(num_pred)])
+        self.bbox_embed = nn.ModuleList([self.bbox_embed for _ in range(num_pred)])
+
 
     def forward(self, srcs, masks):
@@ -163,8 +158,6 @@ def forward(self, srcs, masks):
             outputs_coord = tmp.sigmoid()
             outputs_classes.append(outputs_class[None,])
             outputs_coords.append(outputs_coord[None,])
-            if self.refine_box:
-                reference_points = outputs_coord
 
         output = torch.cat(outputs_classes, dim=0), torch.cat(outputs_coords, dim=0)
 
@@ -285,9 +278,11 @@ def __init__(self, d_model=256, d_ffn=1024,
         self.dropout2 = nn.Dropout(dropout)
         self.norm2 = nn.LayerNorm(d_model)
 
-        # self attention
-        self.self_attn_level = nn.MultiheadAttention(d_model, n_heads, dropout=dropout)
-        self.level_fc = nn.Linear(d_model * n_levels, d_model)
+
+        # level combination
+        if n_levels>1:
+            self.level_fc = nn.Linear(d_model * n_levels, d_model)
+
         # ffn
         self.ffn = FFN(d_model, d_ffn, dropout, activation)
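
Editor's note: the sketch below is an illustration, not part of the patch. It shows why `find_unused_parameters=True` can safely be dropped from the `DistributedDataParallel` call once modules such as `self_attn_level`, `level_embed`, and `level_fc` are removed or constructed only when actually used. The class and attribute names here (`TinyTransformer`, `proj`) are hypothetical.

# Minimal sketch, assuming a single-level model; names are hypothetical.
import torch
import torch.nn as nn


class TinyTransformer(nn.Module):
    def __init__(self, d_model=256, num_feature_levels=1):
        super().__init__()
        self.num_feature_levels = num_feature_levels
        # As in the patched transformer.py: allocate the level embedding only
        # when more than one feature level exists, so a single-level model
        # registers no parameter that never receives a gradient.
        if num_feature_levels > 1:
            self.level_embed = nn.Embedding(num_feature_levels, d_model)
        self.proj = nn.Linear(d_model, d_model)

    def forward(self, x, level=0):
        # Every registered parameter participates in every forward pass.
        if self.num_feature_levels > 1:
            x = x + self.level_embed.weight[level]
        return self.proj(x)


# Because no registered parameter is ever skipped during forward, DDP's
# default (find_unused_parameters=False) is sufficient, and the extra
# per-iteration graph traversal that unused-parameter detection performs
# is avoided:
#     model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])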