Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-531] CNN Examples for Scala new API #11292

Merged
merged 11 commits into from
Jul 16, 2018
Prev Previous commit
Next Next commit
Add changes to pass CPU and GPU test
  • Loading branch information
lanking520 committed Jul 11, 2018
commit b40f55d847403ce41caa6b04ff64ca16b0b3a285
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,8 @@ object CNNTextClassification {
}

def setupCnnModel(ctx: Context, batchSize: Int, sentenceSize: Int, numEmbed: Int,
numLabel: Int = 2, numFilter: Int = 100, filterList: Array[Int ] = Array(3, 4, 5),
dropout: Float = 0.5f): CNNModel = {
numLabel: Int = 2, numFilter: Int = 100, filterList: Array[Int ] = Array(3, 4, 5),
dropout: Float = 0.5f): CNNModel = {

val cnn = makeTextCNN(sentenceSize, numEmbed, batchSize,
numLabel, filterList, numFilter, dropout)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,9 @@ class CNNClassifierExampleSuite extends FunSuite with BeforeAndAfterAll {
"/scala-example-ci/CNN/" + w2vModelName
+ " -P " + tempDirPath + "/CNN/ -q") !

val modelDirPath = tempDirPath + File.separator + "CNN/"
val modelDirPath = tempDirPath + File.separator + "CNN"

val output = CNNTextClassification.test(modelDirPath + w2vModelName,
val output = CNNTextClassification.test(modelDirPath + File.separator + w2vModelName,
modelDirPath, context, modelDirPath)

Process("rm -rf " + modelDirPath) !
Expand All @@ -65,6 +65,5 @@ class CNNClassifierExampleSuite extends FunSuite with BeforeAndAfterAll {
} else {
logger.info("Skip this test as it intended for GPU only")
}

}
}