Skip to content

Commit

Permalink
add losses
Browse files Browse the repository at this point in the history
  • Loading branch information
xinntao committed Jul 6, 2022
1 parent c017e31 commit 1336902
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 5 deletions.
11 changes: 11 additions & 0 deletions losses/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
import importlib
from os import path as osp

from basicsr.utils import scandir

# automatically scan and import loss modules for registry
# scan all the files that end with '_loss.py' under the loss folder
loss_folder = osp.dirname(osp.abspath(__file__))
loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
# import all the loss modules
_model_modules = [importlib.import_module(f'losses.{file_name}') for file_name in loss_filenames]
26 changes: 26 additions & 0 deletions losses/example_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
from torch import nn as nn
from torch.nn import functional as F

from basicsr.utils.registry import LOSS_REGISTRY


@LOSS_REGISTRY.register()
class ExampleLoss(nn.Module):
"""Example Loss.
Args:
loss_weight (float): Loss weight for Example loss. Default: 1.0.
"""

def __init__(self, loss_weight=1.0):
super(ExampleLoss, self).__init__()
self.loss_weight = loss_weight

def forward(self, pred, target, **kwargs):
"""
Args:
pred (Tensor): of shape (N, C, H, W). Predicted tensor.
target (Tensor): of shape (N, C, H, W). Ground truth tensor.
weight (Tensor, optional): of shape (N, C, H, W). Element-wise weights. Default: None.
"""
return self.loss_weight * F.l1_loss(pred, target, reduction='mean')
3 changes: 1 addition & 2 deletions options/example_option.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,8 @@ train:

# losses
l1_opt:
type: L1Loss
type: ExampleLoss
loss_weight: 1.0
reduction: mean

l2_opt:
type: MSELoss
Expand Down
8 changes: 5 additions & 3 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
# flake8: noqa
import os.path as osp

import archs # noqa: F401
import data # noqa: F401
import models # noqa: F401
import archs
import data
import losses
import models
from basicsr.train import train_pipeline

if __name__ == '__main__':
Expand Down

0 comments on commit 1336902

Please sign in to comment.