Skip to content

Commit

Permalink
fix missing dropout in HiddenLayerBuilder
Browse files Browse the repository at this point in the history
  • Loading branch information
haifengl committed Feb 1, 2024
1 parent 38abd9d commit b6cd0a1
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
5 changes: 3 additions & 2 deletions core/src/main/java/smile/base/mlp/HiddenLayer.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@ public class HiddenLayer extends Layer {
* Constructor.
* @param n the number of neurons.
* @param p the number of input variables (not including bias value).
* @param dropout the dropout rate.
* @param activation the activation function.
*/
public HiddenLayer(int n, int p, ActivationFunction activation) {
super(n, p);
public HiddenLayer(int n, int p, double dropout, ActivationFunction activation) {
super(n, p, dropout);
this.activation = activation;
}

Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/smile/base/mlp/HiddenLayerBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,6 @@ public String toString() {

@Override
public HiddenLayer build(int p) {
return new HiddenLayer(neurons, p, activation);
return new HiddenLayer(neurons, p, dropout, activation);
}
}
6 changes: 3 additions & 3 deletions core/src/test/java/smile/classification/MLPTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ public void testUSPS() throws Exception {
int k = MathEx.max(USPS.y) + 1;

MLP model = new MLP(Layer.input(p),
Layer.leaky(768, 0.5, 0.02),
Layer.leaky(768, 0.2, 0.02),
Layer.rectifier(192),
Layer.rectifier(30),
Layer.mle(k, OutputFunction.SOFTMAX)
Expand All @@ -246,7 +246,7 @@ public void testUSPS() throws Exception {
System.out.println("Test Error = " + error);
}

assertEquals(109, error);
assertEquals(115, error, 5);

java.nio.file.Path temp = Write.object(model);
Read.object(temp);
Expand Down Expand Up @@ -281,7 +281,7 @@ public void testUSPSMiniBatch() {
double[][] batchx = new double[batch][];
int[] batchy = new int[batch];
int error = 0;
for (int epoch = 1; epoch <= 10; epoch++) {
for (int epoch = 1; epoch <= 8; epoch++) {
System.out.format("----- epoch %d -----%n", epoch);
int[] permutation = MathEx.permutate(x.length);
int i = 0;
Expand Down

0 comments on commit b6cd0a1

Please sign in to comment.