-
Notifications
You must be signed in to change notification settings - Fork 0
/
rpca_gd.m
157 lines (131 loc) · 4.21 KB
/
rpca_gd.m
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
function [U, V, err, time] = rpca_gd(Y, r, alpha, params)
% [U, V] = RPCA_GD(Y, r, alpha, params)
% [U, V, err, time] = RPCA_GD(Y, r, alpha, params)
% Robust PCA via Non-convex Gradient Descent
%
% Y : A sparse matrix to be decomposed into a low-rank matrix M and a sparse
% matrix S. Unobserved entries are represented as zeros.
% A full (non-sparse) Y is treated as fully observed (p = 1).
% r : Target rank
% alpha : An upper bound of max sparsity over the columns/rows of S
% params : (optional) struct of parameters for the algorithm; any missing
% field, or omitting params altogether, selects the default.
% .step_const : Constant for step size (default .5)
% .max_iter : Maximum number of iterations (default 30)
% .tol : Desired relative Frobenius norm error,
% norm(residual,'fro')/norm(Y,'fro') (default 2e-4)
% .incoh : Incoherence of M (default 5); only used when
% .do_project is set
% .gamma : Multiplier applied to the sparsity levels passed to
% the sparse projection (default 1)
% .do_project : If nonzero, clip the row norms of U and V after the
% initialization and after every step (default 0)
%
% Output:
% U, V : M=U*V' is the estimated lowrank matrix
% err : (optional) row vector; err(1) is Inf, err(t) for t >= 2 is the
% relative residual norm recorded at iteration t
% time : (optional) row vector of elapsed seconds (toc) recorded alongside
% each entry of err
%
% By:
% Xinyang Yi, Dohyung Park, Yudong Chen, Constantine Caramanis
% {yixy,dhpark,constantine}@utexas.edu, [email protected]

% params is optional: without this guard the isfield calls below fail
% with an undefined-variable error when the caller omits it.
if nargin < 4
    params = struct();
end

% Default parameter settings
step_const = .5;
max_iter = 30;
tol = 2e-4;
do_project = 0;
gamma = 1;
incoh = 5;

% Read parameter settings
if isfield(params,'gamma'), gamma = params.gamma; end
if isfield(params,'incoh'), incoh = params.incoh; end
if isfield(params,'step_const'), step_const = params.step_const; end
if isfield(params,'max_iter'), max_iter = params.max_iter; end
if isfield(params,'tol'), tol = params.tol; end
if isfield(params,'do_project'), do_project = params.do_project; end

% Library paths (relative to the current folder)
addpath PROPACK;
addpath MinMaxSelection;

% Setting up
err = zeros(1,max_iter);
time = zeros(1,max_iter);
Ynormfro = norm(Y,'fro');
[d1, d2] = size(Y);
is_sparse = issparse(Y);
if is_sparse
    % Nonzero entries of Y are the observed entries.
    [I, J, Y_vec] = find(Y);
    n = length(Y_vec);
    obs_ind = sub2ind([d1,d2], I, J);
    p = n/d1/d2; % observed fraction
    if p>0.9
        % Nearly fully observed: switch to the dense code path
        % (p keeps its measured value).
        is_sparse = 0;
        Y = full(Y);
    end
else
    p = 1;
end

%% Phase I: Initialization
t1 = tic; t = 1;

% Initial sparse projection
fprintf('Initial sparse projection; time %f \n', toc(t1));
alpha_col = alpha; alpha_row = alpha;
S = Tproj_partial(Y, gamma*p*alpha_col, gamma*p*alpha_row);

% Initial factorization: split the leading r singular values evenly
% between the two factors so that U*V' = U0*Sig*V0'.
fprintf('Initial SVD; time %f \n', toc(t1));
[U,Sig,V] = lansvd((Y-S)/p,r,'L');
U = U(:,1:r) * sqrt(Sig(1:r,1:r));
V = V(:,1:r) * sqrt(Sig(1:r,1:r));

% Projection: scale down any row of U (V) whose 2-norm exceeds const1 (const2)
if do_project
    const1 = sqrt(4*incoh*r/d1)*Sig(1,1);
    const2 = sqrt(4*incoh*r/d2)*Sig(1,1);
    U = U .* repmat(min(ones(d1,1),const1./sqrt(sum(U.^2,2))),1,r);
    V = V .* repmat(min(ones(d2,1),const2./sqrt(sum(V.^2,2))),1,r);
end

% Compute the initial error. Inf guarantees the "No improvement" test
% cannot fire on the first gradient step.
err(t) = inf;
time(t) = toc(t1);

%% Phase II: Gradient Descent
steplength = step_const / Sig(1,1);
fprintf('Begin Gradient descent\n');
converged = 0;
while ~converged
    fprintf('Iter no. %d err %e time %f \n', t, err(t), time(t));
    t = t + 1;

    % Residual Y - U*V', restricted to the observed entries when sparse
    if is_sparse
        UVobs_vec = compute_X_Omega(U, V, obs_ind);
        %UVobs_vec = partXY(U', V', I, J, n)';
        YminusUV = sparse(I, J, Y_vec-UVobs_vec, d1, d2, n); clearvars UVobs_vec;
    else
        YminusUV = Y - U*V';
    end
    %err(t) = norm(YminusUV-S, 'fro')/Ynormfro;

    % Sparse Projection for S
    S = Tproj_partial(YminusUV, gamma*p*alpha_col, gamma*p*alpha_row);
    E = YminusUV - S;
    clearvars S;

    % Gradient Descent for U and V. The second term in each update
    % penalizes the mismatch between U'*U and V'*V.
    Unew = U + steplength * (E * V) /p - steplength/16*U*(U'*U-V'*V);
    Vnew = V + steplength * (U' * E)' /p - steplength/16*V*(V'*V-U'*U);

    % Projection
    if do_project
        Unew = Unew .* repmat(min(ones(d1,1),const1./sqrt(sum(Unew.^2,2))),1,r);
        Vnew = Vnew .* repmat(min(ones(d2,1),const2./sqrt(sum(Vnew.^2,2))),1,r);
    end
    U = Unew;
    V = Vnew;

    % Compute error (residual evaluated at the factors from before this step)
    err(t) = norm(E, 'fro')/Ynormfro;
    time(t) = toc(t1);

    % Convergence check; more than one condition may fire in the same iteration
    if (t >= max_iter)
        converged = 1;
        fprintf('Maximum iterations reached.\n');
    end
    if (err(t) <= max(tol,eps))
        converged = 1;
        fprintf('Target error reached.\n');
    end
    if (err(t) >= err(t-1) - eps)
        converged = 1;
        fprintf('No improvement.\n');
    end
end

% Trim the history to the iterations actually performed
err = err(1:t);
time = time(1:t);