Fixes a couple of issues to add fp16 training support (#488)
* Fixes a couple of issues to add fp16 training support (#476)

* Add half precision support to `nanodet_plus` head

Moves the explicit `sigmoid` calculation inside the `dsl_assigner` so
that `binary_cross_entropy_with_logits` can be used. This allows
`autocast` to be used for training with `fp16` precision. Without this
change, `torch` reports that `binary_cross_entropy` is unsafe to use
under `autocast` and refuses to train the model in `fp16` precision.

* Add model precision settings to config

Allows for setting the model precision during training using the config
system.

Co-authored-by: RangiLyu <[email protected]>

* fix lint

* lightning version

Co-authored-by: Bjarne <[email protected]>
RangiLyu and crisp-snakey committed Jan 20, 2023
1 parent d8ba391 commit a59db3c
Showing 6 changed files with 19 additions and 7 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/workflow.yml
@@ -56,7 +56,7 @@ jobs:
           python -m pip install -U pip
           python -m pip install ninja opencv-python-headless onnx pytest-xdist codecov
           python -m pip install torch==${{matrix.torch}}+cpu torchvision==${{matrix.torchvision}}+cpu -f https://download.pytorch.org/whl/torch_stable.html
-          python -m pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning torchmetrics codecov flake8 pytest timm
+          python -m pip install Cython termcolor numpy tensorboard pycocotools matplotlib pyaml opencv-python tqdm pytorch-lightning==1.8.0 torchmetrics codecov flake8 pytest timm
           python -m pip install -r requirements.txt
       - name: Setup
         run: rm -rf .eggs && python setup.py develop
4 changes: 2 additions & 2 deletions nanodet/model/head/assigner/dsl_assigner.py
@@ -97,9 +97,9 @@ def assign(
         valid_pred_scores = valid_pred_scores.unsqueeze(1).repeat(1, num_gt, 1)

         soft_label = gt_onehot_label * pairwise_ious[..., None]
-        scale_factor = soft_label - valid_pred_scores
+        scale_factor = soft_label - valid_pred_scores.sigmoid()

-        cls_cost = F.binary_cross_entropy(
+        cls_cost = F.binary_cross_entropy_with_logits(
             valid_pred_scores, soft_label, reduction="none"
         ) * scale_factor.abs().pow(2.0)

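The rationale given in the commit message can be illustrated without PyTorch. The following is a minimal plain-Python sketch, not the library's implementation: the helper names `sigmoid`, `bce_from_probability`, `bce_from_logit`, and `cls_cost`, as well as the clamp value `eps`, are assumptions made for illustration. It shows that the two formulations agree for moderate logits, that the probability-based form loses information once the sigmoid saturates, and that the sigmoid is now only needed for the scale factor.

```python
import math


def sigmoid(x: float) -> float:
    """Logistic function, written to avoid overflow for large |x|."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def bce_from_probability(p: float, target: float, eps: float = 1e-12) -> float:
    """Cross-entropy on an already-squashed probability.

    The clamp (eps is an illustrative choice) is required because p can
    round to exactly 0.0 or 1.0, which would make log() fail.
    """
    p = min(max(p, eps), 1.0 - eps)
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


def bce_from_logit(x: float, target: float) -> float:
    """Cross-entropy computed directly from the logit x.

    Uses the identity max(x, 0) - x * target + log(1 + exp(-|x|)),
    which never takes the log of a value that has rounded to 0.
    """
    return max(x, 0.0) - x * target + math.log1p(math.exp(-abs(x)))


def cls_cost(logit: float, soft_label: float) -> float:
    """Per-element cost: logit-based BCE weighted by |soft_label - p|^2."""
    scale_factor = soft_label - sigmoid(logit)
    return bce_from_logit(logit, soft_label) * abs(scale_factor) ** 2.0


if __name__ == "__main__":
    for logit in (-3.0, 0.0, 2.5, 40.0):
        via_prob = bce_from_probability(sigmoid(logit), 0.7)
        via_logit = bce_from_logit(logit, 0.7)
        print(f"logit={logit:5.1f}  from_prob={via_prob:.6f}  from_logit={via_logit:.6f}")
```

At a logit of 40 the sigmoid rounds to 1.0 even in double precision, so the probability-based value is determined by the clamp rather than by the input. With the much narrower range of `fp16`, saturation sets in far earlier, which is consistent with the commit's choice to pass raw logits to the assigner.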
2 changes: 1 addition & 1 deletion nanodet/model/head/nanodet_plus_head.py
@@ -330,7 +330,7 @@ def target_assign_single_img(
             gt_bboxes_ignore = gt_bboxes_ignore.to(decoded_bboxes.dtype)

         assign_result = self.assigner.assign(
-            cls_preds.sigmoid(),
+            cls_preds,
             center_priors,
             decoded_bboxes,
             gt_bboxes,
1 change: 1 addition & 0 deletions nanodet/util/config.py
@@ -14,6 +14,7 @@
 cfg.data.train = CfgNode(new_allowed=True)
 cfg.data.val = CfgNode(new_allowed=True)
 cfg.device = CfgNode(new_allowed=True)
+cfg.device.precision = 32
 # train
 cfg.schedule = CfgNode(new_allowed=True)

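The effect of declaring a default in the config can be sketched with plain dictionaries. This is only an illustration of default-plus-override behaviour, not the config system's own merge logic; `DEFAULT_DEVICE_CFG` and `merge_device_config` are hypothetical names.

```python
from copy import deepcopy

# Hypothetical stand-in for the default declared in nanodet/util/config.py.
DEFAULT_DEVICE_CFG = {"precision": 32}


def merge_device_config(user_cfg: dict) -> dict:
    """Return the defaults overlaid with any keys the user supplied."""
    merged = deepcopy(DEFAULT_DEVICE_CFG)
    merged.update(user_cfg)
    return merged


if __name__ == "__main__":
    # An existing config that never mentions precision keeps training in fp32.
    print(merge_device_config({"gpu_ids": [0]}))
    # Opting in to half precision only requires one extra key.
    print(merge_device_config({"gpu_ids": [0], "precision": 16}))
```

Because the default is `32`, configs written before this commit should behave as before, and half precision is strictly opt-in.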
2 changes: 1 addition & 1 deletion requirements.txt
@@ -7,7 +7,7 @@ onnx-simplifier
 opencv-python
 pyaml
 pycocotools
-pytorch-lightning>=1.7.0
+pytorch-lightning>=1.7.0,<1.9.0
 tabulate
 tensorboard
 termcolor
15 changes: 13 additions & 2 deletions tools/train.py
@@ -113,9 +113,19 @@ def main(args):
     )
     if cfg.device.gpu_ids == -1:
         logger.info("Using CPU training")
-        accelerator, devices, strategy = "cpu", None, None
+        accelerator, devices, strategy, precision = (
+            "cpu",
+            None,
+            None,
+            cfg.device.precision,
+        )
     else:
-        accelerator, devices, strategy = "gpu", cfg.device.gpu_ids, None
+        accelerator, devices, strategy, precision = (
+            "gpu",
+            cfg.device.gpu_ids,
+            None,
+            cfg.device.precision,
+        )

     if devices and len(devices) > 1:
         strategy = "ddp"
@@ -135,6 +145,7 @@ def main(args):
         benchmark=cfg.get("cudnn_benchmark", True),
         gradient_clip_val=cfg.get("grad_clip", 0.0),
         strategy=strategy,
+        precision=precision,
     )

     trainer.fit(task, train_dataloader, val_dataloader)
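The device and precision selection in this hunk can be read as a small pure function. The sketch below is an assumption-laden restatement, not code from the repository: `select_trainer_args` is a hypothetical helper, and it returns a plain dictionary instead of constructing a trainer so that it runs without any training framework installed.

```python
from typing import Any, Dict, List, Union


def select_trainer_args(
    gpu_ids: Union[int, List[int]],
    precision: Union[int, str] = 32,
) -> Dict[str, Any]:
    """Mirror the branch in tools/train.py as a side-effect-free function."""
    if gpu_ids == -1:
        # -1 is the config's marker for CPU training.
        accelerator, devices = "cpu", None
    else:
        accelerator, devices = "gpu", gpu_ids

    # Distributed training is only enabled when more than one device is listed.
    strategy = "ddp" if devices and len(devices) > 1 else None

    return {
        "accelerator": accelerator,
        "devices": devices,
        "strategy": strategy,
        "precision": precision,
    }


if __name__ == "__main__":
    print(select_trainer_args(-1))
    print(select_trainer_args([0], precision=16))
    print(select_trainer_args([0, 1], precision=16))
```

Since `precision` takes the same value in both branches of the diff, this sketch reads it once rather than repeating it per branch; the resulting values are the same.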
