Skip to content

Commit

Permalink
Add files via upload
Browse files Browse the repository at this point in the history
  • Loading branch information
ericlearning authored Jul 16, 2019
1 parent 914cf13 commit 0680105
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion trainers_advanced/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from utils import set_lr, get_lr, generate_noise, plot_multiple_images, save_fig, save, get_sample_images_list, get_display_samples
from losses.losses import SGAN, LSGAN, HINGEGAN, WGAN, RASGAN, RALSGAN, RAHINGEGAN, QPGAN

class Trainer_WGAN_GP_Pro():
class Trainer():
def __init__(self, loss_type, netD, netG, device, train_ds, lr_D = 0.0002, lr_G = 0.0002, resample = True, weight_clip = None, use_gradient_penalty = False, drift = 0.001, loss_interval = 50, image_interval = 50, save_img_dir = 'saved_images/'):
self.loss_type = loss_type
self.loss_dict = {'SGAN':SGAN, 'LSGAN':LSGAN, 'HINGEGAN':HINGEGAN, 'WGAN':WGAN, 'RASGAN':RASGAN, 'RALSGAN':RALSGAN, 'RAHINGEGAN':RAHINGEGAN, 'QPGAN':QPGAN}
Expand Down

0 comments on commit 0680105

Please sign in to comment.