
[Fix] Fix MaskFormer and Mask2Former of MMSegmentation #2532

Merged

17 commits merged into open-mmlab:dev-1.x on Feb 1, 2023

Conversation

@Li-Qingyun (Contributor) commented on Jan 30, 2023

## Motivation

The DETR-related modules were refactored in open-mmlab/mmdetection#8763, which broke MaskFormer and Mask2Former in both MMDetection (fixed in open-mmlab/mmdetection#9515) and MMSegmentation. This PR fixes the bugs in MMSegmentation.

### TO-DO List

- [x] update configs
- [x] check and modify the data flow
- [x] fix unit tests
- [x] align inference
- [x] write a ckpt converter
- [x] write a ckpt update script
- [x] update the model zoo
- [x] update the model links in the readme
- [x] update [faq.md](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/docs/en/notes/faq.md#installation)

## Tips for fixing other implementations based on MaskXFormer of mmseg

  1. The Transformer modules should now be built directly; the original registry-based building has been refactored away.
  2. The config needs to be modified: delete `type` and adjust several keys, following the changes in this PR.
  3. `batch_first` is now uniformly set to `True` in the new implementations, so the data flow must be transposed and the `batch_first` entries in the config updated (see the sketch after this list).
  4. Checkpoints trained with the old implementation must be converted before they can be used with the new one.
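
To make points 2 and 3 concrete, here is a minimal, hypothetical sketch; it is not code from this PR, and the config keys and tensor shapes are illustrative only:

```Python
import torch

# Point 2 (hypothetical config): the decoder is built directly now, so the
# registry key `type` must be deleted; the remaining keys are passed straight
# to the module's constructor.
old_decoder_cfg = dict(type='DetrTransformerDecoder', num_layers=9)
new_decoder_cfg = {k: v for k, v in old_decoder_cfg.items() if k != 'type'}

# Point 3: with batch_first=True the attention modules expect
# (batch, num_queries, dims) instead of the old (num_queries, batch, dims),
# so tensors crossing the old/new boundary must be transposed.
query = torch.rand(100, 2, 256)   # old sequence-first layout
query = query.transpose(0, 1)     # new batch-first layout: (2, 100, 256)
```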

### Convert script

```Python
import argparse
from copy import deepcopy
from collections import OrderedDict

import torch

from mmengine.config import Config
from mmseg.models import build_segmentor
from mmseg.utils import register_all_modules
register_all_modules(init_default_scope=True)


def parse_args():
    parser = argparse.ArgumentParser(
        description='MMSeg convert MaskXFormer model, by Li-Qingyun')
    parser.add_argument('Mask_what_former', type=int,
                        help='Mask what former, can be a `1` or `2`',
                        choices=[1, 2])
    parser.add_argument('CFG_FILE', help='config file path')
    parser.add_argument('OLD_CKPT_FILEPATH', help='old ckpt file path')
    parser.add_argument('NEW_CKPT_FILEPATH', help='new ckpt file path')
    args = parser.parse_args()
    return args


args = parse_args()

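# Map each old (registry-built) parameter name onto the new directly-built
# module layout; only the attention / FFN submodule names changed.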
def get_new_name(old_name: str):
    new_name = old_name

    if 'encoder.layers' in new_name:
        new_name = new_name.replace('attentions.0', 'self_attn')

    new_name = new_name.replace('ffns.0', 'ffn')

    if 'decoder.layers' in new_name:

        if args.Mask_what_former == 2:
            # for Mask2Former: cross-attention comes first in each decoder layer
            new_name = new_name.replace('attentions.0', 'cross_attn')
            new_name = new_name.replace('attentions.1', 'self_attn')
        else:
            # for MaskFormer: self-attention comes first in each decoder layer
            new_name = new_name.replace('attentions.0', 'self_attn')
            new_name = new_name.replace('attentions.1', 'cross_attn')

    return new_name
    
def cvt_sd(old_sd: OrderedDict):
    new_sd = OrderedDict()
    for name, param in old_sd.items():
        new_name = get_new_name(name)
        assert new_name not in new_sd
        new_sd[new_name] = param
    assert len(new_sd) == len(old_sd)
    return new_sd
    
if __name__ == '__main__':
    cfg = Config.fromfile(args.CFG_FILE)
    model_cfg = cfg.model

    segmentor = build_segmentor(model_cfg)

    refer_sd = segmentor.state_dict()
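    # `refer_sd` (the freshly built model's state dict) is not used below, but
    # it is handy for manually diffing parameter names when debugging.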
    old_ckpt = torch.load(args.OLD_CKPT_FILEPATH)
    old_sd = old_ckpt['state_dict']

    new_sd = cvt_sd(old_sd)
    print(segmentor.load_state_dict(new_sd))

    new_ckpt = deepcopy(old_ckpt)
    new_ckpt['state_dict'] = new_sd
    torch.save(new_ckpt, args.NEW_CKPT_FILEPATH)
    print(f'{args.NEW_CKPT_FILEPATH} has been saved!')
```

Usage:

```bash
# for example
python ckpt4pr2532.py 1 configs/maskformer/maskformer_r50-d32_8xb2-160k_ade20k-512x512.py original_ckpts/maskformer_r50-d32_8xb2-160k_ade20k-512x512_20221030_182724-cbd39cc1.pth cvt_outputs/maskformer_r50-d32_8xb2-160k_ade20k-512x512_20221030_182724.pth
python ckpt4pr2532.py 2 configs/mask2former/mask2former_r50_8xb2-160k_ade20k-512x512.py original_ckpts/mask2former_r50_8xb2-160k_ade20k-512x512_20221204_000055-4c62652d.pth cvt_outputs/mask2former_r50_8xb2-160k_ade20k-512x512_20221204_000055.pth
```
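
With PyTorch's default `strict=True`, `load_state_dict` raises an error on any key mismatch, so the `print(...)` in the script should report `<All keys matched successfully>` once the conversion is correct.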

@Li-Qingyun changed the title from "fix: update readme" to "[Fix] Fix MaskFormer and Mask2Former of MMSegmentation" on Jan 30, 2023
@MengzhangLI (Contributor) commented on Jan 30, 2023

Hi, Qingyun, thanks for your PR.

We also need to change the mmdet dependencies in faq.md before this PR is merged into the dev-1.x branch.

Also, we need to change the mmdet dependencies in .github/workflows/ to make sure it passes CI.

@Li-Qingyun (Contributor, author) commented on Jan 30, 2023

> Hi, Qingyun, thanks for your PR.
>
> We also need to change the mmdet dependencies in faq.md before this PR is merged into the dev-1.x branch.

OK, I've added this to the TO-DO list.

@Li-Qingyun closed this on Jan 30, 2023
@Li-Qingyun reopened this on Jan 30, 2023
@Li-Qingyun (Contributor, author) commented:

@MengzhangLI The unit tests test_maskformer_head.py and test_mask2former_head.py now pass.
MaskFormer: [screenshot of passing unit tests]
Mask2Former: [screenshot of passing unit tests]

I have also modified the CI to use mmdet dev-3.x, but the CI failed.

@Li-Qingyun (Contributor, author) commented:

Inference test

MaskFormer + R50, ADE20K:
01/30 20:37:47 - mmengine - INFO - Iter(test) [2000/2000] aAcc: 80.1200 mIoU: 44.3000 mAcc: 56.3600

Mask2Former + R50, ADE20K:
01/30 20:46:46 - mmengine - INFO - Iter(test) [2000/2000] aAcc: 82.2500 mIoU: 47.8100 mAcc: 61.4000

@Li-Qingyun (Contributor, author) commented:

@MeowZheng
scripts-pr2532.zip

Usage:

```bash
# Put it in the project folder of mmsegmentation
unzip scripts-pr2532.zip
bash download_ckpts.sh
bash cvt_ckpts.sh
bash publish_model.sh
```

@codecov bot commented on Jan 31, 2023

Codecov Report

Base: 83.25% // Head: 83.35% // Increases project coverage by +0.09% 🎉

Coverage data is based on head (81e5dd7) compared to base (124b87c).
Patch has no changes to coverable lines.

Additional details and impacted files
```
@@             Coverage Diff             @@
##           dev-1.x    #2532      +/-   ##
===========================================
+ Coverage    83.25%   83.35%   +0.09%
===========================================
  Files          145      145
  Lines         8505     8505
  Branches      1273     1273
===========================================
+ Hits          7081     7089       +8
+ Misses        1213     1202      -11
- Partials       211      214       +3
```
| Flag | Coverage Δ |
| --- | --- |
| unittests | 83.35% <ø> (+0.09%) ⬆️ |

Flags with carried forward coverage won't be shown.

| Impacted Files | Coverage Δ |
| --- | --- |
| mmseg/datasets/transforms/transforms.py | 90.53% <0.00%> (+1.03%) ⬆️ |

☔ View full report at Codecov.

@MeowZheng merged commit a092fea into open-mmlab:dev-1.x on Feb 1, 2023
@Li-Qingyun deleted the fix-maskxformer branch on Feb 1, 2023 at 16:09
nahidnazifi87 pushed a commit to nahidnazifi87/mmsegmentation_playground that referenced this pull request Apr 5, 2024