Skip to content

Commit

Permalink
add training code
Browse files Browse the repository at this point in the history
  • Loading branch information
baowenbo committed Aug 23, 2019
1 parent 82c39bb commit 714d2dc
Show file tree
Hide file tree
Showing 8 changed files with 770 additions and 16 deletions.
23 changes: 23 additions & 0 deletions datasets/Vimeo_90K_interp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import os.path
import random
# import glob
import math
from .listdatasets import ListDataset,Vimeo_90K_loader


def make_dataset(root, list_file):
    """Read a sample list file and return its entries in random order.

    Args:
        root: Directory that contains ``list_file``.
        list_file: Name of a text file holding one sample path per line.

    Returns:
        A list with every line of the file except the final one, shuffled
        in place with ``random.shuffle``.

    Raises:
        AssertionError: If nothing is left after the final line is dropped.
    """
    # A context manager closes the handle deterministically; the previous
    # bare open(...).read() left that to the garbage collector.
    with open(os.path.join(root, list_file)) as list_handle:
        raw_im_list = list_handle.read().splitlines()

    # The last line is invalid in the test set, so the final entry is always
    # discarded (for every list file, not only the test one).
    raw_im_list = raw_im_list[:-1]
    assert len(raw_im_list) > 0, "no samples listed in " + str(list_file)
    random.shuffle(raw_im_list)

    return raw_im_list

def Vimeo_90K_interp(root, split=1.0, single=False, task='interp'):
    """Build the training and test datasets from the triplet list files.

    The sample lists are read from ``tri_trainlist.txt`` and
    ``tri_testlist.txt`` under ``root`` and each is wrapped in a
    ``ListDataset`` that uses ``Vimeo_90K_loader``.

    ``split``, ``single`` and ``task`` are accepted but not used here.

    Returns:
        A ``(train_dataset, test_dataset)`` tuple.
    """
    # Read both lists first (train, then test) before building the datasets.
    sample_lists = [
        make_dataset(root, list_name)
        for list_name in ("tri_trainlist.txt", "tri_testlist.txt")
    ]
    train_dataset, test_dataset = [
        ListDataset(root, samples, loader=Vimeo_90K_loader)
        for samples in sample_lists
    ]
    return train_dataset, test_dataset
7 changes: 7 additions & 0 deletions datasets/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from .Vimeo_90K_interp import Vimeo_90K_interp

# Names exported by ``from datasets import *``.
__all__ = (
'Vimeo_90K_interp',
)

# Vimeo_90K = "/tmp4/wenbobao_data/vimeo_triplet"
67 changes: 67 additions & 0 deletions datasets/listdatasets.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import torch.utils.data as data
import os
import os.path
from scipy.ndimage import imread
import numpy as np
import random

def Vimeo_90K_loader(root, im_path, input_frame_size = (3, 256, 448), output_frame_size = (3, 256, 448), data_aug = True):
    """Load one three-frame sample and return it as float32 arrays.

    The frames ``im1.png``, ``im2.png`` and ``im3.png`` are read from
    ``<root>/sequences/<im_path>``, cropped at one shared random offset to
    ``input_frame_size[1]`` x ``input_frame_size[2]``, optionally augmented,
    moved from (row, column, channel) to (channel, row, column) order and
    divided by 255.

    Args:
        root: Dataset root directory.
        im_path: Sample directory relative to ``<root>/sequences``.
        input_frame_size: ``(channels, height, width)``; only the height and
            width entries are used, as the crop size.
        output_frame_size: Accepted but not used.
        data_aug: When true, randomly reverse the temporal order of the two
            outer frames and randomly flip all frames horizontally and
            vertically.

    Returns:
        A tuple ``(first, last, middle)`` of float32 arrays, where ``middle``
        comes from ``im2.png``.
    """
    # NOTE(review): `imread` comes from the file-level
    # `from scipy.ndimage import imread`, which recent SciPy releases no
    # longer provide - confirm the pinned SciPy version.
    sequence_dir = os.path.join(root, 'sequences', im_path)

    # Temporal-reversal augmentation. Both branches used to assign the same
    # paths, so the random draw had no effect; the outer frames are now
    # actually swapped when the draw succeeds.
    if data_aug and random.randint(0, 1):
        first_name, last_name = "im3.png", "im1.png"
    else:
        first_name, last_name = "im1.png", "im3.png"

    frames = [
        imread(os.path.join(sequence_dir, first_name)),
        imread(os.path.join(sequence_dir, last_name)),
        imread(os.path.join(sequence_dir, "im2.png")),
    ]

    # Derive the valid crop range from the loaded image instead of assuming
    # a fixed 256x448 source; identical for 256x448 inputs.
    crop_h, crop_w = input_frame_size[1], input_frame_size[2]
    full_h, full_w = frames[0].shape[0], frames[0].shape[1]
    h_offset = random.choice(range(full_h - crop_h + 1))
    w_offset = random.choice(range(full_w - crop_w + 1))

    # The same window is applied to all three frames so they stay aligned.
    frames = [
        frame[h_offset:h_offset + crop_h, w_offset:w_offset + crop_w, :]
        for frame in frames
    ]

    if data_aug:
        # Each flip is drawn once and applied to every frame.
        if random.randint(0, 1):
            frames = [np.fliplr(frame) for frame in frames]
        if random.randint(0, 1):
            frames = [np.flipud(frame) for frame in frames]

    X0, X2, y = [
        np.transpose(frame, (2, 0, 1)).astype("float32") / 255.0
        for frame in frames
    ]
    return X0, X2, y



class ListDataset(data.Dataset):
    """Dataset that maps an index to a sample loaded from a path list.

    Each item is produced by calling ``loader(root, path_list[index])``,
    which must yield exactly three values.
    """

    def __init__(self, root, path_list, loader=Vimeo_90K_loader):
        """Store the dataset root, the list of sample paths and the loader."""
        self.root, self.path_list, self.loader = root, path_list, loader

    def __getitem__(self, index):
        """Return the three images the loader produces for ``index``."""
        sample_path = self.path_list[index]
        first, last, middle = self.loader(self.root, sample_path)
        return first, last, middle

    def __len__(self):
        """Return the number of sample paths."""
        return len(self.path_list)
86 changes: 86 additions & 0 deletions loss_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
import sys
import os

import sys
import threading
import torch
from torch.autograd import Variable
from lr_scheduler import *
from torch.autograd import gradcheck

import numpy




def charbonier_loss(x, epsilon):
    """Return the mean of ``sqrt(x^2 + epsilon^2)`` over all elements of ``x``."""
    smoothed_abs = (x * x + epsilon * epsilon).sqrt()
    return smoothed_abs.mean()
def negPSNR_loss(x, epsilon):
    """Return the mean over samples of ``-log(1 / m) / 100``.

    ``m`` is ``sqrt(x^2 + epsilon^2)`` averaged over axis 1 three times in a
    row, which for a four-axis input leaves one value per entry of axis 0.
    """
    reduced = (x * x + epsilon * epsilon).sqrt()
    # Collapse axis 1 three times; each reduction shifts the next axis down.
    for _ in range(3):
        reduced = reduced.mean(dim=1)
    scaled = -torch.log(1.0 / reduced) / 100.0
    return scaled.mean()

def tv_loss(x, epsilon):
    """Return the mean smoothed magnitude of forward differences of ``x``.

    Differences are taken along axes 2 and 3 of ``x`` (it is indexed on four
    axes), and ``epsilon^2`` is added under the square root.
    """
    anchor = x[:, :, :-1, :-1]
    diff_axis2 = anchor - x[:, :, 1:, :-1]
    diff_axis3 = anchor - x[:, :, :-1, 1:]
    magnitude = torch.sqrt(diff_axis2 ** 2 + diff_axis3 ** 2 + epsilon * epsilon)
    return magnitude.mean()


def gra_adap_tv_loss(flow, image, epsilon):
    """Return a total-variation term on ``flow`` weighted by ``image`` gradients.

    The weight is ``exp(-s)``, where ``s`` sums, over axis 1, the absolute
    forward differences of ``image`` along axes 2 and 3. The variation term
    sums, over axis 1, ``sqrt(d2^2 + d3^2 + epsilon^2)`` of the forward
    differences of ``flow``. The result is the mean of their product.
    """
    image_anchor = image[:, :, :-1, :-1]
    image_variation = (
        torch.abs(image_anchor - image[:, :, 1:, :-1])
        + torch.abs(image_anchor - image[:, :, :-1, 1:])
    )
    w = torch.exp(-torch.sum(image_variation, dim=1))

    flow_anchor = flow[:, :, :-1, :-1]
    flow_diff_axis2 = flow_anchor - flow[:, :, 1:, :-1]
    flow_diff_axis3 = flow_anchor - flow[:, :, :-1, 1:]
    tv = torch.sum(
        torch.sqrt(flow_diff_axis2 ** 2 + flow_diff_axis3 ** 2 + epsilon * epsilon),
        dim=1,
    )

    return torch.mean(w * tv)

def smooth_loss(x, epsilon):
    """Return the mean of ``sqrt(d2^2 + d3^2 + epsilon^2)`` over ``x``.

    ``d2`` and ``d3`` are forward differences of ``x`` along axes 2 and 3.
    """
    reference = x[:, :, :-1, :-1]
    step_axis2 = reference - x[:, :, 1:, :-1]
    step_axis3 = reference - x[:, :, :-1, 1:]
    return (step_axis2 ** 2 + step_axis3 ** 2 + epsilon ** 2).sqrt().mean()


def motion_sym_loss(offset, epsilon, occlusion = None):
    """Return the mean of ``sqrt((offset[0] + offset[1])^2 + epsilon^2)``.

    Args:
        offset: Indexable holding two tensors that can be added together.
        epsilon: Smoothing constant added (squared) under the square root.
        occlusion: Accepted for interface compatibility but currently
            ignored; the result does not depend on it.

    Returns:
        A scalar tensor.
    """
    # TODO: how to design the occlusion aware offset symmetric loss?
    # The previous `occlusion == None` test selected between two identical
    # expressions, and `==` on an array-like argument can compare elementwise
    # (raising on truth-testing), so the comparison is dropped altogether.
    return torch.mean(torch.sqrt((offset[0] + offset[1]) ** 2 + epsilon ** 2))




def part_loss(diffs, offsets, occlusions, images, epsilon, use_negPSNR=False):
    """Compute the pixel, offset-smoothness and offset-symmetry loss lists.

    Args:
        diffs: Sequence of difference tensors; one pixel loss is produced
            per entry.
        offsets: Sequence of offset pairs. ``offsets[0][0] is None`` signals
            that no offsets are available.
        occlusions: Accepted but not used.
        images: Indexable whose items 0 and 1 weight the smoothness term of
            ``offset[0]`` and ``offset[1]`` respectively.
        epsilon: Smoothing constant forwarded to every loss function.
        use_negPSNR: Use ``negPSNR_loss`` instead of ``charbonier_loss`` for
            the pixel term.

    Returns:
        A tuple ``(pixel_loss, offset_loss, sym_loss)`` of lists of tensors.
    """
    pixel_fn = negPSNR_loss if use_negPSNR else charbonier_loss
    pixel_loss = [pixel_fn(diff, epsilon) for diff in diffs]

    if offsets[0][0] is not None:
        offset_loss = [
            gra_adap_tv_loss(offset[0], images[0], epsilon)
            + gra_adap_tv_loss(offset[1], images[1], epsilon)
            for offset in offsets
        ]
        sym_loss = [motion_sym_loss(offset, epsilon=epsilon) for offset in offsets]
    else:
        # No offsets: use zero placeholders. They are created alongside the
        # pixel differences (same device and dtype) instead of being forced
        # onto CUDA, which failed on CPU-only hosts.
        if len(diffs) > 0:
            zero = diffs[0].new_zeros(1)
        else:
            zero = torch.zeros(1)
        offset_loss = [zero]
        # The symmetry term used to be computed unconditionally and crashed
        # on `None + ...` in this branch.
        sym_loss = [zero.clone()]

    return pixel_loss, offset_loss, sym_loss

Loading

0 comments on commit 714d2dc

Please sign in to comment.