Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
areslp committed Mar 25, 2013
1 parent 142c15d commit 2abc5eb
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 22 deletions.
5 changes: 3 additions & 2 deletions l21.m
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
function [ M ] = l21( Q, lambda )
QN=sqrt(sum(Q.^2,1));
% assert(length(find(QN<=0))==0);
coefficient=(QN-lambda)./QN;
zero=find(QN==0);
coefficient=(QN-lambda)./QN; % may be NaN
coefficient(zero)=0;
invalid=find(coefficient<=0);
coefficient(invalid)=0;
M=bsxfun(@times,Q,coefficient);
50 changes: 30 additions & 20 deletions multi_NNLRS.m
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
function [Z,ZZ,E] = multi_NNLRS(X,lambda,beta,alpha)
% solve \sum_{i=1}^k(||Z_i||_*+beta||Z_i||_1+lambda||E_i||_{2,1})+alpha||Z||_{2,1}

tic;

% init vars
k=length(X);
[m,n]=size(X{1});
Expand Down Expand Up @@ -52,14 +56,15 @@
eta1=cell(k,1);
for i=1:k
eta1{i}=norm2X{i}*norm2X{i}*1.02;%eta needs to be larger than ||X||_2^2, but need not be too large.
fprintf(1,'eta1{%d} is %f\n',i,eta1{i});
end
mu=1e-6;
max_mu=10^10;
rho=1.9;
% epsilon=1e-4;
% epsilon2=1e-5; % must be small!
epsilon=1e-6;
epsilon2=1e-5; % must be small!
epsilon2=1e-2; % must be small!
MAX_ITER=1000;
iter=0;
convergenced=false;
Expand All @@ -84,31 +89,31 @@
% update Z
[F]=cellfun(@updateF,J,Y2,S,Y3,cmu,'UniformOutput',false);
[M]=cellfun(@updateM,F,'UniformOutput',false);
MM=zeros(k,n*n);
for i=1:k
ZZ(i,:)=M{i};
MM(i,:)=M{i};
end
ZZ=l21(MM,alpha/(2*mu));
if alpha==0
assert(nnz(ZZ-MM)==0);
end
ZZ=l21(ZZ,alpha/(2*mu));
% update Z_i
Zk=Z;
for i=1:k
Z{i}=reshape(ZZ(i,:),n,n)';
end
% update E_i
Ek=E;
[E]=cellfun(@updateE,X,S,E,Y1,cmu,clambda,'UniformOutput',false);

% parameter update rule

% check convergence
[Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec] = cellfun(@caculateTempVars,X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf,'UniformOutput',false);
[Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec,Cmax] = cellfun(@caculateTempVars,X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf,eta1,cmu,'UniformOutput',false);
changeX=max([Xv{:}]);
changeZJ=max([ZJv{:}]);
changeZS=max([ZSv{:}]);
changeZ=max([Zc{:}]);
changeJ=max([Jc{:}]);
changeS=max([Sc{:}]);
changeE=max([Ec{:}]);
tmp=[changeZ changeJ changeS changeE ];
gap=mu*max(tmp);
gap=max([Cmax{:}]);
if mod(iter,50)==0
fprintf(1,'===========================================================================================================\n');
fprintf(1,'gap between two iteration is %f,mu is %f\n',gap,mu);
Expand All @@ -118,8 +123,8 @@
end
fprintf(1,'\n');
end
% if changeX <= epsilon && changeZJ <= epsilon && changeZS <= epsilon
if changeX <= epsilon && gap <=epsilon2 && changeZJ <= epsilon && changeZS <= epsilon
if changeX <= epsilon && gap <=epsilon2
% if changeX <= epsilon && gap <=epsilon2 && changeZJ <= epsilon && changeZS <= epsilon
convergenced=true;
fprintf(2,'convergenced, iter is %d\n',iter);
fprintf(2,'iter %d,mu is %f,ResidualX is %f,changeZJ is %f,changeZS is %f\n',iter,mu,changeX,changeZJ,changeZS);
Expand All @@ -139,6 +144,8 @@
iter=iter+1;
end

toc;

function [S,svp] = updateS(xtx,X,E,Y1,Z,S,Y3,eta1,mu)
T=-mu*(xtx-xtx*S-X'*E+X'*Y1/mu+Z-S+Y3/mu);
% argmin_{S} 1/(mu*eta1)||S||_*+1/2*||S-S_k+T/(mu*eta1)||_F^2
Expand All @@ -147,28 +154,31 @@
function [J] = updateJ(Z,J,Y2,mu,beta)
J=wthresh(Z+Y2/mu,'s',beta/mu);

function [RET] = updateF(J,Y2,S,Y3,mu)
RET=1/2*(J-Y2/mu+S-Y3/mu);
function [F] = updateF(J,Y2,S,Y3,mu)
F=0.5*(J-Y2/mu+S-Y3/mu);

function [M] = updateM(F)
n=length(F);
M=reshape(F',1,n*n);

function [E] = updateE(X,S,E,Y1,mu,lambda)
E=l21(X*S-X-Y1/mu,lambda/mu);
E=l21(X-X*S+Y1/mu,lambda/mu); % TODO: -E not E

function [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec] = caculateTempVars(X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf)
function [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec,Cmax] = caculateTempVars(X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf,eta1,mu)
Xc=X-X*S-E;
ZJc=Z-J;
ZSc=Z-S;
Xv=norm(Xc,'fro')/Xf;
ZJv=norm(ZJc,'fro')/Xf;
ZSv=norm(ZSc,'fro')/Xf;

Zc=norm(Zk-Z,'fro')/Xf;
Jc=norm(Jk-J,'fro')/Xf;
Sc=norm(Sk-S,'fro')/Xf;
Ec=norm(Ek-E,'fro')/Xf;
Zc=norm(Z-Zk,'fro')/Xf;
Jc=norm(J-Jk,'fro')/Xf;
Sc=norm(S-Sk,'fro')/Xf;
Ec=norm(E-Ek,'fro')/Xf;

Cmax=mu*(max([sqrt(eta1)*Sc Jc Zc Ec]));


function [Y1] = updateY1(Y1,mu,Xc)
Y1=Y1+mu*Xc;
Expand Down

0 comments on commit 2abc5eb

Please sign in to comment.