Skip to content

Commit

Permalink
[FLINK-8411] Don't allow null in ListState.add()/addAll()
Browse files Browse the repository at this point in the history
  • Loading branch information
aljoscha committed Feb 17, 2018
1 parent fafd5b6 commit f99c4dd
Show file tree
Hide file tree
Showing 5 changed files with 186 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,7 @@ public Iterable<T> get() throws Exception {

@Override
public void add(T value) throws Exception {
Preconditions.checkNotNull(value, "You cannot add null to a ListState.");
list.add(value);
}

Expand All @@ -717,6 +718,8 @@ public void update(List<T> values) throws Exception {
@Override
public void addAll(List<T> values) throws Exception {
if (values != null) {
values.forEach(v -> Preconditions.checkNotNull(v, "You cannot add null to a ListState."));

list.addAll(values);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.flink.core.memory.DataInputViewStreamWrapper;
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.runtime.state.internal.InternalListState;
import org.apache.flink.util.Preconditions;

import org.rocksdb.ColumnFamilyHandle;
import org.rocksdb.RocksDBException;
Expand Down Expand Up @@ -112,9 +113,7 @@ public Iterable<V> get() {

@Override
public void add(V value) throws IOException {
if (value == null) {
return;
}
Preconditions.checkNotNull(value, "You cannot add null to a ListState.");

try {
writeCurrentKeyWithGroupAndNamespace();
Expand Down Expand Up @@ -169,9 +168,11 @@ public void mergeNamespaces(N target, Collection<N> sources) throws Exception {

@Override
public void update(List<V> values) throws Exception {
Preconditions.checkNotNull(values, "List of values to add cannot be null.");

clear();

if (values != null && !values.isEmpty()) {
if (!values.isEmpty()) {
try {
writeCurrentKeyWithGroupAndNamespace();
byte[] key = keySerializationStream.toByteArray();
Expand All @@ -190,7 +191,9 @@ public void update(List<V> values) throws Exception {

@Override
public void addAll(List<V> values) throws Exception {
if (values != null && !values.isEmpty()) {
Preconditions.checkNotNull(values, "List of values to add cannot be null.");

if (!values.isEmpty()) {
try {
writeCurrentKeyWithGroupAndNamespace();
byte[] key = keySerializationStream.toByteArray();
Expand All @@ -213,6 +216,7 @@ private byte[] getPreMergedValue(List<V> values) throws IOException {
keySerializationStream.reset();
boolean first = true;
for (V value : values) {
Preconditions.checkNotNull(value, "You cannot add null to a ListState.");
if (first) {
first = false;
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -660,6 +660,7 @@ public Iterable<S> get() {

@Override
public void add(S value) {
Preconditions.checkNotNull(value, "You cannot add null to a ListState.");
internalList.add(value);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,7 @@ public Iterable<V> get() {

@Override
public void add(V value) {
if (value == null) {
return;
}
Preconditions.checkNotNull(value, "You cannot add null to a ListState.");

final N namespace = currentNamespace;

Expand Down Expand Up @@ -123,23 +121,36 @@ protected List<V> mergeState(List<V> a, List<V> b) {

@Override
public void update(List<V> values) throws Exception {
if (values != null && !values.isEmpty()) {
stateTable.put(currentNamespace, new ArrayList<>(values));
} else {
Preconditions.checkNotNull(values, "List of values to add cannot be null.");

if (values.isEmpty()) {
clear();
return;
}

List<V> newStateList = new ArrayList<>();
for (V v : values) {
Preconditions.checkNotNull(v, "You cannot add null to a ListState.");
newStateList.add(v);
}

stateTable.put(currentNamespace, newStateList);
}

@Override
public void addAll(List<V> values) throws Exception {
if (values != null && !values.isEmpty()) {
Preconditions.checkNotNull(values, "List of values to add cannot be null.");

if (!values.isEmpty()) {
stateTable.transform(currentNamespace, values, (previousState, value) -> {
if (previousState != null) {
previousState.addAll(value);
return previousState;
} else {
return new ArrayList<>(value);
if (previousState == null) {
previousState = new ArrayList<>();
}
for (V v : value) {
Preconditions.checkNotNull(v, "You cannot add null to a ListState.");
previousState.add(v);
}
return previousState;
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1295,50 +1295,191 @@ public void testListState() throws Exception {
backend.dispose();
}

/**
* This test verifies that all ListState implementations are consistent in not allowing
* adding {@code null}.
*/
@Test
public void testListStateAPIs() throws Exception {

public void testListStateAddNull() throws Exception {
AbstractKeyedStateBackend<String> keyedBackend = createKeyedBackend(StringSerializer.INSTANCE);

final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);

try {
ListState<Long> state =
keyedBackend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescr);
keyedBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescr);

keyedBackend.setCurrentKey("abc");
assertNull(state.get());

expectedException.expect(NullPointerException.class);
state.add(null);
} finally {
keyedBackend.close();
keyedBackend.dispose();
}
}

/**
* This test verifies that all ListState implementations are consistent in not allowing
* {@link ListState#addAll(List)} to be called with {@code null} entries in the list of entries
* to add.
*/
@Test
public void testListStateAddAllNullEntries() throws Exception {
AbstractKeyedStateBackend<String> keyedBackend = createKeyedBackend(StringSerializer.INSTANCE);

final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);

try {
ListState<Long> state =
keyedBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescr);

keyedBackend.setCurrentKey("abc");
assertNull(state.get());

expectedException.expect(NullPointerException.class);

List<Long> adding = new ArrayList<>();
adding.add(3L);
adding.add(null);
adding.add(5L);
state.addAll(adding);
} finally {
keyedBackend.close();
keyedBackend.dispose();
}
}

/**
* This test verifies that all ListState implementations are consistent in not allowing
* {@link ListState#addAll(List)} to be called with {@code null}.
*/
@Test
public void testListStateAddAllNull() throws Exception {
AbstractKeyedStateBackend<String> keyedBackend = createKeyedBackend(StringSerializer.INSTANCE);

final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);

try {
ListState<Long> state =
keyedBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescr);

keyedBackend.setCurrentKey("abc");
assertNull(state.get());

expectedException.expect(NullPointerException.class);
state.addAll(null);
} finally {
keyedBackend.close();
keyedBackend.dispose();
}
}

/**
* This test verifies that all ListState implementations are consistent in not allowing
* {@link ListState#addAll(List)} to be called with {@code null} entries in the list of entries
* to add.
*/
@Test
public void testListStateUpdateNullEntries() throws Exception {
AbstractKeyedStateBackend<String> keyedBackend = createKeyedBackend(StringSerializer.INSTANCE);

final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);

try {
ListState<Long> state =
keyedBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescr);

keyedBackend.setCurrentKey("abc");
assertNull(state.get());

expectedException.expect(NullPointerException.class);

List<Long> adding = new ArrayList<>();
adding.add(3L);
adding.add(null);
adding.add(5L);
state.update(adding);
} finally {
keyedBackend.close();
keyedBackend.dispose();
}
}

/**
* This test verifies that all ListState implementations are consistent in not allowing
* {@link ListState#addAll(List)} to be called with {@code null}.
*/
@Test
public void testListStateUpdateNull() throws Exception {
AbstractKeyedStateBackend<String> keyedBackend = createKeyedBackend(StringSerializer.INSTANCE);

final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);

try {
ListState<Long> state =
keyedBackend.getPartitionedState(
VoidNamespace.INSTANCE,
VoidNamespaceSerializer.INSTANCE,
stateDescr);

keyedBackend.setCurrentKey("abc");
assertNull(state.get());

expectedException.expect(NullPointerException.class);
state.update(null);
} finally {
keyedBackend.close();
keyedBackend.dispose();
}
}

@Test
public void testListStateAPIs() throws Exception {

AbstractKeyedStateBackend<String> keyedBackend = createKeyedBackend(StringSerializer.INSTANCE);

final ListStateDescriptor<Long> stateDescr = new ListStateDescriptor<>("my-state", Long.class);

try {
ListState<Long> state =
keyedBackend.getPartitionedState(VoidNamespace.INSTANCE, VoidNamespaceSerializer.INSTANCE, stateDescr);

keyedBackend.setCurrentKey("def");
assertNull(state.get());
state.add(17L);
state.add(11L);
assertThat(state.get(), containsInAnyOrder(17L, 11L));
// update(null) should remain the value null
state.update(null);
assertNull(state.get());
// update(emptyList) should remain the value null
state.update(Collections.emptyList());
assertNull(state.get());
state.update(Arrays.asList(10L, 16L));
assertThat(state.get(), containsInAnyOrder(16L, 10L));
state.add(null);
assertThat(state.get(), containsInAnyOrder(16L, 10L));

keyedBackend.setCurrentKey("abc");
assertNull(state.get());

keyedBackend.setCurrentKey("g");
assertNull(state.get());
state.addAll(null);
assertNull(state.get());
state.addAll(Collections.emptyList());
assertNull(state.get());
state.addAll(Arrays.asList(3L, 4L));
assertThat(state.get(), containsInAnyOrder(3L, 4L));
state.addAll(null);
assertThat(state.get(), containsInAnyOrder(3L, 4L));
state.addAll(new ArrayList<>());
assertThat(state.get(), containsInAnyOrder(3L, 4L));
Expand All @@ -1347,7 +1488,6 @@ public void testListStateAPIs() throws Exception {
state.addAll(new ArrayList<>());
assertThat(state.get(), containsInAnyOrder(3L, 4L, 5L, 6L));

state.add(null);
assertThat(state.get(), containsInAnyOrder(3L, 4L, 5L, 6L));
state.update(Arrays.asList(1L, 2L));
assertThat(state.get(), containsInAnyOrder(1L, 2L));
Expand Down

0 comments on commit f99c4dd

Please sign in to comment.