Skip to content

Commit

Permalink
add imagenet22k dataset and some minor fixes (microsoft#208)
Browse files Browse the repository at this point in the history
  • Loading branch information
zeliu98 authored May 15, 2022
1 parent 2519a3a commit eda255c
Show file tree
Hide file tree
Showing 5 changed files with 92 additions and 9 deletions.
1 change: 0 additions & 1 deletion config.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,6 @@
_C.MODEL.SWINV2.WINDOW_SIZE = 7
_C.MODEL.SWINV2.MLP_RATIO = 4.
_C.MODEL.SWINV2.QKV_BIAS = True
_C.MODEL.SWINV2.QK_SCALE = None
_C.MODEL.SWINV2.APE = False
_C.MODEL.SWINV2.PATCH_NORM = True
_C.MODEL.SWINV2.PRETRAINED_WINDOW_SIZES = [0, 0, 0, 0]
Expand Down
9 changes: 8 additions & 1 deletion data/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from timm.data import create_transform

from .cached_image_folder import CachedImageFolder
from .imagenet22k_dataset import IN22KDATASET
from .samplers import SubsetRandomSampler

try:
Expand Down Expand Up @@ -108,7 +109,13 @@ def build_dataset(is_train, config):
dataset = datasets.ImageFolder(root, transform=transform)
nb_classes = 1000
elif config.DATA.DATASET == 'imagenet22K':
raise NotImplementedError("Imagenet-22K will come soon.")
prefix = 'ILSVRC2011fall_whole'
if is_train:
ann_file = prefix + "_map_train.txt"
else:
ann_file = prefix + "_map_val.txt"
dataset = IN22KDATASET(config.DATA.DATA_PATH, ann_file, transform)
nb_classes = 21841
else:
raise NotImplementedError("We only support ImageNet Now.")

Expand Down
55 changes: 55 additions & 0 deletions data/imagenet22k_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import os
import json
import torch.utils.data as data
import numpy as np
from PIL import Image

import warnings

warnings.filterwarnings("ignore", "(Possibly )?corrupt EXIF data", UserWarning)


class IN22KDATASET(data.Dataset):
def __init__(self, root, ann_file='', transform=None, target_transform=None):
super(IN22KDATASET, self).__init__()

self.data_path = root
self.ann_path = os.path.join(self.data_path, ann_file)
self.transform = transform
self.target_transform = target_transform
# id & label: https://github.com/google-research/big_transfer/issues/7
# total: 21843; only 21841 class have images: map 21841->9205; 21842->15027
self.database = json.load(open(self.ann_path))

def _load_image(self, path):
try:
im = Image.open(path)
except:
print("ERROR IMG LOADED: ", path)
random_img = np.random.rand(224, 224, 3) * 255
im = Image.fromarray(np.uint8(random_img))
return im

def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is class_index of the target class.
"""
idb = self.database[index]

# images
images = self._load_image(self.data_path + '/' + idb[0]).convert('RGB')
if self.transform is not None:
images = self.transform(images)

# target
target = int(idb[1])
if self.target_transform is not None:
target = self.target_transform(target)

return images, target

def __len__(self):
return len(self.database)
30 changes: 29 additions & 1 deletion get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,24 @@ load data:
n01440764/n01440764_10040.JPEG 0
n01440764/n01440764_10042.JPEG 0
```
- For ImageNet-22K dataset, make a folder named `fall11_whole` and move all images to labeled sub-folders in this
folder. Then download the train-val split
file ([ILSVRC2011fall_whole_map_train.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_train.txt)
& [ILSVRC2011fall_whole_map_val.txt](https://github.com/SwinTransformer/storage/releases/download/v2.0.1/ILSVRC2011fall_whole_map_val.txt))
, and put them in the parent directory of `fall11_whole`. The file structure should look like:

- ```bash
$ tree imagenet22k/
imagenet22k/
├── ILSVRC2011fall_whole_map_train.txt
├── ILSVRC2011fall_whole_map_val.txt
└── fall11_whole
├── n00004475
├── n00005787
├── n00006024
├── n00006484
└── ...
```

### Evaluation

Expand All @@ -140,7 +158,7 @@ python -m torch.distributed.launch --nproc_per_node 1 --master_port 12345 main.p
--cfg configs/swin/swin_base_patch4_window7_224.yaml --resume swin_base_patch4_window7_224.pth --data-path <imagenet-path>
```

### Training from scratch
### Training from scratch on ImageNet-1K

To train a `Swin Transformer` on ImageNet from scratch, run:

Expand Down Expand Up @@ -188,6 +206,16 @@ python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.
--accumulation-steps 2 [--use-checkpoint]
```

### Pre-training on ImageNet-22K

For example, to pre-train a `Swin-B` model on ImageNet-22K:

```bash
python -m torch.distributed.launch --nproc_per_node 8 --master_port 12345 main.py \
--cfg configs/swin/swin_base_patch4_window7_224_22k.yaml --data-path <imagenet22k-path> --batch-size 64 \
--accumulation-steps 8 [--use-checkpoint]
```

### Fine-tuning on higher resolution

For example, to fine-tune a `Swin-B` model pre-trained on 224x224 resolution to 384x384 resolution:
Expand Down
6 changes: 0 additions & 6 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,6 @@
import torch.distributed as dist
from torch._six import inf

try:
# noinspection PyUnresolvedReferences
from apex import amp
except ImportError:
amp = None


def load_checkpoint(config, model, optimizer, lr_scheduler, loss_scaler, logger):
logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
Expand Down

0 comments on commit eda255c

Please sign in to comment.