From fa9f8872d633cdce4885d58bbbe011145e3b18be Mon Sep 17 00:00:00 2001 From: Giacomo Gamba Date: Mon, 22 Mar 2021 02:45:22 +0100 Subject: [PATCH 1/7] [FLINK-21696][docs] Add output parameter in WatermarkGenerator scala sections --- .../dev/datastream/event-time/generating_watermarks.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/content/docs/dev/datastream/event-time/generating_watermarks.md b/docs/content/docs/dev/datastream/event-time/generating_watermarks.md index 4f41236732604c..f41df5fa3df77a 100644 --- a/docs/content/docs/dev/datastream/event-time/generating_watermarks.md +++ b/docs/content/docs/dev/datastream/event-time/generating_watermarks.md @@ -325,11 +325,11 @@ class BoundedOutOfOrdernessGenerator extends WatermarkGenerator[MyEvent] { var currentMaxTimestamp: Long = _ - override def onEvent(element: MyEvent, eventTimestamp: Long): Unit = { + override def onEvent(element: MyEvent, eventTimestamp: Long, output: WatermarkOutput): Unit = { currentMaxTimestamp = max(eventTimestamp, currentMaxTimestamp) } - override def onPeriodicEmit(): Unit = { + override def onPeriodicEmit(output: WatermarkOutput): Unit = { // emit the watermark as current highest timestamp minus the out-of-orderness bound output.emitWatermark(new Watermark(currentMaxTimestamp - maxOutOfOrderness - 1)); } @@ -344,11 +344,11 @@ class TimeLagWatermarkGenerator extends WatermarkGenerator[MyEvent] { val maxTimeLag = 5000L // 5 seconds - override def onEvent(element: MyEvent, eventTimestamp: Long): Unit = { + override def onEvent(element: MyEvent, eventTimestamp: Long, output: WatermarkOutput): Unit = { // don't need to do anything because we work on processing time } - override def onPeriodicEmit(): Unit = { + override def onPeriodicEmit(output: WatermarkOutput): Unit = { output.emitWatermark(new Watermark(System.currentTimeMillis() - maxTimeLag)); } } From a365bc5dae6bfb698b220c98cbdd5c91a2ad78f8 Mon Sep 17 00:00:00 2001 From: Dian Fu Date: Mon, 22 Mar 2021 18:16:36 +0800 Subject: [PATCH 2/7] [FLINK-21937][python] Support batch mode in Python DataStream API for basic operations This closes #15349. 
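A usage sketch of what this enables (illustrative only, not part of this patch): it mirrors the new PyFlinkBatchTestCase added below, which switches a regular StreamExecutionEnvironment to RuntimeExecutionMode.BATCH. The collection data and job name are invented for the example.

    # Minimal sketch, assuming this patch is applied; data and job name are made up.
    from pyflink.common.typeinfo import Types
    from pyflink.datastream import StreamExecutionEnvironment
    from pyflink.datastream.execution_mode import RuntimeExecutionMode

    env = StreamExecutionEnvironment.get_execution_environment()
    # Same switch used by the new PyFlinkBatchTestCase in this patch.
    env.set_runtime_mode(RuntimeExecutionMode.BATCH)
    env.set_parallelism(2)

    # Basic operations (map, key_by, reduce) now also run on bounded input;
    # with BATCH execution only the final aggregate per key is emitted.
    env.from_collection([(1, 'a'), (2, 'a'), (3, 'b')],
                        type_info=Types.TUPLE([Types.INT(), Types.STRING()])) \
        .key_by(lambda x: x[1]) \
        .reduce(lambda a, b: (a[0] + b[0], b[1])) \
        .print()

    env.execute('batch_mode_sketch')
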
--- flink-python/pyflink/common/configuration.py | 2 +- .../pyflink/common/execution_config.py | 2 +- .../pyflink/common/restart_strategy.py | 2 +- flink-python/pyflink/common/serialization.py | 2 +- .../pyflink/dataset/execution_environment.py | 2 +- flink-python/pyflink/datastream/connectors.py | 2 +- .../pyflink/datastream/state_backend.py | 2 +- .../stream_execution_environment.py | 22 +- .../datastream/tests/test_data_stream.py | 628 ++++++++++-------- .../datastream/tests/test_state_backend.py | 2 +- .../test_stream_execution_environment.py | 12 - flink-python/pyflink/table/expression.py | 2 +- flink-python/pyflink/table/expressions.py | 2 +- flink-python/pyflink/table/sinks.py | 9 +- flink-python/pyflink/table/statement_set.py | 2 +- flink-python/pyflink/table/table.py | 4 +- .../pyflink/table/table_environment.py | 6 +- flink-python/pyflink/table/table_schema.py | 2 +- .../table/tests/test_table_environment_api.py | 2 +- flink-python/pyflink/table/types.py | 2 +- flink-python/pyflink/table/udf.py | 10 +- flink-python/pyflink/table/utils.py | 2 +- .../pyflink/testing/source_sink_utils.py | 9 +- .../pyflink/testing/test_case_utils.py | 27 +- .../pyflink/util/{utils.py => java_utils.py} | 25 +- .../flink/python/util/PythonConfigUtil.java | 4 - 26 files changed, 435 insertions(+), 351 deletions(-) rename flink-python/pyflink/util/{utils.py => java_utils.py} (90%) diff --git a/flink-python/pyflink/common/configuration.py b/flink-python/pyflink/common/configuration.py index 15579c6da88998..54e82962f4321a 100644 --- a/flink-python/pyflink/common/configuration.py +++ b/flink-python/pyflink/common/configuration.py @@ -20,7 +20,7 @@ from py4j.java_gateway import JavaObject from pyflink.java_gateway import get_gateway -from pyflink.util.utils import add_jars_to_context_class_loader +from pyflink.util.java_utils import add_jars_to_context_class_loader class Configuration: diff --git a/flink-python/pyflink/common/execution_config.py b/flink-python/pyflink/common/execution_config.py index 4fcf7b9e31eebe..4e7ceae37a4ba5 100644 --- a/flink-python/pyflink/common/execution_config.py +++ b/flink-python/pyflink/common/execution_config.py @@ -23,7 +23,7 @@ from pyflink.common.input_dependency_constraint import InputDependencyConstraint from pyflink.common.restart_strategy import RestartStrategies, RestartStrategyConfiguration from pyflink.java_gateway import get_gateway -from pyflink.util.utils import load_java_class +from pyflink.util.java_utils import load_java_class __all__ = ['ExecutionConfig'] diff --git a/flink-python/pyflink/common/restart_strategy.py b/flink-python/pyflink/common/restart_strategy.py index ac4f334660da8c..b1c4621bf3b1d4 100644 --- a/flink-python/pyflink/common/restart_strategy.py +++ b/flink-python/pyflink/common/restart_strategy.py @@ -22,7 +22,7 @@ from py4j.java_gateway import get_java_class from pyflink.java_gateway import get_gateway -from pyflink.util.utils import to_j_flink_time, from_j_flink_time +from pyflink.util.java_utils import to_j_flink_time, from_j_flink_time __all__ = ['RestartStrategies', 'RestartStrategyConfiguration'] diff --git a/flink-python/pyflink/common/serialization.py b/flink-python/pyflink/common/serialization.py index ab276632135738..7c473e9701542d 100644 --- a/flink-python/pyflink/common/serialization.py +++ b/flink-python/pyflink/common/serialization.py @@ -19,7 +19,7 @@ from pyflink.common import typeinfo from pyflink.common.typeinfo import TypeInformation -from pyflink.util.utils import load_java_class +from pyflink.util.java_utils import 
load_java_class from pyflink.java_gateway import get_gateway from typing import Union diff --git a/flink-python/pyflink/dataset/execution_environment.py b/flink-python/pyflink/dataset/execution_environment.py index 52fa72e14397ed..821189c5d6a794 100644 --- a/flink-python/pyflink/dataset/execution_environment.py +++ b/flink-python/pyflink/dataset/execution_environment.py @@ -19,7 +19,7 @@ from pyflink.common.job_execution_result import JobExecutionResult from pyflink.common.restart_strategy import RestartStrategies, RestartStrategyConfiguration from pyflink.java_gateway import get_gateway -from pyflink.util.utils import load_java_class +from pyflink.util.java_utils import load_java_class class ExecutionEnvironment(object): diff --git a/flink-python/pyflink/datastream/connectors.py b/flink-python/pyflink/datastream/connectors.py index dcd7e93af7bdd4..4919d593147082 100644 --- a/flink-python/pyflink/datastream/connectors.py +++ b/flink-python/pyflink/datastream/connectors.py @@ -24,7 +24,7 @@ from pyflink.common.typeinfo import RowTypeInfo, TypeInformation from pyflink.datastream.functions import SourceFunction, SinkFunction from pyflink.java_gateway import get_gateway -from pyflink.util.utils import load_java_class, to_jarray +from pyflink.util.java_utils import load_java_class, to_jarray from py4j.java_gateway import java_import diff --git a/flink-python/pyflink/datastream/state_backend.py b/flink-python/pyflink/datastream/state_backend.py index 70387d66c314f9..f3c74b968fc748 100644 --- a/flink-python/pyflink/datastream/state_backend.py +++ b/flink-python/pyflink/datastream/state_backend.py @@ -23,7 +23,7 @@ from typing import List, Optional from pyflink.java_gateway import get_gateway -from pyflink.util.utils import load_java_class +from pyflink.util.java_utils import load_java_class __all__ = [ 'StateBackend', diff --git a/flink-python/pyflink/datastream/stream_execution_environment.py b/flink-python/pyflink/datastream/stream_execution_environment.py index cf8afd860be818..bf800a9c2c5a7e 100644 --- a/flink-python/pyflink/datastream/stream_execution_environment.py +++ b/flink-python/pyflink/datastream/stream_execution_environment.py @@ -36,7 +36,7 @@ from pyflink.datastream.time_characteristic import TimeCharacteristic from pyflink.java_gateway import get_gateway from pyflink.serializers import PickleSerializer -from pyflink.util.utils import load_java_class, add_jars_to_context_class_loader +from pyflink.util.java_utils import load_java_class, add_jars_to_context_class_loader, invoke_method __all__ = ['StreamExecutionEnvironment'] @@ -768,10 +768,22 @@ def _from_collection(self, elements: List[Any], execution_config ) - j_data_stream_source = self._j_stream_execution_environment.createInput( - j_input_format, - out_put_type_info.get_java_type_info() - ) + JInputFormatSourceFunction = gateway.jvm.org.apache.flink.streaming.api.functions.\ + source.InputFormatSourceFunction + JBoundedness = gateway.jvm.org.apache.flink.api.connector.source.Boundedness + + j_data_stream_source = invoke_method( + self._j_stream_execution_environment, + "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment", + "addSource", + [JInputFormatSourceFunction(j_input_format, out_put_type_info.get_java_type_info()), + "Collection Source", + out_put_type_info.get_java_type_info(), + JBoundedness.BOUNDED], + ["org.apache.flink.streaming.api.functions.source.SourceFunction", + "java.lang.String", + "org.apache.flink.api.common.typeinfo.TypeInformation", + 
"org.apache.flink.api.connector.source.Boundedness"]) j_data_stream_source.forceNonParallel() return DataStream(j_data_stream=j_data_stream_source) finally: diff --git a/flink-python/pyflink/datastream/tests/test_data_stream.py b/flink-python/pyflink/datastream/tests/test_data_stream.py index 82b07d80a4ebef..2209370136148c 100644 --- a/flink-python/pyflink/datastream/tests/test_data_stream.py +++ b/flink-python/pyflink/datastream/tests/test_data_stream.py @@ -23,7 +23,7 @@ from pyflink.common import Row from pyflink.common.typeinfo import Types from pyflink.common.watermark_strategy import WatermarkStrategy, TimestampAssigner -from pyflink.datastream import StreamExecutionEnvironment, TimeCharacteristic, RuntimeContext +from pyflink.datastream import TimeCharacteristic, RuntimeContext from pyflink.datastream.data_stream import DataStream from pyflink.datastream.functions import CoMapFunction, CoFlatMapFunction, AggregateFunction, \ ReduceFunction @@ -35,46 +35,21 @@ AggregatingStateDescriptor from pyflink.datastream.tests.test_util import DataStreamTestSinkFunction from pyflink.java_gateway import get_gateway -from pyflink.testing.test_case_utils import PyFlinkTestCase, invoke_java_object_method +from pyflink.testing.test_case_utils import invoke_java_object_method, \ + PyFlinkBatchTestCase, PyFlinkStreamingTestCase -class DataStreamTests(PyFlinkTestCase): +class DataStreamTests(object): def setUp(self) -> None: - self.env = StreamExecutionEnvironment.get_execution_environment() - self.env.set_parallelism(2) - getConfigurationMethod = invoke_java_object_method( + super(DataStreamTests, self).setUp() + config = invoke_java_object_method( self.env._j_stream_execution_environment, "getConfiguration") - getConfigurationMethod.setString("akka.ask.timeout", "20 s") + config.setString("akka.ask.timeout", "20 s") self.test_sink = DataStreamTestSinkFunction() - def test_data_stream_name(self): - ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]) - test_name = 'test_name' - ds.name(test_name) - self.assertEqual(test_name, ds.get_name()) - - def test_set_parallelism(self): - parallelism = 3 - ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x) - ds.set_parallelism(parallelism).add_sink(self.test_sink) - plan = eval(str(self.env.get_execution_plan())) - self.assertEqual(parallelism, plan['nodes'][1]['parallelism']) - - def test_set_max_parallelism(self): - max_parallelism = 4 - self.env.set_parallelism(8) - ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x) - ds.set_parallelism(max_parallelism).add_sink(self.test_sink) - plan = eval(str(self.env.get_execution_plan())) - self.assertEqual(max_parallelism, plan['nodes'][1]['parallelism']) - - def test_force_non_parallel(self): - self.env.set_parallelism(8) - ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]) - ds.force_non_parallel().add_sink(self.test_sink) - plan = eval(str(self.env.get_execution_plan())) - self.assertEqual(1, plan['nodes'][0]['parallelism']) + def tearDown(self) -> None: + self.test_sink.clear() def test_reduce_function_without_data_types(self): ds = self.env.from_collection([(1, 'a'), (2, 'a'), (3, 'a'), (4, 'b')], @@ -575,191 +550,6 @@ def test_print_with_align_output(self): self.assertEqual(3, len(plan['nodes'])) self.assertEqual("Sink: Print to Std. 
Out", plan['nodes'][2]['type']) - def test_union_stream(self): - ds_1 = self.env.from_collection([1, 2, 3]) - ds_2 = self.env.from_collection([4, 5, 6]) - ds_3 = self.env.from_collection([7, 8, 9]) - - united_stream = ds_3.union(ds_1, ds_2) - - united_stream.map(lambda x: x + 1).add_sink(self.test_sink) - exec_plan = eval(self.env.get_execution_plan()) - source_ids = [] - union_node_pre_ids = [] - for node in exec_plan['nodes']: - if node['pact'] == 'Data Source': - source_ids.append(node['id']) - if node['pact'] == 'Operator': - for pre in node['predecessors']: - union_node_pre_ids.append(pre['id']) - - source_ids.sort() - union_node_pre_ids.sort() - self.assertEqual(source_ids, union_node_pre_ids) - - def test_project(self): - ds = self.env.from_collection([[1, 2, 3, 4], [5, 6, 7, 8]], - type_info=Types.TUPLE( - [Types.INT(), Types.INT(), Types.INT(), Types.INT()])) - ds.project(1, 3).map(lambda x: (x[0], x[1] + 1)).add_sink(self.test_sink) - exec_plan = eval(self.env.get_execution_plan()) - self.assertEqual(exec_plan['nodes'][1]['type'], 'Projection') - - def test_broadcast(self): - ds_1 = self.env.from_collection([1, 2, 3]) - ds_1.broadcast().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink) - exec_plan = eval(self.env.get_execution_plan()) - broadcast_node = exec_plan['nodes'][1] - pre_ship_strategy = broadcast_node['predecessors'][0]['ship_strategy'] - self.assertEqual(pre_ship_strategy, 'BROADCAST') - - def test_rebalance(self): - ds_1 = self.env.from_collection([1, 2, 3]) - ds_1.rebalance().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink) - exec_plan = eval(self.env.get_execution_plan()) - rebalance_node = exec_plan['nodes'][1] - pre_ship_strategy = rebalance_node['predecessors'][0]['ship_strategy'] - self.assertEqual(pre_ship_strategy, 'REBALANCE') - - def test_rescale(self): - ds_1 = self.env.from_collection([1, 2, 3]) - ds_1.rescale().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink) - exec_plan = eval(self.env.get_execution_plan()) - rescale_node = exec_plan['nodes'][1] - pre_ship_strategy = rescale_node['predecessors'][0]['ship_strategy'] - self.assertEqual(pre_ship_strategy, 'RESCALE') - - def test_shuffle(self): - ds_1 = self.env.from_collection([1, 2, 3]) - ds_1.shuffle().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink) - exec_plan = eval(self.env.get_execution_plan()) - shuffle_node = exec_plan['nodes'][1] - pre_ship_strategy = shuffle_node['predecessors'][0]['ship_strategy'] - self.assertEqual(pre_ship_strategy, 'SHUFFLE') - - def test_partition_custom(self): - ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2), - ('f', 7), ('g', 7), ('h', 8), ('i', 8), ('j', 9)], - type_info=Types.ROW([Types.STRING(), Types.INT()])) - - expected_num_partitions = 5 - - def my_partitioner(key, num_partitions): - assert expected_num_partitions, num_partitions - return key % num_partitions - - partitioned_stream = ds.map(lambda x: x, output_type=Types.ROW([Types.STRING(), - Types.INT()]))\ - .set_parallelism(4).partition_custom(my_partitioner, lambda x: x[1]) - - JPartitionCustomTestMapFunction = get_gateway().jvm\ - .org.apache.flink.python.util.PartitionCustomTestMapFunction - test_map_stream = DataStream(partitioned_stream - ._j_data_stream.map(JPartitionCustomTestMapFunction())) - test_map_stream.set_parallelism(expected_num_partitions).add_sink(self.test_sink) - - self.env.execute('test_partition_custom') - - def test_keyed_stream_partitioning(self): - ds = self.env.from_collection([('ab', 1), 
('bdc', 2), ('cfgs', 3), ('deeefg', 4)]) - keyed_stream = ds.key_by(lambda x: x[1]) - with self.assertRaises(Exception): - keyed_stream.shuffle() - - with self.assertRaises(Exception): - keyed_stream.rebalance() - - with self.assertRaises(Exception): - keyed_stream.rescale() - - with self.assertRaises(Exception): - keyed_stream.broadcast() - - with self.assertRaises(Exception): - keyed_stream.forward() - - def test_slot_sharing_group(self): - source_operator_name = 'collection source' - map_operator_name = 'map_operator' - slot_sharing_group_1 = 'slot_sharing_group_1' - slot_sharing_group_2 = 'slot_sharing_group_2' - ds_1 = self.env.from_collection([1, 2, 3]).name(source_operator_name) - ds_1.slot_sharing_group(slot_sharing_group_1).map(lambda x: x + 1).set_parallelism(3)\ - .name(map_operator_name).slot_sharing_group(slot_sharing_group_2)\ - .add_sink(self.test_sink) - - j_generated_stream_graph = self.env._j_stream_execution_environment \ - .getStreamGraph("test start new_chain", True) - - j_stream_nodes = list(j_generated_stream_graph.getStreamNodes().toArray()) - for j_stream_node in j_stream_nodes: - if j_stream_node.getOperatorName() == source_operator_name: - self.assertEqual(j_stream_node.getSlotSharingGroup(), slot_sharing_group_1) - elif j_stream_node.getOperatorName() == map_operator_name: - self.assertEqual(j_stream_node.getSlotSharingGroup(), slot_sharing_group_2) - - def test_chaining_strategy(self): - chained_operator_name_0 = "map_operator_0" - chained_operator_name_1 = "map_operator_1" - chained_operator_name_2 = "map_operator_2" - - ds = self.env.from_collection([1, 2, 3]) - ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0)\ - .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1)\ - .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2)\ - .add_sink(self.test_sink) - - def assert_chainable(j_stream_graph, expected_upstream_chainable, - expected_downstream_chainable): - j_stream_nodes = list(j_stream_graph.getStreamNodes().toArray()) - for j_stream_node in j_stream_nodes: - if j_stream_node.getOperatorName() == chained_operator_name_1: - JStreamingJobGraphGenerator = get_gateway().jvm \ - .org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator - - j_in_stream_edge = j_stream_node.getInEdges().get(0) - upstream_chainable = JStreamingJobGraphGenerator.isChainable(j_in_stream_edge, - j_stream_graph) - self.assertEqual(expected_upstream_chainable, upstream_chainable) - - j_out_stream_edge = j_stream_node.getOutEdges().get(0) - downstream_chainable = JStreamingJobGraphGenerator.isChainable( - j_out_stream_edge, j_stream_graph) - self.assertEqual(expected_downstream_chainable, downstream_chainable) - - # The map_operator_1 has the same parallelism with map_operator_0 and map_operator_2, and - # ship_strategy for map_operator_0 and map_operator_1 is FORWARD, so the map_operator_1 - # can be chained with map_operator_0 and map_operator_2. 
- j_generated_stream_graph = self.env._j_stream_execution_environment\ - .getStreamGraph("test start new_chain", True) - assert_chainable(j_generated_stream_graph, True, True) - - ds = self.env.from_collection([1, 2, 3]) - # Start a new chain for map_operator_1 - ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0) \ - .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1).start_new_chain() \ - .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2) \ - .add_sink(self.test_sink) - - j_generated_stream_graph = self.env._j_stream_execution_environment \ - .getStreamGraph("test start new_chain", True) - # We start a new chain for map operator, therefore, it cannot be chained with upstream - # operator, but can be chained with downstream operator. - assert_chainable(j_generated_stream_graph, False, True) - - ds = self.env.from_collection([1, 2, 3]) - # Disable chaining for map_operator_1 - ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0) \ - .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1).disable_chaining() \ - .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2) \ - .add_sink(self.test_sink) - - j_generated_stream_graph = self.env._j_stream_execution_environment \ - .getStreamGraph("test start new_chain", True) - # We disable chaining for map_operator_1, therefore, it cannot be chained with - # upstream and downstream operators. - assert_chainable(j_generated_stream_graph, False, False) - def test_primitive_array_type_info(self): ds = self.env.from_collection([(1, [1.1, 1.2, 1.30]), (2, [2.1, 2.2, 2.3]), (3, [3.1, 3.2, 3.3])], @@ -814,7 +604,7 @@ def test_sql_timestamp_type_info(self): expected = ['+I[2021-01-09, 12:00:00, 2021-01-09 12:00:00.011]'] self.assertEqual(expected, results) - def test_timestamp_assigner_and_watermark_strategy(self): + def test_process_function(self): self.env.set_parallelism(1) self.env.get_config().set_auto_watermark_interval(2000) self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime) @@ -829,99 +619,52 @@ class MyTimestampAssigner(TimestampAssigner): def extract_timestamp(self, value, record_timestamp) -> int: return int(value[1]) - class MyProcessFunction(KeyedProcessFunction): + class MyProcessFunction(ProcessFunction): def process_element(self, value, ctx): current_timestamp = ctx.timestamp() current_watermark = ctx.timer_service().current_watermark() - current_key = ctx.get_current_key() - yield "current key: {}, current timestamp: {}, current watermark: {}, " \ - "current_value: {}".format(str(current_key), str(current_timestamp), - str(current_watermark), str(value)) + yield "current timestamp: {}, current watermark: {}, current_value: {}"\ + .format(str(current_timestamp), str(current_watermark), str(value)) - def on_timer(self, timestamp, ctx): + def on_timer(self, timestamp, ctx, out): pass watermark_strategy = WatermarkStrategy.for_monotonous_timestamps()\ .with_timestamp_assigner(MyTimestampAssigner()) data_stream.assign_timestamps_and_watermarks(watermark_strategy)\ - .key_by(lambda x: x[0], key_type_info=Types.INT()) \ .process(MyProcessFunction(), output_type=Types.STRING()).add_sink(self.test_sink) - self.env.execute('test time stamp assigner with keyed process function') + self.env.execute('test process function') result = self.test_sink.get_results() - expected_result = ["current key: 1, current timestamp: 1603708211000, current watermark: " + expected_result = ["current timestamp: 1603708211000, current watermark: " "9223372036854775807, 
current_value: Row(f0=1, f1='1603708211000')", - "current key: 2, current timestamp: 1603708224000, current watermark: " + "current timestamp: 1603708224000, current watermark: " "9223372036854775807, current_value: Row(f0=2, f1='1603708224000')", - "current key: 3, current timestamp: 1603708226000, current watermark: " + "current timestamp: 1603708226000, current watermark: " "9223372036854775807, current_value: Row(f0=3, f1='1603708226000')", - "current key: 4, current timestamp: 1603708289000, current watermark: " + "current timestamp: 1603708289000, current watermark: " "9223372036854775807, current_value: Row(f0=4, f1='1603708289000')"] result.sort() expected_result.sort() self.assertEqual(expected_result, result) - def test_process_function(self): + def test_keyed_process_function_with_state(self): self.env.set_parallelism(1) self.env.get_config().set_auto_watermark_interval(2000) self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime) - data_stream = self.env.from_collection([(1, '1603708211000'), - (2, '1603708224000'), - (3, '1603708226000'), - (4, '1603708289000')], - type_info=Types.ROW([Types.INT(), Types.STRING()])) + data_stream = self.env.from_collection([(1, 'hi', '1603708211000'), + (2, 'hello', '1603708224000'), + (3, 'hi', '1603708226000'), + (4, 'hello', '1603708289000'), + (5, 'hi', '1603708291000'), + (6, 'hello', '1603708293000')], + type_info=Types.ROW([Types.INT(), Types.STRING(), + Types.STRING()])) class MyTimestampAssigner(TimestampAssigner): def extract_timestamp(self, value, record_timestamp) -> int: - return int(value[1]) - - class MyProcessFunction(ProcessFunction): - - def process_element(self, value, ctx): - current_timestamp = ctx.timestamp() - current_watermark = ctx.timer_service().current_watermark() - yield "current timestamp: {}, current watermark: {}, current_value: {}"\ - .format(str(current_timestamp), str(current_watermark), str(value)) - - def on_timer(self, timestamp, ctx, out): - pass - - watermark_strategy = WatermarkStrategy.for_monotonous_timestamps()\ - .with_timestamp_assigner(MyTimestampAssigner()) - data_stream.assign_timestamps_and_watermarks(watermark_strategy)\ - .process(MyProcessFunction(), output_type=Types.STRING()).add_sink(self.test_sink) - self.env.execute('test process function') - result = self.test_sink.get_results() - expected_result = ["current timestamp: 1603708211000, current watermark: " - "9223372036854775807, current_value: Row(f0=1, f1='1603708211000')", - "current timestamp: 1603708224000, current watermark: " - "9223372036854775807, current_value: Row(f0=2, f1='1603708224000')", - "current timestamp: 1603708226000, current watermark: " - "9223372036854775807, current_value: Row(f0=3, f1='1603708226000')", - "current timestamp: 1603708289000, current watermark: " - "9223372036854775807, current_value: Row(f0=4, f1='1603708289000')"] - result.sort() - expected_result.sort() - self.assertEqual(expected_result, result) - - def test_keyed_process_function_with_state(self): - self.env.set_parallelism(1) - self.env.get_config().set_auto_watermark_interval(2000) - self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime) - data_stream = self.env.from_collection([(1, 'hi', '1603708211000'), - (2, 'hello', '1603708224000'), - (3, 'hi', '1603708226000'), - (4, 'hello', '1603708289000'), - (5, 'hi', '1603708291000'), - (6, 'hello', '1603708293000')], - type_info=Types.ROW([Types.INT(), Types.STRING(), - Types.STRING()])) - - class MyTimestampAssigner(TimestampAssigner): - - def 
extract_timestamp(self, value, record_timestamp) -> int: - return int(value[2]) + return int(value[2]) class MyProcessFunction(KeyedProcessFunction): @@ -1063,8 +806,317 @@ def process_element(self, value, ctx): expected_result.sort() self.assertEqual(expected_result, result) - def tearDown(self) -> None: - self.test_sink.clear() + +class StreamingModeDataStreamTests(DataStreamTests, PyFlinkStreamingTestCase): + def test_data_stream_name(self): + ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]) + test_name = 'test_name' + ds.name(test_name) + self.assertEqual(test_name, ds.get_name()) + + def test_set_parallelism(self): + parallelism = 3 + ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x) + ds.set_parallelism(parallelism).add_sink(self.test_sink) + plan = eval(str(self.env.get_execution_plan())) + self.assertEqual(parallelism, plan['nodes'][1]['parallelism']) + + def test_set_max_parallelism(self): + max_parallelism = 4 + self.env.set_parallelism(8) + ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x) + ds.set_parallelism(max_parallelism).add_sink(self.test_sink) + plan = eval(str(self.env.get_execution_plan())) + self.assertEqual(max_parallelism, plan['nodes'][1]['parallelism']) + + def test_force_non_parallel(self): + self.env.set_parallelism(8) + ds = self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]) + ds.force_non_parallel().add_sink(self.test_sink) + plan = eval(str(self.env.get_execution_plan())) + self.assertEqual(1, plan['nodes'][0]['parallelism']) + + def test_union_stream(self): + ds_1 = self.env.from_collection([1, 2, 3]) + ds_2 = self.env.from_collection([4, 5, 6]) + ds_3 = self.env.from_collection([7, 8, 9]) + + united_stream = ds_3.union(ds_1, ds_2) + + united_stream.map(lambda x: x + 1).add_sink(self.test_sink) + exec_plan = eval(self.env.get_execution_plan()) + source_ids = [] + union_node_pre_ids = [] + for node in exec_plan['nodes']: + if node['pact'] == 'Data Source': + source_ids.append(node['id']) + if node['pact'] == 'Operator': + for pre in node['predecessors']: + union_node_pre_ids.append(pre['id']) + + source_ids.sort() + union_node_pre_ids.sort() + self.assertEqual(source_ids, union_node_pre_ids) + + def test_project(self): + ds = self.env.from_collection([[1, 2, 3, 4], [5, 6, 7, 8]], + type_info=Types.TUPLE( + [Types.INT(), Types.INT(), Types.INT(), Types.INT()])) + ds.project(1, 3).map(lambda x: (x[0], x[1] + 1)).add_sink(self.test_sink) + exec_plan = eval(self.env.get_execution_plan()) + self.assertEqual(exec_plan['nodes'][1]['type'], 'Projection') + + def test_broadcast(self): + ds_1 = self.env.from_collection([1, 2, 3]) + ds_1.broadcast().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink) + exec_plan = eval(self.env.get_execution_plan()) + broadcast_node = exec_plan['nodes'][1] + pre_ship_strategy = broadcast_node['predecessors'][0]['ship_strategy'] + self.assertEqual(pre_ship_strategy, 'BROADCAST') + + def test_rebalance(self): + ds_1 = self.env.from_collection([1, 2, 3]) + ds_1.rebalance().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink) + exec_plan = eval(self.env.get_execution_plan()) + rebalance_node = exec_plan['nodes'][1] + pre_ship_strategy = rebalance_node['predecessors'][0]['ship_strategy'] + self.assertEqual(pre_ship_strategy, 'REBALANCE') + + def test_rescale(self): + ds_1 = self.env.from_collection([1, 2, 3]) + ds_1.rescale().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink) + exec_plan 
= eval(self.env.get_execution_plan()) + rescale_node = exec_plan['nodes'][1] + pre_ship_strategy = rescale_node['predecessors'][0]['ship_strategy'] + self.assertEqual(pre_ship_strategy, 'RESCALE') + + def test_shuffle(self): + ds_1 = self.env.from_collection([1, 2, 3]) + ds_1.shuffle().map(lambda x: x + 1).set_parallelism(3).add_sink(self.test_sink) + exec_plan = eval(self.env.get_execution_plan()) + shuffle_node = exec_plan['nodes'][1] + pre_ship_strategy = shuffle_node['predecessors'][0]['ship_strategy'] + self.assertEqual(pre_ship_strategy, 'SHUFFLE') + + def test_partition_custom(self): + ds = self.env.from_collection([('a', 0), ('b', 0), ('c', 1), ('d', 1), ('e', 2), + ('f', 7), ('g', 7), ('h', 8), ('i', 8), ('j', 9)], + type_info=Types.ROW([Types.STRING(), Types.INT()])) + + expected_num_partitions = 5 + + def my_partitioner(key, num_partitions): + assert expected_num_partitions, num_partitions + return key % num_partitions + + partitioned_stream = ds.map(lambda x: x, output_type=Types.ROW([Types.STRING(), + Types.INT()]))\ + .set_parallelism(4).partition_custom(my_partitioner, lambda x: x[1]) + + JPartitionCustomTestMapFunction = get_gateway().jvm\ + .org.apache.flink.python.util.PartitionCustomTestMapFunction + test_map_stream = DataStream(partitioned_stream + ._j_data_stream.map(JPartitionCustomTestMapFunction())) + test_map_stream.set_parallelism(expected_num_partitions).add_sink(self.test_sink) + + self.env.execute('test_partition_custom') + + def test_keyed_stream_partitioning(self): + ds = self.env.from_collection([('ab', 1), ('bdc', 2), ('cfgs', 3), ('deeefg', 4)]) + keyed_stream = ds.key_by(lambda x: x[1]) + with self.assertRaises(Exception): + keyed_stream.shuffle() + + with self.assertRaises(Exception): + keyed_stream.rebalance() + + with self.assertRaises(Exception): + keyed_stream.rescale() + + with self.assertRaises(Exception): + keyed_stream.broadcast() + + with self.assertRaises(Exception): + keyed_stream.forward() + + def test_slot_sharing_group(self): + source_operator_name = 'collection source' + map_operator_name = 'map_operator' + slot_sharing_group_1 = 'slot_sharing_group_1' + slot_sharing_group_2 = 'slot_sharing_group_2' + ds_1 = self.env.from_collection([1, 2, 3]).name(source_operator_name) + ds_1.slot_sharing_group(slot_sharing_group_1).map(lambda x: x + 1).set_parallelism(3)\ + .name(map_operator_name).slot_sharing_group(slot_sharing_group_2)\ + .add_sink(self.test_sink) + + j_generated_stream_graph = self.env._j_stream_execution_environment \ + .getStreamGraph("test start new_chain", True) + + j_stream_nodes = list(j_generated_stream_graph.getStreamNodes().toArray()) + for j_stream_node in j_stream_nodes: + if j_stream_node.getOperatorName() == source_operator_name: + self.assertEqual(j_stream_node.getSlotSharingGroup(), slot_sharing_group_1) + elif j_stream_node.getOperatorName() == map_operator_name: + self.assertEqual(j_stream_node.getSlotSharingGroup(), slot_sharing_group_2) + + def test_chaining_strategy(self): + chained_operator_name_0 = "map_operator_0" + chained_operator_name_1 = "map_operator_1" + chained_operator_name_2 = "map_operator_2" + + ds = self.env.from_collection([1, 2, 3]) + ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0)\ + .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1)\ + .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2)\ + .add_sink(self.test_sink) + + def assert_chainable(j_stream_graph, expected_upstream_chainable, + expected_downstream_chainable): + j_stream_nodes = 
list(j_stream_graph.getStreamNodes().toArray()) + for j_stream_node in j_stream_nodes: + if j_stream_node.getOperatorName() == chained_operator_name_1: + JStreamingJobGraphGenerator = get_gateway().jvm \ + .org.apache.flink.streaming.api.graph.StreamingJobGraphGenerator + + j_in_stream_edge = j_stream_node.getInEdges().get(0) + upstream_chainable = JStreamingJobGraphGenerator.isChainable(j_in_stream_edge, + j_stream_graph) + self.assertEqual(expected_upstream_chainable, upstream_chainable) + + j_out_stream_edge = j_stream_node.getOutEdges().get(0) + downstream_chainable = JStreamingJobGraphGenerator.isChainable( + j_out_stream_edge, j_stream_graph) + self.assertEqual(expected_downstream_chainable, downstream_chainable) + + # The map_operator_1 has the same parallelism with map_operator_0 and map_operator_2, and + # ship_strategy for map_operator_0 and map_operator_1 is FORWARD, so the map_operator_1 + # can be chained with map_operator_0 and map_operator_2. + j_generated_stream_graph = self.env._j_stream_execution_environment\ + .getStreamGraph("test start new_chain", True) + assert_chainable(j_generated_stream_graph, True, True) + + ds = self.env.from_collection([1, 2, 3]) + # Start a new chain for map_operator_1 + ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0) \ + .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1).start_new_chain() \ + .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2) \ + .add_sink(self.test_sink) + + j_generated_stream_graph = self.env._j_stream_execution_environment \ + .getStreamGraph("test start new_chain", True) + # We start a new chain for map operator, therefore, it cannot be chained with upstream + # operator, but can be chained with downstream operator. + assert_chainable(j_generated_stream_graph, False, True) + + ds = self.env.from_collection([1, 2, 3]) + # Disable chaining for map_operator_1 + ds.map(lambda x: x).set_parallelism(2).name(chained_operator_name_0) \ + .map(lambda x: x).set_parallelism(2).name(chained_operator_name_1).disable_chaining() \ + .map(lambda x: x).set_parallelism(2).name(chained_operator_name_2) \ + .add_sink(self.test_sink) + + j_generated_stream_graph = self.env._j_stream_execution_environment \ + .getStreamGraph("test start new_chain", True) + # We disable chaining for map_operator_1, therefore, it cannot be chained with + # upstream and downstream operators. 
+ assert_chainable(j_generated_stream_graph, False, False) + + def test_timestamp_assigner_and_watermark_strategy(self): + self.env.set_parallelism(1) + self.env.get_config().set_auto_watermark_interval(2000) + self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime) + data_stream = self.env.from_collection([(1, '1603708211000'), + (2, '1603708224000'), + (3, '1603708226000'), + (4, '1603708289000')], + type_info=Types.ROW([Types.INT(), Types.STRING()])) + + class MyTimestampAssigner(TimestampAssigner): + + def extract_timestamp(self, value, record_timestamp) -> int: + return int(value[1]) + + class MyProcessFunction(KeyedProcessFunction): + + def process_element(self, value, ctx): + current_timestamp = ctx.timestamp() + current_watermark = ctx.timer_service().current_watermark() + current_key = ctx.get_current_key() + yield "current key: {}, current timestamp: {}, current watermark: {}, " \ + "current_value: {}".format(str(current_key), str(current_timestamp), + str(current_watermark), str(value)) + + def on_timer(self, timestamp, ctx): + pass + + watermark_strategy = WatermarkStrategy.for_monotonous_timestamps()\ + .with_timestamp_assigner(MyTimestampAssigner()) + data_stream.assign_timestamps_and_watermarks(watermark_strategy)\ + .key_by(lambda x: x[0], key_type_info=Types.INT()) \ + .process(MyProcessFunction(), output_type=Types.STRING()).add_sink(self.test_sink) + self.env.execute('test time stamp assigner with keyed process function') + result = self.test_sink.get_results() + expected_result = ["current key: 1, current timestamp: 1603708211000, current watermark: " + "9223372036854775807, current_value: Row(f0=1, f1='1603708211000')", + "current key: 2, current timestamp: 1603708224000, current watermark: " + "9223372036854775807, current_value: Row(f0=2, f1='1603708224000')", + "current key: 3, current timestamp: 1603708226000, current watermark: " + "9223372036854775807, current_value: Row(f0=3, f1='1603708226000')", + "current key: 4, current timestamp: 1603708289000, current watermark: " + "9223372036854775807, current_value: Row(f0=4, f1='1603708289000')"] + result.sort() + expected_result.sort() + self.assertEqual(expected_result, result) + + +class BatchModeDataStreamTests(DataStreamTests, PyFlinkBatchTestCase): + + def test_timestamp_assigner_and_watermark_strategy(self): + self.env.set_parallelism(1) + self.env.get_config().set_auto_watermark_interval(2000) + self.env.set_stream_time_characteristic(TimeCharacteristic.EventTime) + data_stream = self.env.from_collection([(1, '1603708211000'), + (2, '1603708224000'), + (3, '1603708226000'), + (4, '1603708289000')], + type_info=Types.ROW([Types.INT(), Types.STRING()])) + + class MyTimestampAssigner(TimestampAssigner): + + def extract_timestamp(self, value, record_timestamp) -> int: + return int(value[1]) + + class MyProcessFunction(KeyedProcessFunction): + + def process_element(self, value, ctx): + current_timestamp = ctx.timestamp() + current_watermark = ctx.timer_service().current_watermark() + current_key = ctx.get_current_key() + yield "current key: {}, current timestamp: {}, current watermark: {}, " \ + "current_value: {}".format(str(current_key), str(current_timestamp), + str(current_watermark), str(value)) + + def on_timer(self, timestamp, ctx): + pass + + watermark_strategy = WatermarkStrategy.for_monotonous_timestamps() \ + .with_timestamp_assigner(MyTimestampAssigner()) + data_stream.assign_timestamps_and_watermarks(watermark_strategy) \ + .key_by(lambda x: x[0], key_type_info=Types.INT()) \ + 
.process(MyProcessFunction(), output_type=Types.STRING()).add_sink(self.test_sink) + self.env.execute('test time stamp assigner with keyed process function') + result = self.test_sink.get_results() + expected_result = ["current key: 1, current timestamp: 1603708211000, current watermark: " + "-9223372036854775808, current_value: Row(f0=1, f1='1603708211000')", + "current key: 2, current timestamp: 1603708224000, current watermark: " + "-9223372036854775808, current_value: Row(f0=2, f1='1603708224000')", + "current key: 3, current timestamp: 1603708226000, current watermark: " + "-9223372036854775808, current_value: Row(f0=3, f1='1603708226000')", + "current key: 4, current timestamp: 1603708289000, current watermark: " + "-9223372036854775808, current_value: Row(f0=4, f1='1603708289000')"] + result.sort() + expected_result.sort() + self.assertEqual(expected_result, result) class MyMapFunction(MapFunction): diff --git a/flink-python/pyflink/datastream/tests/test_state_backend.py b/flink-python/pyflink/datastream/tests/test_state_backend.py index 3ce18dcd2a7c5e..a22615fd77aa52 100644 --- a/flink-python/pyflink/datastream/tests/test_state_backend.py +++ b/flink-python/pyflink/datastream/tests/test_state_backend.py @@ -21,7 +21,7 @@ from pyflink.java_gateway import get_gateway from pyflink.pyflink_gateway_server import on_windows from pyflink.testing.test_case_utils import PyFlinkTestCase -from pyflink.util.utils import load_java_class +from pyflink.util.java_utils import load_java_class class MemoryStateBackendTests(PyFlinkTestCase): diff --git a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py index cff7236b6f6330..b28b18ee0860d1 100644 --- a/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py +++ b/flink-python/pyflink/datastream/tests/test_stream_execution_environment.py @@ -581,17 +581,5 @@ def add_from_file(i): self.assertIsNotNone(env_config_with_dependencies['python.files']) self.assertIsNotNone(env_config_with_dependencies['python.archives']) - def test_batch_execution_mode(self): - # set the runtime execution mode to BATCH - JRuntimeExecutionMode = get_gateway().jvm \ - .org.apache.flink.api.common.RuntimeExecutionMode.BATCH - self.env._j_stream_execution_environment.setRuntimeMode(JRuntimeExecutionMode) - self.env.from_collection([(1, 'Hi', 'Hello'), (2, 'Hello', 'Hi')]).map(lambda x: x) \ - .add_sink(self.test_sink) - - # Running jobs in Batch mode is not supported yet, it should throw an exception. 
- with self.assertRaises(Exception): - self.env.get_execution_plan() - def tearDown(self) -> None: self.test_sink.clear() diff --git a/flink-python/pyflink/table/expression.py b/flink-python/pyflink/table/expression.py index 3a242d5f3f01cc..0fed620c48fae8 100644 --- a/flink-python/pyflink/table/expression.py +++ b/flink-python/pyflink/table/expression.py @@ -21,7 +21,7 @@ from pyflink import add_version_doc from pyflink.java_gateway import get_gateway from pyflink.table.types import DataType, _to_java_data_type -from pyflink.util.utils import to_jarray +from pyflink.util.java_utils import to_jarray __all__ = ['Expression', 'TimeIntervalUnit', 'TimePointUnit'] diff --git a/flink-python/pyflink/table/expressions.py b/flink-python/pyflink/table/expressions.py index c756824d836e2d..44182ae72a5f8d 100644 --- a/flink-python/pyflink/table/expressions.py +++ b/flink-python/pyflink/table/expressions.py @@ -22,7 +22,7 @@ from pyflink.table.expression import Expression, _get_java_expression, TimePointUnit from pyflink.table.types import _to_java_data_type, DataType, _to_java_type from pyflink.table.udf import UserDefinedFunctionWrapper, UserDefinedTableFunctionWrapper -from pyflink.util.utils import to_jarray, load_java_class +from pyflink.util.java_utils import to_jarray, load_java_class __all__ = ['if_then_else', 'lit', 'col', 'range_', 'and_', 'or_', 'UNBOUNDED_ROW', 'UNBOUNDED_RANGE', 'CURRENT_ROW', 'CURRENT_RANGE', 'current_date', 'current_time', diff --git a/flink-python/pyflink/table/sinks.py b/flink-python/pyflink/table/sinks.py index 66f42df8127423..26df8a040f9d6b 100644 --- a/flink-python/pyflink/table/sinks.py +++ b/flink-python/pyflink/table/sinks.py @@ -18,7 +18,7 @@ from pyflink.java_gateway import get_gateway from pyflink.table.types import _to_java_type -from pyflink.util import utils +from pyflink.util import java_utils __all__ = ['TableSink', 'CsvTableSink', 'WriteMode'] @@ -70,8 +70,9 @@ def __init__(self, field_names, field_types, path, field_delimiter=',', num_file raise Exception('Unsupported write_mode: %s' % write_mode) j_csv_table_sink = gateway.jvm.CsvTableSink( path, field_delimiter, num_files, j_write_mode) - j_field_names = utils.to_jarray(gateway.jvm.String, field_names) - j_field_types = utils.to_jarray(gateway.jvm.TypeInformation, - [_to_java_type(field_type) for field_type in field_types]) + j_field_names = java_utils.to_jarray(gateway.jvm.String, field_names) + j_field_types = java_utils.to_jarray( + gateway.jvm.TypeInformation, + [_to_java_type(field_type) for field_type in field_types]) j_csv_table_sink = j_csv_table_sink.configure(j_field_names, j_field_types) super(CsvTableSink, self).__init__(j_csv_table_sink) diff --git a/flink-python/pyflink/table/statement_set.py b/flink-python/pyflink/table/statement_set.py index 5b097945d54fcb..697733596c033e 100644 --- a/flink-python/pyflink/table/statement_set.py +++ b/flink-python/pyflink/table/statement_set.py @@ -17,7 +17,7 @@ ################################################################################ from pyflink.table import ExplainDetail from pyflink.table.table_result import TableResult -from pyflink.util.utils import to_j_explain_detail_arr +from pyflink.util.java_utils import to_j_explain_detail_arr __all__ = ['StatementSet'] diff --git a/flink-python/pyflink/table/table.py b/flink-python/pyflink/table/table.py index 95a2f135ee5095..bde78806c7c935 100644 --- a/flink-python/pyflink/table/table.py +++ b/flink-python/pyflink/table/table.py @@ -34,8 +34,8 @@ from pyflink.table.utils import 
tz_convert_from_internal, to_expression_jarray from pyflink.table.window import OverWindow, GroupWindow -from pyflink.util.utils import to_jarray -from pyflink.util.utils import to_j_explain_detail_arr +from pyflink.util.java_utils import to_jarray +from pyflink.util.java_utils import to_j_explain_detail_arr __all__ = ['Table', 'GroupedTable', 'GroupWindowedTable', 'OverWindowedTable', 'WindowGroupedTable'] diff --git a/flink-python/pyflink/table/table_environment.py b/flink-python/pyflink/table/table_environment.py index f218a06bf0f848..9ddfdea9b4a1b7 100644 --- a/flink-python/pyflink/table/table_environment.py +++ b/flink-python/pyflink/table/table_environment.py @@ -48,8 +48,8 @@ from pyflink.table.udf import UserDefinedFunctionWrapper, AggregateFunction, udaf, \ UserDefinedAggregateFunctionWrapper, udtaf, TableAggregateFunction from pyflink.table.utils import to_expression_jarray -from pyflink.util import utils -from pyflink.util.utils import get_j_env_configuration, is_local_deployment, load_java_class, \ +from pyflink.util import java_utils +from pyflink.util.java_utils import get_j_env_configuration, is_local_deployment, load_java_class, \ to_j_explain_detail_arr, to_jarray __all__ = [ @@ -524,7 +524,7 @@ def scan(self, *table_path: str) -> Table: """ warnings.warn("Deprecated in 1.10. Use from_path instead.", DeprecationWarning) gateway = get_gateway() - j_table_paths = utils.to_jarray(gateway.jvm.String, table_path) + j_table_paths = java_utils.to_jarray(gateway.jvm.String, table_path) j_table = self._j_tenv.scan(j_table_paths) return Table(j_table, self) diff --git a/flink-python/pyflink/table/table_schema.py b/flink-python/pyflink/table/table_schema.py index 9a0b0523e14ce1..3c4b5eab0e394f 100644 --- a/flink-python/pyflink/table/table_schema.py +++ b/flink-python/pyflink/table/table_schema.py @@ -19,7 +19,7 @@ from pyflink.java_gateway import get_gateway from pyflink.table.types import _to_java_type, _from_java_type, DataType, RowType -from pyflink.util.utils import to_jarray +from pyflink.util.java_utils import to_jarray __all__ = ['TableSchema'] diff --git a/flink-python/pyflink/table/tests/test_table_environment_api.py b/flink-python/pyflink/table/tests/test_table_environment_api.py index c32b1398ca75ce..7c6778c6004410 100644 --- a/flink-python/pyflink/table/tests/test_table_environment_api.py +++ b/flink-python/pyflink/table/tests/test_table_environment_api.py @@ -44,7 +44,7 @@ PyFlinkOldBatchTableTestCase, PyFlinkBlinkBatchTableTestCase, PyFlinkBlinkStreamTableTestCase, \ PyFlinkLegacyBlinkBatchTableTestCase, PyFlinkLegacyFlinkStreamTableTestCase, \ PyFlinkLegacyBlinkStreamTableTestCase, _load_specific_flink_module_jars -from pyflink.util.utils import get_j_env_configuration +from pyflink.util.java_utils import get_j_env_configuration class TableEnvironmentTest(object): diff --git a/flink-python/pyflink/table/types.py b/flink-python/pyflink/table/types.py index 4754a5f5aadb96..05ce897713b690 100644 --- a/flink-python/pyflink/table/types.py +++ b/flink-python/pyflink/table/types.py @@ -32,7 +32,7 @@ from typing import List, Union from pyflink.common.types import _create_row -from pyflink.util.utils import to_jarray, is_instance_of +from pyflink.util.java_utils import to_jarray, is_instance_of from pyflink.java_gateway import get_gateway from pyflink.common import Row, RowKind diff --git a/flink-python/pyflink/table/udf.py b/flink-python/pyflink/table/udf.py index 0e4a657907ba9e..16d8931c47ddd6 100644 --- a/flink-python/pyflink/table/udf.py +++ 
b/flink-python/pyflink/table/udf.py @@ -25,7 +25,7 @@ from pyflink.metrics import MetricGroup from pyflink.table import Expression from pyflink.table.types import DataType, _to_java_type, _to_java_data_type -from pyflink.util import utils +from pyflink.util import java_utils __all__ = ['FunctionContext', 'AggregateFunction', 'ScalarFunction', 'TableFunction', 'TableAggregateFunction', 'udf', 'udtf', 'udaf', 'udtaf'] @@ -377,7 +377,7 @@ def get_python_function_kind(): raise TypeError("Unsupported func_type: %s." % self._func_type) if self._input_types is not None: - j_input_types = utils.to_jarray( + j_input_types = java_utils.to_jarray( gateway.jvm.TypeInformation, [_to_java_type(i) for i in self._input_types]) else: j_input_types = None @@ -458,8 +458,8 @@ def __init__(self, func, input_types, result_types, deterministic=None, name=Non def _create_judf(self, serialized_func, j_input_types, j_function_kind): gateway = get_gateway() - j_result_types = utils.to_jarray(gateway.jvm.TypeInformation, - [_to_java_type(i) for i in self._result_types]) + j_result_types = java_utils.to_jarray(gateway.jvm.TypeInformation, + [_to_java_type(i) for i in self._result_types]) j_result_type = gateway.jvm.org.apache.flink.api.java.typeutils.RowTypeInfo(j_result_types) PythonTableFunction = gateway.jvm \ .org.apache.flink.table.functions.python.PythonTableFunction @@ -514,7 +514,7 @@ def _create_judf(self, serialized_func, j_input_types, j_function_kind): if j_input_types is not None: gateway = get_gateway() - j_input_types = utils.to_jarray( + j_input_types = java_utils.to_jarray( gateway.jvm.DataType, [_to_java_data_type(i) for i in self._input_types]) j_result_type = _to_java_data_type(self._result_type) j_accumulator_type = _to_java_data_type(self._accumulator_type) diff --git a/flink-python/pyflink/table/utils.py b/flink-python/pyflink/table/utils.py index f9c682cb7cb563..9f256010e34972 100644 --- a/flink-python/pyflink/table/utils.py +++ b/flink-python/pyflink/table/utils.py @@ -22,7 +22,7 @@ from pyflink.java_gateway import get_gateway from pyflink.table.types import DataType, LocalZonedTimestampType, Row, RowType, \ TimeType, DateType, ArrayType, MapType, TimestampType, FloatType -from pyflink.util.utils import to_jarray +from pyflink.util.java_utils import to_jarray import datetime import pickle diff --git a/flink-python/pyflink/testing/source_sink_utils.py b/flink-python/pyflink/testing/source_sink_utils.py index e8fd012df44a29..52e9caf0468dea 100644 --- a/flink-python/pyflink/testing/source_sink_utils.py +++ b/flink-python/pyflink/testing/source_sink_utils.py @@ -25,7 +25,7 @@ from pyflink.java_gateway import get_gateway from pyflink.table.sinks import TableSink from pyflink.table.types import _to_java_type -from pyflink.util import utils +from pyflink.util import java_utils class TestTableSink(TableSink): @@ -37,9 +37,10 @@ class TestTableSink(TableSink): def __init__(self, j_table_sink, field_names, field_types): gateway = get_gateway() - j_field_names = utils.to_jarray(gateway.jvm.String, field_names) - j_field_types = utils.to_jarray(gateway.jvm.TypeInformation, - [_to_java_type(field_type) for field_type in field_types]) + j_field_names = java_utils.to_jarray(gateway.jvm.String, field_names) + j_field_types = java_utils.to_jarray( + gateway.jvm.TypeInformation, + [_to_java_type(field_type) for field_type in field_types]) j_table_sink = j_table_sink.configure(j_field_names, j_field_types) super(TestTableSink, self).__init__(j_table_sink) diff --git 
a/flink-python/pyflink/testing/test_case_utils.py b/flink-python/pyflink/testing/test_case_utils.py index c717de03c130e2..479b53db62496d 100644 --- a/flink-python/pyflink/testing/test_case_utils.py +++ b/flink-python/pyflink/testing/test_case_utils.py @@ -29,6 +29,7 @@ from py4j.protocol import Py4JJavaError from pyflink.common import JobExecutionResult +from pyflink.datastream.execution_mode import RuntimeExecutionMode from pyflink.table import TableConfig from pyflink.table.sources import CsvTableSource from pyflink.dataset.execution_environment import ExecutionEnvironment @@ -38,7 +39,7 @@ TableEnvironment from pyflink.table.environment_settings import EnvironmentSettings from pyflink.java_gateway import get_gateway -from pyflink.util.utils import add_jars_to_context_class_loader, to_jarray +from pyflink.util.java_utils import add_jars_to_context_class_loader, to_jarray if os.getenv("VERBOSE"): log_level = logging.DEBUG @@ -258,6 +259,30 @@ def setUp(self): "python.fn-execution.bundle.size", "1") +class PyFlinkStreamingTestCase(PyFlinkTestCase): + """ + Base class for streaming tests. + """ + + def setUp(self): + super(PyFlinkStreamingTestCase, self).setUp() + self.env = StreamExecutionEnvironment.get_execution_environment() + self.env.set_parallelism(2) + self.env.set_runtime_mode(RuntimeExecutionMode.STREAMING) + + +class PyFlinkBatchTestCase(PyFlinkTestCase): + """ + Base class for batch tests. + """ + + def setUp(self): + super(PyFlinkBatchTestCase, self).setUp() + self.env = StreamExecutionEnvironment.get_execution_environment() + self.env.set_parallelism(2) + self.env.set_runtime_mode(RuntimeExecutionMode.BATCH) + + class PythonAPICompletenessTestCase(object): """ Base class for Python API completeness tests, i.e., diff --git a/flink-python/pyflink/util/utils.py b/flink-python/pyflink/util/java_utils.py similarity index 90% rename from flink-python/pyflink/util/utils.py rename to flink-python/pyflink/util/java_utils.py index 065b537be9708b..3ffe22332d5b0e 100644 --- a/flink-python/pyflink/util/utils.py +++ b/flink-python/pyflink/util/java_utils.py @@ -82,15 +82,24 @@ def is_instance_of(java_object, java_class): def get_j_env_configuration(t_env): if is_instance_of(t_env._get_j_env(), "org.apache.flink.api.java.ExecutionEnvironment"): - j_configuration = t_env._get_j_env().getConfiguration() + return t_env._get_j_env().getConfiguration() else: - env_clazz = load_java_class( - "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment") - method = env_clazz.getDeclaredMethod( - "getConfiguration", to_jarray(get_gateway().jvm.Class, [])) - method.setAccessible(True) - j_configuration = method.invoke(t_env._get_j_env(), to_jarray(get_gateway().jvm.Object, [])) - return j_configuration + return invoke_method( + t_env._get_j_env(), + "org.apache.flink.streaming.api.environment.StreamExecutionEnvironment", + "getConfiguration" + ) + + +def invoke_method(obj, object_type, method_name, args=None, arg_types=None): + env_clazz = load_java_class(object_type) + method = env_clazz.getDeclaredMethod( + method_name, + to_jarray( + get_gateway().jvm.Class, + [load_java_class(arg_type) for arg_type in arg_types or []])) + method.setAccessible(True) + return method.invoke(obj, to_jarray(get_gateway().jvm.Object, args or [])) def is_local_deployment(j_configuration): diff --git a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java index 78818c3b5a2944..6681c1f0fa4564 100644 --- 
a/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java +++ b/flink-python/src/main/java/org/apache/flink/python/util/PythonConfigUtil.java @@ -168,10 +168,6 @@ public static StreamGraph generateStreamGraphWithDependencies( Configuration mergedConfig = getEnvConfigWithDependencies(env); boolean executedInBatchMode = isExecuteInBatchMode(env, mergedConfig); - if (executedInBatchMode) { - throw new UnsupportedOperationException( - "Batch mode is still not supported in Python DataStream API."); - } if (mergedConfig.getBoolean(PythonOptions.USE_MANAGED_MEMORY)) { Field transformationsField = From 79fd455ced999676947f0a73b22e27384af63049 Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Fri, 19 Mar 2021 12:33:57 +0100 Subject: [PATCH 3/7] [hotfix][table-runtime-blink] Enable ExternalTypeInfo for all kinds of data types --- .../flink/table/runtime/typeutils/ExternalTypeInfo.java | 3 +-- .../table/runtime/typeutils/ExternalTypeInfoTest.java | 9 +++++++++ 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/ExternalTypeInfo.java b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/ExternalTypeInfo.java index 6652fb9540a8cd..ad8e6e50ad39b1 100644 --- a/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/ExternalTypeInfo.java +++ b/flink-table/flink-table-runtime-blink/src/main/java/org/apache/flink/table/runtime/typeutils/ExternalTypeInfo.java @@ -77,8 +77,7 @@ private static TypeSerializer createExternalTypeSerializer(DataType dataT return (TypeSerializer) rawType.getTypeSerializer(); } } - throw new UnsupportedOperationException( - "External type information is not fully implemented yet."); + return ExternalSerializer.of(dataType); } // -------------------------------------------------------------------------------------------- diff --git a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/ExternalTypeInfoTest.java b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/ExternalTypeInfoTest.java index 63aabb65042a0d..240096b730daa3 100644 --- a/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/ExternalTypeInfoTest.java +++ b/flink-table/flink-table-runtime-blink/src/test/java/org/apache/flink/table/runtime/typeutils/ExternalTypeInfoTest.java @@ -40,6 +40,15 @@ protected ExternalTypeInfo[] getTestData() { DataTypes.RAW( ByteBuffer.class, new KryoSerializer<>(ByteBuffer.class, new ExecutionConfig()))), + ExternalTypeInfo.of(DataTypes.INT()), + ExternalTypeInfo.of( + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.DATE().bridgedTo(Integer.class)))), + ExternalTypeInfo.of( + DataTypes.ROW( + DataTypes.FIELD("a", DataTypes.INT()), + DataTypes.FIELD("b", DataTypes.DATE()))) }; } } From 619ad4bd0fc522c14952ba2e9c37c3fd0099053b Mon Sep 17 00:00:00 2001 From: Timo Walther Date: Mon, 22 Mar 2021 09:52:53 +0100 Subject: [PATCH 4/7] [FLINK-21872][table-api-java] Add utility for DataStream API's DataType, Schema, and projection This closes #15345. 
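A rough illustration of the scenario this utility targets (illustrative only, not part of this patch; it uses the public Schema builder API rather than the new internal ExternalSchemaTranslator, and the column names are invented): the user declares a partial Schema with a computed column, while the remaining physical columns are derived from the DataStream's TypeInformation.

    // Minimal sketch, assuming the public Table API Schema builder; names are made up.
    import org.apache.flink.table.api.DataTypes;
    import org.apache.flink.table.api.Schema;

    public class DeclaredSchemaSketch {
        public static void main(String[] args) {
            Schema declaredSchema =
                    Schema.newBuilder()
                            // explicitly declared physical column (refines the derived type)
                            .column("user_id", DataTypes.BIGINT())
                            // computed column given as a SQL expression (see SqlCallExpression)
                            .columnByExpression("user_name_upper", "UPPER(user_name)")
                            .build();
            System.out.println(declaredSchema);
        }
    }
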
--- .../catalog/ExternalSchemaTranslator.java | 332 ++++++++++++++++++ .../catalog/ExternalSchemaTranslatorTest.java | 249 +++++++++++++ .../org/apache/flink/table/api/Schema.java | 9 +- .../table/expressions/SqlCallExpression.java | 18 + 4 files changed, 607 insertions(+), 1 deletion(-) create mode 100644 flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ExternalSchemaTranslator.java create mode 100644 flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/ExternalSchemaTranslatorTest.java diff --git a/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ExternalSchemaTranslator.java b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ExternalSchemaTranslator.java new file mode 100644 index 00000000000000..6f7f240e7afbad --- /dev/null +++ b/flink-table/flink-table-api-java/src/main/java/org/apache/flink/table/catalog/ExternalSchemaTranslator.java @@ -0,0 +1,332 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.Schema.UnresolvedColumn; +import org.apache.flink.table.api.Schema.UnresolvedPhysicalColumn; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.types.DataType; +import org.apache.flink.table.types.logical.LogicalType; +import org.apache.flink.table.types.logical.LogicalTypeRoot; +import org.apache.flink.table.types.logical.RowType; +import org.apache.flink.table.types.logical.StructuredType; +import org.apache.flink.table.types.logical.utils.LogicalTypeChecks; +import org.apache.flink.table.types.logical.utils.LogicalTypeUtils; +import org.apache.flink.table.types.utils.DataTypeUtils; +import org.apache.flink.table.types.utils.TypeInfoDataTypeConverter; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.hasRoot; + +/** + * Utility to derive a physical {@link DataType}, {@link Schema}, and projections when entering or + * leaving the table ecosystem from and to other APIs where {@link TypeInformation} is required. + */ +public final class ExternalSchemaTranslator { + + /** + * Converts the given {@link TypeInformation} and an optional declared {@link Schema} (possibly + * incomplete) into the final {@link InputResult}. + * + *

<p>This method serves three types of use cases: + * + * <ul> + *   <li>1. Derive physical columns from the input type information. + *   <li>2. Derive physical columns but merge them with declared computed columns and other + *       schema information. + *   <li>3. Derive and enrich physical columns and merge other schema information. + * </ul>
+ */ + public static InputResult fromExternal( + DataTypeFactory dataTypeFactory, + TypeInformation inputTypeInfo, + @Nullable Schema declaredSchema) { + final DataType inputDataType = + TypeInfoDataTypeConverter.toDataType(dataTypeFactory, inputTypeInfo); + final LogicalType inputType = inputDataType.getLogicalType(); + + // we don't allow modifying the number of columns during enrichment, therefore we preserve + // whether the original type was qualified as a top-level record or not + final boolean isTopLevelRecord = LogicalTypeChecks.isCompositeType(inputType); + + // no schema has been declared by the user, + // the schema will be entirely derived from the input + if (declaredSchema == null) { + final Schema.Builder builder = Schema.newBuilder(); + addPhysicalDataTypeFields(builder, inputDataType); + return new InputResult(inputDataType, isTopLevelRecord, builder.build(), null); + } + + final List declaredColumns = declaredSchema.getColumns(); + + // the declared schema does not contain physical information, + // thus, it only enriches the non-physical column parts + if (declaredColumns.stream().noneMatch(ExternalSchemaTranslator::isPhysical)) { + final Schema.Builder builder = Schema.newBuilder(); + addPhysicalDataTypeFields(builder, inputDataType); + builder.fromSchema(declaredSchema); + return new InputResult(inputDataType, isTopLevelRecord, builder.build(), null); + } + + // the declared schema enriches the physical data type and the derived schema, + // it possibly projects the result + final DataType patchedDataType = + patchDataTypeFromDeclaredSchema(dataTypeFactory, inputDataType, declaredColumns); + final Schema patchedSchema = + createPatchedSchema(isTopLevelRecord, patchedDataType, declaredSchema); + final int[] projections = extractProjections(patchedSchema, declaredSchema); + return new InputResult(patchedDataType, isTopLevelRecord, patchedSchema, projections); + } + + private static int[] extractProjections(Schema patchedSchema, Schema declaredSchema) { + final List patchedColumns = + patchedSchema.getColumns().stream() + .map(UnresolvedColumn::getName) + .collect(Collectors.toList()); + return declaredSchema.getColumns().stream() + .map(UnresolvedColumn::getName) + .mapToInt(patchedColumns::indexOf) + .toArray(); + } + + private static Schema createPatchedSchema( + boolean isTopLevelRecord, DataType patchedDataType, Schema declaredSchema) { + final Schema.Builder builder = Schema.newBuilder(); + + // physical columns + if (isTopLevelRecord) { + addPhysicalDataTypeFields(builder, patchedDataType); + } else { + builder.column( + LogicalTypeUtils.getAtomicName(Collections.emptyList()), patchedDataType); + } + + // remaining schema + final List nonPhysicalColumns = + declaredSchema.getColumns().stream() + .filter(c -> !isPhysical(c)) + .collect(Collectors.toList()); + builder.fromColumns(nonPhysicalColumns); + declaredSchema + .getWatermarkSpecs() + .forEach( + spec -> + builder.watermark( + spec.getColumnName(), spec.getWatermarkExpression())); + declaredSchema + .getPrimaryKey() + .ifPresent( + key -> + builder.primaryKeyNamed( + key.getConstraintName(), key.getColumnNames())); + return builder.build(); + } + + private static DataType patchDataTypeFromDeclaredSchema( + DataTypeFactory dataTypeFactory, + DataType inputDataType, + List declaredColumns) { + final List physicalColumns = + declaredColumns.stream() + .filter(ExternalSchemaTranslator::isPhysical) + .map(UnresolvedPhysicalColumn.class::cast) + .collect(Collectors.toList()); + + DataType patchedDataType = 
inputDataType; + for (UnresolvedPhysicalColumn physicalColumn : physicalColumns) { + patchedDataType = + patchDataTypeFromColumn(dataTypeFactory, patchedDataType, physicalColumn); + } + return patchedDataType; + } + + private static DataType patchDataTypeFromColumn( + DataTypeFactory dataTypeFactory, + DataType dataType, + UnresolvedPhysicalColumn physicalColumn) { + final List fieldNames = DataTypeUtils.flattenToNames(dataType); + final String columnName = physicalColumn.getName(); + if (!fieldNames.contains(columnName)) { + throw new ValidationException( + String.format( + "Unable to find a field named '%s' in the physical data type derived " + + "from the given type information for schema declaration. " + + "Make sure that the type information is not a generic raw " + + "type. Currently available fields are: %s", + columnName, fieldNames)); + } + final DataType columnDataType = + dataTypeFactory.createDataType(physicalColumn.getDataType()); + final LogicalType type = dataType.getLogicalType(); + + // the following lines make assumptions on what comes out of the TypeInfoDataTypeConverter + // e.g. we can assume that there will be no DISTINCT type and only anonymously defined + // structured types without a super type + if (hasRoot(type, LogicalTypeRoot.ROW)) { + return patchRowDataType(dataType, columnName, columnDataType); + } else if (hasRoot(type, LogicalTypeRoot.STRUCTURED_TYPE)) { + return patchStructuredDataType(dataType, columnName, columnDataType); + } else { + // this also covers the case where a top-level generic type enters the + // Table API, the type can be patched to a more specific type but the schema will still + // keep it nested in a single field without flattening + return columnDataType; + } + } + + private static DataType patchRowDataType( + DataType dataType, String patchedFieldName, DataType patchedFieldDataType) { + final RowType type = (RowType) dataType.getLogicalType(); + final List oldFieldNames = DataTypeUtils.flattenToNames(dataType); + final List oldFieldDataTypes = dataType.getChildren(); + final Class oldConversion = dataType.getConversionClass(); + + final DataTypes.Field[] fields = + patchFields( + oldFieldNames, oldFieldDataTypes, patchedFieldName, patchedFieldDataType); + + final DataType newDataType = DataTypes.ROW(fields).bridgedTo(oldConversion); + if (!type.isNullable()) { + return newDataType.notNull(); + } + return newDataType; + } + + private static DataType patchStructuredDataType( + DataType dataType, String patchedFieldName, DataType patchedFieldDataType) { + final StructuredType type = (StructuredType) dataType.getLogicalType(); + final List oldFieldNames = DataTypeUtils.flattenToNames(dataType); + final List oldFieldDataTypes = dataType.getChildren(); + final Class oldConversion = dataType.getConversionClass(); + + final DataTypes.Field[] fields = + patchFields( + oldFieldNames, oldFieldDataTypes, patchedFieldName, patchedFieldDataType); + + final DataType newDataType = + DataTypes.STRUCTURED( + type.getImplementationClass() + .orElseThrow(IllegalStateException::new), + fields) + .bridgedTo(oldConversion); + if (!type.isNullable()) { + return newDataType.notNull(); + } + return newDataType; + } + + private static DataTypes.Field[] patchFields( + List oldFieldNames, + List oldFieldDataTypes, + String patchedFieldName, + DataType patchedFieldDataType) { + return IntStream.range(0, oldFieldNames.size()) + .mapToObj( + pos -> { + final String oldFieldName = oldFieldNames.get(pos); + final DataType newFieldDataType; + if 
(oldFieldName.equals(patchedFieldName)) { + newFieldDataType = patchedFieldDataType; + } else { + newFieldDataType = oldFieldDataTypes.get(pos); + } + return DataTypes.FIELD(oldFieldName, newFieldDataType); + }) + .toArray(DataTypes.Field[]::new); + } + + private static void addPhysicalDataTypeFields(Schema.Builder builder, DataType dataType) { + final List fieldDataTypes = DataTypeUtils.flattenToDataTypes(dataType); + final List fieldNames = DataTypeUtils.flattenToNames(dataType); + builder.fromFields(fieldNames, fieldDataTypes); + } + + private static boolean isPhysical(UnresolvedColumn column) { + return column instanceof UnresolvedPhysicalColumn; + } + + // -------------------------------------------------------------------------------------------- + // Result representation + // -------------------------------------------------------------------------------------------- + + /** Result of {@link #fromExternal(DataTypeFactory, TypeInformation, Schema)}. */ + public static class InputResult { + + /** + * Data type expected from the first table ecosystem operator for input conversion. The data + * type might not be a row type and can possibly be nullable. + */ + private final DataType physicalDataType; + + /** + * Whether the first table ecosystem operator should treat the physical record as top-level + * record and thus perform implicit flattening. Otherwise the record needs to be wrapped in + * a top-level row. + */ + private final boolean isTopLevelRecord; + + /** + * Schema derived from the physical data type. It does not include the projections of the + * user-provided schema. + */ + private final Schema schema; + + /** + * List of indices to adjust the presents and order of columns from {@link #schema} for the + * final column structure. + */ + private final @Nullable int[] projections; + + private InputResult( + DataType physicalDataType, + boolean isTopLevelRecord, + Schema schema, + @Nullable int[] projections) { + this.physicalDataType = physicalDataType; + this.isTopLevelRecord = isTopLevelRecord; + this.schema = schema; + this.projections = projections; + } + + public DataType getPhysicalDataType() { + return physicalDataType; + } + + public boolean isTopLevelRecord() { + return isTopLevelRecord; + } + + public Schema getSchema() { + return schema; + } + + public @Nullable int[] getProjections() { + return projections; + } + } +} diff --git a/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/ExternalSchemaTranslatorTest.java b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/ExternalSchemaTranslatorTest.java new file mode 100644 index 00000000000000..367fc95cb248ca --- /dev/null +++ b/flink-table/flink-table-api-java/src/test/java/org/apache/flink/table/catalog/ExternalSchemaTranslatorTest.java @@ -0,0 +1,249 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.table.catalog; + +import org.apache.flink.api.common.typeinfo.TypeInformation; +import org.apache.flink.api.common.typeinfo.Types; +import org.apache.flink.table.api.DataTypes; +import org.apache.flink.table.api.Schema; +import org.apache.flink.table.api.ValidationException; +import org.apache.flink.table.types.utils.DataTypeFactoryMock; +import org.apache.flink.types.Row; + +import org.junit.Test; + +import java.math.BigDecimal; +import java.time.DayOfWeek; +import java.util.Optional; + +import static org.apache.flink.core.testutils.FlinkMatchers.containsMessage; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +/** Tests for {@link ExternalSchemaTranslator}. */ +public class ExternalSchemaTranslatorTest { + + @Test + public void testInputFromRow() { + final TypeInformation inputTypeInfo = + Types.ROW(Types.ROW(Types.INT, Types.BOOLEAN), Types.ENUM(DayOfWeek.class)); + + final ExternalSchemaTranslator.InputResult result = + ExternalSchemaTranslator.fromExternal( + dataTypeFactoryWithRawType(DayOfWeek.class), inputTypeInfo, null); + + assertEquals( + DataTypes.ROW( + DataTypes.FIELD( + "f0", + DataTypes.ROW( + DataTypes.FIELD("f0", DataTypes.INT()), + DataTypes.FIELD("f1", DataTypes.BOOLEAN()))), + DataTypes.FIELD( + "f1", DataTypeFactoryMock.dummyRaw(DayOfWeek.class))) + .notNull(), + result.getPhysicalDataType()); + + assertTrue(result.isTopLevelRecord()); + + assertEquals( + Schema.newBuilder() + .column( + "f0", + DataTypes.ROW( + DataTypes.FIELD("f0", DataTypes.INT()), + DataTypes.FIELD("f1", DataTypes.BOOLEAN()))) + .column("f1", DataTypeFactoryMock.dummyRaw(DayOfWeek.class)) + .build(), + result.getSchema()); + + assertNull(result.getProjections()); + } + + @Test + public void testInputFromAtomic() { + final TypeInformation inputTypeInfo = Types.GENERIC(Row.class); + + final ExternalSchemaTranslator.InputResult result = + ExternalSchemaTranslator.fromExternal( + dataTypeFactoryWithRawType(Row.class), inputTypeInfo, null); + + assertEquals(DataTypeFactoryMock.dummyRaw(Row.class), result.getPhysicalDataType()); + + assertFalse(result.isTopLevelRecord()); + + assertEquals( + Schema.newBuilder().column("f0", DataTypeFactoryMock.dummyRaw(Row.class)).build(), + result.getSchema()); + + assertNull(result.getProjections()); + } + + @Test + public void testInputFromRowWithNonPhysicalDeclaredSchema() { + final TypeInformation inputTypeInfo = Types.ROW(Types.INT, Types.LONG); + + final ExternalSchemaTranslator.InputResult result = + ExternalSchemaTranslator.fromExternal( + dataTypeFactory(), + inputTypeInfo, + Schema.newBuilder() + .columnByExpression("computed", "f1 + 42") + .columnByExpression("computed2", "f1 - 1") + .primaryKeyNamed("pk", "f0") + .build()); + + assertEquals( + DataTypes.ROW( + DataTypes.FIELD("f0", DataTypes.INT()), + DataTypes.FIELD("f1", DataTypes.BIGINT())) + .notNull(), + result.getPhysicalDataType()); + + assertTrue(result.isTopLevelRecord()); + + assertEquals( + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.BIGINT()) + .columnByExpression("computed", "f1 + 42") + .columnByExpression("computed2", "f1 - 1") + .primaryKeyNamed("pk", "f0") + .build(), + result.getSchema()); + + 
assertNull(result.getProjections()); + } + + @Test + public void testInputFromRowWithPhysicalDeclaredSchema() { + final TypeInformation inputTypeInfo = + Types.ROW(Types.INT, Types.LONG, Types.GENERIC(BigDecimal.class), Types.BOOLEAN); + + final ExternalSchemaTranslator.InputResult result = + ExternalSchemaTranslator.fromExternal( + dataTypeFactoryWithRawType(BigDecimal.class), + inputTypeInfo, + Schema.newBuilder() + .primaryKeyNamed("pk", "f0") + .column("f1", DataTypes.BIGINT()) // reordered + .column("f0", DataTypes.INT()) + .columnByExpression("computed", "f1 + 42") + .column("f2", DataTypes.DECIMAL(10, 2)) // enriches + .columnByExpression("computed2", "f1 - 1") + .build()); + + assertEquals( + DataTypes.ROW( + DataTypes.FIELD("f0", DataTypes.INT()), + DataTypes.FIELD("f1", DataTypes.BIGINT()), + DataTypes.FIELD("f2", DataTypes.DECIMAL(10, 2)), + DataTypes.FIELD("f3", DataTypes.BOOLEAN())) + .notNull(), + result.getPhysicalDataType()); + + assertTrue(result.isTopLevelRecord()); + + assertEquals( + Schema.newBuilder() + .column("f0", DataTypes.INT()) + .column("f1", DataTypes.BIGINT()) + .column("f2", DataTypes.DECIMAL(10, 2)) + .column("f3", DataTypes.BOOLEAN()) + .columnByExpression("computed", "f1 + 42") + .columnByExpression("computed2", "f1 - 1") + .primaryKeyNamed("pk", "f0") + .build(), + result.getSchema()); + + assertArrayEquals(new int[] {1, 0, 4, 2, 5}, result.getProjections()); + } + + @Test + public void testInputFromAtomicWithPhysicalDeclaredSchema() { + final TypeInformation inputTypeInfo = Types.GENERIC(Row.class); + + final ExternalSchemaTranslator.InputResult result = + ExternalSchemaTranslator.fromExternal( + dataTypeFactoryWithRawType(Row.class), + inputTypeInfo, + Schema.newBuilder() + .columnByExpression("f0_0", "f0.f0_0") + .column( + "f0", + DataTypes.ROW( + DataTypes.FIELD("f0_0", DataTypes.INT()), + DataTypes.FIELD("f0_1", DataTypes.BOOLEAN()))) + .columnByExpression("f0_1", "f0.f0_1") + .build()); + + assertEquals( + DataTypes.ROW( + DataTypes.FIELD("f0_0", DataTypes.INT()), + DataTypes.FIELD("f0_1", DataTypes.BOOLEAN())), + result.getPhysicalDataType()); + + assertFalse(result.isTopLevelRecord()); + + assertEquals( + Schema.newBuilder() + .column( + "f0", + DataTypes.ROW( + DataTypes.FIELD("f0_0", DataTypes.INT()), + DataTypes.FIELD("f0_1", DataTypes.BOOLEAN()))) + .columnByExpression("f0_0", "f0.f0_0") + .columnByExpression("f0_1", "f0.f0_1") + .build(), + result.getSchema()); + + assertArrayEquals(new int[] {1, 0, 2}, result.getProjections()); + } + + @Test + public void testInvalidDeclaredSchemaColumn() { + final TypeInformation inputTypeInfo = Types.ROW(Types.INT, Types.LONG); + + try { + ExternalSchemaTranslator.fromExternal( + dataTypeFactory(), + inputTypeInfo, + Schema.newBuilder().column("INVALID", DataTypes.BIGINT()).build()); + } catch (ValidationException e) { + assertThat( + e, + containsMessage( + "Unable to find a field named 'INVALID' in the physical data type")); + } + } + + private static DataTypeFactory dataTypeFactoryWithRawType(Class rawType) { + final DataTypeFactoryMock dataTypeFactory = new DataTypeFactoryMock(); + dataTypeFactory.dataType = Optional.of(DataTypeFactoryMock.dummyRaw(rawType)); + return dataTypeFactory; + } + + private static DataTypeFactory dataTypeFactory() { + return new DataTypeFactoryMock(); + } +} diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/Schema.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/Schema.java index bdf81f2a68246e..bf99fbd8f1ba34 
100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/Schema.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/api/Schema.java @@ -61,7 +61,8 @@ * *
<p>
This class is used in the API and catalogs to define an unresolved schema that will be * translated to {@link ResolvedSchema}. Some methods of this class perform basic validation, - * however, the main validation happens during the resolution. + * however, the main validation happens during the resolution. Thus, an unresolved schema can be + * incomplete and might be enriched or merged with a different schema at a later stage. * *
<p>
Since an instance of this class is unresolved, it should not be directly persisted. The {@link * #toString()} shows only a summary of the contained objects. @@ -213,6 +214,12 @@ public Builder fromFields( return this; } + /** Adopts all columns from the given list. */ + public Builder fromColumns(List unresolvedColumns) { + columns.addAll(unresolvedColumns); + return this; + } + /** * Declares a physical column that is appended to this schema. * diff --git a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/SqlCallExpression.java b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/SqlCallExpression.java index ee0444904ebc56..637f4326b0ff75 100644 --- a/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/SqlCallExpression.java +++ b/flink-table/flink-table-common/src/main/java/org/apache/flink/table/expressions/SqlCallExpression.java @@ -22,6 +22,7 @@ import java.util.Collections; import java.util.List; +import java.util.Objects; /** * A call to a SQL expression. @@ -65,6 +66,23 @@ public R accept(ExpressionVisitor visitor) { return visitor.visit(this); } + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + SqlCallExpression that = (SqlCallExpression) o; + return sqlExpression.equals(that.sqlExpression); + } + + @Override + public int hashCode() { + return Objects.hash(sqlExpression); + } + @Override public String toString() { return asSummaryString(); From 662b0610cfe89df9aacfffc9d09e0d773e0a681b Mon Sep 17 00:00:00 2001 From: Xintong Song Date: Mon, 22 Mar 2021 17:26:57 +0800 Subject: [PATCH 5/7] [hotfix][core] Remove unnecessary not-freed check for releasing segments in memory manager. --- .../main/java/org/apache/flink/core/memory/MemorySegment.java | 2 ++ .../java/org/apache/flink/runtime/memory/MemoryManager.java | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegment.java b/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegment.java index 3d135e713bac5c..37543efdee369c 100644 --- a/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegment.java +++ b/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegment.java @@ -19,6 +19,7 @@ package org.apache.flink.core.memory; import org.apache.flink.annotation.Internal; +import org.apache.flink.annotation.VisibleForTesting; import org.apache.flink.util.Preconditions; import javax.annotation.Nonnull; @@ -219,6 +220,7 @@ public int size() { * * @return true, if the memory segment has been freed, false otherwise. 
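* <p>Note: this method is now intended for tests only (see the {@code @VisibleForTesting} annotation added below); the same hotfix drops the not-freed check from MemoryManager's segment-release loop, its remaining production use.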
*/ + @VisibleForTesting public boolean isFreed() { return address > addressLimit; } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/memory/MemoryManager.java b/flink-runtime/src/main/java/org/apache/flink/runtime/memory/MemoryManager.java index 36e043dff22493..22e1fbce534902 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/memory/MemoryManager.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/memory/MemoryManager.java @@ -346,7 +346,7 @@ private MemorySegment releaseSegmentsForOwnerUntilNextOwner( while (segmentsIterator.hasNext()) { MemorySegment segment = segmentsIterator.next(); try { - if (segment == null || segment.isFreed()) { + if (segment == null) { continue; } Object nextOwner = segment.getOwner(); From e2ecd7a85b4b2a8029f615f89ee6c8f6e488905b Mon Sep 17 00:00:00 2001 From: Xintong Song Date: Thu, 18 Mar 2021 17:42:42 +0800 Subject: [PATCH 6/7] [FLINK-21800][core] Guard MemorySegment against concurrent frees. This closes #15273 --- .../flink/core/memory/MemorySegment.java | 29 +++++++++++------- .../OffHeapUnsafeMemorySegmentTest.java | 30 +++++++++++++++++++ 2 files changed, 49 insertions(+), 10 deletions(-) diff --git a/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegment.java b/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegment.java index 37543efdee369c..ce39a122654362 100644 --- a/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegment.java +++ b/flink-core/src/main/java/org/apache/flink/core/memory/MemorySegment.java @@ -33,6 +33,7 @@ import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.ReadOnlyBufferException; +import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Consumer; import java.util.function.Function; @@ -135,6 +136,8 @@ public final class MemorySegment { */ private final boolean allowWrap; + private final AtomicBoolean isFreedAtomic; + /** * Creates a new memory segment that represents the memory of the byte array. * @@ -155,6 +158,7 @@ public final class MemorySegment { this.owner = owner; this.allowWrap = true; this.cleaner = null; + this.isFreedAtomic = new AtomicBoolean(false); } /** @@ -200,6 +204,7 @@ public final class MemorySegment { this.owner = owner; this.allowWrap = allowWrap; this.cleaner = cleaner; + this.isFreedAtomic = new AtomicBoolean(false); } // ------------------------------------------------------------------------ @@ -233,17 +238,21 @@ public boolean isFreed() { * memory segment object has become garbage collected. 
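* <p>With this change, concurrent calls to this method are safe: the {@code isFreedAtomic} flag is toggled with {@code getAndSet(true)}, so the address reset and the optional cleaner run at most once.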
*/ public void free() { - if (checkMultipleFree && isFreed()) { - throw new IllegalStateException("MemorySegment can be freed only once!"); - } - // this ensures we can place no more data and trigger - // the checks for the freed segment - address = addressLimit + 1; - if (cleaner != null) { - cleaner.run(); + if (isFreedAtomic.getAndSet(true)) { + // the segment has already been freed + if (checkMultipleFree) { + throw new IllegalStateException("MemorySegment can be freed only once!"); + } + } else { + // this ensures we can place no more data and trigger + // the checks for the freed segment + address = addressLimit + 1; + offHeapBuffer = null; // to enable GC of unsafe memory + if (cleaner != null) { + cleaner.run(); + cleaner = null; + } } - offHeapBuffer = null; // to enable GC of unsafe memory - cleaner = null; } /** diff --git a/flink-core/src/test/java/org/apache/flink/core/memory/OffHeapUnsafeMemorySegmentTest.java b/flink-core/src/test/java/org/apache/flink/core/memory/OffHeapUnsafeMemorySegmentTest.java index d9a21c1a70b042..297c0339931255 100644 --- a/flink-core/src/test/java/org/apache/flink/core/memory/OffHeapUnsafeMemorySegmentTest.java +++ b/flink-core/src/test/java/org/apache/flink/core/memory/OffHeapUnsafeMemorySegmentTest.java @@ -23,7 +23,10 @@ import org.junit.runners.Parameterized; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.atomic.AtomicInteger; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; /** Tests for the {@link MemorySegment} in off-heap mode using unsafe memory. */ @@ -58,4 +61,31 @@ public void testCallCleanerOnFree() { .free(); assertTrue(cleanerFuture.isDone()); } + + @Test + public void testCallCleanerOnceOnConcurrentFree() throws InterruptedException { + final AtomicInteger counter = new AtomicInteger(0); + final Runnable cleaner = + () -> { + try { + counter.incrementAndGet(); + // make the cleaner unlikely to finish before another invocation (if any) + Thread.sleep(10); + } catch (InterruptedException e) { + e.printStackTrace(); + } + }; + + final MemorySegment segment = + MemorySegmentFactory.allocateOffHeapUnsafeMemory(10, null, cleaner); + + final Thread t1 = new Thread(segment::free); + final Thread t2 = new Thread(segment::free); + t1.start(); + t2.start(); + t1.join(); + t2.join(); + + assertThat(counter.get(), is(1)); + } } From 8ba86cbf68aecdbc36fc25a9b85a1c8cb9c3ad31 Mon Sep 17 00:00:00 2001 From: lincoln-lil Date: Wed, 24 Mar 2021 17:08:37 +0800 Subject: [PATCH 7/7] [FLINK-21946] [table-planner-blink] FlinkRelMdUtil.numDistinctVals produces exceptional Double.NaN result when domainSize is in range(0,1) --- .../planner/plan/utils/FlinkRelMdUtil.scala | 2 +- .../FlinkRelMdDistinctRowCountTest.scala | 3 +-- .../plan/utils/FlinkRelMdUtilTest.scala | 7 ++++++ .../runtime/batch/sql/join/JoinITCase.scala | 23 ++++++++++++++++--- 4 files changed, 29 insertions(+), 6 deletions(-) diff --git a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala index 9204bdc8c72a6c..b904d4461dc575 100644 --- a/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala +++ b/flink-table/flink-table-planner-blink/src/main/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtil.scala @@ -236,7 +236,7 @@ object 
FlinkRelMdUtil { */ def numDistinctVals(domainSize: Double, numSelected: Double): Double = { val EPS = 1e-9 - if (Math.abs(1 / domainSize) < EPS) { + if (Math.abs(1 / domainSize) < EPS || domainSize < 1) { // ln(1+x) ~= x for small x val dSize = RelMdUtil.capInfinity(domainSize) val numSel = RelMdUtil.capInfinity(numSelected) diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala index d81e8cdcff01a3..308870298e9e66 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/metadata/FlinkRelMdDistinctRowCountTest.scala @@ -21,7 +21,6 @@ package org.apache.flink.table.planner.plan.metadata import org.apache.flink.table.planner.plan.nodes.physical.batch.BatchPhysicalRank import org.apache.flink.table.planner.plan.utils.FlinkRelMdUtil -import org.apache.calcite.rel.metadata.RelMdUtil import org.apache.calcite.sql.fun.SqlStdOperatorTable._ import org.apache.calcite.util.ImmutableBitSet import org.junit.Assert._ @@ -82,7 +81,7 @@ class FlinkRelMdDistinctRowCountTest extends FlinkRelMdHandlerTestBase { mq.getDistinctRowCount(logicalValues, ImmutableBitSet.of(0, 1), null)) (0 until logicalValues.getRowType.getFieldCount).foreach { idx => - assertEquals(Double.NaN, mq.getDistinctRowCount(emptyValues, ImmutableBitSet.of(idx), null)) + assertEquals(1.0, mq.getDistinctRowCount(emptyValues, ImmutableBitSet.of(idx), null)) } } diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtilTest.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtilTest.scala index cd3949252b6b47..f90d22db9373bb 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtilTest.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/plan/utils/FlinkRelMdUtilTest.scala @@ -28,6 +28,13 @@ class FlinkRelMdUtilTest { Assert.assertEquals( RelMdUtil.numDistinctVals(1e5, 1e4), FlinkRelMdUtil.numDistinctVals(1e5, 1e4)) + + Assert.assertEquals( + BigDecimal(0.31606027941427883), + BigDecimal.valueOf(FlinkRelMdUtil.numDistinctVals(0.5, 0.5))) + + // This case should be removed once CALCITE-4351 is fixed. 
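+ // For a domainSize in (0, 1), Calcite's RelMdUtil.numDistinctVals currently evaluates to Double.NaN + // (asserted below), which is why FlinkRelMdUtil now handles that range itself via the added + // 'domainSize < 1' guard.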
+ Assert.assertEquals(Double.NaN, RelMdUtil.numDistinctVals(0.5, 0.5)) } @Test diff --git a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala index 148269a74e8345..39ebeca0ca4238 100644 --- a/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala +++ b/flink-table/flink-table-planner-blink/src/test/scala/org/apache/flink/table/planner/runtime/batch/sql/join/JoinITCase.scala @@ -24,7 +24,8 @@ import org.apache.flink.api.common.typeinfo.Types import org.apache.flink.api.common.typeutils.TypeComparator import org.apache.flink.api.dag.Transformation import org.apache.flink.api.java.typeutils.{GenericTypeInfo, RowTypeInfo} -import org.apache.flink.streaming.api.transformations.{OneInputTransformation, LegacySinkTransformation, TwoInputTransformation} +import org.apache.flink.streaming.api.transformations.{LegacySinkTransformation, OneInputTransformation, TwoInputTransformation} +import org.apache.flink.table.api.internal.TableEnvironmentInternal import org.apache.flink.table.planner.delegation.PlannerBase import org.apache.flink.table.planner.expressions.utils.FuncWithOpen import org.apache.flink.table.planner.runtime.batch.sql.join.JoinType.{BroadcastHashJoin, HashJoin, JoinType, NestedLoopJoin, SortMergeJoin} @@ -34,12 +35,13 @@ import org.apache.flink.table.planner.runtime.utils.TestData._ import org.apache.flink.table.planner.sinks.CollectRowTableSink import org.apache.flink.table.planner.utils.{TestingStatementSet, TestingTableEnvironment} import org.apache.flink.table.runtime.operators.CodeGenOperatorFactory +import org.apache.flink.types.Row + import org.junit.runner.RunWith import org.junit.runners.Parameterized import org.junit.{Assert, Before, Test} -import java.util -import org.apache.flink.table.api.internal.TableEnvironmentInternal +import java.util import scala.collection.JavaConversions._ import scala.collection.Seq @@ -588,6 +590,21 @@ class JoinITCase(expectedJoinType: JoinType) extends BatchTestBase { Seq(row(2, 1.0), row(2, 1.0))) } + @Test + def testCorrelatedExist2(): Unit = { + val data: Seq[Row] = Seq( + row(0L), + row(123456L), + row(-123456L), + row(2147483647L), + row(-2147483647L)) + registerCollection("t1", data, new RowTypeInfo(LONG_TYPE_INFO), "f1") + + checkResult( + "select * from t1 o where exists (select 1 from t1 i where i.f1=o.f1 limit 0)", + Seq()) + } + @Test def testCorrelatedNotExist(): Unit = { checkResult(