forked from srivaschennu/MOHAWK
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtestind.m
executable file
·127 lines (107 loc) · 4.32 KB
/
testind.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
function testind(basename,varargin)
% Copyright (C) 2018 Srivas Chennu, University of Kent and University of Cambridge,
% srivas@gmail.com
%
%
% Tests this individual's data against a previously estimated
% classification ensemble to generate a prediction.
%
% Usage:
%   testind(basename)
%   testind(basename, 'nclsyfyrs', N)
%
% Inputs:
%   basename  - base name of the individual's data file; the function reads
%               from and appends to <filepath>/<basename>_mohawk.mat.
%   Options (name/value pairs, parsed with finputcheck):
%     'nclsyfyrs' - real, default 50. Validated but not otherwise used by
%                   this function.
%
% Outputs (appended to <filepath>/<basename>_mohawk.mat with -append):
%   clsyfyrinfo - classifier metadata from the first classifier file, with
%                 the clsyfyrparam rows of all listed classifier files
%                 concatenated.
%   testres     - struct array; testres(k).predlabels is the label predicted
%                 for this individual by the k-th classifier.
%   indprob     - one row per entry of perfsort: the column of that
%                 classifier's normalised confusion matrix selected by its
%                 predicted label.
%   combprob    - row k is the product of indprob(1:k,:), renormalised so
%                 that each row sums to 1.
%
% NOTE(review): `filepath` is expected to be defined by the loadpaths script,
% and changroups is another project script; neither is visible here.
%
% This program is free software: you can redistribute it and/or modify
% it under the terms of the GNU General Public License as published by
% the Free Software Foundation, either version 3 of the License, or
% (at your option) any later version.
%
% This program is distributed in the hope that it will be useful,
% but WITHOUT ANY WARRANTY; without even the implied warranty of
% MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
% GNU General Public License for more details.
%
% You should have received a copy of the GNU General Public License
% along with this program. If not, see <https://www.gnu.org/licenses/>.

param = finputcheck(varargin, {
    'nclsyfyrs', 'real', [], 50; ...
    });
% finputcheck returns an error message string instead of a struct when the
% optional arguments are invalid; surface it rather than carrying on.
if ischar(param)
    error(param);
end

loadpaths
changroups

clsyfyrlist = {
    'svm-rbf_UWS_MCS-'
    };
% Column of graphdata from which graph-theoretic features are read.
% NOTE(review): the name suggests weighted (vs. binarised) measures - confirm.
weiorbin = 2;

savefile = sprintf('%s/%s_mohawk.mat',filepath,basename);
load(savefile,'bpower','matrix','graphdata','tvals');

%% Load and concatenate the classifier ensemble(s)
fprintf('Loading classifiers:');
for c = 1:length(clsyfyrlist)
    fprintf(' %s',clsyfyrlist{c});
    % Load into a struct so file variables do not leak into this workspace.
    loaded = load(sprintf('%s/%s.mat',filepath,clsyfyrlist{c}),'output1','output2','clsyfyrinfo');
    if c == 1
        clsyfyr = vertcat(loaded.output1{:});
        model = loaded.output2;
        clsyfyrinfo = loaded.clsyfyrinfo;
    else
        clsyfyr = cat(1,clsyfyr,vertcat(loaded.output1{:}));
        model = cat(1,model,loaded.output2);
        clsyfyrinfo.clsyfyrparam = cat(1,clsyfyrinfo.clsyfyrparam,loaded.clsyfyrinfo.clsyfyrparam);
    end
end
fprintf('\n');

%% Apply each classifier to this individual's features
nparams = size(clsyfyrinfo.clsyfyrparam,1);
% Thresholds are compared after rounding to the nearest 10^-precision, so
% floating-point noise in stored threshold values does not break matching.
precision = 3;
scale = 10^precision;

fprintf('Testing with clsyfyr');
progstr = '';
for k = 1:nparams
    % Erase the previous progress counter with backspaces, then print the new one.
    fprintf(repmat('\b',1,length(progstr)));
    progstr = sprintf(' %d/%d',k,nparams);
    fprintf('%s',progstr);

    measure = clsyfyrinfo.clsyfyrparam{k,1};
    bandidx = find(strcmp(clsyfyrinfo.clsyfyrparam{k,2},clsyfyrinfo.bands));

    if strcmpi(measure,'power')
        features = bpower(bandidx,:) * 100;
    elseif strcmpi(measure,'median')
        features = median(matrix(bandidx,:,:),3);
    elseif strcmpi(measure,'mean')
        features = mean(matrix(bandidx,:,:),3);
    else
        m = find(strcmpi(measure,graphdata(:,1)),1);
        if isempty(m)
            error('testind: measure ''%s'' not found in graphdata.',measure);
        end
        features = graphdata{m,weiorbin}(bandidx,:,:);

        % Round to the nearest 10^-precision (not down) before matching.
        tvalkeys = round(tvals * scale);
        reqkeys = round(clsyfyrinfo.clsyfyrparam{k,3} * scale);
        % Every requested threshold must exist in tvals. Checking membership
        % of the requested values catches the case where none are found,
        % which the old all(tvals(threshidx) == requested) test accepted
        % vacuously (the empty comparison is all-true).
        if ~all(ismember(reqkeys(:),tvalkeys(:)))
            error('getfeatures: some requested thresholds not found!');
        end
        threshidx = ismember(tvalkeys,reqkeys);
        features = features(:,threshidx,:);
    end

    if ndims(features) == 3
        features = permute(features,[1 3 2]);
    end
    if ~isempty(clsyfyr(k).pcaCoeff)
        features = features * clsyfyr(k).pcaCoeff;
    end

    clsyfyrtype = clsyfyrinfo.clsyfyrparam{k,4}{1};
    switch clsyfyrtype
        case {'knn' 'svm-linear' 'svm-rbf' 'tree' 'nbayes'}
            testres(k).predlabels = predict(model{k}, features);
        case 'nn'
            % vec2ind returns 1-based class indices; shift to 0-based labels.
            testres(k).predlabels = (vec2ind(compet(model{k}(features')))-1)';
        otherwise
            % Previously fell through silently, leaving predlabels unset.
            error('testind: unsupported classifier type ''%s''.',clsyfyrtype);
    end
end
fprintf('\n');

%% Normalise each confusion matrix
% First to rounded row percentages, then add eps so no entry is exactly
% zero, then divide by column sums so each column sums to 1.
for c = 1:length(clsyfyr)
    clsyfyr(c).cm = round(clsyfyr(c).cm * 100 ./ repmat(sum(clsyfyr(c).cm,2),1,size(clsyfyr(c).cm,2),1));
    clsyfyr(c).cm = clsyfyr(c).cm + eps;
    clsyfyr(c).cm = clsyfyr(c).cm ./ repmat(sum(clsyfyr(c).cm,1),size(clsyfyr(c).cm,1),1,1);
end

%% Combine classifier outputs in perfsort order
fprintf('Combining classifiers.\n');
listname = 'liege';
% NOTE(review): perfsort is presumably classifier indices ordered by
% performance - confirm against the script that writes combclsyfyr_*.mat.
load(sprintf('%s/combclsyfyr_%s.mat',filepath,listname),'perfsort');
for k = 1:length(perfsort)
    thispred = testres(perfsort(k)).predlabels;
    % Predicted labels are 0-based, hence +1 to pick the confusion-matrix column.
    indprob(k,:) = clsyfyr(perfsort(k)).cm(:,thispred+1);
    if k == 1
        combprob(k,:) = indprob(k,:);
    else
        combprob(k,:) = combprob(k-1,:) .* indprob(k,:);
    end
    combprob(k,:) = combprob(k,:) ./ sum(combprob(k,:));
end
save(savefile,'clsyfyrinfo','testres','indprob','combprob','-append');