Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-918] Random module #13039

Merged
merged 11 commits into from
Dec 14, 2018
Prev Previous commit
Next Next commit
Merge branch 'master' into random_module_2
# Conflicts:
#	scala-package/macros/src/main/scala/org/apache/mxnet/APIDocGenerator.scala
#	scala-package/macros/src/main/scala/org/apache/mxnet/GeneratorBase.scala
  • Loading branch information
mdespriee committed Nov 29, 2018
commit 0b8837ea21bcb7483efec2f320ed3a4a5f033370
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,17 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
}

def generateAPIDocFromBackend(func: Func, withParam: Boolean = true): String = {
val desc = func.desc.split("\n")
.mkString(" * <pre>", "\n * ", "\n * </pre>")
def fixDesc(desc: String): String = {
var curDesc = desc
var prevDesc = ""
while ( curDesc != prevDesc ) {
prevDesc = curDesc
curDesc = curDesc.replace("[[", "`[ [").replace("]]", "] ]")
}
curDesc
}
val desc = fixDesc(func.desc).split("\n")
.mkString(" *\n * {{{\n *\n * ", "\n * ", "\n * }}}\n * ")

val params = func.listOfArgs.map { absClassArg =>
s" * @param ${absClassArg.safeArgName}\t\t${fixDesc(absClassArg.argDesc)}"
Expand Down Expand Up @@ -178,6 +187,64 @@ private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {
|def ${func.name}$typeParameter (${argDef.mkString(", ")}): $returnType""".stripMargin
}

def generateJavaAPISignature(func : Func) : String = {
val useParamObject = func.listOfArgs.count(arg => arg.isOptional) >= 2
var argDef = ListBuffer[String]()
var classDef = ListBuffer[String]()
var requiredParam = ListBuffer[String]()
func.listOfArgs.foreach(absClassArg => {
val currArgName = absClassArg.safeArgName
// scalastyle:off
if (absClassArg.isOptional && useParamObject) {
classDef +=
s"""private var $currArgName: ${absClassArg.argType} = null
|/**
| * @param $currArgName\t\t${absClassArg.argDesc}
| */
|def set${currArgName.capitalize}($currArgName : ${absClassArg.argType}): ${func.name}Param = {
| this.$currArgName = $currArgName
| this
| }""".stripMargin
}
else {
requiredParam += s" * @param $currArgName\t\t${absClassArg.argDesc}"
argDef += s"$currArgName : ${absClassArg.argType}"
}
classDef += s"def get${currArgName.capitalize}() = this.$currArgName"
// scalastyle:on
})
val experimentalTag = "@Experimental"
val returnType = "Array[NDArray]"
val scalaDoc = generateAPIDocFromBackend(func)
val scalaDocNoParam = generateAPIDocFromBackend(func, false)
if(useParamObject) {
classDef +=
s"""private var out : org.apache.mxnet.NDArray = null
|def setOut(out : NDArray) : ${func.name}Param = {
| this.out = out
| this
| }
| def getOut() = this.out
| """.stripMargin
s"""$scalaDocNoParam
| $experimentalTag
| def ${func.name}(po: ${func.name}Param) : $returnType
| /**
| * This Param Object is specifically used for ${func.name}
| ${requiredParam.mkString("\n")}
| */
| class ${func.name}Param(${argDef.mkString(",")}) {
| ${classDef.mkString("\n ")}
| }""".stripMargin
} else {
argDef += "out : NDArray"
s"""$scalaDoc
|$experimentalTag
| def ${func.name}(${argDef.mkString(", ")}) : $returnType
| """.stripMargin
}
}

def writeFile(FILE_PATH: String, packageDef: String, className: String,
imports: String, absFuncs: Seq[String]): String = {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ private[mxnet] abstract class GeneratorBase {

case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String)

// filter the operators to generate (in the non type-safe apis)
protected def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
val l = getBackEndFunctions(isSymbol)
def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean,
isJava: Boolean = false): List[Func] = {
val l = getBackEndFunctions(isSymbol, isJava)
if (isContrib) {
l.filter(func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
} else {
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.