-
Notifications
You must be signed in to change notification settings - Fork 32
/
platt_scaling.py
42 lines (32 loc) · 1.18 KB
/
platt_scaling.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
from abc import ABC, abstractmethod
from typing import Any
import torch
from torch import Tensor, nn, optim
class PlattMixin(ABC):
    """Mixin for classifier-like objects that can be Platt scaled.

    Concrete classes must provide the `bias` and `scale` parameters and a
    `__call__` whose output can be used as logits by
    `binary_cross_entropy_with_logits`.
    """

    # The two parameters optimized by `platt_scale`.
    bias: nn.Parameter
    scale: nn.Parameter

    @abstractmethod
    def __call__(self, *args: Any, **kwds: Any) -> Any: ...

    def platt_scale(
        self, labels: Tensor, hiddens: Tensor, max_iter: int = 100
    ) -> None:
        """Fit the scale and bias terms to data with LBFGS.

        Minimizes the binary cross-entropy between `self(hiddens)` (treated
        as logits) and `labels`. Only `self.bias` and `self.scale` are handed
        to the optimizer; they are updated in place.

        Args:
            labels: Binary labels of shape [batch].
            hiddens: Hidden states of shape [batch, dim].
            max_iter: Maximum number of iterations for LBFGS.
        """
        # Convergence tolerances are set to machine epsilon of the input dtype.
        eps = torch.finfo(hiddens.dtype).eps
        opt = optim.LBFGS(
            [self.bias, self.scale],
            line_search_fn="strong_wolfe",
            max_iter=max_iter,
            tolerance_change=eps,
            tolerance_grad=eps,
        )

        def closure():
            opt.zero_grad()
            logits = self(hiddens)
            # The loss kernel multiplies `(1 - target)` by the logits in
            # place, so the target dtype must be at least as wide as the
            # logits dtype. A plain `.float()` breaks float64 inputs; the
            # promotion keeps float32 targets for float32/half logits and
            # widens them to float64 only when needed.
            target = labels.to(torch.promote_types(logits.dtype, torch.float32))
            loss = nn.functional.binary_cross_entropy_with_logits(logits, target)
            loss.backward()
            return float(loss)

        opt.step(closure)