Skip to content

Commit

Permalink
Changed regularizer for least-squares
Browse files Browse the repository at this point in the history
  • Loading branch information
wmkouw committed May 14, 2017
1 parent 741a903 commit 263b0f0
Show file tree
Hide file tree
Showing 6 changed files with 12 additions and 14 deletions.
2 changes: 1 addition & 1 deletion experiment-hdis/exp_da_hdis.m
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ function exp_da_hdis(varargin)

% Domain-adaptive classifiers
case 'tcp-lda'
exp_da_tcp(D(ixS,:), y(ixS), D(ixT,:), y(ixT), 'NN', p.Results.NN, 'nR', p.Results.nR, 'nF', p.Results.nF, 'maxIter', p.Results.maxIter, 'xTol', p.Results.xTol, 'saveName', [p.Results.saveName p.Results.dataName '_prep' p.Results.prep{logical(cellfun(@isstr, p.Results.prep))} '_cc' sprintf('%1i', n) '_nR' num2str(p.Results.nR) '_'], 'alpha', p.Results.alpha, 'lambda', p.Results.lambda, 'clf', 'tcp-lda');
exp_da_tcp(D(ixS,:), y(ixS), D(ixT,:), y(ixT), 'NN', p.Results.NN, 'nR', p.Results.nR, 'nF', p.Results.nF, 'maxIter', p.Results.maxIter, 'xTol', p.Results.xTol, 'saveName', [p.Results.saveName p.Results.dataName '_prep' p.Results.prep{logical(cellfun(@isstr, p.Results.prep))} '_cc' sprintf('%1i', n) '_nR' num2str(p.Results.nR) '_'], 'alpha', p.Results.alpha, 'lambda', p.Results.lambda, 'clf', 'tcp-lda', 'lr', p.Results.lr);
case 'tcp-qda'
exp_da_tcp(D(ixS,:), y(ixS), D(ixT,:), y(ixT), 'NN', p.Results.NN, 'nR', p.Results.nR, 'nF', p.Results.nF, 'maxIter', p.Results.maxIter, 'xTol', p.Results.xTol, 'saveName', [p.Results.saveName p.Results.dataName '_prep' p.Results.prep{logical(cellfun(@isstr, p.Results.prep))} '_cc' sprintf('%1i', n) '_nR' num2str(p.Results.nR) '_'], 'alpha', p.Results.alpha, 'lambda', p.Results.lambda, 'clf', 'tcp-qda', 'lr', p.Results.lr);
case 'tcp-ls'
Expand Down
4 changes: 2 additions & 2 deletions experiment-hdis/run_exp_da_hdis.m
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
nR = 1;

% Hyperparameters
lambda = 1e-1;
lambda = 1e-3;
alpha = 2;

% Optimization parameters
Expand All @@ -20,7 +20,7 @@
saveName = 'results/';

% Loop over all included classifiers
clfs = { 'tca'};
clfs = {'tca', 'kmm-lsq', 'rcsa', 'rba', 'tcp-ls', 'tcp-lda', 'tcp-qda'};
for c = 1:length(clfs)

exp_da_hdis('prep', prep, 'nR', nR, 'clf', clfs{c}, ...
Expand Down
3 changes: 1 addition & 2 deletions util/tcp_lda.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
function [theta,varargout] = tcp_lda(X,yX,Z,varargin)
% Function to run the Linear Discriminant Analysis version of the
% Target Contrastive Pessimistic Estimator
% Linear Discriminant Analysis version of the Target Contrastive Pessimistic Estimator
% Input:
% X source data (N samples x D features)
% Z target data (M samples x D features)
Expand Down
12 changes: 6 additions & 6 deletions util/tcp_ls.m
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
function [theta,varargout] = tce_ls(X,yX,Z,varargin)
% Function to run the Least-Squares version of the Target Contrastive Pessimistic Estimator
function [theta,varargout] = tcp_ls(X,yX,Z,varargin)
% Least-Squares version of the Target Contrastive Pessimistic Estimator
% Input:
% X source data (N samples x D features)
% Z target data (M samples x D features)
Expand All @@ -11,7 +11,7 @@
% maxIter maximum number of iterations (default: 500)
% xTol convergence criterion (default: 1e-5)
% Output:
% theta tce estimate
% theta tcp estimate
% Optional output:
% {1} found worst-case labeling q
% {2} target loss of the mcpl/ref estimate with q/u
Expand Down Expand Up @@ -49,7 +49,7 @@
end

% Reference parameter estimates
theta.ref = (X'*X + p.Results.lambda*eye(D))\(X'*yX);
theta.ref = (X'*X + p.Results.lambda*N*eye(D))\(X'*yX);

% Initialize
q = ones(M,K)./K;
Expand All @@ -63,7 +63,7 @@
%%% Minimization

% Closed-form minimization w.r.t. theta
ZZ = svdinv(Z'*Z + p.Results.lambda*eye(D));
ZZ = svdinv(Z'*Z + p.Results.lambda*M*eye(D));
[~,psd] = chol(ZZ); if psd>0; disp('target covariance not psd'); end
theta.tcp = ZZ*(Z'*(labels(1)*q(:,1) + labels(2)*q(:,2)));

Expand Down Expand Up @@ -116,7 +116,7 @@
end

% Oracle parameter estimates
theta.orc = (Z'*Z + p.Results.lambda*eye(D))\(Z'*p.Results.yZ);
theta.orc = (Z'*Z + p.Results.lambda*M*eye(D))\(Z'*p.Results.yZ);

%%% Optional output
if nargout > 1
Expand Down
3 changes: 1 addition & 2 deletions util/tcp_qda.m
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
function [theta,varargout] = tcp_qda(X,yX,Z,varargin)
% Function to run the Quadratic Discriminant Analysis version of the
% Target Contrastive Pessimistic estimator
% Quadratic Discriminant Analysis version of the Target Contrastive Pessimistic Estimator
% Input:
% X source data (N samples x D features)
% Z target data (M samples x D features)
Expand Down
2 changes: 1 addition & 1 deletion util/wlsq.m
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,6 @@
bX = bsxfun(@times, w, X);

% Least-squares
theta = (bX'*bX+p.Results.lambda*eye(D))\(bX'*y);
theta = (bX'*bX+p.Results.lambda*N*eye(D))\(bX'*y);

end

0 comments on commit 263b0f0

Please sign in to comment.