Skip to content

Commit

Permalink
[FLINK-9470] Allow querying the key in KeyedProcessFunction
Browse files Browse the repository at this point in the history
  • Loading branch information
aljoscha committed Jul 11, 2018
1 parent 53e6657 commit cde504e
Show file tree
Hide file tree
Showing 6 changed files with 121 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,11 @@ public abstract class Context {
* @param value The record to emit.
*/
public abstract <X> void output(OutputTag<X> outputTag, X value);

/**
* Get key of the element being processed.
*/
public abstract K getCurrentKey();
}

/**
Expand All @@ -124,6 +129,7 @@ public abstract class OnTimerContext extends Context {
/**
* Get key of the firing timer.
*/
@Override
public abstract K getCurrentKey();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,12 @@ public abstract class ReadOnlyContext extends BaseBroadcastProcessFunction.ReadO
* A {@link TimerService} for querying time and registering timers.
*/
public abstract TimerService timerService();


/**
* Get key of the element being processed.
*/
public abstract KS getCurrentKey();
}

/**
Expand All @@ -174,6 +180,7 @@ public abstract class OnTimerContext extends ReadOnlyContext {
/**
* Get the key of the firing timer.
*/
@Override
public abstract KS getCurrentKey();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,12 @@ public <X> void output(OutputTag<X> outputTag, X value) {

output.collect(outputTag, new StreamRecord<>(value, element.getTimestamp()));
}

@Override
@SuppressWarnings("unchecked")
public K getCurrentKey() {
return (K) KeyedProcessOperator.this.getCurrentKey();
}
}

private class OnTimerContextImpl extends KeyedProcessFunction<K, IN, OUT>.OnTimerContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,13 @@ public <K, V> ReadOnlyBroadcastState<K, V> getBroadcastState(MapStateDescriptor
}
return state;
}

@Override
@SuppressWarnings("unchecked")
public KS getCurrentKey() {
return (KS) CoBroadcastWithKeyedOperator.this.getCurrentKey();
}

}

private class OnTimerContextImpl extends KeyedBroadcastProcessFunction<KS, IN1, IN2, OUT>.OnTimerContext {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeutils.base.IntSerializer;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.streaming.api.TimeDomain;
import org.apache.flink.streaming.api.TimerService;
Expand All @@ -43,6 +44,7 @@
import java.util.concurrent.ConcurrentLinkedQueue;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
* Tests {@link KeyedProcessOperator}.
Expand All @@ -52,6 +54,48 @@ public class KeyedProcessOperatorTest extends TestLogger {
@Rule
public ExpectedException expectedException = ExpectedException.none();

@Test
public void testKeyQuerying() throws Exception {

class KeyQueryingProcessFunction extends KeyedProcessFunction<Integer, Tuple2<Integer, String>, String> {

@Override
public void processElement(
Tuple2<Integer, String> value,
Context ctx,
Collector<String> out) throws Exception {

assertTrue("Did not get expected key.", ctx.getCurrentKey().equals(value.f0));

// we check that we receive this output, to ensure that the assert was actually checked
out.collect(value.f1);
}
}

KeyedProcessOperator<Integer, Tuple2<Integer, String>, String> operator =
new KeyedProcessOperator<>(new KeyQueryingProcessFunction());

try (
OneInputStreamOperatorTestHarness<Tuple2<Integer, String>, String> testHarness =
new KeyedOneInputStreamOperatorTestHarness<>(operator, (in) -> in.f0 , BasicTypeInfo.INT_TYPE_INFO)) {

testHarness.setup();
testHarness.open();

testHarness.processElement(new StreamRecord<>(Tuple2.of(5, "5"), 12L));
testHarness.processElement(new StreamRecord<>(Tuple2.of(42, "42"), 13L));

ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
expectedOutput.add(new StreamRecord<>("5", 12L));
expectedOutput.add(new StreamRecord<>("42", 13L));

TestHarnessUtil.assertOutputEquals(
"Output was not correct.",
expectedOutput,
testHarness.getOutput());
}
}

@Test
public void testTimestampAndWatermarkQuerying() throws Exception {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.apache.flink.api.common.typeinfo.BasicTypeInfo;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.runtime.checkpoint.OperatorSubtaskState;
import org.apache.flink.runtime.state.KeyedStateFunction;
import org.apache.flink.streaming.api.functions.co.KeyedBroadcastProcessFunction;
Expand Down Expand Up @@ -70,6 +71,56 @@ public class CoBroadcastWithKeyedOperatorTest {
BasicTypeInfo.INT_TYPE_INFO
);

@Test
public void testKeyQuerying() throws Exception {

class KeyQueryingProcessFunction extends KeyedBroadcastProcessFunction<Integer, Tuple2<Integer, String>, String, String> {

@Override
public void processElement(
Tuple2<Integer, String> value,
ReadOnlyContext ctx,
Collector<String> out) throws Exception {
assertTrue("Did not get expected key.", ctx.getCurrentKey().equals(value.f0));

// we check that we receive this output, to ensure that the assert was actually checked
out.collect(value.f1);

}

@Override
public void processBroadcastElement(
String value,
Context ctx,
Collector<String> out) throws Exception {

}
}

CoBroadcastWithKeyedOperator<Integer, Tuple2<Integer, String>, String, String> operator =
new CoBroadcastWithKeyedOperator<>(new KeyQueryingProcessFunction(), Collections.emptyList());

try (
TwoInputStreamOperatorTestHarness<Tuple2<Integer, String>, String, String> testHarness =
new KeyedTwoInputStreamOperatorTestHarness<>(operator, (in) -> in.f0 , null, BasicTypeInfo.INT_TYPE_INFO)) {

testHarness.setup();
testHarness.open();

testHarness.processElement1(new StreamRecord<>(Tuple2.of(5, "5"), 12L));
testHarness.processElement1(new StreamRecord<>(Tuple2.of(42, "42"), 13L));

ConcurrentLinkedQueue<Object> expectedOutput = new ConcurrentLinkedQueue<>();
expectedOutput.add(new StreamRecord<>("5", 12L));
expectedOutput.add(new StreamRecord<>("42", 13L));

TestHarnessUtil.assertOutputEquals(
"Output was not correct.",
expectedOutput,
testHarness.getOutput());
}
}

/** Test the iteration over the keyed state on the broadcast side. */
@Test
public void testAccessToKeyedStateIt() throws Exception {
Expand Down

0 comments on commit cde504e

Please sign in to comment.