-
Notifications
You must be signed in to change notification settings - Fork 8
/
misc.py
93 lines (73 loc) · 3.13 KB
/
misc.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
# -------------------------------------------------------------------------
# Copyright (c) 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual
# property and proprietary rights in and to this software, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this software and related documentation
# without an express license agreement from NVIDIA CORPORATION is strictly
# prohibited.
#
# Written by Jiarui Xu
# -------------------------------------------------------------------------
import collections.abc
from collections import OrderedDict
import torch
import torch.distributed as dist
from .imagenet_template import template_meta
def reduce_tensor(tensor):
    """Average a tensor across all distributed workers.

    All-reduces a clone of ``tensor`` (sum) over the default process
    group and divides by the world size, so every rank receives the
    mean value. The input tensor itself is left untouched.

    Args:
        tensor (torch.Tensor): value to average across ranks.

    Returns:
        torch.Tensor: the cross-rank mean, same shape as the input.
    """
    # clone first so the caller's tensor is not modified in place
    averaged = tensor.clone()
    dist.all_reduce(averaged, op=dist.ReduceOp.SUM)
    world_size = dist.get_world_size()
    averaged /= world_size
    return averaged
def get_grad_norm(parameters, norm_type=2):
    """Compute the total gradient norm over a set of parameters.

    Parameters without a ``.grad`` are skipped. The result is the
    ``norm_type``-norm of the concatenation of all gradients, i.e.
    ``(sum_p ||p.grad||_q ^ q) ** (1/q)`` with ``q = norm_type``.

    Args:
        parameters: a single tensor or an iterable of tensors whose
            ``.grad`` fields are inspected.
        norm_type: order of the norm (default 2, the Euclidean norm).

    Returns:
        float: the total gradient norm (0.0 when nothing has a grad).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    norm_type = float(norm_type)
    # gather only the gradients that actually exist
    grads = [p.grad.data for p in parameters if p.grad is not None]
    accumulated = sum(g.norm(norm_type).item() ** norm_type for g in grads)
    return accumulated ** (1.0 / norm_type)
def get_batch_size(data):
    """Recursively infer the batch size (size of dimension 0) of ``data``.

    Descends into mappings (first value) and non-string sequences
    (first element) until a tensor is found, then returns its leading
    dimension.

    Args:
        data: a ``torch.Tensor``, or an arbitrarily nested mapping /
            sequence whose leaves are tensors.

    Returns:
        int: the leading dimension of the first tensor encountered.

    Raises:
        TypeError: if ``data`` is not a tensor, mapping, or sequence.
    """
    if isinstance(data, torch.Tensor):
        return data.size(0)
    elif isinstance(data, collections.abc.Mapping):
        # batch size is read from the first value; all entries are
        # assumed to share the same leading dimension
        return get_batch_size(data[next(iter(data))])
    elif isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        # NOTE: only the first element is inspected — element sizes are
        # assumed consistent, they are not verified here
        it = iter(data)
        return get_batch_size(next(it))
    # bare `raise TypeError` gave callers no clue what went wrong
    raise TypeError(f'Cannot infer batch size from data of type {type(data).__name__}')
def data2cuda(data):
    """Recursively move every tensor in ``data`` onto the CUDA device.

    Mappings are rebuilt as plain dicts and non-string sequences as
    lists, with ``data2cuda`` applied to each element; tensors are
    transferred with ``non_blocking=True``.

    Args:
        data: a tensor, or a nested mapping / sequence of tensors.

    Returns:
        The same structure with all tensors resident on the GPU.

    Raises:
        TypeError: for any leaf that is not a tensor, mapping, or sequence.
    """
    if isinstance(data, torch.Tensor):
        return data.cuda(non_blocking=True)
    if isinstance(data, collections.abc.Mapping):
        return {name: data2cuda(value) for name, value in data.items()}
    if isinstance(data, collections.abc.Sequence) and not isinstance(data, str):
        return [data2cuda(element) for element in data]
    raise TypeError
def parse_losses(losses):
    """Aggregate raw loss outputs into a total loss and a log dict.

    Each entry is reduced to a scalar: tensors via ``.mean()``, lists
    of tensors via the sum of their means. Only entries whose key
    contains the substring ``'loss'`` contribute to the returned total;
    everything else (e.g. accuracies) is logged but not optimized.

    Args:
        losses: mapping from name to a tensor or a list of tensors.

    Returns:
        tuple: ``(loss, log_vars)`` where ``loss`` is the scalar total
        and ``log_vars`` is an ``OrderedDict`` of all reduced values.

    Raises:
        TypeError: if a value is neither a tensor nor a list.
    """
    log_vars = OrderedDict()
    for name, value in losses.items():
        if isinstance(value, torch.Tensor):
            reduced = value.mean()
        elif isinstance(value, list):
            reduced = sum(part.mean() for part in value)
        else:
            raise TypeError(f'{name} is not a tensor or list of tensors')
        log_vars[name] = reduced
    # total only over keys that mention 'loss'
    loss = sum(value for key, value in log_vars.items() if 'loss' in key)
    return loss, log_vars
def build_dataset_class_tokens(text_transform, template_set, classnames):
    """Tokenize every (prompt template, class name) combination.

    For each class name, every template in the chosen template set is
    formatted with the name and passed through ``text_transform``; the
    per-class token tensors are stacked along a new template axis.

    Args:
        text_transform: callable mapping a prompt string to a token
            tensor of shape [L].
        template_set: key into ``template_meta`` selecting the list of
            prompt templates.
        classnames: iterable of class-name strings.

    Returns:
        torch.Tensor: tokens of shape [N, T, L] — N classes, T templates
        per class, L sequence length.
    """
    templates = template_meta[template_set]
    per_class = [
        torch.stack([text_transform(template.format(classname)) for template in templates])
        for classname in classnames
    ]
    return torch.stack(per_class)