-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathwonparafac.m
194 lines (158 loc) · 5.26 KB
/
wonparafac.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
%
% Compute weighted non-negative parallel factors with multiplicative updates.
%
% P = wonparafac(X,k) computes an estimate of the best rank-k PARAFAC
% model of a tensor X with nonnegative constraints on the factors.
% The code requires tensor toolbox version 2.6.
% The tensor toolbox is available at: https://www.sandia.gov/~tgkolda/TensorToolbox/index-2.6.html
% This version uses the Lee & Seung multiplicative updates from
% their NMF algorithm. The input X can be a tensor, sptensor,
% ktensor, or ttensor. The result P is a ktensor.
%
%
% P = wonparafac(X,k,OPTS) specify options:
% OPTS.tol: Tolerance on difference in fit {1.0e-4}
% OPTS.maxiters: Maximum number of iterations {50}
% OPTS.dimorder: Order to loop through dimensions {1:ndims(A)}
% OPTS.init: Initial guess [{'random'}|'nvecs'|cell array]
% OPTS.printitn: Print fit every n iterations {1}
% OPTS.wegiht_W: Weight tensor with size of X. {ones(size(X))}
% OPTS.orthogonal: a vector (length = number of modes of X) with the strength of orthogonality constraint. 0 for no orthogonality constraint {false(1,length(size(X))}.
% OPTS.fix_mode: a logical vector (length = number of modes of X) with indication if the factor matrix for the mode need to be fixed {false(1,length(size(X)))}.
%
% [P,U0] = wonparafac(...) also returns the initial guess.
%
function [P, fit_info, Uinit] = wonparafac(X, k, opts)
switch nargin
case 2
opts = struct;
case 1
error('Number of factors (k) is not given.');
end
% Number of modes of X.
N = ndims(X);
fit_info = struct();
fits = [];
ortho = cell(N,1);
for n = 1:N
ortho{n} = [];
end
% Set options and parameters
fitchangetol = setparam(opts,'tol',1e-4);
maxiters = setparam(opts,'maxiters',500);
dimorder = setparam(opts,'dimorder',1:N);
init = setparam(opts,'init','random');
printitn = setparam(opts,'printitn',1);
orthogonal = setparam(opts, 'orthogonal', false(1,length(size(X))));
epsilon = 1e-12; % Small number to protect against round-off error
WeightW = setparam(opts, 'weight_W', ones(size(X)));
opt_arrange = setparam(opts, 'arrange', 'true');
fix_mode = setparam(opts, 'fix_mode', false(1,N));
%% Error checking
% Error checking on maxiters
if maxiters < 0
error('Maximum iteration (OPTS.maxiter) must be positive');
end
% Error checking on dimorder
if ~isequal(1:N,sort(dimorder))
error('OPTS.dimorder must include all elements from 1 to ndims(X)');
end
% Error checking weight
if sum(WeightW(:) < 0) > 0
error('weight (OPTS.weight_W) should be non-negative');
end
% error checking orthogonal constraint
if sum(orthogonal(:) < 0) > 0 || sum(orthogonal(:) > 1) > 0
error('orthogonal constraint has to between 0 and 1. Revise OPTS.orthogonal')
end
% Error checking fix_mode
if ~islogical(fix_mode) || length(fix_mode) ~= N
error('OPT.fix_mode has to be a length N logical vector')
end
%% Initialization of factors
if iscell(init)
for n = dimorder(1:end)
if ~all(size(Unit{n}) == [size(X,n), k])
error(['Initial factors (OPTS.init{', int2str(n),'}) has incorrect size']);
end
end
Uinit = init;
elseif isstr(init)
switch init
case 'random'
Uinit = cell(N,1);
for n = dimorder(1:end)
Uinit{n} = rand(size(X,n),k) + 0.1;
end
case 'nvecs'
Uinit = cell(N,1);
for n = dimorder(1:end)
e = min(k,size(X,n)-2);
Uinit{n} = abs(nvecs(X,n,e));
if (e < k)
Uinit{n} = [Uinit{n} rand(size(X,n),k-e)];
end
end
otherwise
error('Unsupported initialization method specified');
end
else
error('Unsupported type of initialization (check OPTS.init)')
end
U = Uinit;
fit = 0;
%% Multiplicative update rules
for iter = 1:maxiters
fitold = fit;
% Update across the modes
for n = dimorder(1:end)
if fix_mode(n) == false
inds = setdiff(1:N, n);
Y = khatrirao(U{inds(2)},U{inds(1)});
Y = ((U{n}*Y').*double(tenmat(tensor(WeightW),n,inds)))*Y;
Y = (1-orthogonal(n))*Y + ...
orthogonal(n)*U{n}*mttkrp(X.*WeightW,U,n)'*U{n};
% Initialize factor matrix
Unew = U{n};
% Update factors
tmp = mttkrp(X.*WeightW,U,n);
Unew = Unew .* (tmp + epsilon);
Unew = Unew ./ (Y + epsilon);
U{n} = Unew;
end
end
P = ktensor(U);
fit = 1-sum(sum(sum((double(X)-double(tensor(P))).^2)))/sum(double(X(:)).^2);
fits = [fits, fit];
fitchange = abs(fitold - fit);
% calculate the orthogonality
for n=1:N
ortho{n} = [ortho{n}, norm(U{n}'*U{n} - eye(size(U{n},2)))];
end
if mod(iter,printitn)==0
fprintf(' Iter %2d: fit = %e fitdelta = %7.1e\n', iter, fit, fitchange);
end
% Check for convergence
if (iter > 1) && (fitchange < fitchangetol)
break;
end
end
% return the outcome
fit_info.orthogonal = ortho;
fit_info.fits = fits;
if opt_arrange == true
P = arrange(P);
end
if printitn>0
fit = 1-sum(sum(sum((double(X)-double(tensor(P))).^2)))/sum(double(X(:)).^2);
fprintf(' Final fit = %e \n', fit);
end
return;
end
function x = setparam(opts,name,default)
if isfield(opts,name);
x = opts.(name);
else
x = default;
end
end