In [1]:
# This code is adapted from https://github.com/deepak112/Keras-SRGAN

from model import Generator, Discriminator
import numpy as np
from matplotlib import pyplot as plt
from tensorflow.python.keras import backend as K

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input

import tensorflow as tf
import os
# Expose only the first GPU (device index 0) to this process.
os.environ.update(CUDA_VISIBLE_DEVICES="0")

In [2]:
# More information about the data source and the preprocessing method is available in our paper.
# Generator inputs (X) and targets (y) for the training and test splits.
X_train = np.load("X_train.npy")
y_train = np.load("y_train.npy")
X_test = np.load("X_test.npy")
y_test = np.load("y_test.npy")

# Fail fast on mismatched data. The training loop indexes X and y with the
# same random indices, and the models are built from X_train[0].shape but
# are also evaluated on X_test/y_test. So inputs and targets must be paired,
# and both splits must share a per-sample shape.
assert len(X_train) == len(y_train), "X_train and y_train must have the same number of samples"
assert len(X_test) == len(y_test), "X_test and y_test must have the same number of samples"
assert X_train.shape[1:] == X_test.shape[1:], "train and test inputs must share a sample shape"
assert y_train.shape[1:] == y_test.shape[1:], "train and test targets must share a sample shape"

In [3]:
# Report the shape of every split. The recorded run below shows
# (samples, 1, 256, 58) for all four arrays.
for array in (X_train, y_train, X_test, y_test):
    print(array.shape)

(4056, 1, 256, 58)
(4056, 1, 256, 58)
(1984, 1, 256, 58)
(1984, 1, 256, 58)


In [4]:
# Read the last two axes of the training input (256 and 58 in the recorded run).
# `ch` is passed to the generator factory below. NOTE(review): the names
# suggest samples/frequency and channels; confirm against the preprocessing.
fs, ch = X_train.shape[2], X_train.shape[3]

In [5]:
def get_gan_network(discriminator, shape, generator, optimizer):
    """Stack ``generator`` and a frozen ``discriminator`` into one compiled model.

    Args:
        discriminator: model applied to the generator's output; its
            ``trainable`` flag is set to False before the stack is compiled.
        shape: per-sample input shape for the combined model's ``Input``.
        generator: model mapping the input to a reconstruction.
        optimizer: optimizer handed to ``compile`` for the combined model.

    Returns:
        A compiled ``Model`` with two outputs, [generator output,
        discriminator score]. It is trained with MSE on the first output and
        binary cross-entropy on the second, weighted 1.0 and 5e-4.
    """
    # Freeze D before compiling so that training the stacked model is meant
    # to update only the generator.
    discriminator.trainable = False

    stacked_input = Input(shape=shape)
    reconstruction = generator(stacked_input)
    realism_score = discriminator(reconstruction)

    stacked = Model(inputs=stacked_input, outputs=[reconstruction, realism_score])
    stacked.compile(
        loss=["mse", "binary_crossentropy"],
        loss_weights=[1., 5e-4],
        optimizer=optimizer,
    )
    return stacked

In [6]:
# Start from a clean Keras state. Use tf.keras.backend rather than the
# `tensorflow.python.keras` backend imported as `K`: in newer TensorFlow
# releases that private module is a separate legacy copy of Keras, so its
# clear_session() would not reset the `tensorflow.keras` state used here
# (Model/Input/Adam; presumably model.py uses the same package - confirm).
tf.keras.backend.clear_session()

# Fix the RNGs so batch sampling, label noise and weight init repeat across runs.
SEED = 0
np.random.seed(SEED)
tf.random.set_seed(SEED)

batch_size = 256
epochs = 500

# Number of mini-batches per epoch. Batches are drawn with replacement below.
# Use at least one, so the per-epoch losses printed at the end always exist.
batch_count = max(1, X_train.shape[0] // batch_size)

generator = Generator(X_train[0].shape).generator(ch)
discriminator = Discriminator(X_train[0].shape).discriminator()


def _make_adam():
    """Return a fresh Adam optimizer with the training hyper-parameters.

    Every compiled model gets its own instance. One optimizer shared across
    several models mixes their state (e.g. the step counter), and newer Keras
    optimizers refuse to update a variable set other than the one they were
    first built for. `learning_rate` replaces the deprecated `lr` argument.
    """
    return Adam(learning_rate=1E-3, beta_1=0.9, beta_2=0.999, epsilon=1e-08)


def _soft_labels(n, real):
    """Return ``n`` noisy BCE targets: (0.8, 1.0] if ``real`` else [0.0, 0.2)."""
    noise = np.random.random_sample(n) * 0.2
    return np.ones(n) - noise if real else noise


generator.compile(loss='mse', optimizer=_make_adam())
discriminator.compile(loss="binary_crossentropy", optimizer=_make_adam())

gan = get_gan_network(discriminator, X_train[0].shape, generator, _make_adam())

for e in range(1, epochs + 1):
    print('-' * 15, 'Epoch %d' % e, '-' * 15)
    for _ in range(batch_count):
        # Discriminator step: real targets vs. generator outputs.
        rand_nums = np.random.randint(0, X_train.shape[0], size=batch_size)
        X_train_batch = X_train[rand_nums]
        y_train_batch = y_train[rand_nums]
        generated = generator.predict(X_train_batch, verbose=0)

        real_data_Y = _soft_labels(batch_size, real=True)
        fake_data_Y = _soft_labels(batch_size, real=False)

        # NOTE(review): Keras captures `trainable` when a model is compiled, so
        # these toggles are likely no-ops. D is trainable in its own compiled
        # model and frozen inside `gan`. Kept to match the reference code.
        discriminator.trainable = True
        d_loss_real = discriminator.train_on_batch(y_train_batch, real_data_Y)
        d_loss_fake = discriminator.train_on_batch(generated, fake_data_Y)

        # Generator step through the stacked model, on a fresh batch. Targets
        # are the clean y for the MSE output and "real" labels for D's output.
        rand_nums = np.random.randint(0, y_train.shape[0], size=batch_size)
        X_train_batch = X_train[rand_nums]
        y_train_batch = y_train[rand_nums]

        gan_Y = _soft_labels(batch_size, real=True)

        discriminator.trainable = False
        loss_gan = gan.train_on_batch(X_train_batch, [y_train_batch, gan_Y])

    # Evaluate on the full test set once per epoch, after the batch loop.
    # Only this final value is reported.
    gan_y_test = _soft_labels(X_test.shape[0], real=True)
    loss_test = gan.evaluate(X_test, [y_test, gan_y_test], verbose=0)

    print("Loss real , Loss fake, Loss GAN")
    print(d_loss_real, d_loss_fake, loss_gan)
    print("Loss test")
    print(loss_test)

--------------- Epoch 1 ---------------
Loss real , Loss fake, Loss GAN
0.4633001685142517 1.2061243057250977 [0.052180685102939606, 0.05174751207232475, 0.8663432002067566]
Loss test
[0.06091763451695442, 0.06047450751066208, 0.8862844705581665]
--------------- Epoch 2 ---------------
Loss real , Loss fake, Loss GAN
0.6881305575370789 0.6992399096488953 [0.029459087178111076, 0.029072392731904984, 0.7733879089355469]
Loss test
[0.028185183182358742, 0.027789326384663582, 0.791710376739502]
--------------- Epoch 3 ---------------
Loss real , Loss fake, Loss GAN
0.3563421368598938 0.3519001007080078 [0.022904643788933754, 0.021933982148766518, 1.9413237571716309]
Loss test
[0.022301258519291878, 0.021404553204774857, 1.7934106588363647]
--------------- Epoch 4 ---------------
Loss real , Loss fake, Loss GAN
0.38113757967948914 0.3531765639781952 [0.020838206633925438, 0.019911594688892365, 1.8532251119613647]
Loss test
[0.019869033247232437, 0.018971208482980728, 1.7956552505493164]
---

Loss real , Loss fake, Loss GAN
0.33108454942703247 0.3302263021469116 [0.008409496396780014, 0.007940514013171196, 0.9379638433456421]
Loss test
[0.009200270287692547, 0.008767286315560341, 0.8659610748291016]
--------------- Epoch 35 ---------------
Loss real , Loss fake, Loss GAN
0.33418333530426025 0.33950117230415344 [0.008461855351924896, 0.007861167192459106, 1.20137619972229]
Loss test
[0.00934580247849226, 0.008772461675107479, 1.1466830968856812]
--------------- Epoch 36 ---------------
Loss real , Loss fake, Loss GAN
0.31237781047821045 0.33380964398384094 [0.008140018209815025, 0.007733795326203108, 0.8124459981918335]
Loss test
[0.008666315115988255, 0.00826120562851429, 0.810217559337616]
--------------- Epoch 37 ---------------
Loss real , Loss fake, Loss GAN
0.3268812894821167 0.3473111093044281 [0.007878582924604416, 0.007433384656906128, 0.8903970718383789]
Loss test
[0.008758697658777237, 0.008314322680234909, 0.8887456655502319]
--------------- Epoch 38 ------------

Loss real , Loss fake, Loss GAN
0.36521318554878235 0.34999722242355347 [0.006691493559628725, 0.005882998928427696, 1.6169894933700562]
Loss test
[0.007708204910159111, 0.00685869250446558, 1.6990258693695068]
--------------- Epoch 68 ---------------
Loss real , Loss fake, Loss GAN
0.36949920654296875 2.271470546722412 [0.006220344454050064, 0.005994049832224846, 0.45258966088294983]
Loss test
[0.007192005403339863, 0.00695850420743227, 0.4670043885707855]
--------------- Epoch 69 ---------------
Loss real , Loss fake, Loss GAN
0.382470041513443 0.33786553144454956 [0.006686587817966938, 0.005998261272907257, 1.3766528367996216]
Loss test
[0.007447520736604929, 0.006747561506927013, 1.3999152183532715]
--------------- Epoch 70 ---------------
Loss real , Loss fake, Loss GAN
0.364464670419693 0.3514810800552368 [0.00662257382646203, 0.006052341312170029, 1.140465259552002]
Loss test
[0.007490735501050949, 0.006942971609532833, 1.0955301523208618]
--------------- Epoch 71 --------------

Loss real , Loss fake, Loss GAN
0.33068665862083435 0.3156266212463379 [0.005504873115569353, 0.004793302621692419, 1.423140525817871]
Loss test
[0.0064714038744568825, 0.005764520727097988, 1.4137647151947021]
--------------- Epoch 101 ---------------
Loss real , Loss fake, Loss GAN
0.3253394663333893 0.32175779342651367 [0.005524975247681141, 0.0050007617101073265, 1.048426628112793]
Loss test
[0.006303539965301752, 0.005782132036983967, 1.042812705039978]
--------------- Epoch 102 ---------------
Loss real , Loss fake, Loss GAN
0.32187944650650024 0.32437825202941895 [0.005396833177655935, 0.004914812743663788, 0.9640405178070068]
Loss test
[0.006355463061481714, 0.005884449928998947, 0.942024290561676]
--------------- Epoch 103 ---------------
Loss real , Loss fake, Loss GAN
0.619617223739624 0.6356942057609558 [0.004965964704751968, 0.004576759412884712, 0.7784101963043213]
Loss test
[0.006304533686488867, 0.0059163738042116165, 0.7763243317604065]
--------------- Epoch 104 ------

Loss real , Loss fake, Loss GAN
0.3294001817703247 0.33718928694725037 [0.00532921776175499, 0.004626136738806963, 1.4061615467071533]
Loss test
[0.0058983732014894485, 0.005198772065341473, 1.399204134941101]
--------------- Epoch 134 ---------------
Loss real , Loss fake, Loss GAN
0.32872238755226135 0.3424248695373535 [0.005136095918715, 0.0043611107394099236, 1.5499706268310547]
Loss test
[0.00596048915758729, 0.00519076082855463, 1.5394600629806519]
--------------- Epoch 135 ---------------
Loss real , Loss fake, Loss GAN
0.3335907459259033 0.321377694606781 [0.005002390127629042, 0.004358446225523949, 1.287887692451477]
Loss test
[0.005721359048038721, 0.0050733741372823715, 1.2959734201431274]
--------------- Epoch 136 ---------------
Loss real , Loss fake, Loss GAN
0.32729995250701904 0.3721294105052948 [0.006719823926687241, 0.006256989203393459, 0.9256695508956909]
Loss test
[0.006975023075938225, 0.0065063354559242725, 0.9373766779899597]
--------------- Epoch 137 ----------

Loss real , Loss fake, Loss GAN
0.33958539366722107 0.32240763306617737 [0.004379228223115206, 0.003768176306039095, 1.2221039533615112]
Loss test
[0.005638254340738058, 0.005021106451749802, 1.2342966794967651]
--------------- Epoch 167 ---------------
Loss real , Loss fake, Loss GAN
0.3372730612754822 0.3393775522708893 [0.004647134803235531, 0.004014743957668543, 1.2647819519042969]
Loss test
[0.005584017839282751, 0.004947803448885679, 1.2724251747131348]
--------------- Epoch 168 ---------------
Loss real , Loss fake, Loss GAN
0.33301007747650146 0.3272445797920227 [0.004187099169939756, 0.003626209683716297, 1.1217788457870483]
Loss test
[0.005465494003146887, 0.004908937495201826, 1.1131134033203125]
--------------- Epoch 169 ---------------
Loss real , Loss fake, Loss GAN
0.32162725925445557 0.32127028703689575 [0.004213876090943813, 0.003580352058634162, 1.2670477628707886]
Loss test
[0.00541686313226819, 0.0047817048616707325, 1.2703136205673218]
--------------- Epoch 170 ---

Loss real , Loss fake, Loss GAN
0.3260749578475952 0.32726967334747314 [0.005273020826280117, 0.004734572488814592, 1.0768967866897583]
Loss test
[0.0062476592138409615, 0.00571006815880537, 1.075181484222412]
--------------- Epoch 200 ---------------
Loss real , Loss fake, Loss GAN
0.32155895233154297 0.3276899456977844 [0.005439551081508398, 0.004902822431176901, 1.0734573602676392]
Loss test
[0.006035271566361189, 0.005498203914612532, 1.074134111404419]
--------------- Epoch 201 ---------------
Loss real , Loss fake, Loss GAN
0.3223837614059448 0.31995266675949097 [0.005605730228126049, 0.005035303998738527, 1.140852689743042]
Loss test
[0.005853753536939621, 0.0052869608625769615, 1.1335867643356323]
--------------- Epoch 202 ---------------
Loss real , Loss fake, Loss GAN
0.3332936465740204 0.3357492685317993 [0.005810978356748819, 0.005246101878583431, 1.1297531127929688]
Loss test
[0.005801377352327108, 0.005232726689428091, 1.1372997760772705]
--------------- Epoch 203 -------

Loss real , Loss fake, Loss GAN
0.33027076721191406 0.304948091506958 [0.00429898826405406, 0.0037078256718814373, 1.182325005531311]
Loss test
[0.005125520750880241, 0.00453455513343215, 1.1819323301315308]
--------------- Epoch 233 ---------------
Loss real , Loss fake, Loss GAN
0.3327389359474182 0.31568509340286255 [0.004397477023303509, 0.0038020990323275328, 1.1907554864883423]
Loss test
[0.005568633321672678, 0.004973967559635639, 1.189330816268921]
--------------- Epoch 234 ---------------
Loss real , Loss fake, Loss GAN
0.33031418919563293 0.32506218552589417 [0.004169076215475798, 0.0035746218636631966, 1.188908576965332]
Loss test
[0.00508316233754158, 0.00448601646348834, 1.1942886114120483]
--------------- Epoch 235 ---------------
Loss real , Loss fake, Loss GAN
0.3240734338760376 0.3295719623565674 [0.004999370779842138, 0.004336330108344555, 1.3260815143585205]
Loss test
[0.006262491457164288, 0.005591735243797302, 1.3415100574493408]
--------------- Epoch 236 ---------

Loss real , Loss fake, Loss GAN
0.3260180354118347 0.32103049755096436 [0.0035731690004467964, 0.0030201119370758533, 1.1061142683029175]
Loss test
[0.005114364437758923, 0.004562435671687126, 1.1038583517074585]
--------------- Epoch 266 ---------------
Loss real , Loss fake, Loss GAN
0.3205162584781647 0.32927200198173523 [0.003777826204895973, 0.0032560289837419987, 1.0435941219329834]
Loss test
[0.005401751026511192, 0.004878230858594179, 1.047041893005371]
--------------- Epoch 267 ---------------
Loss real , Loss fake, Loss GAN
0.32942160964012146 0.33311453461647034 [0.003984492272138596, 0.0033702864311635494, 1.2284120321273804]
Loss test
[0.005028112791478634, 0.00441518472507596, 1.2258579730987549]
--------------- Epoch 268 ---------------
Loss real , Loss fake, Loss GAN
0.3316256105899811 0.34122753143310547 [0.0038241937290877104, 0.003178325481712818, 1.2917366027832031]
Loss test
[0.004999748431146145, 0.004353051073849201, 1.2933933734893799]
--------------- Epoch 269 

Loss real , Loss fake, Loss GAN
0.31898826360702515 0.3341437578201294 [0.003637696383520961, 0.00301348976790905, 1.2484130859375]
Loss test
[0.005253266543149948, 0.0046271695755422115, 1.2521944046020508]
--------------- Epoch 299 ---------------
Loss real , Loss fake, Loss GAN
0.3315201997756958 0.315360426902771 [0.003667157143354416, 0.0030587988439947367, 1.2167167663574219]
Loss test
[0.004914301447570324, 0.004304394591599703, 1.2198125123977661]
--------------- Epoch 300 ---------------
Loss real , Loss fake, Loss GAN
0.3219660520553589 0.32160311937332153 [0.003401793073862791, 0.002824245486408472, 1.155094861984253]
Loss test
[0.004918275400996208, 0.0043366821482777596, 1.1631890535354614]
--------------- Epoch 301 ---------------
Loss real , Loss fake, Loss GAN
0.3156992197036743 0.33301323652267456 [0.0033000921830534935, 0.0028096302412450314, 0.9809237122535706]
Loss test
[0.00481432257220149, 0.004324723035097122, 0.9791961312294006]
--------------- Epoch 302 -------

Loss real , Loss fake, Loss GAN
0.3248291313648224 0.33887025713920593 [0.003266723593696952, 0.0026607224717736244, 1.2120022773742676]
Loss test
[0.004820933099836111, 0.004213265608996153, 1.2153338193893433]
--------------- Epoch 332 ---------------
Loss real , Loss fake, Loss GAN
0.3266918957233429 0.32457903027534485 [0.0034743736032396555, 0.002926922868937254, 1.0949015617370605]
Loss test
[0.004792611580342054, 0.00424824096262455, 1.088739275932312]
--------------- Epoch 333 ---------------
Loss real , Loss fake, Loss GAN
0.32764336466789246 0.32865792512893677 [0.0033226273953914642, 0.0027266149409115314, 1.1920249462127686]
Loss test
[0.004914156626909971, 0.004318343009799719, 1.1916286945343018]
--------------- Epoch 334 ---------------
Loss real , Loss fake, Loss GAN
0.330899178981781 0.31957772374153137 [0.003684063209220767, 0.0031036597210913897, 1.1608070135116577]
Loss test
[0.005026426166296005, 0.004443516489118338, 1.1658164262771606]
--------------- Epoch 335 -

Loss real , Loss fake, Loss GAN
0.31785082817077637 0.32074224948883057 [0.002920637372881174, 0.002419258002191782, 1.0027589797973633]
Loss test
[0.004834224469959736, 0.004328077659010887, 1.0122923851013184]
--------------- Epoch 365 ---------------
Loss real , Loss fake, Loss GAN
0.3260791301727295 0.32497671246528625 [0.002726188628003001, 0.00213037570938468, 1.1916258335113525]
Loss test
[0.004865956027060747, 0.0042733545415103436, 1.1852024793624878]
--------------- Epoch 366 ---------------
Loss real , Loss fake, Loss GAN
0.322323739528656 0.35112833976745605 [0.003041342366486788, 0.002484417287632823, 1.1138503551483154]
Loss test
[0.00481915008276701, 0.004265176597982645, 1.1079460382461548]
--------------- Epoch 367 ---------------
Loss real , Loss fake, Loss GAN
0.32643550634384155 0.3308731019496918 [0.002751353895291686, 0.0022546860855072737, 0.9933354258537292]
Loss test
[0.004694877192378044, 0.004195926710963249, 0.9979032278060913]
--------------- Epoch 368 ----

Loss real , Loss fake, Loss GAN
0.3300936222076416 0.31815457344055176 [0.0031906506046652794, 0.002628345973789692, 1.1246089935302734]
Loss test
[0.004834894090890884, 0.004278245382010937, 1.1132991313934326]
--------------- Epoch 398 ---------------
Loss real , Loss fake, Loss GAN
0.3183840215206146 0.31629306077957153 [0.0030220497865229845, 0.0024931696243584156, 1.057760238647461]
Loss test
[0.004808325786143541, 0.004278169013559818, 1.0603147745132446]
--------------- Epoch 399 ---------------
Loss real , Loss fake, Loss GAN
0.33881616592407227 0.33964577317237854 [0.0028005512431263924, 0.0023309325333684683, 0.9392375946044922]
Loss test
[0.004782551899552345, 0.004312851000577211, 0.9394003748893738]
--------------- Epoch 400 ---------------
Loss real , Loss fake, Loss GAN
0.3256502151489258 0.32650238275527954 [0.002947833389043808, 0.002385640051215887, 1.1243866682052612]
Loss test
[0.005218828562647104, 0.004658848512917757, 1.119958519935608]
--------------- Epoch 401 

Loss real , Loss fake, Loss GAN
0.33094364404678345 0.31712639331817627 [0.0025265533477067947, 0.002013149671256542, 1.026807427406311]
Loss test
[0.004778544418513775, 0.0042696730233728886, 1.0177429914474487]
--------------- Epoch 431 ---------------
Loss real , Loss fake, Loss GAN
0.31836628913879395 0.32792750000953674 [0.0023813031148165464, 0.0018813369097188115, 0.9999325275421143]
Loss test
[0.004817503038793802, 0.00431636581197381, 1.0022746324539185]
--------------- Epoch 432 ---------------
Loss real , Loss fake, Loss GAN
0.33162063360214233 0.3225913345813751 [0.0032086933497339487, 0.0026861364021897316, 1.0451139211654663]
Loss test
[0.00525143276900053, 0.00472710095345974, 1.0486632585525513]
--------------- Epoch 433 ---------------
Loss real , Loss fake, Loss GAN
0.326513409614563 0.3266405463218689 [0.0026406152173876762, 0.002165552694350481, 0.9501252174377441]
Loss test
[0.005153370089828968, 0.004677045624703169, 0.9526501893997192]
--------------- Epoch 434 -

Loss real , Loss fake, Loss GAN
0.3262811005115509 0.3176744282245636 [0.003061506198719144, 0.002455734647810459, 1.211543083190918]
Loss test
[0.004820708185434341, 0.004220990464091301, 1.1994335651397705]
--------------- Epoch 464 ---------------
Loss real , Loss fake, Loss GAN
0.32801365852355957 0.31378716230392456 [0.0028620916418731213, 0.002301528351381421, 1.121126413345337]
Loss test
[0.004714054986834526, 0.004152737092226744, 1.1226361989974976]
--------------- Epoch 465 ---------------
Loss real , Loss fake, Loss GAN
0.31800028681755066 0.31943848729133606 [0.002837130334228277, 0.002230360172688961, 1.2135404348373413]
Loss test
[0.004654991906136274, 0.0040511819534003735, 1.2076220512390137]
--------------- Epoch 466 ---------------
Loss real , Loss fake, Loss GAN
0.3249910771846771 0.3329964876174927 [0.0028748575132340193, 0.002316912403330207, 1.1158902645111084]
Loss test
[0.004782053176313639, 0.004228111822158098, 1.107883095741272]
--------------- Epoch 467 ----

Loss real , Loss fake, Loss GAN
0.33003225922584534 0.32879021763801575 [0.0023297308944165707, 0.001809030887670815, 1.0413999557495117]
Loss test
[0.004828079603612423, 0.00430930033326149, 1.0375585556030273]
--------------- Epoch 497 ---------------
Loss real , Loss fake, Loss GAN
0.3232933580875397 0.3321796953678131 [0.0026637506671249866, 0.0021206089295446873, 1.0862834453582764]
Loss test
[0.004887976683676243, 0.004342149011790752, 1.091658115386963]
--------------- Epoch 498 ---------------
Loss real , Loss fake, Loss GAN
0.324586421251297 0.3246138095855713 [0.002642363077029586, 0.0020811313297599554, 1.122463583946228]
Loss test
[0.005043461453169584, 0.004480746109038591, 1.1254303455352783]
--------------- Epoch 499 ---------------
Loss real , Loss fake, Loss GAN
0.32819920778274536 0.32191500067710876 [0.002489725360646844, 0.0019449889659881592, 1.0894728899002075]
Loss test
[0.004873072728514671, 0.00432869466021657, 1.0887562036514282]
--------------- Epoch 500 ----

In [7]:
# Persist the trained generator (the only network needed at inference time).
generator.save(filepath="model.h5")