[Feature] Support tuning parameter-level optim hyperparameters. (#463)
* [Feature] Support tuning parameter-level optim hyperparameters.
* save scheduler state
* update
* update
* update ut
Showing 9 changed files with 386 additions and 35 deletions.
@@ -0,0 +1,130 @@
save_dir: workspace/convnext/nanodet-plus_convnext-nano_640
model:
  weight_averager:
    name: ExpMovingAverager
    decay: 0.9998
  arch:
    name: NanoDetPlus
    detach_epoch: 10
    backbone:
      name: TIMMWrapper
      model_name: convnext_nano
      features_only: True
      pretrained: True
      # output_stride: 32
      out_indices: [1, 2, 3]
    fpn:
      name: GhostPAN
      in_channels: [160, 320, 640]
      out_channels: 128
      kernel_size: 5
      num_extra_level: 1
      use_depthwise: True
      activation: SiLU
    head:
      name: NanoDetPlusHead
      num_classes: 80
      input_channel: 128
      feat_channels: 128
      stacked_convs: 2
      kernel_size: 5
      strides: [8, 16, 32, 64]
      activation: SiLU
      reg_max: 7
      norm_cfg:
        type: BN
      loss:
        loss_qfl:
          name: QualityFocalLoss
          use_sigmoid: True
          beta: 2.0
          loss_weight: 1.0
        loss_dfl:
          name: DistributionFocalLoss
          loss_weight: 0.25
        loss_bbox:
          name: GIoULoss
          loss_weight: 2.0
    # Auxiliary head, only used at training time.
    aux_head:
      name: SimpleConvHead
      num_classes: 80
      input_channel: 256
      feat_channels: 256
      stacked_convs: 4
      strides: [8, 16, 32, 64]
      activation: SiLU
      reg_max: 7
data:
  train:
    name: CocoDataset
    img_path: coco/train2017
    ann_path: coco/annotations/instances_train2017.json
    input_size: [640, 640]  # [w, h]
    keep_ratio: False
    pipeline:
      perspective: 0.0
      scale: [0.1, 2.0]
      stretch: [[0.8, 1.2], [0.8, 1.2]]
      rotation: 0
      shear: 0
      translate: 0.2
      flip: 0.5
      brightness: 0.2
      contrast: [0.6, 1.4]
      saturation: [0.5, 1.2]
      normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
  val:
    name: CocoDataset
    img_path: coco/val2017
    ann_path: coco/annotations/instances_val2017.json
    input_size: [640, 640]  # [w, h]
    keep_ratio: False
    pipeline:
      normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
device:
  gpu_ids: [0, 1, 2, 3]
  workers_per_gpu: 8
  batchsize_per_gpu: 24
schedule:
  # resume:
  # load_model:
  optimizer:
    name: AdamW
    lr: 0.001
    weight_decay: 0.05
    no_norm_decay: True
    param_level_cfg:
      backbone:
        lr_mult: 0.1
  warmup:
    name: linear
    steps: 500
    ratio: 0.0001
  total_epochs: 50
  lr_schedule:
    name: CosineAnnealingLR
    T_max: 50
    eta_min: 0.0005
  val_intervals: 5
  grad_clip: 35
evaluator:
  name: CocoDetectionEvaluator
  save_key: mAP
log:
  interval: 50

class_names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
              'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
              'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
              'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
              'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
              'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
              'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
              'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
              'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
              'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
              'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
              'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
              'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
              'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush']
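The `param_level_cfg` block above is the new feature in this commit: any parameter whose name contains the key `backbone` gets its learning rate scaled by `lr_mult`, so backbone weights train at 0.001 × 0.1 = 0.0001 while the rest of the model keeps the base 0.001. A minimal sketch of that substring-match rule (the parameter names here are illustrative, not taken from the model):

    # sketch of the first-match substring rule applied by the new optimizer builder
    base_lr = 0.001
    param_level_cfg = {"backbone": {"lr_mult": 0.1}}
    for pname in ["backbone.stages.0.weight", "fpn.reduce_layers.0.weight", "head.gfl_cls.0.weight"]:
        lr = base_lr
        for key, cfg in param_level_cfg.items():
            if key in pname:  # plain substring match against the parameter name
                lr = base_lr * cfg["lr_mult"]
                break  # the first matching key wins, as in the builder below
        print(pname, lr)  # backbone params get 0.0001, the others keep 0.001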
@@ -0,0 +1,3 @@
from .builder import build_optimizer

__all__ = ["build_optimizer"]
@@ -0,0 +1,76 @@
import copy
import logging

import torch
from torch.nn import GroupNorm, LayerNorm
from torch.nn.modules.batchnorm import _BatchNorm

NORMS = (GroupNorm, LayerNorm, _BatchNorm)


def build_optimizer(model, config):
    """Build optimizer from config.

    Supports customised parameter-level hyperparameters.
    The config should be like:
    >>> optimizer:
    >>>   name: AdamW
    >>>   lr: 0.001
    >>>   weight_decay: 0.05
    >>>   no_norm_decay: True
    >>>   param_level_cfg:  # parameter-level config
    >>>     backbone:
    >>>       lr_mult: 0.1
    """
    config = copy.deepcopy(config)
    param_dict = {}
    no_norm_decay = config.pop("no_norm_decay", False)
    no_bias_decay = config.pop("no_bias_decay", False)
    param_level_cfg = config.pop("param_level_cfg", {})
    base_lr = config.get("lr", None)
    base_wd = config.get("weight_decay", None)

    name = config.pop("name")
    optim_cls = getattr(torch.optim, name)

    logger = logging.getLogger("NanoDet")

    # custom param-wise lr and weight_decay
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        param_dict[p] = {"name": name}

        # apply the first param_level_cfg key that matches the parameter name
        for key in param_level_cfg:
            if key in name:
                if "lr_mult" in param_level_cfg[key] and base_lr:
                    param_dict[p].update(
                        {"lr": base_lr * param_level_cfg[key]["lr_mult"]}
                    )
                if "decay_mult" in param_level_cfg[key] and base_wd:
                    param_dict[p].update(
                        {"weight_decay": base_wd * param_level_cfg[key]["decay_mult"]}
                    )
                break
    if no_norm_decay:
        # disable weight decay on normalization layers
        for _, m in model.named_modules():
            if isinstance(m, NORMS):
                param_dict[m.bias].update({"weight_decay": 0})
                param_dict[m.weight].update({"weight_decay": 0})
    if no_bias_decay:
        # disable weight decay on bias parameters
        for _, m in model.named_modules():
            # m.bias can be None (e.g. Conv2d(..., bias=False)), so check it
            if hasattr(m, "bias") and m.bias is not None:
                param_dict[m.bias].update({"weight_decay": 0})

    # convert param dict to optimizer's param groups
    param_groups = []
    for p, pconfig in param_dict.items():
        name = pconfig.pop("name", None)
        if "weight_decay" in pconfig or "lr" in pconfig:
            logger.info(f"special optimizer hyperparameter: {name} - {pconfig}")
        param_groups += [{"params": p, **pconfig}]

    optimizer = optim_cls(param_groups, **config)
    return optimizer
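A minimal usage sketch (not part of the commit), using a toy model; the `cfg` dict mirrors the `schedule.optimizer` node of the YAML config above:

    import torch.nn as nn

    # toy model: parameter names containing "backbone" will match param_level_cfg
    model = nn.Sequential()
    model.add_module("backbone", nn.Conv2d(3, 8, 3, bias=False))
    model.add_module("head", nn.Linear(8, 4))

    cfg = {
        "name": "AdamW",
        "lr": 0.001,
        "weight_decay": 0.05,
        "no_norm_decay": True,
        "param_level_cfg": {"backbone": {"lr_mult": 0.1}},
    }
    optimizer = build_optimizer(model, cfg)
    for group in optimizer.param_groups:
        # the backbone weight lands in a group with lr=0.0001; the rest keep lr=0.001
        print(group["lr"], group["weight_decay"])

Note that every parameter becomes its own param group; groups without an explicit override simply inherit the optimizer-level defaults passed via `**config`.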
@@ -0,0 +1,65 @@
import os
import platform
import warnings

import torch.multiprocessing as mp


def set_multi_processing(
    mp_start_method: str = "fork", opencv_num_threads: int = 0, distributed: bool = True
) -> None:
    """Set up the multi-processing environment.

    This function is adapted from https://github.com/open-mmlab/mmengine/blob/main/mmengine/utils/dl_utils/setup_env.py

    Args:
        mp_start_method (str): Method used to start child processes.
            Defaults to 'fork'.
        opencv_num_threads (int): Number of threads for OpenCV.
            Defaults to 0.
        distributed (bool): True if running in a distributed environment.
            Defaults to True.
    """  # noqa
    # set the multi-process start method to `fork` to speed up training
    if platform.system() != "Windows":
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f"Multi-processing start method `{mp_start_method}` is "
                f"different from the previous setting `{current_method}`. "
                f"It will be forcibly set to `{mp_start_method}`. You can "
                "change this behavior by changing `mp_start_method` in "
                "your config."
            )
        mp.set_start_method(mp_start_method, force=True)

    try:
        import cv2

        # disable OpenCV multithreading to avoid overloading the system
        cv2.setNumThreads(opencv_num_threads)
    except ImportError:
        pass

    # set up OMP threads
    # This code is adapted from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    if "OMP_NUM_THREADS" not in os.environ and distributed:
        omp_num_threads = 1
        warnings.warn(
            "Setting OMP_NUM_THREADS environment variable for each process "
            f"to be {omp_num_threads} by default, to avoid your system "
            "being overloaded. Please further tune the variable for "
            "optimal performance in your application as needed."
        )
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    # set up MKL threads
    if "MKL_NUM_THREADS" not in os.environ and distributed:
        mkl_num_threads = 1
        warnings.warn(
            "Setting MKL_NUM_THREADS environment variable for each process "
            f"to be {mkl_num_threads} by default, to avoid your system "
            "being overloaded. Please further tune the variable for "
            "optimal performance in your application as needed."
        )
        os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
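A short illustrative sketch (not part of the commit): the helper is intended to be called once at process startup, before any data-loader workers are spawned, so the start method and thread caps take effect for all children:

    if __name__ == "__main__":
        # pick the start method and cap OpenCV/OMP/MKL threads before spawning workers
        set_multi_processing(mp_start_method="fork", opencv_num_threads=0, distributed=True)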