Skip to content

Commit

Permalink
[FLINK-703] [scala api] Use complete element as join key
Browse files Browse the repository at this point in the history
This closes apache#572
  • Loading branch information
chiwanpark authored and fhueske committed Apr 21, 2015
1 parent 30a74c7 commit 45e680c
Show file tree
Hide file tree
Showing 9 changed files with 236 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -280,7 +280,7 @@ public ExpressionKeys(String[] expressionsIn, TypeInformation<T> type) {
if (!type.isKeyType()) {
throw new InvalidProgramException("This type (" + type + ") cannot be used as key.");
} else if (expressionsIn.length != 1 || !(Keys.ExpressionKeys.SELECT_ALL_CHAR.equals(expressionsIn[0]) || Keys.ExpressionKeys.SELECT_ALL_CHAR_SCALA.equals(expressionsIn[0]))) {
throw new IllegalArgumentException("Field expression for atomic type must be equal to '*' or '_'.");
throw new InvalidProgramException("Field expression for atomic type must be equal to '*' or '_'.");
}

keyFields = new ArrayList<FlatFieldDescriptor>(1);
Expand All @@ -297,7 +297,7 @@ public ExpressionKeys(String[] expressionsIn, TypeInformation<T> type) {
for (int i = 0; i < expressions.length; i++) {
List<FlatFieldDescriptor> keys = cType.getFlatFields(expressions[i]); // use separate list to do a size check
if(keys.size() == 0) {
throw new IllegalArgumentException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType);
throw new InvalidProgramException("Unable to extract key from expression '"+expressions[i]+"' on key "+cType);
}
keyFields.addAll(keys);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -807,7 +807,6 @@ class DataSet[T: ClassTag](set: JavaDataSet[T]) {
* This will not create a new DataSet, it will just attach the field names which will be
* used for grouping when executing a grouped operation.
*
* This only works on CaseClass DataSets.
*/
def groupBy(firstField: String, otherFields: String*): GroupedDataSet[T] = {
new GroupedDataSet[T](
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,6 @@ private[flink] abstract class UnfinishedKeyPairOperation[L, R, O](
* a [[HalfUnfinishedKeyPairOperation]] on which `equalTo` must be called to specify the
* key for the right side. The result after specifying the right side key is the finished
* operation.
*
* This only works on a CaseClass [[DataSet]].
*/
def where(firstLeftField: String, otherLeftFields: String*) = {
val leftKey = new ExpressionKeys[L](
Expand Down Expand Up @@ -113,8 +111,6 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
/**
* Specify the key fields for the right side of the key based operation. This returns
* the finished operation.
*
* This only works on a CaseClass [[DataSet]].
*/
def equalTo(firstRightField: String, otherRightFields: String*): O = {
val rightKey = new ExpressionKeys[R](
Expand All @@ -125,7 +121,6 @@ private[flink] class HalfUnfinishedKeyPairOperation[L, R, O](
leftKey + " Right: " + rightKey)
}
unfinished.finish(leftKey, rightKey)

}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -383,5 +383,47 @@ class CoGroupITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mo
env.execute()
expectedResult = "-1,20000,Flink\n" + "-1,10000,Flink\n" + "-1,30000,Flink\n"
}

@Test
def testCoGroupWithAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = CollectionDataSets.getSmall3TupleDataSet(env)
val ds2 = env.fromElements(0, 1, 2)
val coGroupDs = ds1.coGroup(ds2).where(0).equalTo("*") {
(first, second, out: Collector[(Int, Long, String)]) =>
for (p <- first) {
for (t <- second) {
if (p._1 == t) {
out.collect(p)
}
}
}
}

coGroupDs.writeAsText(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expectedResult = "(1,1,Hi)\n(2,2,Hello)"
}

@Test
def testCoGroupWithAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 1, 2)
val ds2 = CollectionDataSets.getSmall3TupleDataSet(env)
val coGroupDs = ds1.coGroup(ds2).where("*").equalTo(0) {
(first, second, out: Collector[(Int, Long, String)]) =>
for (p <- first) {
for (t <- second) {
if (p == t._1) {
out.collect(t)
}
}
}
}

coGroupDs.writeAsText(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expectedResult = "(1,1,Hi)\n(2,2,Hello)"
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
*/
package org.apache.flink.api.scala.operators

import java.util

import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
import org.junit.Assert
import org.junit.Test
Expand Down Expand Up @@ -268,6 +271,60 @@ class CoGroupOperatorTest {
// Should not work, more than one field position key
ds1.coGroup(ds2).where(1, 3).equalTo { _.myLong }
}

@Test
def testCoGroupWithAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromElements(0, 1, 2)

ds1.coGroup(ds2).where(0).equalTo("*")
}

@Test
def testCoGroupWithAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 1, 2)
val ds2 = env.fromCollection(emptyTupleData)

ds1.coGroup(ds2).where("*").equalTo(0)
}

@Test(expected = classOf[InvalidProgramException])
def testCoGroupWithInvalidAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 1, 2)
val ds2 = env.fromCollection(emptyTupleData)

ds1.coGroup(ds2).where("invalidKey")
}

@Test(expected = classOf[InvalidProgramException])
def testCoGroupWithInvalidAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromElements(0, 1, 2)

ds1.coGroup(ds2).where(0).equalTo("invalidKey")
}

@Test(expected = classOf[InvalidProgramException])
def testCoGroupWithInvalidAtomic3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(new util.ArrayList[Integer]())
val ds2 = env.fromElements(0, 0, 0)

ds1.coGroup(ds2).where("*")
}

@Test(expected = classOf[InvalidProgramException])
def testCoGroupWithInvalidAtomic4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 0, 0)
val ds2 = env.fromElements(new util.ArrayList[Integer]())

ds1.coGroup(ds2).where("*").equalTo("*")
}
}


Original file line number Diff line number Diff line change
Expand Up @@ -743,6 +743,15 @@ class GroupReduceITCase(mode: TestExecutionMode) extends MultipleProgramsTestBas
expected = "b\nccc\nee\n"
}

@Test
def testWithAtomic1: Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(0, 1, 1, 2)
val reduceDs = ds.groupBy("*").reduceGroup((ints: Iterator[Int]) => ints.next())
reduceDs.writeAsText(resultPath, WriteMode.OVERWRITE)
env.execute()
expected = "0\n1\n2"
}
}

@RichGroupReduceFunction.Combinable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@
*/
package org.apache.flink.api.scala.operators

import java.util

import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.Assert
import org.apache.flink.api.common.InvalidProgramException
Expand Down Expand Up @@ -96,7 +98,7 @@ class GroupingTest {
}
}

@Test(expected = classOf[IllegalArgumentException])
@Test(expected = classOf[InvalidProgramException])
def testGroupByKeyFields2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val longDs = env.fromCollection(emptyLongData)
Expand Down Expand Up @@ -146,7 +148,7 @@ class GroupingTest {
}
}

@Test(expected = classOf[IllegalArgumentException])
@Test(expected = classOf[InvalidProgramException])
def testGroupByKeyExpressions2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment

Expand Down Expand Up @@ -224,5 +226,37 @@ class GroupingTest {
case e: Exception => Assert.fail()
}
}

@Test
def testAtomicValue1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(0, 1, 2)

ds.groupBy("*")
}

@Test(expected = classOf[InvalidProgramException])
def testAtomicValueInvalid1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(0, 1, 2)

ds.groupBy("invalidKey")
}

@Test(expected = classOf[InvalidProgramException])
def testAtomicValueInvalid2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(0, 1, 2)

ds.groupBy("_", "invalidKey")
}

@Test(expected = classOf[InvalidProgramException])
def testAtomicValueInvalid3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds = env.fromElements(new util.ArrayList[Integer]())

ds.groupBy("*")
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -384,4 +384,26 @@ class JoinITCase(mode: TestExecutionMode) extends MultipleProgramsTestBase(mode)
"2 Second (20,200,2000,Two) 20000,(20000,20,200,2000,Two,2,Second)\n" +
"3 Third (30,300,3000,Three) 30000,(30000,30,300,3000,Three,3,Third)\n"
}

@Test
def testWithAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = CollectionDataSets.getSmall3TupleDataSet(env)
val ds2 = env.fromElements(0, 1, 2)
val joinDs = ds1.join(ds2).where(0).equalTo("*")
joinDs.writeAsCsv(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expected = "(1,1,Hi),1\n(2,2,Hello),2"
}

@Test
def testWithAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(0, 1, 2)
val ds2 = CollectionDataSets.getSmall3TupleDataSet(env)
val joinDs = ds1.join(ds2).where("*").equalTo(0)
joinDs.writeAsCsv(resultPath, writeMode = WriteMode.OVERWRITE)
env.execute()
expected = "1,(1,1,Hi)\n2,(2,2,Hello)"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,13 @@
*/
package org.apache.flink.api.scala.operators

import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.Assert
import org.apache.flink.api.common.InvalidProgramException
import org.junit.Ignore
import org.junit.Test
import java.util

import org.apache.flink.api.common.InvalidProgramException
import org.apache.flink.api.java.operators.Keys.IncompatibleKeysException
import org.apache.flink.api.scala._
import org.apache.flink.api.scala.util.CollectionDataSets.CustomType
import org.junit.{Assert, Test}

class JoinOperatorTest {

Expand Down Expand Up @@ -272,5 +271,68 @@ class JoinOperatorTest {
// should not work, more than one field position key
ds1.join(ds2).where(1, 3) equalTo { _.myLong }
}

@Test
def testJoinWithAtomic(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyLongData)

ds1.join(ds2).where(1).equalTo("*")
}

@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic1(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyLongData)

ds1.join(ds2).where(1).equalTo("invalidKey")
}

@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic2(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyLongData)
val ds2 = env.fromCollection(emptyTupleData)

ds1.join(ds2).where("invalidKey").equalTo(1)
}

@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic3(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyTupleData)
val ds2 = env.fromCollection(emptyLongData)

ds1.join(ds2).where(1).equalTo("_", "invalidKey")
}

@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic4(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyLongData)
val ds2 = env.fromCollection(emptyTupleData)

ds1.join(ds2).where("_", "invalidKey").equalTo(1)
}

@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic5(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromElements(new util.ArrayList[Integer]())
val ds2 = env.fromCollection(emptyLongData)

ds1.join(ds2).where("*")
}

@Test(expected = classOf[InvalidProgramException])
def testJoinWithInvalidAtomic6(): Unit = {
val env = ExecutionEnvironment.getExecutionEnvironment
val ds1 = env.fromCollection(emptyLongData)
val ds2 = env.fromElements(new util.ArrayList[Integer]())

ds1.join(ds2).where("*").equalTo("*")
}
}

0 comments on commit 45e680c

Please sign in to comment.