
Commit

update
ain-soph committed Jun 24, 2023
1 parent 518f432 commit b2b2589
Showing 3 changed files with 12 additions and 10 deletions.
8 changes: 3 additions & 5 deletions trojanzoo/models.py
@@ -1083,26 +1083,24 @@ def get_data(self, data, **kwargs):
         return data

     def accuracy(self, _output: torch.Tensor, _label: torch.Tensor,
-                 num_classes: int = None,
-                 topk: Iterable[int] = (1, 5)) -> dict[str, float]:
+                 topk: Iterable[int] = (1, 5), **kwargs) -> dict[str, float]:
         r"""Computes the accuracy over the k top predictions
         for the specified values of k.

         Args:
             _output (torch.Tensor): The batched logit tensor with shape ``(N, C)``.
             _label (torch.Tensor): The batched label tensor with shape ``(N)``.
-            num_classes (int): Number of classes. Defaults to :attr:`self.num_classes`.
             topk (~collections.abc.Iterable[int]): Which top-k accuracies to show.
                 Defaults to ``(1, 5)``.
+            **kwargs: Keyword arguments passed to :func:`trojanzoo.utils.model.accuracy`.

         Returns:
             dict[str, float]: Top-k accuracies.

         Note:
             The implementation is in :func:`trojanzoo.utils.model.accuracy`.
         """
-        num_classes = num_classes or self.num_classes
-        return accuracy(_output, _label, num_classes, topk)
+        return accuracy(_output, _label, topk, **kwargs)

     def activate_params(self, params: Iterator[nn.Parameter] = []) -> None:
         r"""Set ``requires_grad=True`` for selected :attr:`params` of :attr:`module`.
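A minimal self-contained sketch of the calling pattern this hunk moves to (the DummyModel class and _accuracy_util helper below are stand-ins written for illustration, not trojanzoo code): the wrapper no longer resolves num_classes itself and simply forwards any extra keyword arguments to the utility, which reads the class count off the logits.

import torch


def _accuracy_util(_output: torch.Tensor, _label: torch.Tensor,
                   topk=(1,), **kwargs) -> dict[str, float]:
    """Stand-in for trojanzoo.utils.model.accuracy, simplified to top-1."""
    num_classes = _output.size(1)      # inferred from the logit width, no explicit argument
    assert int(_label.max()) < num_classes
    top1 = float(_output.argmax(dim=1).eq(_label).float().mean() * 100)
    return {'top1': top1}


class DummyModel:
    """Placeholder model; only the accuracy wrapper mirrors the new signature."""

    def accuracy(self, _output: torch.Tensor, _label: torch.Tensor,
                 topk=(1,), **kwargs) -> dict[str, float]:
        # no `num_classes = num_classes or self.num_classes` step any more;
        # extra keyword arguments pass straight through to the utility
        return _accuracy_util(_output, _label, topk, **kwargs)


logits = torch.randn(32, 10)              # (N, C) batched logits
labels = torch.randint(0, 10, (32,))      # (N,) ground-truth labels
print(DummyModel().accuracy(logits, labels))   # e.g. {'top1': 12.5}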
7 changes: 4 additions & 3 deletions trojanzoo/utils/model.py
@@ -471,21 +471,22 @@ def activate_params(module: nn.Module, params: Iterator[nn.Parameter] = []):


 @torch.no_grad()
-def accuracy(_output: torch.Tensor, _label: torch.Tensor, num_classes: int,
-             topk: Iterable[int] = (1, 5)) -> list[float]:
+def accuracy(_output: torch.Tensor, _label: torch.Tensor,
+             topk: Iterable[int] = (1, 5), **kwargs) -> list[float]:
     r"""Computes the accuracy over the k top predictions
     for the specified values of k.

     Args:
         _output (torch.Tensor): The batched logit tensor with shape ``(N, C)``.
         _label (torch.Tensor): The batched label tensor with shape ``(N)``.
-        num_classes (int): Number of classes.
         topk (~collections.abc.Iterable[int]): Which top-k accuracies to show.
             Defaults to ``(1, 5)``.
+        **kwargs: Any keyword argument (unused).

     Returns:
         dict[str, float]: Top-k accuracies.
     """
+    num_classes = _output.size(1)
     maxk = min(max(topk), num_classes)
     batch_size = _label.size(0)
     _, pred = _output.topk(maxk, 1, True, True)
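The hunk above shows only the opening lines of the updated utility. Below is a hedged standalone sketch of how the top-k computation typically continues from them, in the common torchvision-style formulation; the tail of trojanzoo's own implementation (including how a k larger than the class count is treated) is not visible in this diff and is assumed here. The annotation in the diff says list[float] while the docstring and the logger.update(**metrics) call sites in train.py expect a mapping, so the sketch returns a dict keyed 'top1', 'top5', and so on.

import torch
from collections.abc import Iterable


@torch.no_grad()
def topk_accuracy(_output: torch.Tensor, _label: torch.Tensor,
                  topk: Iterable[int] = (1, 5), **kwargs) -> dict[str, float]:
    num_classes = _output.size(1)                 # inferred from the logit width
    maxk = min(max(topk), num_classes)
    batch_size = _label.size(0)
    _, pred = _output.topk(maxk, 1, True, True)   # (N, maxk) highest-scoring class indices
    pred = pred.t()                               # (maxk, N)
    correct = pred.eq(_label.view(1, -1).expand_as(pred))
    res: dict[str, float] = {}
    for k in topk:
        if k <= num_classes:
            correct_k = correct[:k].flatten().float().sum()
            res[f'top{k}'] = float(correct_k * (100.0 / batch_size))
        else:
            res[f'top{k}'] = 100.0                # assumed: saturate when k exceeds C
    return res


logits = torch.randn(16, 3)
labels = torch.randint(0, 3, (16,))
print(topk_accuracy(logits, labels))              # 'top5' saturates because C == 3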
7 changes: 5 additions & 2 deletions trojanzoo/utils/train.py
@@ -164,7 +164,7 @@ def train(module: nn.Module, num_classes: int,

             if lr_scheduler and lr_scheduler_freq == 'iter':
                 lr_scheduler.step()
-            metrics = metric_fn(_output, _label, num_classes=num_classes, **metric_kwargs)
+            metrics = metric_fn(_input=_input, _label=_label, _output=_output, **metric_kwargs)
             batch_size = int(_label.size(0))
             logger_train.update(n=batch_size, loss=float(loss), **metrics)
             empty_cache()
@@ -176,6 +176,7 @@ def train(module: nn.Module, num_classes: int,
         activate_params(module, [])
         loss, acc = (logger_train.meters['loss'].global_avg,
                      logger_train.meters['top1'].global_avg)
+        logger_train.reset()
         if writer is not None:
             from torch.utils.tensorboard import SummaryWriter
             assert isinstance(writer, SummaryWriter)
@@ -245,6 +246,7 @@ def validate(module: nn.Module, num_classes: int,
     if logger is None:
         logger = MetricLogger()
         logger.create_meters(loss=None, top1=None)
+    logger.reset()
     loader_epoch = loader
     if verbose:
         header: str = '{yellow}{0}{reset}'.format(print_prefix, **ansi)
@@ -257,7 +259,7 @@ def validate(module: nn.Module, num_classes: int,
         with torch.no_grad():
             _output = forward_fn(_input)
             loss = float(loss_fn(_input, _label, _output=_output, **kwargs))
-            metrics = metric_fn(_output, _label, num_classes=num_classes, **metric_kwargs)
+            metrics = metric_fn(_input=_input, _label=_label, _output=_output, **metric_kwargs)
         batch_size = int(_label.size(0))
         logger.update(n=batch_size, loss=float(loss), **metrics)
     acc, loss = (logger.meters['top1'].global_avg,
@@ -269,6 +271,7 @@ def validate(module: nn.Module, num_classes: int,
                                tag_scalar_dict={tag: acc}, global_step=_epoch)
             writer.add_scalars(main_tag='Loss/' + main_tag,
                                tag_scalar_dict={tag: loss}, global_step=_epoch)
+    logger.reset()
     return acc, loss

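With the two call-site changes above, train() and validate() now invoke metric_fn purely by keyword (_input, _label, _output, plus whatever metric_kwargs carries), so a custom metric function only needs to accept those parameter names and swallow anything it does not use. A hedged sketch of such a function (the name my_metric_fn is hypothetical; the 'top1' key matches the meter the loggers create):

import torch


def my_metric_fn(_input: torch.Tensor = None, _label: torch.Tensor = None,
                 _output: torch.Tensor = None, **kwargs) -> dict[str, float]:
    """Illustrative metric_fn compatible with the keyword-only call sites above."""
    # _input is accepted but unused here; **kwargs absorbs extra metric_kwargs entries
    top1 = float(_output.argmax(dim=1).eq(_label).float().mean() * 100)
    return {'top1': top1}   # keys become meter names via logger.update(n=batch_size, **metrics)


# quick check with random data
logits = torch.randn(8, 5)
labels = torch.randint(0, 5, (8,))
print(my_metric_fn(_label=labels, _output=logits))

The added logger_train.reset() and logger.reset() calls presumably keep the meters' global averages from accumulating across epochs or across repeated validate() calls.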
