tidy up dbstream

MaxHalford committed Jun 23, 2023
1 parent 0af58db commit 7fdf107
Showing 2 changed files with 18 additions and 12 deletions.
8 changes: 6 additions & 2 deletions docs/releases/unreleased.md
@@ -5,6 +5,10 @@
- Added `bandit.BayesUCB`.
- Added `bandit.evaluate_offline`, for evaluating bandits on historical (logged) data.

+## cluster
+
+- `DBStream` will now only recluster on demand, rather than at every call to `learn_one`.
+
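A rough sketch of what this change means in practice (not part of the commit; the constructor argument and `stream.iter_array` usage are borrowed from the docstring example further down, and `clusters` is the property shown in the diff below):

```python
from river import cluster, stream

X = [[1, 0.5], [1, 0.625], [4, 1.5], [4, 2.25]]

dbstream = cluster.DBSTREAM(clustering_threshold=1.5)

for x, _ in stream.iter_array(X):
    dbstream = dbstream.learn_one(x)  # only updates the micro-clusters now

# Reading the `clusters` property is what triggers the reclustering
print(len(dbstream.clusters))
```
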
## compat

- The `predict_many` method of scikit-learn models wrapped with `compat.convert_sklearn_to_river` raised an exception if the model had not been fitted on any data yet. Instead, default predictions will now be produced, which is consistent with the rest of River.
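
A minimal sketch of the new behavior, assuming `predict_many` accepts a pandas DataFrame as the `*_many` methods do elsewhere in River:

```python
import pandas as pd
from sklearn import linear_model
from river import compat

# Wrap a scikit-learn model that has never been fitted
model = compat.convert_sklearn_to_river(linear_model.SGDRegressor())

X = pd.DataFrame({"x": [1.0, 2.0, 3.0]})

# This used to raise an exception; it now falls back to default predictions
print(model.predict_many(X))
```
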
@@ -35,9 +39,9 @@
## tree

- Expose the `min_branch_fraction` parameter to avoid splits where most of the data goes to a single branch. Affects classification trees.
- Added the `max_share_to_split` parameter to Hoeffding Tree classifiers. This parameter avoids splitting when the majority class has most of the data.
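
A hypothetical configuration showing where the two parameters plug in; the values are illustrative, not defaults:

```python
from river import tree

model = tree.HoeffdingTreeClassifier(
    min_branch_fraction=0.05,  # every branch must receive at least 5% of the data
    max_share_to_split=0.95,   # don't split while one class holds over 95% of the data
)
```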

## utils

22 changes: 12 additions & 10 deletions river/cluster/dbstream.py
@@ -75,7 +75,6 @@ class DBSTREAM(base.Clusterer):
relative to the area covered by micro clusters. This parameter is used to determine
whether a micro cluster or a shared density is weak.
Attributes
----------
n_clusters
@@ -111,11 +110,13 @@ class DBSTREAM(base.Clusterer):
... [4, 1.5], [4, 2.25], [4, 2.5], [4, 3], [4, 3.25], [4, 3.5]
... ]
->>> dbstream = cluster.DBSTREAM(clustering_threshold = 1.5,
-... fading_factor = 0.05,
-... cleanup_interval = 4,
-... intersection_factor = 0.5,
-... minimum_weight = 1)
+>>> dbstream = cluster.DBSTREAM(
+...     clustering_threshold=1.5,
+...     fading_factor=0.05,
+...     cleanup_interval=4,
+...     intersection_factor=0.5,
+...     minimum_weight=1
+... )
>>> for x, _ in stream.iter_array(X):
... dbstream = dbstream.learn_one(x)
@@ -128,6 +129,7 @@ class DBSTREAM(base.Clusterer):
>>> dbstream._n_clusters
2
"""

def __init__(
@@ -148,9 +150,9 @@ def __init__(
self.minimum_weight = minimum_weight

self._n_clusters: int = 0
self._clusters: typing.Dict[int, "DBSTREAMMicroCluster"] = {}
self._clusters: typing.Dict[int, DBSTREAMMicroCluster] = {}
self._centers: typing.Dict = {}
self._micro_clusters: typing.Dict[int, "DBSTREAMMicroCluster"] = {}
self._micro_clusters: typing.Dict[int, DBSTREAMMicroCluster] = {}

self.s: dict[int, dict[int, float]] = {}
self.s_t: dict[int, dict[int, float]] = {}
@@ -405,7 +407,7 @@ def n_clusters(self) -> int:
return self._n_clusters

@property
def clusters(self) -> typing.Dict[int, "DBSTREAMMicroCluster"]:
def clusters(self) -> typing.Dict[int, DBSTREAMMicroCluster]:
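# (annotation, not part of the commit) This property is what makes reclustering
# on-demand: `learn_one` only maintains micro-clusters; `_recluster` runs here.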
self._recluster()
return self._clusters

Expand All @@ -415,7 +417,7 @@ def centers(self) -> typing.Dict:
return self._centers

@property
def micro_clusters(self) -> typing.Dict[int, "DBSTREAMMicroCluster"]:
def micro_clusters(self) -> typing.Dict[int, DBSTREAMMicroCluster]:
return self._micro_clusters


