Skip to content

Commit

Permalink
Add ood test
Browse files Browse the repository at this point in the history
  • Loading branch information
jhoon-oh committed May 21, 2021
1 parent 9e4af6b commit 938d27f
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions models/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,51 @@ def test_img_local(net_g, dataset, args, user_idx=-1, idxs=None, return_features
return accuracy, test_loss, features, targets
else:
return accuracy, test_loss

def ood_test_img_local(net_g, dataset, args, user_idx=-1, idxs=None, user_train_targets=None):
net_g.eval()
# testing
per_total = 0
per_correct = 0
ood_total = 0
ood_correct = 0
# data_loader = DataLoader(dataset, batch_size=args.bs)
data_loader = DataLoader(DatasetSplit(dataset, idxs), batch_size=args.bs, shuffle=False)
l = len(data_loader)

for idx, (data, target) in enumerate(data_loader):
if args.gpu != -1:
data, target = data.to(args.device), target.to(args.device)
user_train_targets = user_train_targets.to(args.device)
log_probs = net_g(data)
y_pred = log_probs.data.max(1, keepdim=True)[1]

# get the index of the max log-probability
target_dup = torch.cat([target.view(-1, 1)]*len(user_train_targets), dim=1)
user_train_targets_dup = torch.cat([user_train_targets.view(1, -1)]*len(target), dim=0)
per_ood = torch.sum(target_dup == user_train_targets_dup, dim=1)

per_idx = torch.where(per_ood == 1)
ood_idx = torch.where(per_ood == 0)

per_pred = y_pred[per_idx]
ood_pred = y_pred[ood_idx]

per_target = target[per_idx]
ood_target = target[ood_idx]

per_total += len(per_target)
ood_total += len(ood_target)

per_correct += per_pred.eq(per_target.data.view_as(per_pred)).long().cpu().sum()
ood_correct += ood_pred.eq(ood_target.data.view_as(ood_pred)).long().cpu().sum()

if args.verbose:
print('Local model {}: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)'.format(
user_idx, test_loss, correct, len(data_loader.dataset), accuracy))
else:
return per_correct.item()/per_total*100, ood_correct.item()/ood_total*100

def distance_test_img_local(net_g, dataset_train, dataset_test, args, user_idx=-1, train_idxs=None, test_idxs=None):
net_g.eval()

Expand Down

0 comments on commit 938d27f

Please sign in to comment.