diff --git a/supervised/algorithms/knn.py b/supervised/algorithms/knn.py index 3978e562..c7a879e9 100644 --- a/supervised/algorithms/knn.py +++ b/supervised/algorithms/knn.py @@ -55,6 +55,14 @@ def fit( else: self.model.fit(X, y) + @property + def _classes(self): + # Returns the unique classes based on the fitted model + if hasattr(self.model, "classes_"): + return self.model.classes_ + else: + return None + class KNeighborsAlgorithm(KNNFit, RegressorMixin): algorithm_name = "k-Nearest Neighbors" diff --git a/tests/tests_algorithms/test_knn.py b/tests/tests_algorithms/test_knn.py index 1572a549..3e4119be 100644 --- a/tests/tests_algorithms/test_knn.py +++ b/tests/tests_algorithms/test_knn.py @@ -17,7 +17,11 @@ class KNeighborsRegressorAlgorithmTest(unittest.TestCase): @classmethod def setUpClass(cls): cls.X, cls.y = datasets.make_regression( - n_samples=100, n_features=5, n_informative=4, shuffle=False, random_state=0 + n_samples=100, + n_features=5, + n_informative=4, + shuffle=False, + random_state=0 ) def test_reproduce_fit(self): @@ -77,3 +81,15 @@ def test_is_fitted(self): self.assertFalse(model.is_fitted()) model.fit(self.X, self.y) self.assertTrue(model.is_fitted()) + + def test_classes_attribute(self): + params = {"ml_task": "binary_classification"} + model = KNeighborsAlgorithm(params) + model.fit(self.X,self.y) + + try: + classes = model._classes + except AttributeError: + classes = None + + self.assertTrue(np.array_equal(np.unique(self.y), classes))