Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Vandermode committed May 24, 2020
1 parent 94a7e56 commit c007539
Show file tree
Hide file tree
Showing 6 changed files with 184 additions and 31 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,11 @@ Download ICVL hyperspectral image database from [here](http://icvl.cs.bgu.ac.il/
* Download our pretrained models from [OneDrive](https://1drv.ms/u/s!AqddfvhavTRiijWftKWgLfUgdSaD?e=nHGjIk) and move them to ```checkpoints/qrnn3d/gauss/``` and ```checkpoints/qrnn3d/complex/``` respectively.

* [Blind Gaussian noise removal]:
```python hsi_eval.py -a qrnn3d -p gauss -r -rp checkpoints/qrnn3d/gauss/model_epoch_50_118454.pth```
```python hsi_test.py -a qrnn3d -p gauss -r -rp checkpoints/qrnn3d/gauss/model_epoch_50_118454.pth```

* [Mixture noise removal]:
```python hsi_eval.py -a qrnn3d -p complex -r -rp checkpoints/qrnn3d/complex/model_epoch_100_159904.pth```
```python hsi_test.py -a qrnn3d -p complex -r -rp checkpoints/qrnn3d/complex/model_epoch_100_159904.pth```

### 3. Training from scratch
Expand Down
Binary file added data/Satellite/IMAGE.mat
Binary file not shown.
64 changes: 64 additions & 0 deletions hsi_eval.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import os
import argparse

from utility import *
from hsi_setup import Engine, train_options
import models


# All lowercase, callable, public entries of the `models` package: these are
# the architecture names accepted by the -a/--arch command-line option.
model_names = sorted(name for name in models.__dict__
                     if name.islower() and not name.startswith("__")
                     and callable(models.__dict__[name]))


# Run tag marking this entry point as an evaluation/test run.
prefix = 'test'

if __name__ == '__main__':
    """Training settings"""
    # NOTE(review): despite the string above, this script performs evaluation
    # (validation over a folder of .mat hyperspectral cubes), not training.
    parser = argparse.ArgumentParser(
        description='Hyperspectral Image Denoising')
    opt = train_options(parser)
    print(opt)

    cuda = not opt.no_cuda
    # Evaluation only: suppress experiment logging.
    opt.no_log = True

    """Setup Engine"""
    engine = Engine(opt)

    datadir = ''  # your input data dir
    # NOTE(review): machine-specific absolute path — update for your environment.
    basefolder = '/media/kaixuan/DATA/Papers/Code/Matlab/ECCV2018/ECCVData'
    # datadir = os.path.join(basefolder, 'icvl_512_50')
    # datadir = os.path.join(basefolder, 'icvl_512_blind')
    # datadir = os.path.join(basefolder, 'icvl_512_noniid')
    datadir = os.path.join(basefolder, 'icvl_512_mixture')

    # size=None: use every .mat file found in the folder.
    mat_dataset = MatDataFromFolder(datadir, size=None)

    # mat_dataset.filenames = [
    #     os.path.join(datadir, 'Lehavim_0910-1627.mat')
    # ]

    # Each .mat provides noisy 'input' and clean 'gt'; [None] adds the
    # leading channel axis expected by the 3-D network.
    mat_transform = Compose([
        LoadMatHSI(input_key='input', gt_key='gt', transform=lambda x:x[:,:,:][None]), # for validation
        # LoadMatKey(key='hsi'), # for testing
        # lambda x: x[None]
    ])

    mat_dataset = TransformDataset(mat_dataset, mat_transform)
    # batch_size=1: one full hyperspectral cube per step; pin memory only on GPU.
    mat_loader = DataLoader(
        mat_dataset,
        batch_size=1, shuffle=False,
        num_workers=1, pin_memory=cuda
    )

    resdir = None # your result dir

    # res_arr, input_arr = engine.test_develop(mat_loader, savedir=resdir, verbose=True)
    # print(res_arr.mean(axis=0))
    engine.validate(mat_loader, '')
99 changes: 96 additions & 3 deletions hsi_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ def forward(self, predict, target):
total_loss += loss(predict, target) * weight
return total_loss

def extra_repr(self):
    """Show the per-loss weights in this module's repr."""
    return f'weight={self.weight}'


def train_options(parser):
def _parse_str_args(args):
Expand All @@ -50,7 +53,7 @@ def _parse_str_args(args):
parser.add_argument('--wd', type=float, default=0,
help='weight decay. default=0')
parser.add_argument('--loss', type=str, default='l2',
help='which loss to choose.', choices=['l1', 'l2', 'smooth_l1', 'ssim'])
help='which loss to choose.', choices=['l1', 'l2', 'smooth_l1', 'ssim', 'l2_ssim'])
parser.add_argument('--init', type=str, default='kn',
help='which init scheme to choose.', choices=['kn', 'ku', 'xn', 'xu', 'edsr'])
parser.add_argument('--no-cuda', action='store_true', help='disable cuda?')
Expand All @@ -64,6 +67,8 @@ def _parse_str_args(args):
help='resume from checkpoint')
parser.add_argument('--no-ropt', '-nro', action='store_true',
help='not resume optimizer')
parser.add_argument('--chop', action='store_true',
help='forward chop')
parser.add_argument('--resumePath', '-rp', type=str,
default=None, help='checkpoint to use.')
parser.add_argument('--dataroot', '-d', type=str,
Expand Down Expand Up @@ -148,7 +153,7 @@ def __setup(self):
if self.opt.loss == 'ssim':
self.criterion = SSIMLoss(data_range=1, channel=31)
if self.opt.loss == 'l2_ssim':
self.criterion = MultipleLoss([nn.MSELoss(), SSIMLoss(data_range=1, channel=31)], weight=[1, 0.1])
self.criterion = MultipleLoss([nn.MSELoss(), SSIMLoss(data_range=1, channel=31)], weight=[1, 2.5e-3])

print(self.criterion)

Expand All @@ -173,6 +178,51 @@ def __setup(self):
print('==> Building model..')
print(self.net)

def forward(self, inputs):
    """Run the network on `inputs`, using the memory-saving chopped
    forward pass when the --chop option is set."""
    runner = self.forward_chop if self.opt.chop else self.net
    return runner(inputs)

def forward_chop(self, x, base=16):
    """Memory-saving forward pass: split the spatial plane into four
    overlapping quadrants, run the net on each, and stitch the central
    crops back together.

    Args:
        x: 5-D tensor of shape (batch, channel, band, height, width).
        base: padded quadrant sides are rounded up to a multiple of this
            value so the network's strided stages divide evenly.

    Returns:
        Tensor of the same shape as ``x``.
    """
    n, c, b, h, w = x.size()
    h_half, w_half = h // 2, w // 2

    # Extra margin ("shave") that rounds each half up to a multiple of
    # `base`; it becomes the overlap shared by neighbouring quadrants.
    shave_h = np.ceil(h_half / base) * base - h_half
    shave_w = np.ceil(w_half / base) * base - w_half

    # Guarantee at least 10 pixels of overlap to limit border artifacts.
    shave_h = shave_h if shave_h >= 10 else shave_h + base
    shave_w = shave_w if shave_w >= 10 else shave_w + base

    h_size, w_size = int(h_half + shave_h), int(w_half + shave_w)

    # Four overlapping corner crops that jointly cover the full image.
    inputs = [
        x[..., 0:h_size, 0:w_size],
        x[..., 0:h_size, (w - w_size):w],
        x[..., (h - h_size):h, 0:w_size],
        x[..., (h - h_size):h, (w - w_size):w]
    ]

    outputs = [self.net(input_i) for input_i in inputs]

    output = torch.zeros_like(x)
    output_w = torch.zeros_like(x)

    # Paste back only each quadrant's non-overlapping central region; the
    # four regions tile the image exactly, so every weight ends up 1 and
    # the final division is a no-op kept for safety.
    output[..., 0:h_half, 0:w_half] += outputs[0][..., 0:h_half, 0:w_half]
    output_w[..., 0:h_half, 0:w_half] += 1
    output[..., 0:h_half, w_half:w] += outputs[1][..., 0:h_half, (w_size - w + w_half):w_size]
    output_w[..., 0:h_half, w_half:w] += 1
    output[..., h_half:h, 0:w_half] += outputs[2][..., (h_size - h + h_half):h_size, 0:w_half]
    output_w[..., h_half:h, 0:w_half] += 1
    output[..., h_half:h, w_half:w] += outputs[3][..., (h_size - h + h_half):h_size, (w_size - w + w_half):w_size]
    output_w[..., h_half:h, w_half:w] += 1

    output /= output_w

    return output

def __step(self, train, inputs, targets):
if train:
self.optimizer.zero_grad()
Expand All @@ -190,6 +240,14 @@ def __step(self, train, inputs, targets):
outputs = torch.cat(O, dim=1)
else:
outputs = self.net(inputs)
# outputs = torch.clamp(self.net(inputs), 0, 1)
# loss = self.criterion(outputs, targets)

# if outputs.ndimension() == 5:
# loss = self.criterion(outputs[:,0,...], torch.clamp(targets[:,0,...], 0, 1))
# else:
# loss = self.criterion(outputs, torch.clamp(targets, 0, 1))

loss = self.criterion(outputs, targets)

if train:
Expand Down Expand Up @@ -328,6 +386,10 @@ def torch2numpy(hsi):
res_arr[batch_idx, :] = MSIQA(outputs, targets)
input_arr[batch_idx, :] = MSIQA(inputs, targets)

"""Visualization"""
# Visualize3D(inputs.data[0].cpu().numpy())
# Visualize3D(outputs.data[0].cpu().numpy())

psnr = res_arr[batch_idx, 0]
ssim = res_arr[batch_idx, 1]
if verbose:
Expand All @@ -345,8 +407,39 @@ def torch2numpy(hsi):

return res_arr, input_arr

def test_real(self, test_loader, savedir=None):
    """Denoise real-world data (no ground truth), visualize each result,
    and optionally save the denoised cubes as .mat files.

    Warning: this code is not compatible with bandwise flag.

    Args:
        test_loader: DataLoader yielding 5-D input tensors (batch of 1).
        savedir: if given, results are saved under
            ``savedir/<image-name>/<arch>.mat`` with key 'R_hsi'.

    Returns:
        The network output for the LAST batch (kept for interactive use).
    """
    from scipy.io import savemat
    # FIX: `join` was used below but only `basename` was imported locally,
    # silently relying on a module-level `join` being in scope.
    from os.path import basename, join
    import os
    self.net.eval()
    # Unwrap TransformDataset to reach the underlying folder dataset —
    # assumes the loader wraps MatDataFromFolder; verify against caller.
    dataset = test_loader.dataset.dataset

    with torch.no_grad():
        for batch_idx, inputs in enumerate(test_loader):
            if not self.opt.no_cuda:
                inputs = inputs.cuda()

            outputs = self.forward(inputs)

            """Visualization"""
            input_np = inputs[0].cpu().numpy()
            output_np = outputs[0].cpu().numpy()

            # Show noisy input and denoised output side by side.
            display = np.concatenate([input_np, output_np], axis=-1)

            Visualize3D(display)
            # Visualize3D(outputs[0].cpu().numpy())
            # Visualize3D((outputs-inputs).data[0].cpu().numpy())

            if savedir:
                # (1, C, H, W) -> (H, W, C) for MATLAB consumers.
                R_hsi = outputs.data[0].cpu().numpy()[0,...].transpose((1,2,0))
                savepath = join(savedir, basename(dataset.filenames[batch_idx]).split('.')[0], self.opt.arch + '.mat')
                # FIX: savepath nests a per-image directory — create it so
                # savemat does not fail on the first save.
                os.makedirs(os.path.dirname(savepath), exist_ok=True)
                savemat(savepath, {'R_hsi': R_hsi})

    return outputs

def get_net(self):
    """Return the underlying network, unwrapping nn.DataParallel when
    more than one GPU is configured.

    FIX: merge residue left a duplicated, unreachable `return self.net`
    after the if/else; collapsed to a single guard + return.
    """
    if len(self.opt.gpu_ids) > 1:
        # Multi-GPU: the real module sits inside the DataParallel wrapper.
        return self.net.module
    return self.net
26 changes: 7 additions & 19 deletions hsi_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,34 +31,22 @@
"""Setup Engine"""
engine = Engine(opt)

datadir = '' # your input data dir
basefolder = '/media/kaixuan/DATA/Papers/Code/Matlab/ECCV2018/ECCVData'
# datadir = os.path.join(basefolder, 'icvl_512_50')
datadir = os.path.join(basefolder, 'icvl_512_blind')
# datadir = os.path.join(basefolder, 'icvl_512_noniid')
# datadir = os.path.join(basefolder, 'icvl_512_mixture')

mat_dataset = MatDataFromFolder(datadir, size=None)

# mat_dataset.filenames = [
# os.path.join(datadir, 'Lehavim_0910-1627.mat')
# ]
mat_dataset = MatDataFromFolder('data/Satellite')

mat_transform = Compose([
LoadMatHSI(input_key='input', gt_key='gt', transform=lambda x:x[:,:,:][None]), # for validation
# LoadMatKey(key='hsi'), # for testing
# lambda x: x[None]
LoadMatKey(key='img'), # for testing
lambda x: x[:,:220,:256][None],
minmax_normalize,
])

mat_dataset = TransformDataset(mat_dataset, mat_transform)

mat_loader = DataLoader(
mat_dataset,
batch_size=1, shuffle=False,
num_workers=1, pin_memory=cuda
)

resdir = None # your result dir
# print(engine.net)

# res_arr, input_arr = engine.test_develop(mat_loader, savedir=resdir, verbose=True)
# print(res_arr.mean(axis=0))
engine.validate(mat_loader, '')
engine.test_real(mat_loader, savedir=None)
24 changes: 15 additions & 9 deletions utility/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def __init__(self, key):

def __call__(self, mat):
    """Extract the array stored under ``self.key`` from a loaded .mat
    mapping and return it channel-first as float32.

    FIX: diff residue left two consecutive return statements (the second
    unreachable); kept the newer float32-casting behavior only.
    """
    # HWC -> CHW. `[:]` materializes the stored data before transposing
    # (presumably needed for lazy h5py-style datasets — confirm loader).
    item = mat[self.key][:].transpose((2,0,1))
    # float32 is the dtype expected when converting to torch tensors.
    return item.astype(np.float32)


# Define Datasets
Expand All @@ -297,13 +297,19 @@ def __len__(self):

class MatDataFromFolder(Dataset):
"""Wrap mat data from folder"""
def __init__(self, data_dir, load=loadmat, suffix='mat', size=None):
def __init__(self, data_dir, load=loadmat, suffix='mat', fns=None, size=None):
super(MatDataFromFolder, self).__init__()
self.filenames = [
os.path.join(data_dir, fn)
for fn in os.listdir(data_dir)
if fn.endswith(suffix)
]
if fns is not None:
self.filenames = [
os.path.join(data_dir, fn) for fn in fns
]
else:
self.filenames = [
os.path.join(data_dir, fn)
for fn in os.listdir(data_dir)
if fn.endswith(suffix)
]

self.load = load

if size and size <= len(self.filenames):
Expand Down Expand Up @@ -428,8 +434,8 @@ def __getitem__(self, idx):

if __name__ == '__main__':
    """Mat dataset test"""
    # FIX: merge residue duplicated the commented-out example path and the
    # trailing `pass`; kept the newer path and a single no-op statement.
    # dataset = MatDataFromFolder('/media/kaixuan/DATA/Papers/Code/Matlab/ECCV2018/ECCVResult/Indian/Indian_pines/')
    # mat = dataset[0]
    # hsi = mat['R_hsi'].transpose((2,0,1))
    # Visualize3D(hsi)
    pass

0 comments on commit c007539

Please sign in to comment.