Skip to content

Commit

Permalink
Fix formatting of neural.py
Browse files Browse the repository at this point in the history
Fix formatting of neural.py to make PEP8 compliant.
  • Loading branch information
gkhayes committed Apr 2, 2019
1 parent 078ff37 commit 8d36c2d
Showing 1 changed file with 64 additions and 41 deletions.
105 changes: 64 additions & 41 deletions mlrose/neural.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def gradient_descent(problem, max_attempts=10, max_iters=np.inf,
random_state: int, default: None
If random_state is a positive integer, random_state is the seed used
by np.random.seed(); otherwise, the random seed is not set.
curve: bool, default: False
Boolean to keep fitness values for a curve.
If :code:`False`, then no curve is stored.
Expand All @@ -111,7 +111,7 @@ def gradient_descent(problem, max_attempts=10, max_iters=np.inf,
best_fitness: float
Value of fitness function at best state.
fitness_curve: array
Numpy array containing the fitness at every iteration.
Only returned if input argument :code:`curve` is :code:`True`.
Expand All @@ -136,9 +136,9 @@ def gradient_descent(problem, max_attempts=10, max_iters=np.inf,
problem.reset()
else:
problem.set_state(init_state)

if curve:
fitness_curve=[]
fitness_curve = []

attempts = 0
iters = 0
Expand All @@ -163,15 +163,15 @@ def gradient_descent(problem, max_attempts=10, max_iters=np.inf,
if next_fitness > problem.get_maximize()*best_fitness:
best_fitness = problem.get_maximize()*next_fitness
best_state = next_state

if curve:
fitness_curve.append(problem.get_fitness())

problem.set_state(next_state)

if curve:
return best_state, best_fitness, np.asarray(fitness_curve)

return best_state, best_fitness


Expand Down Expand Up @@ -540,16 +540,21 @@ def fit(self, X, y=None, init_weights=None):
for _ in range(self.restarts + 1):
if init_weights is None:
init_weights = np.random.uniform(-1, 1, num_nodes)

if self.curve:
current_weights, current_loss, fitness_curve = random_hill_climb(
problem,
max_attempts=self.max_attempts if self.early_stopping else self.max_iters, max_iters=self.max_iters,
restarts=0, init_state=init_weights, curve=self.curve)
current_weights, current_loss, fitness_curve = \
random_hill_climb(problem,
max_attempts=self.max_attempts if
self.early_stopping else
self.max_iters,
max_iters=self.max_iters,
restarts=0, init_state=init_weights,
curve=self.curve)
else:
current_weights, current_loss = random_hill_climb(
problem,
max_attempts=self.max_attempts if self.early_stopping else self.max_iters,
max_attempts=self.max_attempts if self.early_stopping
else self.max_iters,
max_iters=self.max_iters,
restarts=0, init_state=init_weights, curve=self.curve)

Expand All @@ -565,15 +570,17 @@ def fit(self, X, y=None, init_weights=None):
fitted_weights, loss, fitness_curve = simulated_annealing(
problem,
schedule=self.schedule,
max_attempts=self.max_attempts if self.early_stopping else self.max_iters,
max_attempts=self.max_attempts if self.early_stopping else
self.max_iters,
max_iters=self.max_iters,
init_state=init_weights,
curve=self.curve)
else:
fitted_weights, loss = simulated_annealing(
problem,
schedule=self.schedule,
max_attempts=self.max_attempts if self.early_stopping else self.max_iters,
max_attempts=self.max_attempts if self.early_stopping else
self.max_iters,
max_iters=self.max_iters,
init_state=init_weights,
curve=self.curve)
Expand All @@ -584,36 +591,50 @@ def fit(self, X, y=None, init_weights=None):
problem,
pop_size=self.pop_size,
mutation_prob=self.mutation_prob,
max_attempts=self.max_attempts if self.early_stopping else self.max_iters,
max_attempts=self.max_attempts if self.early_stopping else
self.max_iters,
max_iters=self.max_iters,
curve=self.curve)
else:
fitted_weights, loss = genetic_alg(
problem,
pop_size=self.pop_size, mutation_prob=self.mutation_prob,
max_attempts=self.max_attempts if self.early_stopping else self.max_iters,
max_attempts=self.max_attempts if self.early_stopping else
self.max_iters,
max_iters=self.max_iters,
curve=self.curve)

else: # Gradient descent case
if init_weights is None:
init_weights = np.random.uniform(-1, 1, num_nodes)
fitted_weights, loss, fitness_curve = gradient_descent(
problem,
max_attempts=self.max_attempts if self.early_stopping else self.max_iters,
max_iters=self.max_iters,
curve=self.curve,
init_state=init_weights)

if self.curve:
fitted_weights, loss, fitness_curve = gradient_descent(
problem,
max_attempts=self.max_attempts if self.early_stopping else
self.max_iters,
max_iters=self.max_iters,
curve=self.curve,
init_state=init_weights)

else:
fitted_weights, loss = gradient_descent(
problem,
max_attempts=self.max_attempts if self.early_stopping else
self.max_iters,
max_iters=self.max_iters,
curve=self.curve,
init_state=init_weights)

# Save fitted weights and node list
self.node_list = node_list
self.fitted_weights = fitted_weights
self.loss = loss
self.output_activation = fitness.get_output_activation()

if self.curve:
self.fitness_curve = fitness_curve

return self

def predict(self, X):
Expand Down Expand Up @@ -785,9 +806,9 @@ class NeuralNetwork(BaseNeuralNetwork, ClassifierMixin):
random_state: int, default: None
If random_state is a positive integer, random_state is the seed used
by np.random.seed(); otherwise, the random seed is not set.
curve: bool, default: False
If bool is True, fitness_curve containing the fitness at each training
If bool is True, fitness_curve containing the fitness at each training
iteration is returned.
Attributes
Expand All @@ -804,9 +825,9 @@ class NeuralNetwork(BaseNeuralNetwork, ClassifierMixin):
:code:`predict` is performed for multi-class classification data; or
the predicted probability for class 1 when :code:`predict` is performed
for binary classification data.
fitness_curve: array
Numpy array giving the fitness at each training iteration.
Numpy array giving the fitness at each training iteration.
"""

def __init__(self, hidden_nodes=None,
Expand Down Expand Up @@ -897,10 +918,10 @@ class LinearRegression(BaseNeuralNetwork, RegressorMixin):
random_state: int, default: None
If random_state is a positive integer, random_state is the seed used
by np.random.seed(); otherwise, the random seed is not set.
curve: bool, default: False
If bool is true, curve containing the fitness at each training
iteration is returned.
If bool is true, curve containing the fitness at each training
iteration is returned.
Attributes
----------
Expand All @@ -910,9 +931,9 @@ class LinearRegression(BaseNeuralNetwork, RegressorMixin):
loss: float
Value of loss function for fitted weights when :code:`fit` is
performed.
fitness_curve: array
Numpy array giving the fitness at each training iteration.
Numpy array giving the fitness at each training iteration.
"""

def __init__(self, algorithm='random_hill_climb', max_iters=100, bias=True,
Expand Down Expand Up @@ -983,9 +1004,9 @@ class LogisticRegression(BaseNeuralNetwork, ClassifierMixin):
random_state: int, default: None
If random_state is a positive integer, random_state is the seed used
by np.random.seed(); otherwise, the random seed is not set.
curve: bool, default: False
If bool is true, curve containing the fitness at each training
If bool is true, curve containing the fitness at each training
iteration is returned.
Attributes
Expand All @@ -996,20 +1017,22 @@ class LogisticRegression(BaseNeuralNetwork, ClassifierMixin):
loss: float
Value of loss function for fitted weights when :code:`fit` is
performed.
fitness_curve: array
Numpy array giving the fitness at each training iteration.
"""

def __init__(self, algorithm='random_hill_climb', max_iters=100, bias=True,
learning_rate=0.1, early_stopping=False, clip_max=1e+10,
restarts=0, schedule=GeomDecay(), pop_size=200, mutation_prob=0.1,
max_attempts=10, random_state=None, curve=False):
restarts=0, schedule=GeomDecay(), pop_size=200,
mutation_prob=0.1, max_attempts=10, random_state=None,
curve=False):

BaseNeuralNetwork.__init__(
self, hidden_nodes=[], activation='sigmoid',
algorithm=algorithm, max_iters=max_iters, bias=bias,
is_classifier=True, learning_rate=learning_rate,
early_stopping=early_stopping, clip_max=clip_max, restarts=restarts,
schedule=schedule, pop_size=pop_size, mutation_prob=mutation_prob,
max_attempts=max_attempts, random_state=random_state, curve=curve)
early_stopping=early_stopping, clip_max=clip_max,
restarts=restarts, schedule=schedule, pop_size=pop_size,
mutation_prob=mutation_prob, max_attempts=max_attempts,
random_state=random_state, curve=curve)

0 comments on commit 8d36c2d

Please sign in to comment.