Commit

[Feature] Support tuning parameter-level optim hyperparameters. (#463)
* [Feature] Support tuning parameter-level optim hyperparameters.

* save scheduler state

* update

* update

* update ut
RangiLyu committed Sep 29, 2022
1 parent a3b3452 commit ad410c2
Showing 9 changed files with 386 additions and 35 deletions.
130 changes: 130 additions & 0 deletions config/convnext/nanodet-plus_convnext-nano_640.yml
@@ -0,0 +1,130 @@
save_dir: workspace/convnext/nanodet-plus_convnext-nano_640
model:
  weight_averager:
    name: ExpMovingAverager
    decay: 0.9998
  arch:
    name: NanoDetPlus
    detach_epoch: 10
    backbone:
      name: TIMMWrapper
      model_name: convnext_nano
      features_only: True
      pretrained: True
      # output_stride: 32
      out_indices: [1, 2, 3]
    fpn:
      name: GhostPAN
      in_channels: [160, 320, 640]
      out_channels: 128
      kernel_size: 5
      num_extra_level: 1
      use_depthwise: True
      activation: SiLU
    head:
      name: NanoDetPlusHead
      num_classes: 80
      input_channel: 128
      feat_channels: 128
      stacked_convs: 2
      kernel_size: 5
      strides: [8, 16, 32, 64]
      activation: SiLU
      reg_max: 7
      norm_cfg:
        type: BN
      loss:
        loss_qfl:
          name: QualityFocalLoss
          use_sigmoid: True
          beta: 2.0
          loss_weight: 1.0
        loss_dfl:
          name: DistributionFocalLoss
          loss_weight: 0.25
        loss_bbox:
          name: GIoULoss
          loss_weight: 2.0
    # Auxiliary head, only use in training time.
    aux_head:
      name: SimpleConvHead
      num_classes: 80
      input_channel: 256
      feat_channels: 256
      stacked_convs: 4
      strides: [8, 16, 32, 64]
      activation: SiLU
      reg_max: 7
data:
  train:
    name: CocoDataset
    img_path: coco/train2017
    ann_path: coco/annotations/instances_train2017.json
    input_size: [640,640] #[w,h]
    keep_ratio: False
    pipeline:
      perspective: 0.0
      scale: [0.1, 2.0]
      stretch: [[0.8, 1.2], [0.8, 1.2]]
      rotation: 0
      shear: 0
      translate: 0.2
      flip: 0.5
      brightness: 0.2
      contrast: [0.6, 1.4]
      saturation: [0.5, 1.2]
      normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
  val:
    name: CocoDataset
    img_path: coco/val2017
    ann_path: coco/annotations/instances_val2017.json
    input_size: [640,640] #[w,h]
    keep_ratio: False
    pipeline:
      normalize: [[103.53, 116.28, 123.675], [57.375, 57.12, 58.395]]
device:
  gpu_ids: [0, 1, 2, 3]
  workers_per_gpu: 8
  batchsize_per_gpu: 24
schedule:
  # resume:
  # load_model:
  optimizer:
    name: AdamW
    lr: 0.001
    weight_decay: 0.05
    no_norm_decay: True
    param_level_cfg:
      backbone:
        lr_mult: 0.1
  warmup:
    name: linear
    steps: 500
    ratio: 0.0001
  total_epochs: 50
  lr_schedule:
    name: CosineAnnealingLR
    T_max: 50
    eta_min: 0.0005
  val_intervals: 5
grad_clip: 35
evaluator:
  name: CocoDetectionEvaluator
  save_key: mAP
log:
  interval: 50

class_names: ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
              'train', 'truck', 'boat', 'traffic_light', 'fire_hydrant',
              'stop_sign', 'parking_meter', 'bench', 'bird', 'cat', 'dog',
              'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
              'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
              'skis', 'snowboard', 'sports_ball', 'kite', 'baseball_bat',
              'baseball_glove', 'skateboard', 'surfboard', 'tennis_racket',
              'bottle', 'wine_glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
              'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
              'hot_dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
              'potted_plant', 'bed', 'dining_table', 'toilet', 'tv', 'laptop',
              'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
              'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
              'vase', 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush']
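
The schedule block above pairs a 500-step linear warmup (ratio 0.0001) with CosineAnnealingLR decaying from the base lr of 0.001 down to 0.0005 over 50 epochs, while param_level_cfg scales every backbone parameter group by 0.1. A rough standalone sketch of that trajectory, not part of the commit (the single dummy parameter and the printed checkpoints are illustrative only; the warmup_factor helper mirrors the per-step multiplier added to optimizer_step in nanodet/trainer/task.py below):

import torch

# One dummy parameter group using the optimizer / lr_schedule values above.
opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=0.001, weight_decay=0.05)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50, eta_min=0.0005)

def warmup_factor(step, steps=500, ratio=0.0001, name="linear"):
    # Mirrors the multiplier k applied during warmup in the trainer.
    if name == "constant":
        return ratio
    if name == "linear":
        return 1 - (1 - step / steps) * (1 - ratio)
    if name == "exp":
        return ratio ** (1 - step / steps)
    raise ValueError(f"Unsupported warm up type: {name}")

print(warmup_factor(0), warmup_factor(250), warmup_factor(500))  # ~1e-4, ~0.5, 1.0

lrs = []
for epoch in range(50):
    lrs.append(opt.param_groups[0]["lr"])
    sched.step()
print(lrs[0], lrs[25], lrs[-1])  # 0.001 at the start, ~7.5e-4 mid-way, approaching 5e-4

A backbone parameter group built with lr_mult 0.1 would follow the same curve scaled by 0.1.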
3 changes: 3 additions & 0 deletions nanodet/optim/__init__.py
@@ -0,0 +1,3 @@
from .builder import build_optimizer

__all__ = ["build_optimizer"]
76 changes: 76 additions & 0 deletions nanodet/optim/builder.py
@@ -0,0 +1,76 @@
import copy
import logging

import torch
from torch.nn import GroupNorm, LayerNorm
from torch.nn.modules.batchnorm import _BatchNorm

NORMS = (GroupNorm, LayerNorm, _BatchNorm)


def build_optimizer(model, config):
    """Build optimizer from config.
    Supports customised parameter-level hyperparameters.
    The config should be like:
    >>> optimizer:
    >>>   name: AdamW
    >>>   lr: 0.001
    >>>   weight_decay: 0.05
    >>>   no_norm_decay: True
    >>>   param_level_cfg:  # parameter-level config
    >>>     backbone:
    >>>       lr_mult: 0.1
    """
    config = copy.deepcopy(config)
    param_dict = {}
    no_norm_decay = config.pop("no_norm_decay", False)
    no_bias_decay = config.pop("no_bias_decay", False)
    param_level_cfg = config.pop("param_level_cfg", {})
    base_lr = config.get("lr", None)
    base_wd = config.get("weight_decay", None)

    name = config.pop("name")
    optim_cls = getattr(torch.optim, name)

    logger = logging.getLogger("NanoDet")

    # custom param-wise lr and weight_decay
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        param_dict[p] = {"name": name}

        for key in param_level_cfg:
            if key in name:
                if "lr_mult" in param_level_cfg[key] and base_lr:
                    param_dict[p].update(
                        {"lr": base_lr * param_level_cfg[key]["lr_mult"]}
                    )
                if "decay_mult" in param_level_cfg[key] and base_wd:
                    param_dict[p].update(
                        {"weight_decay": base_wd * param_level_cfg[key]["decay_mult"]}
                    )
                break
    if no_norm_decay:
        # update norms decay
        for name, m in model.named_modules():
            if isinstance(m, NORMS):
                param_dict[m.bias].update({"weight_decay": 0})
                param_dict[m.weight].update({"weight_decay": 0})
    if no_bias_decay:
        # update bias decay
        for name, m in model.named_modules():
            if hasattr(m, "bias"):
                param_dict[m.bias].update({"weight_decay": 0})

    # convert param dict to optimizer's param groups
    param_groups = []
    for p, pconfig in param_dict.items():
        name = pconfig.pop("name", None)
        if "weight_decay" in pconfig or "lr" in pconfig:
            logger.info(f"special optimizer hyperparameter: {name} - {pconfig}")
        param_groups += [{"params": p, **pconfig}]

    optimizer = optim_cls(param_groups, **config)
    return optimizer
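
A minimal usage sketch for the new builder, not part of the commit (the toy module and the plain config dict below are made up to show what param_level_cfg and no_norm_decay do):

import torch.nn as nn

from nanodet.optim import build_optimizer


class ToyModel(nn.Module):
    # Hypothetical model whose parameter names contain "backbone", so the
    # param_level_cfg key below matches them.
    def __init__(self):
        super().__init__()
        self.backbone = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
        self.head = nn.Conv2d(8, 4, 1)


cfg = {
    "name": "AdamW",
    "lr": 0.001,
    "weight_decay": 0.05,
    "no_norm_decay": True,
    "param_level_cfg": {"backbone": {"lr_mult": 0.1}},
}
optimizer = build_optimizer(ToyModel(), cfg)

# Each parameter becomes its own group: backbone parameters get lr 1e-4,
# the BatchNorm weight/bias additionally get weight_decay 0, and the head
# keeps the base lr and weight_decay.
for group in optimizer.param_groups:
    print(len(group["params"]), group["lr"], group["weight_decay"])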
41 changes: 15 additions & 26 deletions nanodet/trainer/task.py
@@ -24,6 +24,7 @@
 from pytorch_lightning.utilities import rank_zero_only

 from nanodet.data.batch_process import stack_batch_img
+from nanodet.optim import build_optimizer
 from nanodet.util import convert_avg_params, gather_results, mkdir

 from ..model.arch import build_model
@@ -81,7 +82,7 @@ def training_step(self, batch, batch_idx):
         memory = (
             torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
         )
-        lr = self.optimizers().param_groups[0]["lr"]
+        lr = self.trainer.optimizers[0].param_groups[0]["lr"]
         log_msg = "Train|Epoch{}/{}|Iter{}({}/{})| mem:{:.3g}G| lr:{:.2e}| ".format(
             self.current_epoch + 1,
             self.cfg.schedule.total_epochs,
@@ -108,7 +109,6 @@ def training_step(self, batch, batch_idx):

     def training_epoch_end(self, outputs: List[Any]) -> None:
         self.trainer.save_checkpoint(os.path.join(self.cfg.save_dir, "model_last.ckpt"))
-        self.lr_scheduler.step()

     def validation_step(self, batch, batch_idx):
         batch = self._preprocess_batch_input(batch)
@@ -121,7 +121,7 @@ def validation_step(self, batch, batch_idx):
         memory = (
             torch.cuda.memory_reserved() / 1e9 if torch.cuda.is_available() else 0
         )
-        lr = self.optimizers().param_groups[0]["lr"]
+        lr = self.trainer.optimizers[0].param_groups[0]["lr"]
         log_msg = "Val|Epoch{}/{}|Iter{}({}/{})| mem:{:.3g}G| lr:{:.2e}| ".format(
             self.current_epoch + 1,
             self.cfg.schedule.total_epochs,
@@ -225,20 +225,17 @@ def configure_optimizers(self):
             optimizer
         """
         optimizer_cfg = copy.deepcopy(self.cfg.schedule.optimizer)
-        name = optimizer_cfg.pop("name")
-        build_optimizer = getattr(torch.optim, name)
-        optimizer = build_optimizer(params=self.parameters(), **optimizer_cfg)
+        optimizer = build_optimizer(self.model, optimizer_cfg)

         schedule_cfg = copy.deepcopy(self.cfg.schedule.lr_schedule)
         name = schedule_cfg.pop("name")
         build_scheduler = getattr(torch.optim.lr_scheduler, name)
-        self.lr_scheduler = build_scheduler(optimizer=optimizer, **schedule_cfg)
-        # lr_scheduler = {'scheduler': self.lr_scheduler,
-        #                 'interval': 'epoch',
-        #                 'frequency': 1}
-        # return [optimizer], [lr_scheduler]
-
-        return optimizer
+        scheduler = {
+            "scheduler": build_scheduler(optimizer=optimizer, **schedule_cfg),
+            "interval": "epoch",
+            "frequency": 1,
+        }
+        return dict(optimizer=optimizer, lr_scheduler=scheduler)

     def optimizer_step(
         self,
@@ -266,23 +263,19 @@ def optimizer_step(
         # warm up lr
         if self.trainer.global_step <= self.cfg.schedule.warmup.steps:
             if self.cfg.schedule.warmup.name == "constant":
-                warmup_lr = (
-                    self.cfg.schedule.optimizer.lr * self.cfg.schedule.warmup.ratio
-                )
+                k = self.cfg.schedule.warmup.ratio
             elif self.cfg.schedule.warmup.name == "linear":
-                k = (1 - self.trainer.global_step / self.cfg.schedule.warmup.steps) * (
-                    1 - self.cfg.schedule.warmup.ratio
-                )
-                warmup_lr = self.cfg.schedule.optimizer.lr * (1 - k)
+                k = 1 - (
+                    1 - self.trainer.global_step / self.cfg.schedule.warmup.steps
+                ) * (1 - self.cfg.schedule.warmup.ratio)
             elif self.cfg.schedule.warmup.name == "exp":
                 k = self.cfg.schedule.warmup.ratio ** (
                     1 - self.trainer.global_step / self.cfg.schedule.warmup.steps
                 )
-                warmup_lr = self.cfg.schedule.optimizer.lr * k
             else:
                 raise Exception("Unsupported warm up type!")
             for pg in optimizer.param_groups:
-                pg["lr"] = warmup_lr
+                pg["lr"] = pg["initial_lr"] * k

         # update params
         optimizer.step(closure=optimizer_closure)
@@ -315,10 +308,6 @@ def save_model_state(self, path):
         torch.save({"state_dict": state_dict}, path)

     # ------------Hooks-----------------
-    def on_train_start(self) -> None:
-        if self.current_epoch > 0:
-            self.lr_scheduler.last_epoch = self.current_epoch - 1
-
     def on_fit_start(self) -> None:
         if "weight_averager" in self.cfg.model:
             self.logger.info("Weight Averaging is enabled")
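
The warmup rewrite above multiplies pg["initial_lr"] by k instead of overwriting every group with the single config lr, so parameter-level settings such as the backbone lr_mult survive warmup. A small sketch of why that works, not part of the commit (it relies only on the torch behavior that constructing an lr scheduler records each group's starting lr under "initial_lr"):

import torch

net = torch.nn.Linear(4, 2)
opt = torch.optim.AdamW(
    [
        {"params": [net.weight], "lr": 0.0001},  # e.g. a backbone group built with lr_mult 0.1
        {"params": [net.bias]},  # a regular group at the base lr
    ],
    lr=0.001,
)
# Creating the scheduler stores each group's starting lr as "initial_lr".
torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=50, eta_min=0.0005)

k = 0.5  # some mid-warmup factor
for pg in opt.param_groups:
    pg["lr"] = pg["initial_lr"] * k
print([pg["lr"] for pg in opt.param_groups])  # [5e-05, 0.0005] -- the 10x gap is preserved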
65 changes: 65 additions & 0 deletions nanodet/util/env_utils.py
@@ -0,0 +1,65 @@
import os
import platform
import warnings

import torch.multiprocessing as mp


def set_multi_processing(
    mp_start_method: str = "fork", opencv_num_threads: int = 0, distributed: bool = True
) -> None:
    """Set multi-processing related environment.
    This function is referred from https://github.com/open-mmlab/mmengine/blob/main/mmengine/utils/dl_utils/setup_env.py
    Args:
        mp_start_method (str): Set the method which should be used to start
            child processes. Defaults to 'fork'.
        opencv_num_threads (int): Number of threads for opencv.
            Defaults to 0.
        distributed (bool): True if distributed environment.
            Defaults to True.
    """  # noqa
    # set multi-process start method as `fork` to speed up the training
    if platform.system() != "Windows":
        current_method = mp.get_start_method(allow_none=True)
        if current_method is not None and current_method != mp_start_method:
            warnings.warn(
                f"Multi-processing start method `{mp_start_method}` is "
                f"different from the previous setting `{current_method}`. "
                f"It will be force set to `{mp_start_method}`. You can "
                "change this behavior by changing `mp_start_method` in "
                "your config."
            )
        mp.set_start_method(mp_start_method, force=True)

    try:
        import cv2

        # disable opencv multithreading to avoid system being overloaded
        cv2.setNumThreads(opencv_num_threads)
    except ImportError:
        pass

    # setup OMP threads
    # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py  # noqa
    if "OMP_NUM_THREADS" not in os.environ and distributed:
        omp_num_threads = 1
        warnings.warn(
            "Setting OMP_NUM_THREADS environment variable for each process"
            f" to be {omp_num_threads} in default, to avoid your system "
            "being overloaded, please further tune the variable for "
            "optimal performance in your application as needed."
        )
        os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

    # setup MKL threads
    if "MKL_NUM_THREADS" not in os.environ and distributed:
        mkl_num_threads = 1
        warnings.warn(
            "Setting MKL_NUM_THREADS environment variable for each process"
            f" to be {mkl_num_threads} in default, to avoid your system "
            "being overloaded, please further tune the variable for "
            "optimal performance in your application as needed."
        )
        os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
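
A minimal call sketch, not part of the commit (it assumes the helper is invoked from the training entry point before dataloaders are created; the import path simply follows the file location above):

from nanodet.util.env_utils import set_multi_processing

if __name__ == "__main__":
    # Fork-based workers, single-threaded OpenCV, and capped OMP/MKL threads
    # for distributed training.
    set_multi_processing(mp_start_method="fork", opencv_num_threads=0, distributed=True)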