From f99c4dd395b71877eb70e7fc743d109957205a3d Mon Sep 17 00:00:00 2001 From: Aljoscha Krettek Date: Wed, 14 Feb 2018 12:04:20 +0100 Subject: [PATCH] [FLINK-8411] Don't allow null in ListState.add()/addAll() --- .../kafka/FlinkKafkaConsumerBaseTest.java | 3 + .../streaming/state/RocksDBListState.java | 14 +- .../state/DefaultOperatorStateBackend.java | 1 + .../runtime/state/heap/HeapListState.java | 35 ++-- .../runtime/state/StateBackendTestBase.java | 160 ++++++++++++++++-- 5 files changed, 186 insertions(+), 27 deletions(-) diff --git a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java index 5040966337af4..403e627e6d34e 100644 --- a/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java +++ b/flink-connectors/flink-connector-kafka-base/src/test/java/org/apache/flink/streaming/connectors/kafka/FlinkKafkaConsumerBaseTest.java @@ -696,6 +696,7 @@ public Iterable get() throws Exception { @Override public void add(T value) throws Exception { + Preconditions.checkNotNull(value, "You cannot add null to a ListState."); list.add(value); } @@ -717,6 +718,8 @@ public void update(List values) throws Exception { @Override public void addAll(List values) throws Exception { if (values != null) { + values.forEach(v -> Preconditions.checkNotNull(v, "You cannot add null to a ListState.")); + list.addAll(values); } } diff --git a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java index 413615b2ea18e..f0481ec45be96 100644 --- a/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java +++ b/flink-contrib/flink-statebackend-rocksdb/src/main/java/org/apache/flink/contrib/streaming/state/RocksDBListState.java @@ -24,6 +24,7 @@ import org.apache.flink.core.memory.DataInputViewStreamWrapper; import org.apache.flink.core.memory.DataOutputViewStreamWrapper; import org.apache.flink.runtime.state.internal.InternalListState; +import org.apache.flink.util.Preconditions; import org.rocksdb.ColumnFamilyHandle; import org.rocksdb.RocksDBException; @@ -112,9 +113,7 @@ public Iterable get() { @Override public void add(V value) throws IOException { - if (value == null) { - return; - } + Preconditions.checkNotNull(value, "You cannot add null to a ListState."); try { writeCurrentKeyWithGroupAndNamespace(); @@ -169,9 +168,11 @@ public void mergeNamespaces(N target, Collection sources) throws Exception { @Override public void update(List values) throws Exception { + Preconditions.checkNotNull(values, "List of values to add cannot be null."); + clear(); - if (values != null && !values.isEmpty()) { + if (!values.isEmpty()) { try { writeCurrentKeyWithGroupAndNamespace(); byte[] key = keySerializationStream.toByteArray(); @@ -190,7 +191,9 @@ public void update(List values) throws Exception { @Override public void addAll(List values) throws Exception { - if (values != null && !values.isEmpty()) { + Preconditions.checkNotNull(values, "List of values to add cannot be null."); + + if (!values.isEmpty()) { try { writeCurrentKeyWithGroupAndNamespace(); byte[] key = keySerializationStream.toByteArray(); @@ -213,6 +216,7 @@ private byte[] getPreMergedValue(List values) throws IOException { keySerializationStream.reset(); boolean first = true; for (V value : values) { + Preconditions.checkNotNull(value, "You cannot add null to a ListState."); if (first) { first = false; } else { diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java index f4866439b467f..266483f93e940 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/DefaultOperatorStateBackend.java @@ -660,6 +660,7 @@ public Iterable get() { @Override public void add(S value) { + Preconditions.checkNotNull(value, "You cannot add null to a ListState."); internalList.add(value); } diff --git a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java index dfc7362e0d994..f7b5cd2d5f08d 100644 --- a/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java +++ b/flink-runtime/src/main/java/org/apache/flink/runtime/state/heap/HeapListState.java @@ -67,9 +67,7 @@ public Iterable get() { @Override public void add(V value) { - if (value == null) { - return; - } + Preconditions.checkNotNull(value, "You cannot add null to a ListState."); final N namespace = currentNamespace; @@ -123,23 +121,36 @@ protected List mergeState(List a, List b) { @Override public void update(List values) throws Exception { - if (values != null && !values.isEmpty()) { - stateTable.put(currentNamespace, new ArrayList<>(values)); - } else { + Preconditions.checkNotNull(values, "List of values to add cannot be null."); + + if (values.isEmpty()) { clear(); + return; } + + List newStateList = new ArrayList<>(); + for (V v : values) { + Preconditions.checkNotNull(v, "You cannot add null to a ListState."); + newStateList.add(v); + } + + stateTable.put(currentNamespace, newStateList); } @Override public void addAll(List values) throws Exception { - if (values != null && !values.isEmpty()) { + Preconditions.checkNotNull(values, "List of values to add cannot be null."); + + if (!values.isEmpty()) { stateTable.transform(currentNamespace, values, (previousState, value) -> { - if (previousState != null) { - previousState.addAll(value); - return previousState; - } else { - return new ArrayList<>(value); + if (previousState == null) { + previousState = new ArrayList<>(); + } + for (V v : value) { + Preconditions.checkNotNull(v, "You cannot add null to a ListState."); + previousState.add(v); } + return previousState; }); } } diff --git a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java index ad69ae8fcf2bc..8acefa4f62c08 100644 --- a/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java +++ b/flink-runtime/src/test/java/org/apache/flink/runtime/state/StateBackendTestBase.java @@ -1295,36 +1295,179 @@ public void testListState() throws Exception { backend.dispose(); } + /** + * This test verifies that all ListState implementations are consistent in not allowing + * adding {@code null}. + */ @Test - public void testListStateAPIs() throws Exception { - + public void testListStateAddNull() throws Exception { AbstractKeyedStateBackend keyedBackend = createKeyedBackend(StringSerializer.INSTANCE); final ListStateDescriptor stateDescr = new ListStateDescriptor<>("my-state", Long.class); try { ListState state = - keyedBackend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescr); + keyedBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + stateDescr); keyedBackend.setCurrentKey("abc"); assertNull(state.get()); + + expectedException.expect(NullPointerException.class); state.add(null); + } finally { + keyedBackend.close(); + keyedBackend.dispose(); + } + } + + /** + * This test verifies that all ListState implementations are consistent in not allowing + * {@link ListState#addAll(List)} to be called with {@code null} entries in the list of entries + * to add. + */ + @Test + public void testListStateAddAllNullEntries() throws Exception { + AbstractKeyedStateBackend keyedBackend = createKeyedBackend(StringSerializer.INSTANCE); + + final ListStateDescriptor stateDescr = new ListStateDescriptor<>("my-state", Long.class); + + try { + ListState state = + keyedBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + stateDescr); + + keyedBackend.setCurrentKey("abc"); assertNull(state.get()); + expectedException.expect(NullPointerException.class); + + List adding = new ArrayList<>(); + adding.add(3L); + adding.add(null); + adding.add(5L); + state.addAll(adding); + } finally { + keyedBackend.close(); + keyedBackend.dispose(); + } + } + + /** + * This test verifies that all ListState implementations are consistent in not allowing + * {@link ListState#addAll(List)} to be called with {@code null}. + */ + @Test + public void testListStateAddAllNull() throws Exception { + AbstractKeyedStateBackend keyedBackend = createKeyedBackend(StringSerializer.INSTANCE); + + final ListStateDescriptor stateDescr = new ListStateDescriptor<>("my-state", Long.class); + + try { + ListState state = + keyedBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + stateDescr); + + keyedBackend.setCurrentKey("abc"); + assertNull(state.get()); + + expectedException.expect(NullPointerException.class); + state.addAll(null); + } finally { + keyedBackend.close(); + keyedBackend.dispose(); + } + } + + /** + * This test verifies that all ListState implementations are consistent in not allowing + * {@link ListState#addAll(List)} to be called with {@code null} entries in the list of entries + * to add. + */ + @Test + public void testListStateUpdateNullEntries() throws Exception { + AbstractKeyedStateBackend keyedBackend = createKeyedBackend(StringSerializer.INSTANCE); + + final ListStateDescriptor stateDescr = new ListStateDescriptor<>("my-state", Long.class); + + try { + ListState state = + keyedBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + stateDescr); + + keyedBackend.setCurrentKey("abc"); + assertNull(state.get()); + + expectedException.expect(NullPointerException.class); + + List adding = new ArrayList<>(); + adding.add(3L); + adding.add(null); + adding.add(5L); + state.update(adding); + } finally { + keyedBackend.close(); + keyedBackend.dispose(); + } + } + + /** + * This test verifies that all ListState implementations are consistent in not allowing + * {@link ListState#addAll(List)} to be called with {@code null}. + */ + @Test + public void testListStateUpdateNull() throws Exception { + AbstractKeyedStateBackend keyedBackend = createKeyedBackend(StringSerializer.INSTANCE); + + final ListStateDescriptor stateDescr = new ListStateDescriptor<>("my-state", Long.class); + + try { + ListState state = + keyedBackend.getPartitionedState( + VoidNamespace.INSTANCE, + VoidNamespaceSerializer.INSTANCE, + stateDescr); + + keyedBackend.setCurrentKey("abc"); + assertNull(state.get()); + + expectedException.expect(NullPointerException.class); + state.update(null); + } finally { + keyedBackend.close(); + keyedBackend.dispose(); + } + } + + @Test + public void testListStateAPIs() throws Exception { + + AbstractKeyedStateBackend keyedBackend = createKeyedBackend(StringSerializer.INSTANCE); + + final ListStateDescriptor stateDescr = new ListStateDescriptor<>("my-state", Long.class); + + try { + ListState state = + keyedBackend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescr); + keyedBackend.setCurrentKey("def"); assertNull(state.get()); state.add(17L); state.add(11L); assertThat(state.get(), containsInAnyOrder(17L, 11L)); - // update(null) should remain the value null - state.update(null); - assertNull(state.get()); // update(emptyList) should remain the value null state.update(Collections.emptyList()); assertNull(state.get()); state.update(Arrays.asList(10L, 16L)); assertThat(state.get(), containsInAnyOrder(16L, 10L)); - state.add(null); assertThat(state.get(), containsInAnyOrder(16L, 10L)); keyedBackend.setCurrentKey("abc"); @@ -1332,13 +1475,11 @@ public void testListStateAPIs() throws Exception { keyedBackend.setCurrentKey("g"); assertNull(state.get()); - state.addAll(null); assertNull(state.get()); state.addAll(Collections.emptyList()); assertNull(state.get()); state.addAll(Arrays.asList(3L, 4L)); assertThat(state.get(), containsInAnyOrder(3L, 4L)); - state.addAll(null); assertThat(state.get(), containsInAnyOrder(3L, 4L)); state.addAll(new ArrayList<>()); assertThat(state.get(), containsInAnyOrder(3L, 4L)); @@ -1347,7 +1488,6 @@ public void testListStateAPIs() throws Exception { state.addAll(new ArrayList<>()); assertThat(state.get(), containsInAnyOrder(3L, 4L, 5L, 6L)); - state.add(null); assertThat(state.get(), containsInAnyOrder(3L, 4L, 5L, 6L)); state.update(Arrays.asList(1L, 2L)); assertThat(state.get(), containsInAnyOrder(1L, 2L));