diff --git a/elk/calibration.py b/elk/calibration.py index db56fa02..23a48b8e 100644 --- a/elk/calibration.py +++ b/elk/calibration.py @@ -34,8 +34,8 @@ def update(self, labels: Tensor, probs: Tensor) -> "CalibrationError": assert labels.shape == probs.shape assert torch.is_floating_point(probs) - self.labels.append(probs) - self.pred_probs.append(labels) + self.labels.append(labels) + self.pred_probs.append(probs) return self def compute(self, p: int = 2) -> CalibrationEstimate: