about Dense Query Recollection and Recurrence #12

Open
MinorityA opened this issue May 2, 2024 · 4 comments

Comments

@MinorityA

Hi, thank you for your amazing work.

I am particularly interested in the DQRR section, which your team implemented on AdaMixer. I am curious whether your team has also tested it on DETR, and if so, what the results were. Any details you could share would be greatly appreciated.

@Fangyi-Chen
Owner

I did not test that, but I think it will work on DETR as well.
Thanks.

@MinorityA
Author

> I did not test that, but I think it will work on DETR as well. Thanks.

Hello! Could you please provide more information or guidance on how to implement this step correctly? I trained it on DAB-Deformable-DETR following the instructions in the paper, feeding the outputs of layer 6 back into layer 6 again, but when I shared layer 6's parameters across all layers, the evaluation AP was nearly 0.

@Fangyi-Chen
Owner

Fangyi-Chen commented Jun 28, 2024

Hi, I have a draft implementation (but correct and runnable) of DQRR on AdaMixer.

In the implementation, fakesetsize is only there for acceleration; you can ignore it.

When stage == self.num_stages, i.e., at the last stage, it is treated differently from the other stages.

During testing, you can use the last stage only.

def forward_train(self,
                  x,
                  query_xyzr,
                  query_content,
                  img_metas,
                  gt_bboxes,
                  gt_labels,
                  gt_bboxes_ignore=None,
                  imgs_whwh=None,
                  gt_masks=None):
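    """Training forward pass with Dense Query Recollection and Recurrence (DQRR).

    Each regular stage consumes the concatenation of the initial query set and
    all query sets produced by earlier stages (dense query recollection); the
    extra iteration at stage == self.num_stages re-applies the final stage to
    the query sets it produced itself (recurrence). Assumes the usual mmdet
    SparseRoIHead context, i.e. torch and bbox_xyxy_to_cxcywh are imported at
    module level.
    """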

    num_imgs = len(img_metas)
    num_queries = query_xyzr.size(1)
    imgs_whwh_keep = imgs_whwh.repeat(1, num_queries, 1)
    all_stage_bbox_results = []
    all_stage_loss = {}

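    # Recollection buffers: every stage appends its output query set here so
    # later stages can be trained on all intermediate query sets. The *_last
    # buffers collect only the final stage's outputs, which feed the extra
    # recurrence pass below.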
    query_xyzr_list_reserve = [query_xyzr]
    query_content_list_reserve = [query_content]
    query_xyzr_list_reserve_last = []
    query_content_list_reserve_last = []

    batchsize = len(img_metas)
    fakesetsize = 2  # 8 will reduce 16 hours; 2 will reduce 9 hours; 4 will reduce 15 hours
    x_keep = [_ for _ in x]
    img_metas_keep = img_metas.copy()
    gt_bboxes_keep = gt_bboxes.copy()
    gt_labels_keep = gt_labels.copy()
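    # One extra iteration beyond the real decoder stages: the pass at
    # stage == self.num_stages re-runs the final stage on its own outputs.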
    for stage in range(self.num_stages+1):

        if stage == self.num_stages: # at the latest stage
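            # Recurrence: re-run the final stage's head (index stage - 1) on the
            # query sets that stage produced during the previous loop iteration.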
            query_xyzr = torch.cat(query_xyzr_list_reserve_last, dim=0)
            query_content = torch.cat(query_content_list_reserve_last, dim=0)
            setsize = int(len(query_content) / batchsize)
            if setsize > fakesetsize:
                single_stage_group_loss = []
                num_group = int(setsize / fakesetsize)


                for groupid in range(num_group):
                    query_xyzr_this_group = query_xyzr[fakesetsize * batchsize * groupid:fakesetsize * batchsize * (
                                groupid + 1)]
                    query_content_this_group = query_content[
                                               fakesetsize * batchsize * groupid:fakesetsize * batchsize * (
                                                           groupid + 1)]
                    bbox_results = self._bbox_forward(stage-1, x, query_xyzr_this_group, query_content_this_group,
                                                      img_metas)
                    # all_stage_bbox_results.append(bbox_results)
                    if gt_bboxes_ignore is None:
                        # TODO support ignore
                        gt_bboxes_ignore = [None for _ in range(num_imgs)]
                    sampling_results = []
                    cls_pred_list = bbox_results['detach_cls_score_list']
                    bboxes_list = bbox_results['detach_bboxes_list']


                    for i in range(num_imgs * fakesetsize):
                        normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(bboxes_list[i] /
                                                                  imgs_whwh[i])
                        assign_result = self.bbox_assigner[stage-1].assign(
                            normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                            gt_labels[i], img_metas[i])
                        sampling_result = self.bbox_sampler[stage-1].sample(
                            assign_result, bboxes_list[i], gt_bboxes[i])
                        sampling_results.append(sampling_result)
                    bbox_targets = self.bbox_head[stage-1].get_targets(
                        sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage-1],
                        True)

                    cls_score = bbox_results['cls_score']
                    decode_bbox_pred = bbox_results['decode_bbox_pred']

                    single_stage_group_loss.append(self.bbox_head[stage-1].loss(
                        cls_score.view(-1, cls_score.size(-1)),
                        decode_bbox_pred.view(-1, 4),
                        *bbox_targets,
                        imgs_whwh=imgs_whwh)
                    )

                # TODO: weight the group losses, giving the most important group the highest weight
                # TODO: multiply each loss by fakesetsize or not? Do not forget to modify the setsize branch below accordingly
                for groupid, single_stage_single_group_loss in enumerate(single_stage_group_loss):
                    if groupid == 0:
                        for key, value in single_stage_single_group_loss.items():
                            all_stage_loss[f'stage{stage}_{key}'] = value * \
                                                                    self.stage_loss_weights[stage-1] * fakesetsize
                    else:
                        for key, value in single_stage_single_group_loss.items():
                            all_stage_loss[f'stage{stage}_{key}'] += value * \
                                                                     self.stage_loss_weights[stage-1] * fakesetsize
            else:

                bbox_results = self._bbox_forward(stage-1, x, query_xyzr, query_content,
                                                  img_metas)
                all_stage_bbox_results.append(bbox_results)
                if gt_bboxes_ignore is None:
                    # TODO support ignore
                    gt_bboxes_ignore = [None for _ in range(num_imgs)]
                sampling_results = []
                cls_pred_list = bbox_results['detach_cls_score_list']
                bboxes_list = bbox_results['detach_bboxes_list']

                for i in range(num_imgs * setsize):
                    normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(bboxes_list[i] /
                                                              imgs_whwh[i])
                    assign_result = self.bbox_assigner[stage-1].assign(
                        normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                        gt_labels[i], img_metas[i])
                    sampling_result = self.bbox_sampler[stage-1].sample(
                        assign_result, bboxes_list[i], gt_bboxes[i])
                    sampling_results.append(sampling_result)
                bbox_targets = self.bbox_head[stage-1].get_targets(
                    sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage-1],
                    True)

                cls_score = bbox_results['cls_score']
                decode_bbox_pred = bbox_results['decode_bbox_pred']

                single_stage_group_loss = self.bbox_head[stage-1].loss(
                    cls_score.view(-1, cls_score.size(-1)),
                    decode_bbox_pred.view(-1, 4),
                    *bbox_targets,
                    imgs_whwh=imgs_whwh)

                # TODO: multiply each loss by setsize or not? Do not forget to modify the fakesetsize branch above accordingly
                for key, value in single_stage_group_loss.items():
                    all_stage_loss[f'stage{stage}_{key}'] = value * \
                                                            self.stage_loss_weights[stage-1] * setsize

            return all_stage_loss

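        # Regular stages: concatenate all query sets recollected so far so this
        # stage is supervised on the dense collection of queries.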
        query_xyzr = torch.cat(query_xyzr_list_reserve, dim=0)
        query_content = torch.cat(query_content_list_reserve, dim=0)
        setsize = int(len(query_content) / batchsize)

        if setsize > fakesetsize:
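            # Process the recollected query sets in groups of fakesetsize sets
            # (the acceleration trick mentioned above); each group's loss is
            # accumulated into all_stage_loss below.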
            single_stage_group_loss = []
            num_group = int(setsize / fakesetsize)

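            # Tile the image-level inputs so each of the fakesetsize query sets
            # in a group is paired with its own copy of the batch.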
            x = [x_.repeat(fakesetsize, 1, 1, 1) for x_ in x_keep]
            img_metas = img_metas_keep * fakesetsize
            gt_bboxes = gt_bboxes_keep * fakesetsize
            gt_labels = gt_labels_keep * fakesetsize
            imgs_whwh = imgs_whwh_keep.repeat(fakesetsize, 1, 1)

            for groupid in range(num_group):
                query_xyzr_this_group = query_xyzr[fakesetsize*batchsize*groupid:fakesetsize*batchsize*(groupid+1)]
                query_content_this_group = query_content[fakesetsize*batchsize*groupid:fakesetsize*batchsize*(groupid+1)]
                bbox_results = self._bbox_forward(stage, x, query_xyzr_this_group, query_content_this_group,
                                                  img_metas)
                # all_stage_bbox_results.append(bbox_results)
                if gt_bboxes_ignore is None:
                    # TODO support ignore
                    gt_bboxes_ignore = [None for _ in range(num_imgs)]
                sampling_results = []
                cls_pred_list = bbox_results['detach_cls_score_list']
                bboxes_list = bbox_results['detach_bboxes_list']

                query_xyzr_new = bbox_results['query_xyzr'].detach()
                query_content_new = bbox_results['query_content']
                # TODO: detach query content for noisy queries since they are not going to be used anyway?
                # TODO: only append important query groups, e.g. from the last layer
                if stage == self.num_stages - 1:
                    query_xyzr_list_reserve_last.append(query_xyzr_new)
                    query_content_list_reserve_last.append(query_content_new)
                else:
                    query_xyzr_list_reserve.append(query_xyzr_new)
                    query_content_list_reserve.append(query_content_new)

                for i in range(num_imgs * fakesetsize):
                    normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(bboxes_list[i] /
                                                              imgs_whwh[i])
                    assign_result = self.bbox_assigner[stage].assign(
                        normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                        gt_labels[i], img_metas[i])
                    sampling_result = self.bbox_sampler[stage].sample(
                        assign_result, bboxes_list[i], gt_bboxes[i])
                    sampling_results.append(sampling_result)
                bbox_targets = self.bbox_head[stage].get_targets(
                    sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage],
                    True)

                cls_score = bbox_results['cls_score']
                decode_bbox_pred = bbox_results['decode_bbox_pred']

                single_stage_group_loss.append(self.bbox_head[stage].loss(
                    cls_score.view(-1, cls_score.size(-1)),
                    decode_bbox_pred.view(-1, 4),
                    *bbox_targets,
                    imgs_whwh=imgs_whwh)
                )

            # TODO: weight the group losses, giving the most important group the highest weight
            # TODO: multiply each loss by fakesetsize or not? Do not forget to modify the setsize branch below accordingly
            for groupid, single_stage_single_group_loss in enumerate(single_stage_group_loss):
                if groupid == 0:
                    for key, value in single_stage_single_group_loss.items():
                        all_stage_loss[f'stage{stage}_{key}'] = value * \
                                                                self.stage_loss_weights[stage] * fakesetsize
                else:
                    for key, value in single_stage_single_group_loss.items():
                        all_stage_loss[f'stage{stage}_{key}'] += value * \
                                                                 self.stage_loss_weights[stage] * fakesetsize
        else:
            x = [x_.repeat(setsize, 1, 1, 1) for x_ in x_keep]
            img_metas = img_metas_keep * setsize
            gt_bboxes = gt_bboxes_keep * setsize
            gt_labels = gt_labels_keep * setsize
            imgs_whwh = imgs_whwh_keep.repeat(setsize, 1, 1)

            bbox_results = self._bbox_forward(stage, x, query_xyzr, query_content,
                                              img_metas)
            all_stage_bbox_results.append(bbox_results)
            if gt_bboxes_ignore is None:
                # TODO support ignore
                gt_bboxes_ignore = [None for _ in range(num_imgs)]
            sampling_results = []
            cls_pred_list = bbox_results['detach_cls_score_list']
            bboxes_list = bbox_results['detach_bboxes_list']

            query_xyzr_new = bbox_results['query_xyzr'].detach()
            query_content_new = bbox_results['query_content']
            # TODO: detach query content for noisy queries since they are not going to be used anyway?
            # TODO: only append important query groups, e.g. from the last layer
            if stage == self.num_stages - 1:
                query_xyzr_list_reserve_last.append(query_xyzr_new)
                query_content_list_reserve_last.append(query_content_new)
            else:
                query_xyzr_list_reserve.append(query_xyzr_new)
                query_content_list_reserve.append(query_content_new)

            for i in range(num_imgs * setsize):
                normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(bboxes_list[i] /
                                                          imgs_whwh[i])
                assign_result = self.bbox_assigner[stage].assign(
                    normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
                    gt_labels[i], img_metas[i])
                sampling_result = self.bbox_sampler[stage].sample(
                    assign_result, bboxes_list[i], gt_bboxes[i])
                sampling_results.append(sampling_result)
            bbox_targets = self.bbox_head[stage].get_targets(
                sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage],
                True)

            cls_score = bbox_results['cls_score']
            decode_bbox_pred = bbox_results['decode_bbox_pred']

            single_stage_group_loss = self.bbox_head[stage].loss(
                cls_score.view(-1, cls_score.size(-1)),
                decode_bbox_pred.view(-1, 4),
                *bbox_targets,
                imgs_whwh=imgs_whwh)

            # TODO: multiply each loss by setsize or not? Do not forget to modify the fakesetsize branch above accordingly
            for key, value in single_stage_group_loss.items():
                all_stage_loss[f'stage{stage}_{key}'] = value * \
                                                        self.stage_loss_weights[stage] * setsize

    # Never reached: the loss dict is returned inside the stage == self.num_stages branch above.
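
For reference, a minimal sketch of what test time could look like under this setup, assuming the same _bbox_forward interface as in the training code above (the method name simple_test_queries and its signature are hypothetical, not taken from the repository). At test time there is no recollection: each stage consumes only the previous stage's query set, and only the final predictions are kept, optionally after one extra recurrent pass of the last stage.

def simple_test_queries(self, x, query_xyzr, query_content, img_metas):
    # Chain the decoder stages once; no query sets are recollected at test time.
    for stage in range(self.num_stages):
        bbox_results = self._bbox_forward(stage, x, query_xyzr, query_content,
                                          img_metas)
        query_xyzr = bbox_results['query_xyzr']
        query_content = bbox_results['query_content']
    # Optionally re-apply the last stage to its own outputs once (recurrence),
    # mirroring the extra training iteration in forward_train above.
    bbox_results = self._bbox_forward(self.num_stages - 1, x, query_xyzr,
                                      query_content, img_metas)
    return bbox_results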

@Fangyi-Chen
Owner

> I did not test that, but I think it will work on DETR as well. Thanks.

> Hello! Could you please provide more information or guidance on how to implement this step correctly? I trained it on DAB-Deformable-DETR following the instructions in the paper, feeding the outputs of layer 6 back into layer 6 again, but when I shared layer 6's parameters across all layers, the evaluation AP was nearly 0.

I'm not sure, but did you get the post-norm/pre-norm order right when you implemented the recurrence?
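
For context, here is a minimal self-contained sketch of the two orderings (the class names are made up for illustration; this is not the DAB-Deformable-DETR code). In a post-norm layer (original DETR style) the output has already been passed through LayerNorm, while in a pre-norm layer it has not, so when layer 6's output is fed back into layer 6 the normalization placement, and hence the input statistics the shared layer was trained on, has to match.

import torch
import torch.nn as nn

class PostNormSelfAttn(nn.Module):
    """Post-norm (original DETR): sublayer -> residual add -> LayerNorm."""
    def __init__(self, dim=256, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, q):
        out, _ = self.attn(q, q, q)
        return self.norm(q + out)  # output is LayerNorm-ed

class PreNormSelfAttn(nn.Module):
    """Pre-norm: LayerNorm -> sublayer -> residual add (output not normalized)."""
    def __init__(self, dim=256, heads=8):
        super().__init__()
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)
        self.norm = nn.LayerNorm(dim)

    def forward(self, q):
        q_norm = self.norm(q)
        out, _ = self.attn(q_norm, q_norm, q_norm)
        return q + out  # output keeps the un-normalized residual stream

if __name__ == "__main__":
    q = torch.randn(2, 300, 256)  # (batch, num_queries, dim)
    # Feeding a layer's output back into itself only matches training if the
    # norm placement (and thus the input distribution) stays the same.
    print(PostNormSelfAttn()(q).shape, PreNormSelfAttn()(q).shape)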
