Skip to content

Commit

Permalink
[FLINK-1378] [scala] Fix type extraction for nested type parameters
Browse files Browse the repository at this point in the history
Before, something like:
def f[T: TypeInformation](data: DataSet[T]) = {
  val tpe = createTypeInformation[(T, Seq[T])]
  println("Type: " + tpe)
}

f(Seq(1.0f, 2.0f)

would fail because the type extractor could not re-use existing
TypeInformation for nested types.
  • Loading branch information
aljoscha authored and StephanEwen committed Jan 11, 2015
1 parent 3a39352 commit 935e316
Show file tree
Hide file tree
Showing 4 changed files with 54 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ import org.apache.flink.types.StringValue
import org.apache.flink.types.LongValue
import org.apache.flink.types.ShortValue


private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
with TypeDescriptors[C] =>

Expand All @@ -54,6 +53,9 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]

cache.getOrElseUpdate(tpe) { id =>
tpe match {

case TypeParameter() => TypeParameterDescriptor(id, tpe)

case PrimitiveType(default, wrapper) => PrimitiveDescriptor(id, tpe, default, wrapper)

case BoxedPrimitiveType(default, wrapper, box, unbox) =>
Expand Down Expand Up @@ -282,6 +284,10 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
}
}

private object TypeParameter {
def unapply(tpe: Type): Boolean = tpe.typeSymbol.isParameter
}

private object CaseClassType {
def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isCaseClass
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
def canBeKey = tpe <:< typeOf[Comparable[_]]
}

case class TypeParameterDescriptor(id: Int, tpe: Type) extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}

case class PrimitiveDescriptor(id: Int, tpe: Type, default: Literal, wrapper: Type)
extends UDTDescriptor {
override val isPrimitiveProduct = true
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ private[flink] trait TypeInformationGen[C <: Context] {
mkCaseClassTypeInfo(cc)(c.WeakTypeTag(tpe).asInstanceOf[c.WeakTypeTag[Product]])
.asInstanceOf[c.Expr[TypeInformation[T]]]

case tp: TypeParameterDescriptor => mkTypeParameter(tp)

case p : PrimitiveDescriptor => mkPrimitiveTypeInfo(p.tpe)
case p : BoxedPrimitiveDescriptor => mkPrimitiveTypeInfo(p.tpe)

Expand Down Expand Up @@ -271,6 +273,24 @@ private[flink] trait TypeInformationGen[C <: Context] {
}
}

def mkTypeParameter[T: c.WeakTypeTag](
typeParameter: TypeParameterDescriptor): c.Expr[TypeInformation[T]] = {

val result = c.inferImplicitValue(
c.weakTypeOf[TypeInformation[T]],
silent = true,
withMacrosDisabled = false,
pos = c.enclosingPosition)

if (result.isEmpty) {
c.error(
c.enclosingPosition,
s"could not find implicit value of type TypeInformation[${typeParameter.tpe}].")
}

c.Expr[TypeInformation[T]](result)
}

def mkPrimitiveTypeInfo[T: c.WeakTypeTag](tpe: Type): c.Expr[TypeInformation[T]] = {
val tpeClazz = c.Expr[Class[T]](Literal(Constant(tpe)))
reify {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import java.io.DataInput
import java.io.DataOutput
import org.apache.flink.api.common.typeinfo._
import org.apache.flink.api.java.typeutils._
import org.apache.flink.api.scala.typeutils.CaseClassTypeInfo
import org.apache.flink.types.{IntValue, StringValue}
import org.apache.hadoop.io.Writable
import org.junit.Assert
Expand Down Expand Up @@ -59,6 +60,26 @@ class TypeInformationGenTest {
Assert.assertEquals(classOf[java.lang.Boolean], ti.getTypeClass)
}

@Test
def testTypeParameters(): Unit = {

val data = Seq(1.0d, 2.0d)

def f[T: TypeInformation](data: Seq[T]): (T, Seq[T]) = {

val ti = createTypeInformation[(T, Seq[T])]

Assert.assertTrue(ti.isTupleType)
val ccti = ti.asInstanceOf[CaseClassTypeInfo[(T, Seq[T])]]
Assert.assertEquals(BasicTypeInfo.DOUBLE_TYPE_INFO, ccti.getTypeAt(0))

(data.head, data)
}

f(data)

}

@Test
def testWritableType(): Unit = {
val ti = createTypeInformation[MyWritable]
Expand Down

0 comments on commit 935e316

Please sign in to comment.