Skip to content

Commit

Permalink
[FLINK-12301] Fix ScalaCaseClassSerializer to support value types
Browse files Browse the repository at this point in the history
We now use Scala reflection because it correctly deals with Scala
language features.
  • Loading branch information
Igal Shilman authored and aljoscha committed May 13, 2019
1 parent 9c2bcae commit 9caf2c4
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 56 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@

package org.apache.flink.api.scala.typeutils

import java.io.ObjectInputStream

import org.apache.flink.api.common.typeutils.CompositeTypeSerializerUtil.delegateCompatibilityCheckToNewSnapshot
import org.apache.flink.api.common.typeutils.TypeSerializerConfigSnapshot.SelfResolvingTypeSerializer
import org.apache.flink.api.common.typeutils._
import org.apache.flink.api.java.typeutils.runtime.TupleSerializerConfigSnapshot
import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializer.lookupConstructor

import java.io.ObjectInputStream
import java.lang.invoke.{MethodHandle, MethodHandles}

import scala.collection.JavaConverters._
import scala.reflect.runtime.universe

Expand All @@ -38,25 +37,24 @@ import scala.reflect.runtime.universe
*/
@SerialVersionUID(1L)
class ScalaCaseClassSerializer[T <: Product](
clazz: Class[T],
scalaFieldSerializers: Array[TypeSerializer[_]]
) extends CaseClassSerializer[T](clazz, scalaFieldSerializers)
with SelfResolvingTypeSerializer[T] {
clazz: Class[T],
scalaFieldSerializers: Array[TypeSerializer[_]]
) extends CaseClassSerializer[T](clazz, scalaFieldSerializers)
with SelfResolvingTypeSerializer[T] {

@transient
private var constructor = lookupConstructor(clazz)

override def createInstance(fields: Array[AnyRef]): T = {
constructor.invoke(fields).asInstanceOf[T]
constructor(fields)
}

override def snapshotConfiguration(): TypeSerializerSnapshot[T] = {
new ScalaCaseClassSerializerSnapshot[T](this)
}

override def resolveSchemaCompatibilityViaRedirectingToNewSnapshotClass(
s: TypeSerializerConfigSnapshot[T]
): TypeSerializerSchemaCompatibility[T] = {
s: TypeSerializerConfigSnapshot[T]): TypeSerializerSchemaCompatibility[T] = {

require(s.isInstanceOf[TupleSerializerConfigSnapshot[_]])

Expand Down Expand Up @@ -85,22 +83,8 @@ class ScalaCaseClassSerializer[T <: Product](

object ScalaCaseClassSerializer {

def lookupConstructor[T](clazz: Class[_]): MethodHandle = {
val types = findPrimaryConstructorParameterTypes(clazz, clazz.getClassLoader)

val constructor = clazz.getConstructor(types: _*)

val handle = MethodHandles
.lookup()
.unreflectConstructor(constructor)
.asSpreader(classOf[Array[AnyRef]], types.length)

handle
}

private def findPrimaryConstructorParameterTypes(cls: Class[_], cl: ClassLoader):
List[Class[_]] = {
val rootMirror = universe.runtimeMirror(cl)
def lookupConstructor[T](cls: Class[T]): Array[AnyRef] => T = {
val rootMirror = universe.runtimeMirror(cls.getClassLoader)
val classSymbol = rootMirror.classSymbol(cls)

require(
Expand All @@ -113,30 +97,21 @@ object ScalaCaseClassSerializer {
|""".stripMargin
)

val primaryConstructorSymbol = findPrimaryConstructorMethodSymbol(classSymbol)
val scalaTypes = getArgumentsTypes(primaryConstructorSymbol)
scalaTypes.map(tpe => scalaTypeToJavaClass(rootMirror)(tpe))
}

private def findPrimaryConstructorMethodSymbol(classSymbol: universe.ClassSymbol):
universe.MethodSymbol = {
classSymbol.toType
val primaryConstructorSymbol = classSymbol.toType
.decl(universe.termNames.CONSTRUCTOR)
.alternatives
.collectFirst({
case constructorSymbol: universe.MethodSymbol if constructorSymbol.isPrimaryConstructor =>
constructorSymbol
})
.head
.asMethod
}

private def getArgumentsTypes(primaryConstructorSymbol: universe.MethodSymbol):
List[universe.Type] = {
primaryConstructorSymbol.typeSignature
.paramLists
.head
.map(symbol => symbol.typeSignature)
}
val classMirror = rootMirror.reflectClass(classSymbol)
val constructorMethodMirror = classMirror.reflectConstructor(primaryConstructorSymbol)

private def scalaTypeToJavaClass(mirror: universe.Mirror)(scalaType: universe.Type): Class[_] = {
val erasure = scalaType.erasure
mirror.runtimeClass(erasure)
arr: Array[AnyRef] => {
constructorMethodMirror.apply(arr: _*).asInstanceOf[T]
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,11 @@

package org.apache.flink.api.scala.typeutils

import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializerReflectionTest.{Generic, HigherKind, SimpleCaseClass}
import org.apache.flink.api.scala.typeutils.ScalaCaseClassSerializerReflectionTest._
import org.apache.flink.util.TestLogger

import org.junit.Assert.assertEquals
import org.junit.Test

import java.lang.invoke.MethodHandle

/**
* Test obtaining the primary constructor of a case class
Expand All @@ -34,40 +32,40 @@ class ScalaCaseClassSerializerReflectionTest extends TestLogger {

@Test
def usageExample(): Unit = {
val constructor: MethodHandle = ScalaCaseClassSerializer
val constructor = ScalaCaseClassSerializer
.lookupConstructor(classOf[SimpleCaseClass])

val actual = constructor.invoke(Array("hi", 1.asInstanceOf[Any]))
val actual = constructor(Array("hi", 1.asInstanceOf[AnyRef]))

assertEquals(SimpleCaseClass("hi", 1), actual)
}

@Test
def genericCaseClass(): Unit = {
val constructor: MethodHandle = ScalaCaseClassSerializer
val constructor = ScalaCaseClassSerializer
.lookupConstructor(classOf[Generic[_]])

val actual = constructor.invoke(Array(1.asInstanceOf[AnyRef]))
val actual = constructor(Array(1.asInstanceOf[AnyRef]))

assertEquals(Generic[Int](1), actual)
}

@Test
def caseClassWithParameterizedList(): Unit = {
val constructor: MethodHandle = ScalaCaseClassSerializer
val constructor = ScalaCaseClassSerializer
.lookupConstructor(classOf[HigherKind])

val actual = constructor.invoke(Array(List(1, 2, 3), "hey"))
val actual = constructor(Array(List(1, 2, 3), "hey"))

assertEquals(HigherKind(List(1, 2, 3), "hey"), actual)
}

@Test
def tupleType(): Unit = {
val constructor: MethodHandle = ScalaCaseClassSerializer
val constructor = ScalaCaseClassSerializer
.lookupConstructor(classOf[(String, String, Int)])

val actual = constructor.invoke(Array("a", "b", 7))
val actual = constructor(Array("a", "b", 7.asInstanceOf[AnyRef]))

assertEquals(("a", "b", 7), actual)
}
Expand All @@ -80,6 +78,21 @@ class ScalaCaseClassSerializerReflectionTest extends TestLogger {
ScalaCaseClassSerializer
.lookupConstructor(classOf[outerInstance.InnerCaseClass])
}

@Test
def valueClass(): Unit = {
val constructor = ScalaCaseClassSerializer
.lookupConstructor(classOf[Measurement])

val arguments = Array(
1.asInstanceOf[AnyRef],
new DegreeCelsius(0.5f).asInstanceOf[AnyRef]
)

val actual = constructor(arguments)

assertEquals(Measurement(1, new DegreeCelsius(0.5f)), actual)
}
}

object ScalaCaseClassSerializerReflectionTest {
Expand All @@ -94,6 +107,12 @@ object ScalaCaseClassSerializerReflectionTest {

case class Generic[T](item: T)

class DegreeCelsius(val value: Float) extends AnyVal {
override def toString: String = s"$value °C"
}

case class Measurement(i: Int, temperature: DegreeCelsius)

}

class OuterClass {
Expand Down

0 comments on commit 9caf2c4

Please sign in to comment.