francois-rozet/piqa

PIQA is not endorsed by Facebook, Inc.; PyTorch, the PyTorch logo and any related marks are trademarks of Facebook, Inc.

PyTorch Image Quality Assessment

The piqa package is a collection of measures and metrics for image quality assessment in various image processing tasks such as denoising, super-resolution, image interpolation, etc. It relies only on PyTorch and takes advantage of its efficiency and automatic differentiation.

PIQA is directly inspired by the piq project, but focuses on the conciseness, readability and understandability of its (sub-)modules, such that anyone can easily reuse and/or adapt them to their needs.

However, conciseness should never come at the expense of efficiency; PIQA's implementations are up to 3 times faster than those of other PyTorch IQA packages like kornia, piq and IQA-pytorch.

PIQA should be pronounced pika (like Pikachu ⚡️)

Installation

The piqa package is available on PyPI, which means it is installable with pip:

pip install piqa

Alternatively, if you need the latest features, you can install it using

pip install git+https://github.com/francois-rozet/piqa

or copy the package directly to your project, with

git clone https://github.com/francois-rozet/piqa
cp -R piqa/piqa <path/to/project>/piqa
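
To confirm which version ended up in the environment, a minimal check can rely on the standard library's importlib.metadata. This is only a sketch: the helper name installed_version is hypothetical and not part of piqa, and it only sees distributions installed with pip, not a package copied by hand into a project.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name: str):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


if __name__ == "__main__":
    # Prints a version string if piqa was installed with pip, otherwise None.
    print("piqa version:", installed_version("piqa"))
```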

Documentation

The documentation of this package is generated automatically using pdoc.

Getting started

In piqa, each metric is associated with a class, a subclass of torch.nn.Module, which has to be instantiated to evaluate the metric.

import torch

# PSNR
from piqa import PSNR

x = torch.rand(5, 3, 256, 256)
y = torch.rand(5, 3, 256, 256)

psnr = PSNR()
l = psnr(x, y)

# SSIM
from piqa import SSIM

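# Create x directly on the GPU so that it stays a leaf tensor and x.grad is populated
x = torch.rand(5, 3, 256, 256, device='cuda', requires_grad=True)
y = torch.rand(5, 3, 256, 256, device='cuda')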

ssim = SSIM().cuda()
l = 1 - ssim(x, y)
l.backward()

Like torch.nn built-in components, these classes are based on functional definitions of the metrics, which are less user-friendly, but more versatile.

import torch

from piqa.ssim import ssim
from piqa.utils.functional import gaussian_kernel

x = torch.rand(5, 3, 256, 256)
y = torch.rand(5, 3, 256, 256)

kernel = gaussian_kernel(11, sigma=1.5).repeat(3, 1, 1)

l = ssim(x, y, kernel=kernel, channel_avg=False)

Metrics

Acronym  Class    Range   Objective  Year  Metric
TV       TV       [0, ∞]  /          1937  Total Variation
PSNR     PSNR     [0, ∞]  max        /     Peak Signal-to-Noise Ratio
SSIM     SSIM     [0, 1]  max        2004  Structural Similarity
MS-SSIM  MS_SSIM  [0, 1]  max        2004  Multi-Scale Structural Similarity
LPIPS    LPIPS    [0, ∞]  min        2018  Learned Perceptual Image Patch Similarity
GMSD     GMSD     [0, ∞]  min        2013  Gradient Magnitude Similarity Deviation
MS-GMSD  MS_GMSD  [0, ∞]  min        2017  Multi-Scale Gradient Magnitude Similarity Deviation
MDSI     MDSI     [0, ∞]  min        2016  Mean Deviation Similarity Index
HaarPSI  HaarPSI  [0, 1]  max        2018  Haar Perceptual Similarity Index
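
To make the Range and Objective columns concrete, here is a plain-Python sketch of the textbook PSNR definition, 10 * log10(R^2 / MSE), where R is the value range of the images. It is not piqa's implementation, which operates on batched tensors; the function name psnr_sketch and the value_range parameter are assumptions made for illustration.

```python
import math


def psnr_sketch(x, y, value_range=1.0):
    """Textbook PSNR between two equally sized flat sequences of pixel values."""
    if len(x) != len(y) or not x:
        raise ValueError("inputs must be non-empty and of equal length")
    mse = sum((a - b) ** 2 for a, b in zip(x, y)) / len(x)
    if mse == 0.0:
        return math.inf  # identical images: the error vanishes
    return 10.0 * math.log10(value_range ** 2 / mse)


reference = [0.0, 0.5, 1.0, 0.25]
slightly_off = [0.0, 0.5, 1.0, 0.30]
far_off = [1.0, 0.0, 0.0, 0.75]

# A closer image yields a higher PSNR, hence the "max" objective.
print(psnr_sketch(reference, slightly_off) > psnr_sketch(reference, far_off))
```

When a "max" metric is used as a training loss, it is typically negated or subtracted from its upper bound, as in the 1 - ssim(x, y) example of the Getting started section.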

JIT

Most functional components of piqa support PyTorch's JIT, i.e. TorchScript, which is a way to create serializable and optimizable functions from PyTorch code.

By default, jitting is disabled for those components. To enable it, the PIQA_JIT environment variable has to be set to 1. To do so temporarily,

  • UNIX-like bash
export PIQA_JIT=1
  • Windows cmd
set PIQA_JIT=1
  • Microsoft PowerShell
$env:PIQA_JIT=1
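
The same variable can also be set from within Python. The sketch below assumes that piqa reads PIQA_JIT when it is first imported, so the assignment has to come before the import; that timing is an assumption, not something guaranteed here.

```python
import os

# Environment variables are strings; set the flag before piqa is imported.
# Assumption: piqa checks PIQA_JIT at import time.
os.environ["PIQA_JIT"] = "1"

try:
    import piqa  # noqa: F401  (imported only after the variable is set)
except ImportError:
    piqa = None  # piqa is not installed in this environment

print(os.environ["PIQA_JIT"])
```

Like the shell commands, this only affects the current process and its children.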

Assert

PIQA uses type assertions to raise meaningful messages when an object-oriented component doesn't receive an input of the expected type. This feature greatly eases early prototyping and debugging, but it may slightly hurt performance.

If you need the absolute best performance, the assertions can be disabled with the Python flag -O. For example,

python -O your_awesome_code_using_piqa.py
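
The effect of -O can be observed with any assert statement: Python sets the built-in constant __debug__ to False and skips assertions entirely. The function below is a hypothetical stand-in for a type-checked component, not piqa code.

```python
def checked_mean(values):
    """Average a list of numbers, validating the input only in debug mode."""
    # This assertion is removed from the bytecode when Python runs with -O.
    assert isinstance(values, list), "expected a list of numbers"
    return sum(values) / len(values)


if __name__ == "__main__":
    print("assertions enabled:", __debug__)
    try:
        checked_mean((1.0, 2.0, 3.0))  # a tuple, not a list
    except AssertionError as error:
        print("rejected:", error)
    else:
        print("accepted without checking")
```

Run normally, the tuple is rejected by the assertion; run with python -O, the check is skipped and the call goes through.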