This repository has been archived by the owner on Nov 29, 2022. It is now read-only.

artemmavrin/focal-loss

Focal Loss

TensorFlow implementation of focal loss [1]: a loss function that generalizes binary and multiclass cross-entropy by down-weighting well-classified (easy) examples, so that training concentrates on hard-to-classify examples.
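For intuition, the binary case can be sketched in plain Python. This is a minimal illustration of the formula from [1], not the package's TensorFlow implementation, and the helper names below are hypothetical. With p_t denoting the predicted probability of the true class, the focal loss is -(1 - p_t)^gamma * log(p_t). Setting gamma = 0 recovers ordinary binary cross-entropy, while larger gamma shrinks the loss on confident, correct predictions by the factor (1 - p_t)^gamma.

```python
import math


def binary_focal_loss_scalar(y_true: int, p: float, gamma: float = 2.0) -> float:
    """Focal loss for a single example (hypothetical helper, not the package API).

    y_true: 0 or 1; p: predicted probability that the label is 1.
    Assumes 0 < p < 1 (no clipping is done here).
    """
    # p_t is the probability the model assigns to the true class.
    p_t = p if y_true == 1 else 1.0 - p
    # The modulating factor (1 - p_t) ** gamma down-weights easy examples.
    return -((1.0 - p_t) ** gamma) * math.log(p_t)


def binary_cross_entropy_scalar(y_true: int, p: float) -> float:
    """Plain binary cross-entropy for comparison (hypothetical helper)."""
    p_t = p if y_true == 1 else 1.0 - p
    return -math.log(p_t)


if __name__ == "__main__":
    # With gamma = 0 the two losses coincide.
    print(binary_focal_loss_scalar(1, 0.7, gamma=0.0), binary_cross_entropy_scalar(1, 0.7))
    # Compare an easy example (p_t = 0.9) with a hard one (p_t = 0.1) at gamma = 2.
    for p in (0.9, 0.1):
        ratio = binary_focal_loss_scalar(1, p, gamma=2.0) / binary_cross_entropy_scalar(1, p)
        print(f"p_t={p}: focal/CE ratio = {ratio:.2f}")
```

The ratio printed in the loop is (1 - p_t)^2, so the easy example keeps only about 1% of its cross-entropy loss while the hard example keeps about 81%.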

The focal_loss package provides functions and classes that can be used as off-the-shelf replacements for tf.keras.losses functions and classes, respectively.

# Typical tf.keras API usage
import tensorflow as tf
from focal_loss import BinaryFocalLoss

model = tf.keras.Model(...)
model.compile(
    optimizer=...,
    loss=BinaryFocalLoss(gamma=2),  # Used here like a tf.keras loss
    metrics=...,
)
history = model.fit(...)
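The function/class split mirrors tf.keras.losses: a loss class stores hyperparameters such as gamma and reduces the per-example losses to one scalar for the batch. The plain-Python sketch below imitates that pattern only to show the idea; `BinaryFocalLossSketch`, its helper function, and its mean reduction are assumptions for illustration, not the package's actual `BinaryFocalLoss` class.

```python
import math
from typing import List, Sequence


def binary_focal_loss_values(
    y_true: Sequence[int], y_pred: Sequence[float], gamma: float
) -> List[float]:
    """Per-example binary focal loss (hypothetical helper; assumes 0 < p < 1)."""
    losses = []
    for y, p in zip(y_true, y_pred):
        p_t = p if y == 1 else 1.0 - p  # probability of the true class
        losses.append(-((1.0 - p_t) ** gamma) * math.log(p_t))
    return losses


class BinaryFocalLossSketch:
    """Callable loss object that stores gamma, in the style of a tf.keras loss class."""

    def __init__(self, gamma: float):
        if gamma < 0:
            raise ValueError("gamma must be non-negative")
        self.gamma = gamma

    def __call__(self, y_true: Sequence[int], y_pred: Sequence[float]) -> float:
        values = binary_focal_loss_values(y_true, y_pred, self.gamma)
        # Mean over the batch (an assumed reduction chosen for this sketch).
        return sum(values) / len(values)


if __name__ == "__main__":
    loss_fn = BinaryFocalLossSketch(gamma=2)
    print(loss_fn([1, 0, 1], [0.9, 0.2, 0.4]))
```

Keeping gamma on the object is what lets a configured instance be passed directly as `loss=` to `model.compile`, as in the example above.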

The focal_loss package includes the functions

  • binary_focal_loss
  • sparse_categorical_focal_loss

and wrapper classes

  • BinaryFocalLoss (use like tf.keras.losses.BinaryCrossentropy)
  • SparseCategoricalFocalLoss (use like tf.keras.losses.SparseCategoricalCrossentropy)
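In the multiclass "sparse" setting, each label is an integer class index (as with tf.keras.losses.SparseCategoricalCrossentropy), and p_t is the predicted probability of that class. The plain-Python sketch below assumes the predictions are already probabilities (each row sums to 1) rather than logits; the helper name is hypothetical and this is not the package's TensorFlow code.

```python
import math
from typing import List, Sequence


def sparse_categorical_focal_loss_values(
    y_true: Sequence[int],
    y_prob: Sequence[Sequence[float]],
    gamma: float = 2.0,
) -> List[float]:
    """Per-example focal loss for integer labels (hypothetical helper).

    y_true[i] is a class index; y_prob[i] is a probability vector over classes.
    Assumes every true-class probability lies in (0, 1].
    """
    losses = []
    for label, probs in zip(y_true, y_prob):
        p_t = probs[label]  # probability assigned to the true class
        # Same modulating factor as the binary case, applied to the true class.
        losses.append(-((1.0 - p_t) ** gamma) * math.log(p_t))
    return losses


if __name__ == "__main__":
    labels = [0, 2]
    probs = [[0.8, 0.1, 0.1], [0.3, 0.5, 0.2]]
    print(sparse_categorical_focal_loss_values(labels, probs, gamma=2.0))
```

As in the binary case, gamma = 0 reduces each value to the sparse categorical cross-entropy -log(p_t).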

Documentation is available at Read the Docs.

(Figure: focal loss plot.)

Installation

The focal_loss package can be installed using the pip utility. For the latest version, install directly from the package's GitHub page:

pip install git+https://github.com/artemmavrin/focal-loss.git

Alternatively, install a recent release from the Python Package Index (PyPI):

pip install focal-loss

Note. To install the project for development (e.g., to make changes to the source code), clone the project repository from GitHub and run make dev:

git clone https://github.com/artemmavrin/focal-loss.git
cd focal-loss
# Optional but recommended: create and activate a new environment first
make dev

This will additionally install the requirements needed to run tests, check code coverage, and produce documentation.

References

[1] T. Lin, P. Goyal, R. Girshick, K. He, and P. Dollár. Focal loss for dense object detection. IEEE Transactions on Pattern Analysis and Machine Intelligence, 2018. arXiv preprint arXiv:1708.02002.