Skip to content

Commit

Permalink
auc fix
Browse files Browse the repository at this point in the history
  • Loading branch information
rodrigosnader committed Jun 30, 2022
1 parent b7ed97a commit 7058b18
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/wavy/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
"regression": {
"loss": "MSE",
"optimizer": "adam",
"metrics": ["MAE"],
"metrics": ["mae"],
"last_activation": "linear",
},
"classification": {
Expand Down Expand Up @@ -223,7 +223,7 @@ def score(self, on=None, **kwargs):
self.x_val, self.y_val, verbose=0, **kwargs
)

indexes = [self.model.metrics_names.index(metric) for metric in self.metrics]
indexes = [self.model.metrics_names.index(metric.lower()) for metric in self.metrics]

return pd.DataFrame(
{key: [value[index] for index in indexes] for key, value in dic.items()},
Expand Down Expand Up @@ -736,6 +736,7 @@ def residuals(self):


def compute_score_per_model(*models, on="val"):
# BUG
"""
Compute score per model
Expand All @@ -754,6 +755,7 @@ def compute_score_per_model(*models, on="val"):


def compute_default_scores(x, y, model_type, epochs=10, verbose=0, **kwargs):
# BUG
"""
Compute default scores for a model.
Expand All @@ -768,7 +770,7 @@ def compute_default_scores(x, y, model_type, epochs=10, verbose=0, **kwargs):
Returns:
pd.DataFrame: Scores
"""
models = [BaselineShift, DenseModel, ConvModel]
models = [BaselineConstant, BaselineShift, DenseModel]
models = [model(x=x, y=y, model_type=model_type) for model in models]
for model in models:
model.fit(epochs=epochs, verbose=verbose, **kwargs)
Expand Down

0 comments on commit 7058b18

Please sign in to comment.