Skip to content

Commit

Permalink
🧪 Clean up tests
Browse files Browse the repository at this point in the history
  • Loading branch information
francois-rozet committed Apr 10, 2023
1 parent 8950f5e commit 91bd63a
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 178 deletions.
259 changes: 113 additions & 146 deletions tests/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,104 +3,102 @@
r"""Benchmark against other packages
Packages:
kornia: https://pypi.org/project/kornia/
piq: https://pypi.org/project/piq/
IQA-pytorch: https://pypi.org/project/IQA-pytorch/
pytorch-msssim: https://pypi.org/project/pytorch-msssim/
"""

import contextlib
import numpy as np
import os
import argparse
import pandas as pd
import sys
import time
import torch
import urllib.request as request

from torchvision import transforms
from PIL import Image, ImageFilter
from torch import Tensor
from torchvision.transforms.functional import to_tensor
from typing import *
from urllib.request import urlopen

import kornia.losses as kornia
import piq
import IQA_pytorch as IQA
import piqa
import pytorch_msssim as vainf

sys.path.append(os.path.abspath('..'))

import piqa

LENNA = 'https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_(test_image).png'

LENNA = 'https://upload.wikimedia.org/wikipedia/en/7/7d/Lenna_%28test_image%29.png'
NO_REF = {
'TV': {
'piq.tv': piq.total_variation,
'piqa.TV': piqa.TV(norm='L2'),
},
}

METRICS = {
'TV': (1, {
'kornia.tv': kornia.total_variation,
'piq.tv': lambda x: piq.total_variation(x, norm_type='l1'),
'piqa.TV': piqa.TV(),
}),
'PSNR': (2, {
FULL_REF = {
'PSNR': {
'piq.psnr': piq.psnr,
'kornia.PSNR': kornia.PSNRLoss(1.),
'piqa.PSNR': piqa.PSNR(),
}),
'SSIM': (2, {
},
'SSIM': {
'piq.ssim': lambda x, y: piq.ssim(x, y, downsample=False),
'kornia.SSIM-halfloss': kornia.SSIMLoss(11),
'IQA.SSIM-loss': IQA.SSIM(),
'vainf.SSIM': vainf.SSIM(data_range=1.),
'vainf.SSIM': vainf.SSIM(data_range=1),
'piqa.SSIM': piqa.SSIM(),
}),
'MS-SSIM': (2, {
},
'MS-SSIM': {
'piq.ms_ssim': piq.multi_scale_ssim,
'IQA.MS_SSIM-loss': IQA.MS_SSIM(),
'vainf.MS_SSIM': vainf.MS_SSIM(data_range=1.),
'vainf.MS_SSIM': vainf.MS_SSIM(data_range=1),
'piqa.MS_SSIM': piqa.MS_SSIM(),
}),
'LPIPS': (2, {
},
'LPIPS': {
# 'piq.LPIPS': piq.LPIPS(),
# 'IQA.LPIPS': IQA.LPIPSvgg(),
'piqa.LPIPS': piqa.LPIPS(network='vgg')
}),
'GMSD': (2, {
'piqa.LPIPS': piqa.LPIPS(network='vgg'),
},
'GMSD': {
'piq.gmsd': piq.gmsd,
'piqa.GMSD': piqa.GMSD(),
}),
'MS-GMSD': (2, {
},
'MS-GMSD': {
'piq.ms_gmsd': piq.multi_scale_gmsd,
'piqa.MS_GMSD': piqa.MS_GMSD(),
}),
'MDSI': (2, {
},
'MDSI': {
'piq.mdsi': piq.mdsi,
'piqa.MDSI': piqa.MDSI(),
}),
'HaarPSI': (2, {
},
'HaarPSI': {
'piq.haarpsi': piq.haarpsi,
'piqa.HaarPSI': piqa.HaarPSI(),
}),
'VSI': (2, {
},
'VSI': {
'piq.vsi': piq.vsi,
'piqa.VSI': piqa.VSI(),
}),
'FSIM': (2, {
},
'FSIM': {
'piq.fsim': piq.fsim,
'piqa.FSIM': piqa.FSIM(),
}),
},
}

MANY_REF = {
'FID': {
'piq.FID': piq.FID(),
'piqa.FID': piqa.FID(),
},
}

def timeit(f, n: int) -> float:
METRICS = [*NO_REF, *FULL_REF, *MANY_REF]


def timeit(f: Callable, n: int) -> float:
start = time.perf_counter()

for _ in range(n):
f()

end = time.perf_counter()

return (end - start) * 1000
return (end - start) * 1000 / n


def cuda_timeit(f, n: int) -> float:
def cuda_timeit(f: Callable, n: int) -> float:
start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)

Expand All @@ -113,129 +111,98 @@ def cuda_timeit(f, n: int) -> float:

torch.cuda.synchronize()

return start.elapsed_time(end)
return start.elapsed_time(end) / n


def main(
url: str,
metrics: list = [],
def benchmark(
metrics: List[str] = METRICS,
batch: int = 1,
warm: int = 69,
loops: int = 420,
device: str = 'cuda',
grad: bool = True,
backend: bool = True,
warmup: int = 16,
repeat: int = 64,
device: str = 'cpu',
tracing: bool = False,
):
# Device
if not torch.cuda.is_available():
device = 'cpu'

# Backend
if device == 'cuda':
torch.backends.cudnn.enabled = backend
torch.backends.cudnn.benchmark = backend

# Images
truth = Image.open(request.urlopen(url))
noisy = truth.filter(ImageFilter.BLUR)
# Image
lenna = Image.open(urlopen(LENNA))
noisy = lenna.filter(ImageFilter.GaussianBlur)

noisy, truth = np.array(noisy), np.array(truth)
lenna = to_tensor(lenna).repeat(batch, 1, 1, 1)
noisy = to_tensor(noisy).repeat(batch, 1, 1, 1)

totensor = transforms.ToTensor()

x = totensor(noisy).repeat(batch, 1, 1, 1).to(device)
y = totensor(truth).repeat(batch, 1, 1, 1).to(device)

x.requires_grad_()

# Metrics
if metrics:
metrics = {
k: v for (k, v) in METRICS.items()
if k in metrics
}
else:
metrics = {k: v for (k, v) in METRICS.items()}

del metrics['LPIPS']
# Features
A = torch.randn(256, 256) / 256 ** 0.5
fx = torch.randn(4096, 256)
fy = torch.randn(4096, 256) @ A + torch.randn(256)

# Benchmark
for name, (nargs, methods) in metrics.items():
if metrics is None:
metrics = METRICS.copy()
metrics.remove('LPIPS')
metrics.remove('FID')

for name in metrics:
if name in NO_REF:
versions = NO_REF[name]
x, y = lenna.to(device), None
elif name in FULL_REF:
versions = FULL_REF[name]
x, y = noisy.to(device), lenna.to(device)
elif name in MANY_REF:
versions = MANY_REF[name]
x, y = fx.to(device), fy.to(device)
else:
continue

print(name)
print('-' * len(name))

data = {
'method': [],
'value': [],
'time': []
}

with contextlib.nullcontext() if grad else torch.no_grad():
for key, method in methods.items():
if hasattr(method, 'to'):
method.to(x.device)
rows = []

if '-np' in key:
a, b = truth, noisy
else:
a, b = x, y
for key, metric in versions.items():
if hasattr(metric, 'to'):
metric.to(device)

if nargs == 1:
f = lambda: method(a).mean()
else:
f = lambda: method(a, b).mean()
if y is None:
f = lambda x: metric(x)
else:
f = lambda x: metric(x, y)

if '-loss' in key:
g = lambda: 1. - f()
elif '-halfloss' in key:
g = lambda: 1. - 2. * f()
else:
g = f
if tracing and 'piqa' in key:
f = torch.jit.trace(f, x)

if '-' in key:
base = key[:key.find('-')]
else:
base = key
g = lambda: torch.autograd.functional.jacobian(f, x)

data['method'].append(base)
data['value'].append(float(g()))
for _ in range(warmup):
g()

time = timeit(g, n=warm) / warm
if device == 'cuda':
time = cuda_timeit(g, n=repeat)
else:
time = timeit(g, n=repeat)

if device == 'cuda' and '-np' not in key:
time = cuda_timeit(g, n=loops) / loops
rows.append({
'version': key,
'value': f(x).item(),
'gradient': g().norm(2).item(),
'time': time,
})

data['time'].append(time)

df = pd.DataFrame(data)
df = pd.DataFrame(rows)

print(df.sort_values(by='time', ignore_index=True))
print()


if __name__ == '__main__':
import argparse

parser = argparse.ArgumentParser(description='Benchmark')

parser.add_argument('-u', '--url', default=LENNA, help='image URL')
parser.add_argument('-m', '--metrics', nargs='+', default=[], choices=list(METRICS.keys()), help='metrics to benchmark')
parser.add_argument('-m', '--metrics', nargs='+', default=None, choices=METRICS, help='metrics to benchmark')
parser.add_argument('-b', '--batch', type=int, default=1, help='batch size')
parser.add_argument('-w', '--warm', type=int, default=69, help='number of warming loops')
parser.add_argument('-l', '--loops', type=int, default=420, help='number of loops')
parser.add_argument('-d', '--device', default='cuda', choices=['cpu', 'cuda'], help='computation device')
parser.add_argument('-grad', default=True, action='store_false', help='enable gradients')
parser.add_argument('-backend', default=True, action='store_false', help='enable backends')
parser.add_argument('-w', '--warmup', type=int, default=16, help='number of warmups')
parser.add_argument('-r', '--repeat', type=int, default=64, help='number of repeats')
parser.add_argument('-d', '--device', default='cpu', choices=['cpu', 'cuda'], help='computation device')
parser.add_argument('-t', '--tracing', default=False, action='store_true', help='enable tracing')

args = parser.parse_args()

main(
metrics=args.metrics,
url=args.url,
batch=args.batch,
warm=args.warm,
loops=args.loops,
device=args.device,
grad=args.grad,
backend=args.backend,
)
benchmark(**vars(args))
Loading

0 comments on commit 91bd63a

Please sign in to comment.