forked from baowenbo/DAIN
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
770 additions
and
16 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
import os.path | ||
import random | ||
# import glob | ||
import math | ||
from .listdatasets import ListDataset,Vimeo_90K_loader | ||
|
||
|
||
def make_dataset(root, list_file): | ||
raw_im_list = open(os.path.join(root, list_file)).read().splitlines() | ||
# the last line is invalid in test set. | ||
# print("The last sample is : " + raw_im_list[-1]) | ||
raw_im_list = raw_im_list[:-1] | ||
assert len(raw_im_list) > 0 | ||
random.shuffle(raw_im_list) | ||
|
||
return raw_im_list | ||
|
||
def Vimeo_90K_interp(root, split=1.0, single=False, task = 'interp' ): | ||
train_list = make_dataset(root,"tri_trainlist.txt") | ||
test_list = make_dataset(root,"tri_testlist.txt") | ||
train_dataset = ListDataset(root, train_list, loader=Vimeo_90K_loader) | ||
test_dataset = ListDataset(root, test_list, loader=Vimeo_90K_loader) | ||
return train_dataset, test_dataset |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
from .Vimeo_90K_interp import Vimeo_90K_interp | ||
|
||
__all__ = ( | ||
'Vimeo_90K_interp', | ||
) | ||
|
||
# Vimeo_90K = "/tmp4/wenbobao_data/vimeo_triplet" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import torch.utils.data as data | ||
import os | ||
import os.path | ||
from scipy.ndimage import imread | ||
import numpy as np | ||
import random | ||
|
||
def Vimeo_90K_loader(root, im_path, input_frame_size = (3, 256, 448), output_frame_size = (3, 256, 448), data_aug = True): | ||
|
||
|
||
root = os.path.join(root,'sequences',im_path) | ||
|
||
if data_aug and random.randint(0, 1): | ||
path_pre2 = os.path.join(root, "im1.png") | ||
path_pre1 = os.path.join(root, "im2.png") | ||
path_mid = os.path.join(root, "im3.png") | ||
else: | ||
path_pre2 = os.path.join(root, "im1.png") | ||
path_pre1 = os.path.join(root, "im2.png") | ||
path_mid = os.path.join(root, "im3.png") | ||
|
||
im_pre2 = imread(path_pre2) | ||
im_pre1 = imread(path_pre1) | ||
im_mid = imread(path_mid) | ||
|
||
h_offset = random.choice(range(256 - input_frame_size[1] + 1)) | ||
w_offset = random.choice(range(448 - input_frame_size[2] + 1)) | ||
|
||
im_pre2 = im_pre2[h_offset:h_offset + input_frame_size[1], w_offset: w_offset + input_frame_size[2], :] | ||
im_pre1 = im_pre1[h_offset:h_offset + input_frame_size[1], w_offset: w_offset + input_frame_size[2], :] | ||
im_mid = im_mid[h_offset:h_offset + input_frame_size[1], w_offset: w_offset + input_frame_size[2], :] | ||
|
||
if data_aug: | ||
if random.randint(0, 1): | ||
im_pre2 = np.fliplr(im_pre2) | ||
im_mid = np.fliplr(im_mid) | ||
im_pre1 = np.fliplr(im_pre1) | ||
if random.randint(0, 1): | ||
im_pre2 = np.flipud(im_pre2) | ||
im_mid = np.flipud(im_mid) | ||
im_pre1 = np.flipud(im_pre1) | ||
|
||
X0 = np.transpose(im_pre2,(2,0,1)) | ||
X2 = np.transpose(im_mid, (2, 0, 1)) | ||
|
||
y = np.transpose(im_pre1, (2, 0, 1)) | ||
return X0.astype("float32")/ 255.0, \ | ||
X2.astype("float32")/ 255.0,\ | ||
y.astype("float32")/ 255.0 | ||
|
||
|
||
|
||
class ListDataset(data.Dataset): | ||
def __init__(self, root, path_list, loader=Vimeo_90K_loader): | ||
|
||
self.root = root | ||
self.path_list = path_list | ||
self.loader = loader | ||
|
||
def __getitem__(self, index): | ||
path = self.path_list[index] | ||
# print(path) | ||
image_0,image_2,image_1 = self.loader(self.root, path) | ||
return image_0,image_2,image_1 | ||
|
||
def __len__(self): | ||
return len(self.path_list) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
import sys | ||
import os | ||
|
||
import sys | ||
import threading | ||
import torch | ||
from torch.autograd import Variable | ||
from lr_scheduler import * | ||
from torch.autograd import gradcheck | ||
|
||
import numpy | ||
|
||
|
||
|
||
|
||
def charbonier_loss(x,epsilon): | ||
loss = torch.mean(torch.sqrt(x * x + epsilon * epsilon)) | ||
return loss | ||
def negPSNR_loss(x,epsilon): | ||
loss = torch.mean(torch.mean(torch.mean(torch.sqrt(x * x + epsilon * epsilon),dim=1),dim=1),dim=1) | ||
return torch.mean(-torch.log(1.0/loss) /100.0) | ||
|
||
def tv_loss(x,epsilon): | ||
loss = torch.mean( torch.sqrt( | ||
(x[:, :, :-1, :-1] - x[:, :, 1:, :-1]) ** 2 + | ||
(x[:, :, :-1, :-1] - x[:, :, :-1, 1:]) ** 2 + epsilon *epsilon | ||
) | ||
) | ||
return loss | ||
|
||
|
||
def gra_adap_tv_loss(flow, image, epsilon): | ||
w = torch.exp( - torch.sum( torch.abs(image[:,:,:-1, :-1] - image[:,:,1:, :-1]) + | ||
torch.abs(image[:,:,:-1, :-1] - image[:,:,:-1, 1:]), dim = 1)) | ||
tv = torch.sum(torch.sqrt((flow[:, :, :-1, :-1] - flow[:, :, 1:, :-1]) ** 2 + (flow[:, :, :-1, :-1] - flow[:, :, :-1, 1:]) ** 2 + epsilon *epsilon) ,dim=1) | ||
loss = torch.mean( w * tv ) | ||
return loss | ||
|
||
def smooth_loss(x,epsilon): | ||
loss = torch.mean( | ||
torch.sqrt( | ||
(x[:,:,:-1,:-1] - x[:,:,1:,:-1]) **2 + | ||
(x[:,:,:-1,:-1] - x[:,:,:-1,1:]) **2+ epsilon**2 | ||
) | ||
) | ||
return loss | ||
|
||
|
||
def motion_sym_loss(offset, epsilon, occlusion = None): | ||
if occlusion == None: | ||
# return torch.mean(torch.sqrt( (offset[:,:2,...] + offset[:,2:,...])**2 + epsilon **2)) | ||
return torch.mean(torch.sqrt( (offset[0] + offset[1])**2 + epsilon **2)) | ||
else: | ||
# TODO: how to design the occlusion aware offset symmetric loss? | ||
# return torch.mean(torch.sqrt((offset[:,:2,...] + offset[:,2:,...])**2 + epsilon **2)) | ||
return torch.mean(torch.sqrt((offset[0] + offset[1])**2 + epsilon **2)) | ||
|
||
|
||
|
||
|
||
def part_loss(diffs, offsets, occlusions, images, epsilon, use_negPSNR=False): | ||
if use_negPSNR: | ||
pixel_loss = [negPSNR_loss(diff, epsilon) for diff in diffs] | ||
else: | ||
pixel_loss = [charbonier_loss(diff, epsilon) for diff in diffs] | ||
#offset_loss = [tv_loss(offset[0], epsilon) + tv_loss(offset[1], epsilon) for offset in | ||
# offsets] | ||
|
||
if offsets[0][0] is not None: | ||
offset_loss = [gra_adap_tv_loss(offset[0],images[0], epsilon) + gra_adap_tv_loss(offset[1], images[1], epsilon) for offset in | ||
offsets] | ||
else: | ||
offset_loss = [Variable(torch.zeros(1).cuda())] | ||
# print(torch.max(occlusions[0])) | ||
# print(torch.min(occlusions[0])) | ||
# print(torch.mean(occlusions[0])) | ||
|
||
# occlusion_loss = [smooth_loss(occlusion, epsilon) + charbonier_loss(occlusion - 0.5, epsilon) for occlusion in occlusions] | ||
# occlusion_loss = [smooth_loss(occlusion, epsilon) + charbonier_loss(occlusion[:, 0, ...] - occlusion[:, 1, ...], epsilon) for occlusion in occlusions] | ||
|
||
|
||
|
||
sym_loss = [motion_sym_loss(offset,epsilon=epsilon) for offset in offsets] | ||
# sym_loss = [ motion_sym_loss(offset,occlusion) for offset,occlusion in zip(offsets,occlusions)] | ||
return pixel_loss, offset_loss, sym_loss | ||
|
Oops, something went wrong.