Skip to content

Commit

Permalink
final version VeIGAN
Browse files Browse the repository at this point in the history
  • Loading branch information
nuneslu committed Oct 1, 2019
1 parent f8899a2 commit 4d45004
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 39 deletions.
7 changes: 4 additions & 3 deletions inpaint.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ WGAN_GP_LAMBDA: 10
COARSE_L1_ALPHA: 1.2
L1_LOSS_ALPHA: 1.2
AE_LOSS_ALPHA: 1.2
V_LOSS_ALPHA: 0.8
GAN_WITH_MASK: False
DISCOUNTED_MASK: True
RANDOM_SEED: False
Expand All @@ -21,7 +22,7 @@ PADDING: 'SAME'
NUM_GPUS: 1
GPU_ID: 0 # -1 indicate select any available one, otherwise select gpu ID, e.g. [0,1,3]
TRAIN_SPE: 1000
MAX_ITERS: 30000
MAX_ITERS: 60000
VIZ_MAX_OUT: 10
GRADS_SUMMARY: False
GRADIENT_CLIP: False
Expand All @@ -31,8 +32,8 @@ VAL_PSTEPS: 1000
# data
DATA_FLIST:
veigan: [
'./data_flist/train_depth_shuffled.flist',
'./data_flist/validation_depth_shuffled.flist'
'/home/lrm/CaRINA/Inpainting/ddata_flist3/train_cs_shuffled.flist',
'/home/lrm/CaRINA/Inpainting/ddata_flist3/validation_cs_shuffled.flist'
]

STATIC_VIEW_SIZE: 30
Expand Down
62 changes: 26 additions & 36 deletions inpaint_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,40 +87,23 @@ def build_inpaint_net(self, x, mask, config=None, reuse=False,
#####################################################

x = gen_conv(x, cnum, 5, 1, name='conv1')
#x, _ = sparse_conv(x, filters=cnum, kernel_size=5, binary_mask=None, strides=1, name='sparse_conv1')
x = gen_conv(x, 2*cnum, 3, 2, name='conv2_downsample')
x = gen_conv(x, 2*cnum, 3, 1, name='conv3')
#x, _ = sparse_conv(x, filters=2*cnum, kernel_size=5, binary_mask=None, strides=1, name='sparse_conv3')
x = gen_conv(x, 4*cnum, 3, 2, name='conv4_downsample')
x = gen_conv(x, 4*cnum, 3, 1, name='conv5')
x = gen_conv(x, 4*cnum, 3, 1, name='conv6')
#x, _ = sparse_conv(x, filters=4*cnum, kernel_size=3, binary_mask=None, strides=1, name='sparse_conv6')
mask_s = resize_mask_like(mask, x)
x = gen_conv(x, 4*cnum, 3, rate=2, name='conv7_atrous')
x = gen_conv(x, 4*cnum, 3, rate=4, name='conv8_atrous')
x = gen_conv(x, 4*cnum, 3, rate=8, name='conv9_atrous')
x = gen_conv(x, 4*cnum, 3, rate=16, name='conv10_atrous')
#x, _ = sparse_conv(x, filters=4*cnum, kernel_size=3, binary_mask=None, strides=1, name='sparse_conv10')
x = gen_conv(x, 4*cnum, 3, 1, name='conv11')
x = gen_conv(x, 4*cnum, 3, 1, name='conv12')
#x, _ = sparse_conv(x, filters=4*cnum, kernel_size=3, binary_mask=None, strides=1, name='sparse_conv12')
x = gen_deconv(x, 2*cnum, name='conv13_upsample')
x = gen_conv(x, 2*cnum, 3, 1, name='conv14')
#x, _ = sparse_conv(x, filters=2*cnum, kernel_size=3, binary_mask=None, strides=1, name='sparse_conv14')
x = gen_deconv(x, cnum, name='conv15_upsample')
x = gen_conv(x, cnum//2, 3, 1, name='conv16')
x = gen_conv(x, 3, 3, 1, activation=None, name='conv17')
#x, _ = sparse_conv(x, filters=3, kernel_size=3, binary_mask=None, strides=1, name='sparse_conv17')
#Sparse Convolution
#b_mask = None
#x, b_mask = sparse_conv(x, filters=cnum, kernel_size=3, binary_mask=b_mask, strides=1, name="sparse_conv1")
#x, b_mask = sparse_conv(x, filters=2*cnum, kernel_size=3, binary_mask=b_mask, strides=1, name="sparse_conv2")
#x, b_mask = sparse_conv(x, filters=2*cnum, kernel_size=3, binary_mask=b_mask, strides=1, name="sparse_conv3")
#x, _ = sparse_conv(x, filters=cnum, kernel_size=3, binary_mask=b_mask, strides=1, name="sparse_conv4")

#x = gen_conv(x, cnum, 3, 1, name='conv18')
#x = gen_conv(x, cnum//2, 3, 1, name='conv19')
#x = gen_conv(x, 3, 3, 1, activation=None, name='conv20')

x = tf.clip_by_value(x, -1., 1.)
x_stage1 = x
Expand All @@ -133,30 +116,24 @@ def build_inpaint_net(self, x, mask, config=None, reuse=False,
x.set_shape(xin.get_shape().as_list())
# conv branch
xnow = tf.concat([x, ones_x, ones_x*mask], axis=3)
#Mask the surf output to compute only real depth values (ignore 0 depth/black occluded regions)
x_surf = surf_conv((tf.image.rgb_to_grayscale(x) + 1.) * 127.5) * xsurf_mask
x_surf = surf_conv((tf.image.rgb_to_grayscale(x) + 1.) * 127.5)
xnow_surf = tf.concat([tf.image.rgb_to_grayscale(x), x_surf, ones_x, ones_x*mask], axis=3)

x = gen_conv(xnow, cnum, 5, 1, name='xconv1')
#x, _ = sparse_conv(x, filters=cnum, kernel_size=5, binary_mask=None, strides=1, name='xsparse_conv1')
x = gen_conv(x, cnum, 3, 2, name='xconv2_downsample')
x = gen_conv(x, 2*cnum, 3, 1, name='xconv3')
#x, _ = sparse_conv(x, filters=2*cnum, kernel_size=3, binary_mask=None, strides=1, name='xsparse_conv3')
x = gen_conv(x, 2*cnum, 3, 2, name='xconv4_downsample')
x = gen_conv(x, 4*cnum, 3, 1, name='xconv5')
x = gen_conv(x, 4*cnum, 3, 1, name='xconv6')
#x, _ = sparse_conv(x, filters=4*cnum, kernel_size=3, binary_mask=None, strides=1, name='xsparse_conv6')
x = gen_conv(x, 4*cnum, 3, rate=2, name='xconv7_atrous')
x = gen_conv(x, 4*cnum, 3, rate=4, name='xconv8_atrous')
x = gen_conv(x, 4*cnum, 3, rate=8, name='xconv9_atrous')
x = gen_conv(x, 4*cnum, 3, rate=16, name='xconv10_atrous')
x_hallu = x
# attention branch
x = gen_conv(xnow_surf, cnum, 5, 1, name='pmconv1')
#x, _ = sparse_conv(x, filters=cnum, kernel_size=5, binary_mask=None, strides=1, name='sparse_pmconv1')
x = gen_conv(x, cnum, 3, 2, name='pmconv2_downsample')
x = gen_conv(x, 2*cnum, 3, 1, name='pmconv3')
#x, _ = sparse_conv(x, filters=2*cnum, kernel_size=3, binary_mask=None, strides=1, name='sparse_pmconv3')
x = gen_conv(x, 4*cnum, 3, 2, name='pmconv4_downsample')
x = gen_conv(x, 4*cnum, 3, 1, name='pmconv5')
x = gen_conv(x, 4*cnum, 3, 1, name='pmconv6',
Expand All @@ -171,10 +148,8 @@ def build_inpaint_net(self, x, mask, config=None, reuse=False,

x = gen_conv(x, 4*cnum, 3, 1, name='allconv11')
x = gen_conv(x, 4*cnum, 3, 1, name='allconv12')
#x, _ = sparse_conv(x, filters=4*cnum, kernel_size=3, binary_mask=None, strides=1, name='sparse_allconv12')
x = gen_deconv(x, 2*cnum, name='allconv13_upsample')
x = gen_conv(x, 2*cnum, 3, 1, name='allconv14')
#x, _ = sparse_conv(x, filters=2*cnum, kernel_size=3, binary_mask=None, strides=1, name='sparse_allconv14')
x = gen_deconv(x, cnum, name='allconv15_upsample')
x = gen_conv(x, cnum//2, 3, 1, name='allconv16')
x = gen_conv(x, 3, 3, 1, activation=None, name='allconv17')
Expand All @@ -190,7 +165,6 @@ def build_wgan_local_discriminator(self, x, reuse=False, training=True):
x = dis_conv(x, cnum*2, name='conv2', training=training)
x = dis_conv(x, cnum*4, name='conv3', training=training)
x = dis_conv(x, cnum*8, name='conv4', training=training)
#x = dis_conv(x, cnum*8, name='conv5', training=training)
x = flatten(x, name='flatten')
return x

Expand Down Expand Up @@ -226,7 +200,7 @@ def build_graph_with_losses(self, batch_data, config, training=True,
# generate mask, 1 represents masked point
bbox = random_bbox(config)

#Generate mask from box parameters
#Generate mask from box parameters
mask = bbox2mask(bbox, config, name='mask_c')

batch_incomplete = batch_pos*(1.-mask)
Expand Down Expand Up @@ -269,8 +243,21 @@ def build_graph_with_losses(self, batch_data, config, training=True,
tf.concat(viz_img, axis=2),
'raw_incomplete_predicted_complete', config.VIZ_MAX_OUT)

# global surface patch
surf_pos = surf_conv((tf.image.rgb_to_grayscale(batch_pos) + 1.) * 127.5)
surf_neg = surf_conv((tf.image.rgb_to_grayscale(batch_complete) + 1.) * 127.5)
surf_batch_pos_neg = tf.concat([surf_pos, surf_neg], axis=0)
batch_pos = tf.concat([batch_pos, surf_pos], axis=-1)
batch_complete = tf.concat([batch_complete, surf_neg], axis=-1)
# gan
batch_pos_neg = tf.concat([batch_pos, batch_complete], axis=0)

# local surface patch
local_surf_pos = surf_conv((tf.image.rgb_to_grayscale(local_patch_batch_pos) + 1.) * 127.5)
local_surf_neg = surf_conv((tf.image.rgb_to_grayscale(local_patch_batch_complete) + 1.) * 127.5)
surf_batch_pos_neg = tf.concat([surf_pos, surf_neg], axis=0)
local_patch_batch_pos = tf.concat([local_patch_batch_pos, local_surf_pos], axis=-1)
local_patch_batch_complete = tf.concat([local_patch_batch_complete, local_surf_neg], axis=-1)
# local deterministic patch
local_patch_batch_pos_neg = tf.concat([local_patch_batch_pos, local_patch_batch_complete], 0)

Expand Down Expand Up @@ -299,6 +286,14 @@ def build_graph_with_losses(self, batch_data, config, training=True,
losses['gp_loss'] = config.WGAN_GP_LAMBDA * (penalty_local + penalty_global)
losses['d_loss'] = losses['d_loss'] + losses['gp_loss']

####################################VECTORIAL LOSS##########################################
batch_out = tf.image.rgb_to_grayscale((batch_predicted + 1.) * 127.5)
batch_coarse = tf.image.rgb_to_grayscale((x1 + 1.) * 127.5)
batch_in = tf.image.rgb_to_grayscale(batch_data)
losses['v_loss'] = tf.reduce_mean(tf.abs(surf_conv(batch_out) - surf_conv(batch_in))) #FINAL RESULT
losses['v_loss'] += tf.reduce_mean(tf.abs(surf_conv(batch_coarse) - surf_conv(batch_in))) #COARSE RESULT (AE LOSS)
############################################################################################

if summary and not config.PRETRAIN_COARSE_NETWORK:
gradients_summary(g_loss_local, batch_predicted, name='g_loss_local')
gradients_summary(g_loss_global, batch_predicted, name='g_loss_global')
Expand All @@ -318,19 +313,14 @@ def build_graph_with_losses(self, batch_data, config, training=True,
gradients_summary(losses['l1_loss'], x2, name='l1_loss_to_x2')
gradients_summary(losses['ae_loss'], x1, name='ae_loss_to_x1')
gradients_summary(losses['ae_loss'], x2, name='ae_loss_to_x2')
gradients_summary(losses['v_loss'], batch_coarse, name='v_loss_to_x1')
gradients_summary(losses['v_loss'], batch_out, name='v_loss_to_x2')
if config.PRETRAIN_COARSE_NETWORK:
losses['g_loss'] = 0
else:
losses['g_loss'] = config.GAN_LOSS_ALPHA * losses['g_loss']
losses['g_loss'] += config.L1_LOSS_ALPHA * losses['l1_loss']

batch_out = tf.image.rgb_to_grayscale((batch_predicted + 1.) * 127.5)
batch_in = tf.image.rgb_to_grayscale(batch_data)

batch_mask = tf.clip_by_value(batch_data, 0., 1.)
####################################VECTORIAL LOSS##########################################
losses['g_loss'] += 0.8 * tf.reduce_mean(tf.abs(surf_conv(batch_out) * batch_mask - surf_conv(batch_in) * batch_mask))
############################################################################################
losses['g_loss'] += config.V_LOSS_ALPHA * losses['v_loss']


logger.info('Set L1_LOSS_ALPHA to %f' % config.L1_LOSS_ALPHA)
Expand Down

0 comments on commit 4d45004

Please sign in to comment.