Skip to content

Commit

Permalink
fix style
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Dec 6, 2021
1 parent 4063fab commit 18dfa5c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 7 deletions.
4 changes: 3 additions & 1 deletion project/sentiment_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,9 @@ def render_run_sentiment_interface():
"Learning rate", value=0.01, step=0.001, format="%.3f"
)
n_training_data = col1.number_input("N datapoints from training data", value=450)
n_validation_data = col2.number_input("N datapoints from validation data", value=100)
n_validation_data = col2.number_input(
"N datapoints from validation data", value=100
)
batch_size = st.number_input("Batch size", value=10)

if st.button("Train model"):
Expand Down
9 changes: 3 additions & 6 deletions tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,17 @@
def test_avg(t):
out = minitorch.avgpool2d(t, (2, 2))
assert_close(
out[0, 0, 0, 0],
sum([t[0, 0, i, j] for i in range(2) for j in range(2)]) / 4.0
out[0, 0, 0, 0], sum([t[0, 0, i, j] for i in range(2) for j in range(2)]) / 4.0
)

out = minitorch.avgpool2d(t, (2, 1))
assert_close(
out[0, 0, 0, 0],
sum([t[0, 0, i, j] for i in range(2) for j in range(1)]) / 2.0
out[0, 0, 0, 0], sum([t[0, 0, i, j] for i in range(2) for j in range(1)]) / 2.0
)

out = minitorch.avgpool2d(t, (1, 2))
assert_close(
out[0, 0, 0, 0],
sum([t[0, 0, i, j] for i in range(1) for j in range(2)]) / 2.0
out[0, 0, 0, 0], sum([t[0, 0, i, j] for i in range(1) for j in range(2)]) / 2.0
)
minitorch.grad_check(lambda t: minitorch.avgpool2d(t, (2, 2)), t)

Expand Down

0 comments on commit 18dfa5c

Please sign in to comment.