This repository contains the code for the experiments in our ICML paper: Implicit competitive regularization in GANs.
The optimizers in this package solve competitive optimization problems of the form $$ \min_{x}\max_{y} f(x,y) $$
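To see why a competitive optimizer matters, consider the toy bilinear game $f(x,y) = xy$ (an illustration, not one of the paper's experiments). Following the CGD paper, the zero-sum CGD update with step size $\eta$ is

$$ x_{k+1} = x_k - \eta \left(I + \eta^2 D_{xy}^2 f \, D_{yx}^2 f\right)^{-1} \left(\nabla_x f + \eta D_{xy}^2 f \, \nabla_y f\right) $$
$$ y_{k+1} = y_k + \eta \left(I + \eta^2 D_{yx}^2 f \, D_{xy}^2 f\right)^{-1} \left(\nabla_y f - \eta D_{yx}^2 f \, \nabla_x f\right) $$

The pure-Python sketch below compares this update with simultaneous gradient descent-ascent (GDA) on $f(x,y) = xy$. The function names, starting point, and step size are illustrative assumptions; this is not the package's implementation.

```python
def gda_step(x, y, eta):
    """Simultaneous gradient descent (x) / ascent (y) on f(x, y) = x * y."""
    return x - eta * y, y + eta * x


def cgd_step(x, y, eta):
    """Zero-sum CGD step on f(x, y) = x * y.

    Here grad_x f = y, grad_y f = x and the mixed second derivative is 1,
    so the matrix inverse reduces to the scalar 1 / (1 + eta**2).
    """
    scale = eta / (1.0 + eta ** 2)
    return x - scale * (y + eta * x), y + scale * (x - eta * y)


def distance_after(step, x=1.0, y=1.0, eta=0.2, iters=100):
    """Run `iters` steps and return the distance to the equilibrium (0, 0)."""
    for _ in range(iters):
        x, y = step(x, y, eta)
    return (x ** 2 + y ** 2) ** 0.5


print(distance_after(gda_step))  # grows: GDA spirals away from (0, 0)
print(distance_after(cgd_step))  # shrinks: CGD converges to (0, 0)
```

On this game each GDA step multiplies the distance to the equilibrium by $\sqrt{1+\eta^2}$, while each CGD step multiplies it by $1/\sqrt{1+\eta^2}$, so only CGD converges for every $\eta > 0$.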
pip install CGDs
See the CGDs page on PyPI for details.
Alternatively, you can copy the 'optims' folder directly into your workspace.
The package contains the original Competitive Gradient Descent (BCGD) and Adaptive Competitive Gradient Descent (ACGD).
Quickstart notebook: Examples of using ACGD.
It is important to let cuDNN benchmark the available convolution algorithms and pick the fastest one (`torch.backends.cudnn.benchmark = True`).
More details are available at cgds-package: Package for CGD and ACGD optimizers.
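import torch
from CGDs import ACGD
# Let cuDNN benchmark candidate convolution algorithms and pick the fastest one.
torch.backends.cudnn.benchmark = True
device = torch.device('cuda:0')
lr = 0.0001
# Generator, Discriminator, dataloader, z_dim and criterion are defined by the user.
G = Generator().to(device)
D = Discriminator().to(device)
# ACGD: Adaptive CGD. max_params maximize the objective function, while min_params minimize it.
optimizer = ACGD(max_params=G.parameters(), min_params=D.parameters(), lr_max=lr, lr_min=lr, device=device)
# BCGD takes the same arguments:
# optimizer = BCGD(max_params=G.parameters(), min_params=D.parameters(), lr_max=lr, lr_min=lr, device=device)
for img in dataloader:
    img = img.to(device)
    d_real = D(img)
    # Match the noise batch to the real batch (the last batch may be smaller).
    z = torch.randn((img.size(0), z_dim), device=device)
    d_fake = D(G(z))
    loss = criterion(d_real, d_fake)
    optimizer.zero_grad()
    # The optimizer computes the gradients and Hessian-vector products from the loss itself.
    optimizer.step(loss=loss)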
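The training loop above leaves `criterion` undefined. As one possible choice (an assumption, not necessarily the loss used in the paper's experiments), the standard minimax GAN loss below matches the roles passed to `ACGD`: the discriminator (`min_params`) minimizes it and the generator (`max_params`) maximizes it. It assumes `D` outputs raw logits.

```python
import torch
import torch.nn.functional as F


def criterion(d_real, d_fake):
    """Minimax GAN loss: D (min player) minimizes it, G (max player) maximizes it.

    Assumes d_real and d_fake are raw discriminator logits.
    """
    real_loss = F.binary_cross_entropy_with_logits(d_real, torch.ones_like(d_real))
    fake_loss = F.binary_cross_entropy_with_logits(d_fake, torch.zeros_like(d_fake))
    return real_loss + fake_loss


# An uninformative discriminator (all logits 0) gives 2 * log(2), about 1.386.
print(criterion(torch.zeros(4, 1), torch.zeros(4, 1)).item())
```

Because a single scalar is shared by both players, no sign flip is needed for the generator: the optimizer ascends the same loss for `max_params` that it descends for `min_params`.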
==Warning==:
- Zero-sum game setting only. This implementation uses the conjugate gradient (CG) method to solve the linear system in the CGD update efficiently, which requires the system matrix to be positive definite; this holds for zero-sum games. If you use the competitive gradient descent (CGD) algorithm for non-zero-sum games, please see the CGD paper https://arxiv.org/abs/1905.12103 for details. For example, GMRES (the generalized minimal residual method) can serve as a solver in the non-zero-sum setting.
- This implementation does not work with the torch.nn.parallel.DistributedDataParallel (DDP) module, because computing Hessian-vector products requires autograd.grad(), which DDP does not support. See the PyTorch 1.7.0 DistributedDataParallel documentation for details.
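The zero-sum restriction in the first warning can be made concrete. For a zero-sum game with $A = D_{xy}^2 f$, the CGD update of $x$ solves a linear system with matrix $I + \eta^2 A A^\top$, which is symmetric positive definite because $v^\top (I + \eta^2 A A^\top) v = \|v\|^2 + \eta^2 \|A^\top v\|^2 > 0$ for $v \neq 0$. That is exactly the setting where conjugate gradient applies. The pure-Python sketch below solves such a system for a small, hypothetical $A$ and $\eta$; it only needs matrix-vector products, which is what lets an implementation use Hessian-vector products instead of forming $A$. It is an illustration, not the package's solver.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def matvec(M, v):
    """Dense matrix-vector product; M is a list of rows."""
    return [dot(row, v) for row in M]


def conjugate_gradient(apply_op, b, tol=1e-10, max_iter=100):
    """Solve op(x) = b for a symmetric positive definite linear operator op."""
    x = [0.0] * len(b)
    r = list(b)  # residual b - op(x) for the initial guess x = 0
    p = list(r)  # first search direction
    rs_old = dot(r, r)
    for _ in range(max_iter):
        if rs_old ** 0.5 < tol:
            break
        op_p = apply_op(p)
        alpha = rs_old / dot(p, op_p)  # exact line search along p
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * qi for ri, qi in zip(r, op_p)]
        rs_new = dot(r, r)
        # New direction, conjugate to the previous ones with respect to op.
        p = [ri + (rs_new / rs_old) * pi for ri, pi in zip(r, p)]
        rs_old = rs_new
    return x


# Hypothetical mixed second derivative A = D_xy f and step size eta.
A = [[1.0, 2.0], [0.0, 1.0]]
A_T = [list(col) for col in zip(*A)]
eta = 0.5


def zero_sum_op(v):
    """v -> (I + eta^2 A A^T) v, using only matrix-vector products."""
    AAt_v = matvec(A, matvec(A_T, v))
    return [vi + eta ** 2 * wi for vi, wi in zip(v, AAt_v)]


b = [1.0, 1.0]
x = conjugate_gradient(zero_sum_op, b)
print(x, zero_sum_op(x))  # op(x) reproduces b up to rounding
```

For a non-zero-sum game the corresponding matrix $I - \eta^2 D_{xy}^2 f \, D_{yx}^2 g$ is generally not symmetric, so CG's guarantees no longer hold; a solver such as GMRES, which does not need symmetry, is the natural replacement.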
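The DDP limitation in the second warning comes from how mixed Hessian-vector products such as $D_{xy}^2 f \, v$ are computed: by differentiating $\langle \nabla_y f, v \rangle$ with respect to $x$, using `torch.autograd.grad(..., create_graph=True)` for the inner gradient. The helper below is a minimal sketch of that double-backward pattern on the bilinear function $f(x,y) = x^\top A y$, for which $D_{xy}^2 f \, v = A v$. The name `mixed_hvp` and the example tensors are hypothetical and not part of the CGDs API.

```python
import torch


def mixed_hvp(f, x, y, v):
    """Return D_xy^2 f @ v, i.e. the gradient w.r.t. x of <grad_y f(x, y), v>."""
    # create_graph=True keeps grad_y differentiable so it can be differentiated again.
    (grad_y,) = torch.autograd.grad(f(x, y), y, create_graph=True)
    # Vector-Jacobian product of grad_y with respect to x, weighted by v.
    (hvp,) = torch.autograd.grad(grad_y, x, grad_outputs=v)
    return hvp


A = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
x = torch.tensor([1.0, -1.0], requires_grad=True)
y = torch.tensor([0.5, 2.0], requires_grad=True)
v = torch.tensor([1.0, 0.5])

print(mixed_hvp(lambda a, b: a @ A @ b, x, y, v))  # equals A @ v = [2., 5.]
```

These gradients are returned directly by `torch.autograd.grad()` instead of being accumulated into the parameters' `.grad` attributes by `backward()`, and DDP only supports the latter, so its gradient synchronization does not apply here.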
Please cite the following paper if you find this code useful. Thanks!
@misc{schfer2019implicit,
title={Implicit competitive regularization in GANs},
author={Florian Schäfer and Hongkai Zheng and Anima Anandkumar},
year={2019},
eprint={1910.05852},
archivePrefix={arXiv},
primaryClass={cs.LG}
}