mask2former_loss.py
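"""Mask2Former training loss.

Wraps the Hugging Face Mask2Former loss with a custom Hungarian matcher and
normalizes mask losses by the number of target masks averaged across
distributed ranks.
"""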
from typing import List, Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from transformers.models.mask2former.modeling_mask2former import (
    Mask2FormerLoss as TransformersMask2FormerLoss,
)

from training.mask2former_matcher import Mask2formerMatcher


class Mask2formerLoss(TransformersMask2FormerLoss):
    def __init__(
        self,
        num_points: int,
        oversample_ratio: float,
        importance_sample_ratio: float,
        mask_coefficient: float,
        dice_coefficient: float,
        class_coefficient: Optional[float] = None,
        num_labels: Optional[int] = None,
        no_object_coefficient: Optional[float] = None,
    ):
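        # Skip TransformersMask2FormerLoss.__init__ (it expects a config and
        # builds its own matcher); initialize nn.Module state directly and
        # wire up a custom matcher below.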
        nn.Module.__init__(self)
        self.num_points = num_points
        self.oversample_ratio = oversample_ratio
        self.importance_sample_ratio = importance_sample_ratio
        self.mask_coefficient = mask_coefficient
        self.dice_coefficient = dice_coefficient
        self.class_coefficient = class_coefficient
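        # Classification setup: the extra (num_labels + 1)-th entry is the
        # "no object" class, down-weighted by no_object_coefficient in the
        # cross-entropy weight vector.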
        if num_labels is not None:
            self.num_labels = num_labels
            self.eos_coef = no_object_coefficient
            empty_weight = torch.ones(self.num_labels + 1)
            empty_weight[-1] = self.eos_coef  # type: ignore
            self.register_buffer("empty_weight", empty_weight)
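        # Hungarian matcher that assigns predicted queries to ground-truth
        # masks using mask, dice, and (optionally) class costs.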
        self.matcher = Mask2formerMatcher(
            num_points=num_points,
            mask_coefficient=mask_coefficient,
            dice_coefficient=dice_coefficient,
            class_coefficient=class_coefficient,
        )
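
    # Matching involves data-dependent shapes and control flow, so keep the
    # forward pass out of torch.compile graphs.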
    @torch.compiler.disable
    def forward(
        self,
        masks_queries_logits: torch.Tensor,
        targets: List[dict],
        class_queries_logits: Optional[torch.Tensor] = None,
    ):
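        # Per-image ground truth: half-precision masks and integer class ids.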
        mask_labels = [target["masks"].half() for target in targets]
        class_labels = [target["labels"].long() for target in targets]
        indices = self.matcher(
            masks_queries_logits=masks_queries_logits,
            mask_labels=mask_labels,
            class_queries_logits=class_queries_logits,
            class_labels=class_labels,
        )
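        # Compute raw mask/dice losses with num_masks=1; the proper
        # normalization by the averaged mask count happens below.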
        loss_masks = self.loss_masks(masks_queries_logits, mask_labels, indices, 1)
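        # Average the target-mask count across distributed ranks so the
        # per-rank normalization matches single-process training.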
        num_masks = sum(len(tgt) for (_, tgt) in indices)
        num_masks_tensor = torch.as_tensor(
            num_masks, dtype=torch.float, device=masks_queries_logits.device
        )
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(num_masks_tensor)
            world_size = dist.get_world_size()
        else:
            world_size = 1
        num_masks = torch.clamp(num_masks_tensor / world_size, min=1).item()
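        # Normalize mask and dice losses by the effective mask count.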
        for key in loss_masks:
            loss_masks[key] = loss_masks[key] / num_masks
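        # Cross-entropy over matched query/label pairs, weighted by empty_weight.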
        loss_classes = self.loss_labels(class_queries_logits, class_labels, indices)  # type: ignore
        return {**loss_masks, **loss_classes}
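
    # Reduce the per-layer loss dict to a single weighted scalar, logging each
    # component. Substring matching picks up auxiliary-layer keys as well.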
    def loss_total(self, losses_all_layers, log_fn) -> torch.Tensor:
        loss_total = None
        for loss_key, loss in losses_all_layers.items():
            log_fn(f"train_{loss_key}", loss, sync_dist=True)
            if "mask" in loss_key:
                weighted_loss = loss * self.mask_coefficient
            elif "dice" in loss_key:
                weighted_loss = loss * self.dice_coefficient
            elif "cross_entropy" in loss_key:
                weighted_loss = loss * self.class_coefficient
            else:
                raise ValueError(f"Unknown loss key: {loss_key}")
            if loss_total is None:
                loss_total = weighted_loss
            else:
                loss_total = loss_total + weighted_loss
        log_fn("train_loss_total", loss_total, sync_dist=True, prog_bar=True)
        return loss_total  # type: ignore