Merge pull request #819 from Sentient07/huber-loss
Added Huber loss
f0k committed Feb 21, 2018
2 parents c712b4a + 8694951 commit ffc8b8a
Showing 3 changed files with 75 additions and 1 deletion.
1 change: 1 addition & 0 deletions docs/modules/objectives.rst
@@ -12,6 +12,7 @@ Loss functions
.. autofunction:: squared_error
.. autofunction:: binary_hinge_loss
.. autofunction:: multiclass_hinge_loss
.. autofunction:: huber_loss


Aggregation functions
51 changes: 50 additions & 1 deletion lasagne/objectives.py
@@ -2,7 +2,7 @@
Provides some minimal help with building loss expressions for training or
validating a neural network.
Five functions build element- or item-wise loss expressions from network
Six functions build element- or item-wise loss expressions from network
predictions and targets:
.. autosummary::
@@ -13,6 +13,7 @@
    squared_error
    binary_hinge_loss
    multiclass_hinge_loss
    huber_loss
A convenience function aggregates such losses into a scalar expression
suitable for differentiation:
@@ -86,6 +87,7 @@
    "aggregate",
    "binary_hinge_loss",
    "multiclass_hinge_loss",
    "huber_loss",
    "binary_accuracy",
    "categorical_accuracy"
]
@@ -343,6 +345,53 @@ def multiclass_hinge_loss(predictions, targets, delta=1):
    return theano.tensor.nnet.relu(rest - corrects + delta)


def huber_loss(predictions, targets, delta=1):
    """Computes the Huber loss between predictions and targets.

    .. math::
        L_i = \\begin{cases}
            \\frac{(p_i - t_i)^2}{2} & \\text{if } |p_i - t_i| \\le \\delta \\\\
            \\delta \\, (|p_i - t_i| - \\frac{\\delta}{2}) & \\text{if } |p_i - t_i| > \\delta
        \\end{cases}

    Parameters
    ----------
    predictions : Theano 1D or 2D tensor
        Prediction outputs of a neural network.
    targets : Theano 1D or 2D tensor
        Ground truth to which the predictions are compared.
        Either a vector or a 2D tensor.
    delta : scalar, default 1
        Threshold at which the loss switches from the squared to the
        absolute-error regime. The default of 1 gives the `SmoothL1Loss`
        described in the Fast R-CNN paper [1]_.

    Returns
    -------
    Theano tensor
        An expression for the element-wise Huber loss [2]_.

    Notes
    -----
    This is an alternative to the squared error for regression problems
    that is less sensitive to outliers.

    References
    ----------
    .. [1] Ross Girshick (2015):
           Fast R-CNN
           https://arxiv.org/pdf/1504.08083.pdf
    .. [2] Peter J. Huber (1964):
           Robust Estimation of a Location Parameter
           https://projecteuclid.org/euclid.aoms/1177703732
    """
    predictions, targets = align_targets(predictions, targets)
    abs_diff = abs(targets - predictions)
    ift = 0.5 * squared_error(targets, predictions)
    iff = delta * (abs_diff - delta / 2.)
    return theano.tensor.switch(abs_diff <= delta, ift, iff)


def binary_accuracy(predictions, targets, threshold=0.5):
"""Computes the binary accuracy between predictions and targets.
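For orientation, the new function slots into Lasagne's usual objective-building workflow exactly like `squared_error`: build the element-wise loss expression, then aggregate it to a scalar for differentiation. The following minimal sketch is not part of this commit; the toy network, variable names and learning rate are illustrative assumptions.

import theano
import theano.tensor as T
import lasagne

# Illustrative regression model -- layer sizes and variable names are
# assumptions for this sketch, not part of the commit.
input_var = T.matrix('inputs')
target_var = T.matrix('targets')
network = lasagne.layers.InputLayer((None, 100), input_var)
network = lasagne.layers.DenseLayer(network, num_units=1,
                                    nonlinearity=lasagne.nonlinearities.linear)

predictions = lasagne.layers.get_output(network)
# Element-wise Huber loss (delta=1 is the SmoothL1 case), then a scalar mean.
loss = lasagne.objectives.huber_loss(predictions, target_var, delta=1)
loss = lasagne.objectives.aggregate(loss, mode='mean')

params = lasagne.layers.get_all_params(network, trainable=True)
updates = lasagne.updates.sgd(loss, params, learning_rate=0.01)
train_fn = theano.function([input_var, target_var], loss, updates=updates)

Because the loss grows only linearly for large residuals, gradients from outliers are bounded by delta, which is the usual reason to prefer it over squared_error on noisy regression targets.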
24 changes: 24 additions & 0 deletions lasagne/tests/test_objectives.py
@@ -152,6 +152,30 @@ def test_binary_hinge_loss(colvect):
    assert np.allclose(hinge, c.eval({p: predictions, t: targets}))


@pytest.mark.parametrize('colvect', (False, True))
@pytest.mark.parametrize('delta', (0.5, 1.0))
def test_huber_loss(colvect, delta):
    from lasagne.objectives import huber_loss
    if not colvect:
        a, b = theano.tensor.matrices('ab')
        l = huber_loss(a, b, delta)
    else:
        a, b = theano.tensor.vectors('ab')
        l = huber_loss(a.dimshuffle(0, 'x'), b, delta)[:, 0]

    # numeric version
    floatX = theano.config.floatX
    shape = (10, 20) if not colvect else (10,)
    x = np.random.rand(*shape).astype(floatX)
    y = np.random.rand(*shape).astype(floatX)
    abs_diff = abs(x - y)
    ift = 0.5 * abs_diff ** 2
    iff = delta * (abs_diff - delta / 2.)
    z = np.where(abs_diff <= delta, ift, iff)
    # compare
    assert np.allclose(z, l.eval({a: x, b: y}))


@pytest.mark.parametrize('colvect', (False, True))
def test_binary_hinge_loss_not_binary_targets(colvect):
    from lasagne.objectives import binary_hinge_loss
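As a quick sanity check on the piecewise definition exercised by test_huber_loss above, the two branches meet at |p - t| = delta, so the loss is continuous at the transition point. A small NumPy sketch, not part of this commit, with a reference function mirroring the docstring formula:

import numpy as np

def huber_ref(diff, delta):
    # Reference piecewise form from the docstring: quadratic inside
    # [-delta, delta], linear outside.
    diff = np.asarray(diff, dtype=float)
    return np.where(np.abs(diff) <= delta,
                    0.5 * diff ** 2,
                    delta * (np.abs(diff) - delta / 2.))

delta = 0.5
# At the threshold both branches evaluate to delta**2 / 2 = 0.125,
# so the theano.tensor.switch in huber_loss introduces no jump.
print(huber_ref(delta, delta))        # 0.125 (quadratic branch)
print(delta * (delta - delta / 2.))   # 0.125 (linear branch)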
