Skip to content

Commit

Permalink
add valid for graph on report 2
Browse files Browse the repository at this point in the history
  • Loading branch information
orbxball committed Mar 16, 2017
1 parent 7ab989b commit 2a6da32
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion hw1/stable.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,14 @@ def extract_feature(M, features, squares, cubics):
std = np.std(x_data, axis=0)
x_data = (x_data - mean) / (std + 1e-20)

number = 5000
x_data = x_data[:number]
y_data = y_data[:number]

valid_num = -20
x_data_valid = x_data[valid_num:]
y_data_valid = y_data[valid_num:]

# ydata = b + w * xdata
b = 0.0
w = np.zeros(length*9)
Expand All @@ -58,21 +66,23 @@ def extract_feature(M, features, squares, cubics):
for e in range(epoch):
# Calculate the value of the loss function
error = y_data - b - np.dot(x_data, w) #shape: (5652,)
error2 = y_data_valid - b - np.dot(x_data_valid, w) #shape: (valid_num,)

# Calculate gradient
b_grad = -2*np.sum(error)*1 #shape: ()
w_grad = -2*np.dot(error, x_data) #shape: (162,)
b_lr = b_lr + b_grad**2
w_lr = w_lr + w_grad**2
loss = np.mean(np.square(error))
valid_loss = np.mean(np.square(error2))

# Update parameters.
b = b - lr/np.sqrt(b_lr) * b_grad
w = w - lr/np.sqrt(w_lr) * w_grad

# Print loss
if (e+1) % 1000 == 0:
print('epoch:{}\n Loss:{}'.format(e+1, np.sqrt(loss)))
print('epoch:{}\n Loss:{} valid:{}\n'.format(e+1, np.sqrt(loss), np.sqrt(valid_loss)))


# Test
Expand Down

0 comments on commit 2a6da32

Please sign in to comment.