Skip to content

Commit

Permalink
[FLINK-1191] Add support for Scala Collections and Special Types
Browse files Browse the repository at this point in the history
"The special types" are Option and Either. This should work for all
Scala collections except SortedSet and SortedMap, for which the type
checker prints an error message.
  • Loading branch information
aljoscha committed Nov 5, 2014
1 parent 1e1df6d commit bd66a08
Show file tree
Hide file tree
Showing 14 changed files with 1,247 additions and 51 deletions.
1 change: 1 addition & 0 deletions flink-dist/src/main/flink-bin/LICENSE
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,7 @@ BSD-style licenses:
- Scala Library (https://www.scala-lang.org/) - Copyright (c) 2002-2014 EPFL, Copyright (c) 2011-2014 Typesafe, Inc.
- Scala Compiler (BSD-like) - (https://www.scala-lang.org/) - Copyright (c) 2002-2014 EPFL, Copyright (c) 2011-2014 Typesafe, Inc.
- Scala Compiler Reflect (BSD-like) - (https://www.scala-lang.org/) - Copyright (c) 2002-2014 EPFL, Copyright (c) 2011-2014 Typesafe, Inc.
- Scala Quasiquotes (BSD-like) - (https://scalamacros.org/) - Copyright (c) 2002-2014 EPFL, Copyright (c) 2011-2014 Typesafe, Inc.
- ASM (BSD-like) - (https://asm.ow2.org/) - Copyright (c) 2000-2011 INRIA, France Telecom


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@
*/
package org.apache.flink.api.scala.codegen

import scala.collection.GenTraversableOnce
import scala.collection.mutable
import scala.collection._
import scala.reflect.macros.Context
import scala.util.DynamicVariable

Expand Down Expand Up @@ -56,30 +55,71 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
cache.getOrElseUpdate(tpe) { id =>
tpe match {
case PrimitiveType(default, wrapper) => PrimitiveDescriptor(id, tpe, default, wrapper)

case BoxedPrimitiveType(default, wrapper, box, unbox) =>
BoxedPrimitiveDescriptor(id, tpe, default, wrapper, box, unbox)
case ListType(elemTpe, iter) =>
analyzeList(id, tpe, elemTpe, iter)

case ArrayType(elemTpe) => analyzeArray(id, tpe, elemTpe)

case NothingType() => NothingDesciptor(id, tpe)

case TraversableType(elemTpe) => analyzeTraversable(id, tpe, elemTpe)

case EitherType(leftTpe, rightTpe) => analyzeEither(id, tpe, leftTpe, rightTpe)

case OptionType(elemTpe) => analyzeOption(id, tpe, elemTpe)

case CaseClassType() => analyzeCaseClass(id, tpe)

case ValueType() => ValueDescriptor(id, tpe)

case WritableType() => WritableDescriptor(id, tpe)
case CaseClassType() => analyzeCaseClass(id, tpe)

case JavaType() =>
// It's a Java Class, let the TypeExtractor deal with it...
c.warning(c.enclosingPosition, s"Type $tpe is a java class. Will be analyzed by " +
s"TypeExtractor at runtime.")
GenericClassDescriptor(id, tpe)

case _ => analyzePojo(id, tpe)
}
}
}

private def analyzeList(
private def analyzeArray(
id: Int,
tpe: Type,
elemTpe: Type,
iter: Tree => Tree): UDTDescriptor = analyze(elemTpe) match {
elemTpe: Type): UDTDescriptor = analyze(elemTpe) match {
case UnsupportedDescriptor(_, _, errs) => UnsupportedDescriptor(id, tpe, errs)
case desc => ListDescriptor(id, tpe, iter, desc)
case desc => ArrayDescriptor(id, tpe, desc)
}

private def analyzeTraversable(
id: Int,
tpe: Type,
elemTpe: Type): UDTDescriptor = analyze(elemTpe) match {
case UnsupportedDescriptor(_, _, errs) => UnsupportedDescriptor(id, tpe, errs)
case desc => TraversableDescriptor(id, tpe, desc)
}

private def analyzeEither(
id: Int,
tpe: Type,
leftTpe: Type,
rightTpe: Type): UDTDescriptor = analyze(leftTpe) match {
case UnsupportedDescriptor(_, _, errs) => UnsupportedDescriptor(id, tpe, errs)
case leftDesc => analyze(rightTpe) match {
case UnsupportedDescriptor(_, _, errs) => UnsupportedDescriptor(id, tpe, errs)
case rightDesc => EitherDescriptor(id, tpe, leftDesc, rightDesc)
}
}

private def analyzeOption(
id: Int,
tpe: Type,
elemTpe: Type): UDTDescriptor = analyze(elemTpe) match {
case UnsupportedDescriptor(_, _, errs) => UnsupportedDescriptor(id, tpe, errs)
case elemDesc => OptionDescriptor(id, tpe, elemDesc)
}

private def analyzePojo(id: Int, tpe: Type): UDTDescriptor = {
Expand All @@ -99,7 +139,7 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
.filterNot { _.annotations.exists( _.tpe <:< typeOf[scala.transient]) }

if (fields.isEmpty) {
c.warning(c.enclosingPosition, "Type $tpe has no fields that are visible from Scala Type" +
c.warning(c.enclosingPosition, s"Type $tpe has no fields that are visible from Scala Type" +
" analysis. Falling back to Java Type Analysis (TypeExtractor).")
return GenericClassDescriptor(id, tpe)
}
Expand Down Expand Up @@ -210,48 +250,72 @@ private[flink] trait TypeAnalyzer[C <: Context] { this: MacroContextHolder[C]
boxedPrimitives.get(tpe.typeSymbol)
}

private object ListType {
private object ArrayType {
def unapply(tpe: Type): Option[Type] = tpe match {
case TypeRef(_, _, elemTpe :: Nil) if tpe <:< typeOf[Array[_]] => Some(elemTpe)
case _ => None
}
}

def unapply(tpe: Type): Option[(Type, Tree => Tree)] = tpe match {

case ArrayType(elemTpe) =>
val iter = { source: Tree =>
Select(source, newTermName("iterator"))
}
Some(elemTpe, iter)
private object TraversableType {
def unapply(tpe: Type): Option[Type] = tpe match {
case _ if tpe <:< typeOf[BitSet] => Some(typeOf[Int])

case TraversableType(elemTpe) =>
val iter = { source: Tree => Select(source, newTermName("toIterator")) }
Some(elemTpe, iter)
case _ if tpe <:< typeOf[SortedMap[_, _]] => None
case _ if tpe <:< typeOf[SortedSet[_]] => None

case _ if tpe <:< typeOf[TraversableOnce[_]] =>
// val traversable = tpe.baseClasses
// .map(tpe.baseType)
// .find(t => t.erasure =:= typeOf[TraversableOnce[_]].erasure)

val traversable = tpe.baseType(typeOf[TraversableOnce[_]].typeSymbol)

traversable match {
case TypeRef(_, _, elemTpe :: Nil) =>
Some(elemTpe.asSeenFrom(tpe, tpe.typeSymbol))
case _ => None
}

case _ => None
}
}

private object ArrayType {
def unapply(tpe: Type): Option[Type] = tpe match {
case TypeRef(_, _, elemTpe :: Nil) if tpe <:< typeOf[Array[_]] => Some(elemTpe)
case _ => None
}
}
private object CaseClassType {
def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isCaseClass
}

private object NothingType {
def unapply(tpe: Type): Boolean = tpe =:= typeOf[Nothing]
}

private object TraversableType {
def unapply(tpe: Type): Option[Type] = tpe match {
case _ if tpe <:< typeOf[GenTraversableOnce[_]] =>
// val abstrElemTpe = genTraversableOnceClass.typeConstructor.typeParams.head.tpe
// val elemTpe = abstrElemTpe.asSeenFrom(tpe, genTraversableOnceClass)
// Some(elemTpe)
// TODO make sure this works as it should
tpe match {
case TypeRef(_, _, elemTpe :: Nil) => Some(elemTpe.asSeenFrom(tpe, tpe.typeSymbol))
}

case _ => None
private object EitherType {
def unapply(tpe: Type): Option[(Type, Type)] = {
if (tpe <:< typeOf[Either[_, _]]) {
val either = tpe.baseType(typeOf[Either[_, _]].typeSymbol)
either match {
case TypeRef(_, _, leftTpe :: rightTpe :: Nil) =>
Some(leftTpe, rightTpe)
}
} else {
None
}
}
}

private object CaseClassType {
def unapply(tpe: Type): Boolean = tpe.typeSymbol.asClass.isCaseClass
private object OptionType {
def unapply(tpe: Type): Option[Type] = {
if (tpe <:< typeOf[Option[_]]) {
val option = tpe.baseType(typeOf[Option[_]].typeSymbol)
option match {
case TypeRef(_, _, elemTpe :: Nil) =>
Some(elemTpe)
}
} else {
None
}
}
}

private object ValueType {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,27 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
override def canBeKey = wrapper <:< typeOf[org.apache.flink.types.Key[_]]
}

case class NothingDesciptor(id: Int, tpe: Type)
extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}

case class EitherDescriptor(id: Int, tpe: Type, left: UDTDescriptor, right: UDTDescriptor)
extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}

case class OptionDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override val isPrimitiveProduct = false
override def flatten = Seq(this)
override def canBeKey = false
}

case class BoxedPrimitiveDescriptor(
id: Int, tpe: Type, default: Literal, wrapper: Type, box: Tree => Tree, unbox: Tree => Tree)
extends UDTDescriptor {
Expand All @@ -96,19 +117,32 @@ private[flink] trait TypeDescriptors[C <: Context] { this: MacroContextHolder[C]
}
}

case class ListDescriptor(id: Int, tpe: Type, iter: Tree => Tree, elem: UDTDescriptor)
case class ArrayDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override def canBeKey = false
override def flatten = this +: elem.flatten

override def hashCode() = (id, tpe, elem).hashCode()
override def equals(that: Any) = that match {
case that @ ArrayDescriptor(thatId, thatTpe, thatElem) =>
(id, tpe, elem).equals((thatId, thatTpe, thatElem))
case _ => false
}
}

case class TraversableDescriptor(id: Int, tpe: Type, elem: UDTDescriptor)
extends UDTDescriptor {
override def canBeKey = false
override def flatten = this +: elem.flatten

def getInnermostElem: UDTDescriptor = elem match {
case list: ListDescriptor => list.getInnermostElem
case list: TraversableDescriptor => list.getInnermostElem
case _ => elem
}

override def hashCode() = (id, tpe, elem).hashCode()
override def equals(that: Any) = that match {
case that @ ListDescriptor(thatId, thatTpe, _, thatElem) =>
case that @ TraversableDescriptor(thatId, thatTpe, thatElem) =>
(id, tpe, elem).equals((thatId, thatTpe, thatElem))
case _ => false
}
Expand Down
Loading

0 comments on commit bd66a08

Please sign in to comment.