-
Notifications
You must be signed in to change notification settings - Fork 9
/
NonlinearSVM3DDWTRegionSampling.m
82 lines (82 loc) · 3.3 KB
/
NonlinearSVM3DDWTRegionSampling.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
function NonlinearSVM3DDWTRegionSampling(DataFile, timeofRepeatition, sampleRateList)
% NonlinearSVM3DDWTRegionSampling  libsvm classification of a data cube using
% dwt3d_feature features, repeated over several training sampling rates.
%
%   NonlinearSVM3DDWTRegionSampling(DataFile, timeofRepeatition)
%   NonlinearSVM3DDWTRegionSampling(DataFile, timeofRepeatition, sampleRateList)
%
%   DataFile          - file loaded with importdata; must hold a 3-D array.
%                       The ground truth is read from '<prefix>_gt.mat' in the
%                       same folder, where <prefix> is the base file name up
%                       to its first '_' (the whole base name if it has none).
%   timeofRepeatition - number of repetitions (positive integer, default 1).
%   sampleRateList    - sampling rates passed to createTrainingSamples
%                       (default [0.05, 0.1, 0.25]).
%
%   For each repetition and each rate, the train/test pixel indices come from
%   createTrainingSamples. The seeds returned for the first rate are reused
%   for the other rates of the same repetition. C and gamma are chosen by a
%   libsvm 5-fold cross-validation grid search, and the test predictions are
%   scored with assessment. The matrix results(rateIndex, repetition) is
%   saved to Jresults/<prefix>_<mfilename>.mat (the folder is created if
%   missing).
%
%   If the loaded data is not a 3-D array, a warning is issued and nothing is
%   saved.

if nargin < 2 || isempty(timeofRepeatition)
    timeofRepeatition = 1;
end
if nargin < 3 || isempty(sampleRateList)
    sampleRateList = [0.05, 0.1, 0.25];
end
% Fail fast: with zero repetitions 'results' would never be assigned, and the
% final save() would only error after loading and feature extraction.
validateattributes(timeofRepeatition, {'numeric'}, {'scalar', 'integer', 'positive'}, ...
    mfilename, 'timeofRepeatition', 2);
validateattributes(sampleRateList, {'numeric'}, {'vector', 'nonempty', 'positive'}, ...
    mfilename, 'sampleRateList', 3);

% Paths are relative to the current folder. fullfile keeps them portable,
% whereas '..\...' literals only resolve on Windows.
addpath(fullfile('..', 'data', 'remote sensing data'));
addpath(fullfile('..', 'tools', 'libsvm-3.20', 'matlab'));

DataFile = char(DataFile); % also accept string scalars
rawData = importdata(DataFile); % load the data cube
if ndims(rawData) ~= 3
    % Skip non-cube inputs (saves time), but say so instead of returning
    % silently.
    warning('NonlinearSVM3DDWTRegionSampling:notACube', ...
        'Skipping %s: expected a 3-D data cube.', DataFile);
    return;
end

[prefix, dataDir] = datasetPrefix(DataFile);
groundTruth = importdata(fullfile(dataDir, [prefix, '_gt.mat']));
resultsDir = 'Jresults';
if ~exist(resultsDir, 'dir')
    mkdir(resultsDir); % save() does not create missing folders
end
resultsFile = fullfile(resultsDir, [prefix, '_', mfilename, '.mat']);

[m, n, ~] = size(rawData);
feats = single(dwt3d_feature(rawData));
% One row per pixel. The feature count is inferred from dwt3d_feature's
% output rather than hard-coded as 15*b.
vdataCube = reshape(feats, m*n, []);

% As double, 1:numofClass yields double labels, which libsvm expects.
numofClass = double(max(groundTruth(:)));
trainingSamples = cell(numofClass, 1);
trainingLabels  = cell(numofClass, 1);
testingSamples  = cell(numofClass, 1);
testingLabels   = cell(numofClass, 1);

log2cList = -1:8; % candidate C     = 2.^log2cList
log2gList = -1:8; % candidate gamma = 2.^log2gList

for repeat = 1:timeofRepeatition
    for rateIdx = 1:numel(sampleRateList)
        samplingRate = sampleRateList(rateIdx);
        % Try to use the same seeds for every sampling rate of this repetition.
        if rateIdx == 1
            [trainingIndex, testingIndex, seeds] = createTrainingSamples(groundTruth, samplingRate);
        else
            [trainingIndex, testingIndex] = createTrainingSamples(groundTruth, samplingRate, seeds);
        end
        for c = 1:numofClass
            trainingSamples{c} = vdataCube(trainingIndex{c}, :);
            trainingLabels{c}  = c * ones(numel(trainingIndex{c}), 1);
            testingSamples{c}  = vdataCube(testingIndex{c}, :);
            testingLabels{c}   = c * ones(numel(testingIndex{c}), 1);
        end
        % libsvm's MATLAB interface requires double-precision instance matrices.
        mtrainingData   = double(cell2mat(trainingSamples));
        mtrainingLabels = cell2mat(trainingLabels);
        mtestingData    = double(cell2mat(testingSamples));
        mtestingLabels  = cell2mat(testingLabels);
        % Debug view of the training pixels:
        % trainingMap = zeros(m*n, 1);
        % trainingMap(cell2mat(trainingIndex)) = mtrainingLabels;
        % figure; imagesc(reshape(trainingMap, [m, n]));

        [bestc, bestg] = selectSvmParameters(mtrainingLabels, mtrainingData, log2cList, log2gList);
        optPara = ['-q -c ', num2str(bestc), ' -g ', num2str(bestg)];
        svm = svmtrain(mtrainingLabels, mtrainingData, optPara);
        [predicted_labels, ~, ~] = svmpredict(mtestingLabels, mtestingData, svm);
        % Debug view of the classification map:
        % resultMap = reshape(groundTruth, [], 1);
        % resultMap(cell2mat(testingIndex)) = predicted_labels;
        % figure, imagesc(reshape(resultMap, [m, n]));

        % calculate OA, kappa, AA (results grows; the element type is
        % whatever assessment returns)
        results(rateIdx, repeat) = assessment(mtestingLabels, predicted_labels, 'class'); %#ok<AGROW>
    end
end
save(resultsFile, 'results');

function [prefix, dataDir] = datasetPrefix(DataFile)
% Split DataFile into its folder and a dataset prefix. The prefix is the base
% name (extension removed) up to its first '_', or the whole base name if it
% has no '_'. Underscores in folder names are therefore ignored.
[dataDir, baseName] = fileparts(DataFile);
cut = find(baseName == '_', 1);
if isempty(cut)
    prefix = baseName;
else
    prefix = baseName(1:cut-1);
end

function [bestc, bestg] = selectSvmParameters(labels, data, log2cList, log2gList)
% Grid search for libsvm's -c and -g. Scores every pair (2^log2c, 2^log2g)
% with the accuracy returned by 'svmtrain ... -v 5' (5-fold cross-validation)
% and returns the best pair. Ties go to the first maximum in column-major
% order of the cv matrix.
cv = zeros(numel(log2cList), numel(log2gList));
parfor indexC = 1:numel(log2cList)
    log2c = log2cList(indexC);
    tempcv = zeros(1, numel(log2gList));
    for indexG = 1:numel(log2gList)
        cmd = ['-q -v 5 -c ', num2str(2^log2c), ' -g ', num2str(2^log2gList(indexG))];
        tempcv(indexG) = svmtrain(labels, data, cmd);
    end
    cv(indexC, :) = tempcv;
end
[~, indexcv] = max(cv(:));
[bestindexC, bestindexG] = ind2sub(size(cv), indexcv);
bestc = 2^log2cList(bestindexC);
bestg = 2^log2gList(bestindexG);