diff --git a/l21.m b/l21.m index 60fe8a1..b393755 100644 --- a/l21.m +++ b/l21.m @@ -1,7 +1,8 @@ function [ M ] = l21( Q, lambda ) QN=sqrt(sum(Q.^2,1)); -% assert(length(find(QN<=0))==0); -coefficient=(QN-lambda)./QN; +zero=find(QN==0); +coefficient=(QN-lambda)./QN; % may be NaN +coefficient(zero)=0; invalid=find(coefficient<=0); coefficient(invalid)=0; M=bsxfun(@times,Q,coefficient); diff --git a/multi_NNLRS.m b/multi_NNLRS.m index 43b52db..b64d8bf 100644 --- a/multi_NNLRS.m +++ b/multi_NNLRS.m @@ -1,4 +1,8 @@ function [Z,ZZ,E] = multi_NNLRS(X,lambda,beta,alpha) +% solve \sum_{i=1}^k(||Z_i||_*+beta||Z_i||_1+lambda||E_i||_{2,1})+alpha||Z||_{2,1} + +tic; + % init vars k=length(X); [m,n]=size(X{1}); @@ -52,6 +56,7 @@ eta1=cell(k,1); for i=1:k eta1{i}=norm2X{i}*norm2X{i}*1.02;%eta needs to be larger than ||X||_2^2, but need not be too large. + fprintf(1,'eta1{%d} is %f\n',i,eta1{i}); end mu=1e-6; max_mu=10^10; @@ -59,7 +64,7 @@ % epsilon=1e-4; % epsilon2=1e-5; % must be small! epsilon=1e-6; -epsilon2=1e-5; % must be small! +epsilon2=1e-2; % must be small! MAX_ITER=1000; iter=0; convergenced=false; @@ -84,31 +89,31 @@ % update Z [F]=cellfun(@updateF,J,Y2,S,Y3,cmu,'UniformOutput',false); [M]=cellfun(@updateM,F,'UniformOutput',false); + MM=zeros(k,n*n); for i=1:k - ZZ(i,:)=M{i}; + MM(i,:)=M{i}; + end + ZZ=l21(MM,alpha/(2*mu)); + if alpha==0 + assert(nnz(ZZ-MM)==0); end - ZZ=l21(ZZ,alpha/(2*mu)); % update Z_i Zk=Z; for i=1:k Z{i}=reshape(ZZ(i,:),n,n)'; end % update E_i + Ek=E; [E]=cellfun(@updateE,X,S,E,Y1,cmu,clambda,'UniformOutput',false); % parameter update rule % check convergence - [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec] = cellfun(@caculateTempVars,X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf,'UniformOutput',false); + [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec,Cmax] = cellfun(@caculateTempVars,X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf,eta1,cmu,'UniformOutput',false); changeX=max([Xv{:}]); changeZJ=max([ZJv{:}]); changeZS=max([ZSv{:}]); - changeZ=max([Zc{:}]); - changeJ=max([Jc{:}]); - changeS=max([Sc{:}]); - changeE=max([Ec{:}]); - tmp=[changeZ changeJ changeS changeE ]; - gap=mu*max(tmp); + gap=max([Cmax{:}]); if mod(iter,50)==0 fprintf(1,'===========================================================================================================\n'); fprintf(1,'gap between two iteration is %f,mu is %f\n',gap,mu); @@ -118,8 +123,8 @@ end fprintf(1,'\n'); end - % if changeX <= epsilon && changeZJ <= epsilon && changeZS <= epsilon - if changeX <= epsilon && gap <=epsilon2 && changeZJ <= epsilon && changeZS <= epsilon + if changeX <= epsilon && gap <=epsilon2 + % if changeX <= epsilon && gap <=epsilon2 && changeZJ <= epsilon && changeZS <= epsilon convergenced=true; fprintf(2,'convergenced, iter is %d\n',iter); fprintf(2,'iter %d,mu is %f,ResidualX is %f,changeZJ is %f,changeZS is %f\n',iter,mu,changeX,changeZJ,changeZS); @@ -139,6 +144,8 @@ iter=iter+1; end +toc; + function [S,svp] = updateS(xtx,X,E,Y1,Z,S,Y3,eta1,mu) T=-mu*(xtx-xtx*S-X'*E+X'*Y1/mu+Z-S+Y3/mu); % argmin_{S} 1/(mu*eta1)||S||_*+1/2*||S-S_k+T/(mu*eta1)||_F^2 @@ -147,17 +154,17 @@ function [J] = updateJ(Z,J,Y2,mu,beta) J=wthresh(Z+Y2/mu,'s',beta/mu); -function [RET] = updateF(J,Y2,S,Y3,mu) - RET=1/2*(J-Y2/mu+S-Y3/mu); +function [F] = updateF(J,Y2,S,Y3,mu) + F=0.5*(J-Y2/mu+S-Y3/mu); function [M] = updateM(F) n=length(F); M=reshape(F',1,n*n); function [E] = updateE(X,S,E,Y1,mu,lambda) - E=l21(X*S-X-Y1/mu,lambda/mu); + E=l21(X-X*S+Y1/mu,lambda/mu); % TODO: -E not E -function [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec] = caculateTempVars(X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf) +function [Xv,Xc,ZJv,ZJc,ZSv,ZSc,Zc,Jc,Sc,Ec,Cmax] = caculateTempVars(X,S,E,Z,J,Zk,Jk,Sk,Ek,Xf,eta1,mu) Xc=X-X*S-E; ZJc=Z-J; ZSc=Z-S; @@ -165,10 +172,13 @@ ZJv=norm(ZJc,'fro')/Xf; ZSv=norm(ZSc,'fro')/Xf; - Zc=norm(Zk-Z,'fro')/Xf; - Jc=norm(Jk-J,'fro')/Xf; - Sc=norm(Sk-S,'fro')/Xf; - Ec=norm(Ek-E,'fro')/Xf; + Zc=norm(Z-Zk,'fro')/Xf; + Jc=norm(J-Jk,'fro')/Xf; + Sc=norm(S-Sk,'fro')/Xf; + Ec=norm(E-Ek,'fro')/Xf; + + Cmax=mu*(max([sqrt(eta1)*Sc Jc Zc Ec])); + function [Y1] = updateY1(Y1,mu,Xc) Y1=Y1+mu*Xc;