Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-531] CNN Examples for Scala new API #11292

Merged
merged 11 commits into from
Jul 16, 2018
Prev Previous commit
Next Next commit
Change dropout and epoch number
  • Loading branch information
lanking520 committed Jul 11, 2018
commit 4a0a839249bdc7c7133a131bf19fb6324b84047c
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ object CNNTextClassification {

def setupCnnModel(ctx: Context, batchSize: Int, sentenceSize: Int, numEmbed: Int,
numLabel: Int = 2, numFilter: Int = 100, filterList: Array[Int ] = Array(3, 4, 5),
dropout: Float = 0.5f): CNNModel = {
dropout: Float = 0.0f): CNNModel = {

val cnn = makeTextCNN(sentenceSize, numEmbed, batchSize,
numLabel, filterList, numFilter, dropout)
Expand All @@ -98,7 +98,7 @@ object CNNTextClassification {
devLabels: Array[Float], batchSize: Int, saveModelPath: String,
learningRate: Float = 0.001f): Float = {
val maxGradNorm = 0.5f
val epoch = 200
val epoch = 30
val initializer = new Uniform(0.1f)
val opt = new RMSProp(learningRate)
val updater = Optimizer.getUpdater(opt)
Expand Down