Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
introduce random API
Browse files Browse the repository at this point in the history
  • Loading branch information
mdespriee committed Nov 5, 2018
1 parent 53c5a72 commit 704ab60
Show file tree
Hide file tree
Showing 11 changed files with 433 additions and 89 deletions.
16 changes: 16 additions & 0 deletions scala-package/core/src/main/scala/org/apache/mxnet/Base.scala
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,19 @@ private[mxnet] object Base {
}

class MXNetError(val err: String) extends Exception(err)

// used in the API Symbol.random for functions accepting multiple input types
class SymbolOrValue[T]
object SymbolOrValue {
implicit object FSymbolWitness extends SymbolOrValue[Float]
implicit object ISymbolWitness extends SymbolOrValue[Int]
implicit object SymbolWitness extends SymbolOrValue[Symbol]
}

// used in the API NDArray.random for functions accepting multiple input types
class NDArrayOrValue[T]
object NDArrayOrValue {
implicit object FArrayWitness extends NDArrayOrValue[Float]
implicit object IArrayWitness extends NDArrayOrValue[Int]
implicit object ArrayWitness extends NDArrayOrValue[NDArray]
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ object NDArray extends NDArrayBase {
private val functions: Map[String, NDArrayFunction] = initNDArrayModule()

val api = NDArrayAPI
val random = NDArrayRandomAPI

private def addDependency(froms: Array[NDArray], tos: Array[NDArray]): Unit = {
froms.foreach { from =>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,22 @@
* limitations under the License.
*/
package org.apache.mxnet
@AddNDArrayAPIs(false)

/**
* typesafe NDArray API: NDArray.api._
* Main code will be generated during compile time through Macros
*/
@AddNDArrayAPIs(false)
object NDArrayAPI extends NDArrayAPIBase {
// TODO: Implement CustomOp for NDArray
}

/**
* typesafe NDArray random module: NDArray.random._
* Main code will be generated during compile time through Macros
*/
@AddNDArrayRandomAPIs(false)
object NDArrayRandomAPI extends NDArrayRandomAPIBase {

}

Original file line number Diff line number Diff line change
Expand Up @@ -841,6 +841,7 @@ object Symbol extends SymbolBase {
private val bindReqMap = Map("null" -> 0, "write" -> 1, "add" -> 3)

val api = SymbolAPI
val random = SymbolRandomAPI

def pow(sym1: Symbol, sym2: Symbol): Symbol = {
Symbol.createFromListedSymbols("_Power")(Array(sym1, sym2))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ package org.apache.mxnet
import scala.collection.mutable


@AddSymbolAPIs(false)
/**
* typesafe Symbol API: Symbol.api._
* Main code will be generated during compile time through Macros
*/
@AddSymbolAPIs(false)
object SymbolAPI extends SymbolAPIBase {
def Custom (op_type : String, kwargs : mutable.Map[String, Any],
name : String = null, attr : Map[String, String] = null) : Symbol = {
Expand All @@ -32,3 +32,13 @@ object SymbolAPI extends SymbolAPIBase {
Symbol.createSymbolGeneral("Custom", name, attr, Seq(), map.toMap)
}
}

/**
* typesafe Symbol random module: Symbol.random._
* Main code will be generated during compile time through Macros
*/
@AddSymbolRandomAPIs(false)
object SymbolRandomAPI extends SymbolRandomAPIBase {

}

Original file line number Diff line number Diff line change
Expand Up @@ -576,4 +576,21 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
assert(arr.internal.toDoubleArray === Array(2d, 2d))
assert(arr.internal.toByteArray === Array(2.toByte, 2.toByte))
}

test("NDArray random module is generated properly") {
val lam = NDArray.ones(1, 2)
val rnd = NDArray.random.poisson(lam = Some(lam), shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random.poisson(lam = Some(1f), shape = Some(Shape(3, 4)))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}

test("NDArray random module is generated properly - special case of 'normal'") {
val mu = NDArray.ones(1, 2)
val sigma = NDArray.ones(1, 2) * 2
val rnd = NDArray.random.normal(mu = Some(mu), sigma = Some(sigma), shape = Some(Shape(3, 4)))
val rnd2 = NDArray.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(3, 4)))
assert(rnd.shape === Shape(1, 2, 3, 4))
assert(rnd2.shape === Shape(3, 4))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package org.apache.mxnet
import org.scalatest.{BeforeAndAfterAll, FunSuite}

class SymbolSuite extends FunSuite with BeforeAndAfterAll {

test("symbol compose") {
val data = Symbol.Variable("data")

Expand Down Expand Up @@ -71,4 +72,25 @@ class SymbolSuite extends FunSuite with BeforeAndAfterAll {
val data2 = data.clone()
assert(data.toJson === data2.toJson)
}

test("Symbol random module is generated properly") {
val lam = Symbol.Variable("lam")
val rnd = Symbol.random.poisson(lam = Some(lam), shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.poisson(lam = Some(1f), shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.poisson debug info: ${rnd.debugStr}")
println(s"Symbol.random.poisson debug info: ${rnd2.debugStr}")
// scalastyle:on println
}

test("Symbol random module is generated properly - special case of 'normal'") {
val loc = Symbol.Variable("loc")
val scale = Symbol.Variable("scale")
val rnd = Symbol.random.normal(mu = Some(loc), sigma = Some(scale), shape = Some(Shape(2, 2)))
val rnd2 = Symbol.random.normal(mu = Some(1f), sigma = Some(2f), shape = Some(Shape(2, 2)))
// scalastyle:off println
println(s"Symbol.random.sample_normal debug info: ${rnd.debugStr}")
println(s"Symbol.random.random_normal debug info: ${rnd2.debugStr}")
// scalastyle:on println
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ import scala.collection.mutable.ListBuffer
* Two file namely: SymbolAPIBase.scala and NDArrayAPIBase.scala
* The code will be executed during Macros stage and file live in Core stage
*/
private[mxnet] object APIDocGenerator extends GeneratorBase {
private[mxnet] object APIDocGenerator extends GeneratorBase with RandomHelpers {

def main(args: Array[String]): Unit = {
val FILE_PATH = args(0)
val hashCollector = ListBuffer[String]()
hashCollector += typeSafeClassGen(FILE_PATH, true)
hashCollector += typeSafeClassGen(FILE_PATH, false)
hashCollector += typeSafeRandomClassGen(FILE_PATH, true)
hashCollector += typeSafeRandomClassGen(FILE_PATH, false)
hashCollector += nonTypeSafeClassGen(FILE_PATH, true)
hashCollector += nonTypeSafeClassGen(FILE_PATH, false)
val finalHash = hashCollector.mkString("\n")
Expand Down Expand Up @@ -61,6 +63,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
generated)
}

def typeSafeRandomClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val generated = typeSafeRandomFunctionsToGenerate(isSymbol)
.map { func =>
val scalaDoc = generateAPIDocFromBackend(func)
val typeParameter = randomGenericTypeSpec(isSymbol)
val decl = generateAPISignature(func, isSymbol, typeParameter)
s"$scalaDoc\n$decl"
}

writeFile(
FILE_PATH,
if (isSymbol) "SymbolRandomAPIBase" else "NDArrayRandomAPIBase",
"package org.apache.mxnet",
generated)
}

def nonTypeSafeClassGen(FILE_PATH: String, isSymbol: Boolean): String = {
val absFuncs = functionsToGenerate(isSymbol, isContrib = false)
.map { func =>
Expand Down Expand Up @@ -113,22 +131,22 @@ private[mxnet] object APIDocGenerator extends GeneratorBase {
}
}

def generateAPISignature(func: Func, isSymbol: Boolean): String = {
val argDef = ListBuffer[String]()
def generateAPISignature(func: Func, isSymbol: Boolean, typeParameter: String = ""): String = {
val argDecl = ListBuffer[String]()

argDef ++= typedFunctionCommonArgDef(func)
argDecl ++= buildArgDecl(func)

if (isSymbol) {
argDef += "name : String = null"
argDef += "attr : Map[String, String] = null"
argDecl += "name : String = null"
argDecl += "attr : Map[String, String] = null"
} else {
argDef += "out : Option[NDArray] = None"
argDecl += "out : Option[NDArray] = None"
}

val returnType = func.returnType

s"""@Experimental
|def ${func.name} (${argDef.mkString(", ")}): $returnType""".stripMargin
|def ${func.name}$typeParameter (${argDecl.mkString(", ")}): $returnType""".stripMargin
}

def writeFile(FILE_PATH: String, className: String, packageDef: String,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.mxnet.utils.{CToScalaUtils, OperatorBuildUtils}
import scala.collection.mutable.ListBuffer
import scala.reflect.macros.blackbox

abstract class GeneratorBase {
private[mxnet] abstract class GeneratorBase {
type Handle = Long

case class Arg(argName: String, argType: String, argDesc: String, isOptional: Boolean) {
Expand All @@ -36,7 +36,8 @@ abstract class GeneratorBase {

case class Func(name: String, desc: String, listOfArgs: List[Arg], returnType: String)

def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
// filter the operators to generate (in the non type-safe apis)
protected def functionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
val l = getBackEndFunctions(isSymbol)
if (isContrib) {
l.filter(func => func.name.startsWith("_contrib_") || !func.name.startsWith("_"))
Expand All @@ -45,7 +46,8 @@ abstract class GeneratorBase {
}
}

def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
// filter the operators to generate in the type-safe Symbol.api and NDArray.api
protected def typeSafeFunctionsToGenerate(isSymbol: Boolean, isContrib: Boolean): List[Func] = {
// Operators that should not be generated
val notGenerated = Set("Custom")

Expand Down Expand Up @@ -138,8 +140,8 @@ abstract class GeneratorBase {
result
}

protected def typedFunctionCommonArgDef(func: Func): List[String] = {
// build function argument definition, with optionality, and safe names
// build function argument definition, with optionality, and safe names
protected def buildArgDecl(func: Func): List[String] = {
func.listOfArgs.map(arg =>
if (arg.isOptional) {
// let's avoid a stupid Option[Array[...]]
Expand All @@ -155,3 +157,70 @@ abstract class GeneratorBase {
)
}
}

// a mixin to ease generating the Random module
private[mxnet] trait RandomHelpers {
self: GeneratorBase =>

// a generic type spec used in Symbol.random and NDArray.random modules
protected def randomGenericTypeSpec(isSymbol: Boolean): String = {
if (isSymbol) "[T: SymbolOrValue : scala.reflect.runtime.universe.TypeTag]"
else "[T: NDArrayOrValue : scala.reflect.runtime.universe.TypeTag]"
}

// filter the operators to generate in the type-safe Symbol.random and NDArray.random
protected def typeSafeRandomFunctionsToGenerate(isSymbol: Boolean): List[Func] = {
getBackEndFunctions(isSymbol)
.filter(f => f.name.startsWith("_sample_") || f.name.startsWith("_random_"))
.map(f => f.copy(name = f.name.stripPrefix("_")))
// unify _random and _sample
.map(f => unifyRandom(f, isSymbol))
// deduplicate
.groupBy(_.name)
.mapValues(_.head)
.values
.toList
}

// unify call targets (random_xyz and sample_xyz) and unify their argument types
private def unifyRandom(func: Func, isSymbol: Boolean): Func = {
var typeConv = Set("org.apache.mxnet.NDArray", "org.apache.mxnet.Symbol",
"org.apache.mxnet.Base.MXFloat", "Int")

func.copy(
name = func.name.replaceAll("(random|sample)_", ""),
listOfArgs = func.listOfArgs
.map(hackNormalFunc)
.map(arg =>
if (typeConv(arg.argType)) arg.copy(argType = "T")
else arg
)
// TODO: some functions are non consistent in random_ vs sample_ regarding optionality
// we may try to unify that as well here.
)
}

// hacks to manage the fact that random_normal and sample_normal have
// non-consistent parameter naming in the back-end
// this first one, merge loc/scale and mu/sigma
protected def hackNormalFunc(arg: Arg): Arg = {
if (arg.argName == "loc") arg.copy(argName = "mu")
else if (arg.argName == "scale") arg.copy(argName = "sigma")
else arg
}

// this second one reverts this merge prior to back-end call
protected def unhackNormalFunc(func: Func): String = {
if (func.name.equals("normal")) {
s"""if(target.equals("random_normal")) {
| if(map.contains("mu")) { map("loc") = map("mu"); map.remove("mu") }
| if(map.contains("sigma")) { map("scale") = map("sigma"); map.remove("sigma") }
|}
""".stripMargin
} else {
""
}

}

}
Loading

0 comments on commit 704ab60

Please sign in to comment.