From f562d49dd432d7d540cd624db3f7d83c4b5628fa Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Wed, 8 Oct 2014 17:45:30 +0200 Subject: [PATCH] Simplify Pojo/Tuple/CaseClass comparator extractKeys() method Also fixes a bug with Java/Scala interop: TupleTypeComparator was only checking for nested Java Tuples and Pojos, not Scala Case classes. --- .../java/typeutils/runtime/PojoComparator.java | 9 +-------- .../java/typeutils/runtime/TupleComparator.java | 9 +-------- .../scala/typeutils/CaseClassComparator.scala | 16 ++++++++++++---- 3 files changed, 14 insertions(+), 20 deletions(-) diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java index 51d8090ad2ef9..2cccfcf65b8f2 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/PojoComparator.java @@ -340,14 +340,7 @@ public PojoComparator duplicate() { public int extractKeys(Object record, Object[] target, int index) { int localIndex = index; for (int i = 0; i < comparators.length; i++) { - if(comparators[i] instanceof CompositeTypeComparator) { - localIndex += comparators[i].extractKeys(accessField(keyFields[i], record), target, localIndex) -1; - } else { - // non-composite case (= atomic). We can assume this to have only one key. - // comparators[i].extractKeys(accessField(keyFields[i], record), target, i); - target[localIndex] = accessField(keyFields[i], record); - } - localIndex++; + localIndex += comparators[i].extractKeys(accessField(keyFields[i], record), target, localIndex); } return localIndex - index; } diff --git a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java index 61a1567c8a66f..89b77945d962e 100644 --- a/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java +++ b/flink-java/src/main/java/org/apache/flink/api/java/typeutils/runtime/TupleComparator.java @@ -147,14 +147,7 @@ public void putNormalizedKey(T value, MemorySegment target, int offset, int numB public int extractKeys(Object record, Object[] target, int index) { int localIndex = index; for(int i = 0; i < comparators.length; i++) { - // handle nested case - if(comparators[i] instanceof TupleComparator || comparators[i] instanceof PojoComparator) { - localIndex += comparators[i].extractKeys(((Tuple) record).getField(keyPositions[i]), target, localIndex) -1; - } else { - // flat - target[localIndex] = ((Tuple) record).getField(keyPositions[i]); - } - localIndex++; + localIndex += comparators[i].extractKeys(((Tuple) record).getField(keyPositions[i]), target, localIndex); } return localIndex - index; } diff --git a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala index 1353b44420b87..bde009c4bb1a1 100644 --- a/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala +++ b/flink-scala/src/main/scala/org/apache/flink/api/scala/typeutils/CaseClassComparator.scala @@ -17,7 +17,8 @@ */ package org.apache.flink.api.scala.typeutils -import org.apache.flink.api.common.typeutils.{TypeComparator, TypeSerializer} +import org.apache.flink.api.common.typeutils.{CompositeTypeComparator, TypeComparator, +TypeSerializer} import org.apache.flink.api.java.typeutils.runtime.TupleComparatorBase import org.apache.flink.core.memory.MemorySegment import org.apache.flink.types.{KeyFieldOutOfBoundsException, NullKeyFieldException} @@ -140,9 +141,16 @@ class CaseClassComparator[T <: Product]( } def extractKeys(value: AnyRef, target: Array[AnyRef], index: Int) = { - for (i <- 0 until keyPositions.length ) { - target(index + i) = value.asInstanceOf[T].productElement(keyPositions(i)).asInstanceOf[AnyRef] + val in = value.asInstanceOf[T] + + var localIndex: Int = index + for (i <- 0 until comparators.length) { + localIndex += comparators(i).extractKeys( + in.productElement(keyPositions(i)), + target, + localIndex) } - keyPositions.length + + localIndex - index } }