Skip to content

Commit

Permalink
[FLINK-9034] [core] StateDescriptor does not throw away TypeInformati…
Browse files Browse the repository at this point in the history
…on upon serialization.

Throwing away TypeInformation upon serialization was previously done because the type
information was not serializable. Now that it is serializable, we can (and should) keep
it to provide consistent user experience, where all serializers respect the ExecutionConfig.
  • Loading branch information
StephanEwen committed Mar 22, 2018
1 parent 87dcc89 commit 87d31f5
Show file tree
Hide file tree
Showing 6 changed files with 200 additions and 239 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
import org.apache.flink.util.Preconditions;

import javax.annotation.Nullable;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
Expand All @@ -35,6 +37,7 @@
import java.io.Serializable;

import static org.apache.flink.util.Preconditions.checkNotNull;
import static org.apache.flink.util.Preconditions.checkState;

/**
* Base class for state descriptors. A {@code StateDescriptor} is used for creating partitioned
Expand Down Expand Up @@ -76,19 +79,24 @@ public enum Type {
protected final String name;

/** The serializer for the type. May be eagerly initialized in the constructor,
* or lazily once the type is serialized or an ExecutionConfig is provided. */
* or lazily once the {@link #initializeSerializerUnlessSet(ExecutionConfig)} method
* is called. */
@Nullable
protected TypeSerializer<T> serializer;

/** The type information describing the value type. Only used to if the serializer
* is created lazily. */
@Nullable
private TypeInformation<T> typeInfo;

/** Name for queries against state created from this StateDescriptor. */
@Nullable
private String queryableStateName;

/** The default value returned by the state when no other value is bound to a key. */
@Nullable
protected transient T defaultValue;

/** The type information describing the value type. Only used to lazily create the serializer
* and dropped during serialization */
private transient TypeInformation<T> typeInfo;

// ------------------------------------------------------------------------

/**
Expand All @@ -99,7 +107,7 @@ public enum Type {
* @param defaultValue The default value that will be set when requesting state without setting
* a value before.
*/
protected StateDescriptor(String name, TypeSerializer<T> serializer, T defaultValue) {
protected StateDescriptor(String name, TypeSerializer<T> serializer, @Nullable T defaultValue) {
this.name = checkNotNull(name, "name must not be null");
this.serializer = checkNotNull(serializer, "serializer must not be null");
this.defaultValue = defaultValue;
Expand All @@ -113,7 +121,7 @@ protected StateDescriptor(String name, TypeSerializer<T> serializer, T defaultVa
* @param defaultValue The default value that will be set when requesting state without setting
* a value before.
*/
protected StateDescriptor(String name, TypeInformation<T> typeInfo, T defaultValue) {
protected StateDescriptor(String name, TypeInformation<T> typeInfo, @Nullable T defaultValue) {
this.name = checkNotNull(name, "name must not be null");
this.typeInfo = checkNotNull(typeInfo, "type information must not be null");
this.defaultValue = defaultValue;
Expand All @@ -130,7 +138,7 @@ protected StateDescriptor(String name, TypeInformation<T> typeInfo, T defaultVal
* @param defaultValue The default value that will be set when requesting state without setting
* a value before.
*/
protected StateDescriptor(String name, Class<T> type, T defaultValue) {
protected StateDescriptor(String name, Class<T> type, @Nullable T defaultValue) {
this.name = checkNotNull(name, "name must not be null");
checkNotNull(type, "type class must not be null");

Expand Down Expand Up @@ -208,6 +216,7 @@ public void setQueryable(String queryableStateName) {
*
* @return Queryable state name or <code>null</code> if not set.
*/
@Nullable
public String getQueryableStateName() {
return queryableStateName;
}
Expand Down Expand Up @@ -249,12 +258,13 @@ public boolean isSerializerInitialized() {
*/
public void initializeSerializerUnlessSet(ExecutionConfig executionConfig) {
if (serializer == null) {
if (typeInfo != null) {
serializer = typeInfo.createSerializer(executionConfig);
} else {
throw new IllegalStateException(
"Cannot initialize serializer after TypeInformation was dropped during serialization");
}
checkState(typeInfo != null, "no serializer and no type info");

// instantiate the serializer
serializer = typeInfo.createSerializer(executionConfig);

// we can drop the type info now, no longer needed
typeInfo = null;
}
}

Expand Down Expand Up @@ -285,9 +295,6 @@ public String toString() {
// ------------------------------------------------------------------------

private void writeObject(final ObjectOutputStream out) throws IOException {
// make sure we have a serializer before the type information gets lost
initializeSerializerUnlessSet(new ExecutionConfig());

// write all the non-transient fields
out.defaultWriteObject();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@
package org.apache.flink.api.common.state;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.ListSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.testutils.CommonTestUtils;

import org.junit.Test;
Expand All @@ -35,15 +32,14 @@
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
* Tests for the {@link ListStateDescriptor}.
*/
public class ListStateDescriptorTest {

@Test
public void testValueStateDescriptorEagerSerializer() throws Exception {
public void testListStateDescriptor() throws Exception {

TypeSerializer<String> serializer = new KryoSerializer<>(String.class, new ExecutionConfig());

Expand All @@ -66,48 +62,6 @@ public void testValueStateDescriptorEagerSerializer() throws Exception {
assertEquals(serializer, copy.getElementSerializer());
}

@Test
public void testValueStateDescriptorLazySerializer() throws Exception {
// some different registered value
ExecutionConfig cfg = new ExecutionConfig();
cfg.registerKryoType(TaskInfo.class);

ListStateDescriptor<Path> descr =
new ListStateDescriptor<>("testName", Path.class);

try {
descr.getSerializer();
fail("should cause an exception");
} catch (IllegalStateException ignored) {}

descr.initializeSerializerUnlessSet(cfg);

assertNotNull(descr.getSerializer());
assertTrue(descr.getSerializer() instanceof ListSerializer);

assertNotNull(descr.getElementSerializer());
assertTrue(descr.getElementSerializer() instanceof KryoSerializer);

assertTrue(((KryoSerializer<?>) descr.getElementSerializer()).getKryo().getRegistration(TaskInfo.class).getId() > 0);
}

@Test
public void testValueStateDescriptorAutoSerializer() throws Exception {

ListStateDescriptor<String> descr =
new ListStateDescriptor<>("testName", String.class);

ListStateDescriptor<String> copy = CommonTestUtils.createCopySerializable(descr);

assertEquals("testName", copy.getName());

assertNotNull(copy.getSerializer());
assertTrue(copy.getSerializer() instanceof ListSerializer);

assertNotNull(copy.getElementSerializer());
assertEquals(StringSerializer.INSTANCE, copy.getElementSerializer());
}

/**
* FLINK-6775.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,9 @@
package org.apache.flink.api.common.state;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.LongSerializer;
import org.apache.flink.api.common.typeutils.base.MapSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.testutils.CommonTestUtils;

import org.junit.Test;
Expand All @@ -36,15 +32,14 @@
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;

/**
* Tests for the {@link MapStateDescriptor}.
*/
public class MapStateDescriptorTest {

@Test
public void testMapStateDescriptorEagerSerializer() throws Exception {
public void testMapStateDescriptor() throws Exception {

TypeSerializer<Integer> keySerializer = new KryoSerializer<>(Integer.class, new ExecutionConfig());
TypeSerializer<String> valueSerializer = new KryoSerializer<>(String.class, new ExecutionConfig());
Expand Down Expand Up @@ -72,53 +67,6 @@ public void testMapStateDescriptorEagerSerializer() throws Exception {
assertEquals(valueSerializer, copy.getValueSerializer());
}

@Test
public void testMapStateDescriptorLazySerializer() throws Exception {
// some different registered value
ExecutionConfig cfg = new ExecutionConfig();
cfg.registerKryoType(TaskInfo.class);

MapStateDescriptor<Path, String> descr =
new MapStateDescriptor<>("testName", Path.class, String.class);

try {
descr.getSerializer();
fail("should cause an exception");
} catch (IllegalStateException ignored) {}

descr.initializeSerializerUnlessSet(cfg);

assertNotNull(descr.getSerializer());
assertTrue(descr.getSerializer() instanceof MapSerializer);

assertNotNull(descr.getKeySerializer());
assertTrue(descr.getKeySerializer() instanceof KryoSerializer);

assertTrue(((KryoSerializer<?>) descr.getKeySerializer()).getKryo().getRegistration(TaskInfo.class).getId() > 0);

assertNotNull(descr.getValueSerializer());
assertTrue(descr.getValueSerializer() instanceof StringSerializer);
}

@Test
public void testMapStateDescriptorAutoSerializer() throws Exception {

MapStateDescriptor<String, Long> descr =
new MapStateDescriptor<>("testName", String.class, Long.class);

MapStateDescriptor<String, Long> copy = CommonTestUtils.createCopySerializable(descr);

assertEquals("testName", copy.getName());

assertNotNull(copy.getSerializer());
assertTrue(copy.getSerializer() instanceof MapSerializer);

assertNotNull(copy.getKeySerializer());
assertEquals(StringSerializer.INSTANCE, copy.getKeySerializer());
assertNotNull(copy.getValueSerializer());
assertEquals(LongSerializer.INSTANCE, copy.getValueSerializer());
}

/**
* FLINK-6775.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,9 @@
package org.apache.flink.api.common.state;

import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.api.common.typeutils.base.StringSerializer;
import org.apache.flink.api.java.typeutils.runtime.kryo.KryoSerializer;
import org.apache.flink.core.fs.Path;
import org.apache.flink.core.testutils.CommonTestUtils;
import org.apache.flink.util.TestLogger;

Expand All @@ -33,20 +30,16 @@
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNotSame;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.Mockito.mock;

/**
* Tests for the {@link ReducingStateDescriptor}.
*/
public class ReducingStateDescriptorTest extends TestLogger {

@Test
public void testValueStateDescriptorEagerSerializer() throws Exception {
public void testReducingStateDescriptor() throws Exception {

@SuppressWarnings("unchecked")
ReduceFunction<String> reducer = mock(ReduceFunction.class);
ReduceFunction<String> reducer = (a, b) -> a;

TypeSerializer<String> serializer = new KryoSerializer<>(String.class, new ExecutionConfig());

Expand All @@ -56,6 +49,7 @@ public void testValueStateDescriptorEagerSerializer() throws Exception {
assertEquals("testName", descr.getName());
assertNotNull(descr.getSerializer());
assertEquals(serializer, descr.getSerializer());
assertEquals(reducer, descr.getReduceFunction());

ReducingStateDescriptor<String> copy = CommonTestUtils.createCopySerializable(descr);

Expand All @@ -64,48 +58,6 @@ public void testValueStateDescriptorEagerSerializer() throws Exception {
assertEquals(serializer, copy.getSerializer());
}

@Test
public void testValueStateDescriptorLazySerializer() throws Exception {

@SuppressWarnings("unchecked")
ReduceFunction<Path> reducer = mock(ReduceFunction.class);

// some different registered value
ExecutionConfig cfg = new ExecutionConfig();
cfg.registerKryoType(TaskInfo.class);

ReducingStateDescriptor<Path> descr =
new ReducingStateDescriptor<>("testName", reducer, Path.class);

try {
descr.getSerializer();
fail("should cause an exception");
} catch (IllegalStateException ignored) {}

descr.initializeSerializerUnlessSet(cfg);

assertNotNull(descr.getSerializer());
assertTrue(descr.getSerializer() instanceof KryoSerializer);

assertTrue(((KryoSerializer<?>) descr.getSerializer()).getKryo().getRegistration(TaskInfo.class).getId() > 0);
}

@Test
public void testValueStateDescriptorAutoSerializer() throws Exception {

@SuppressWarnings("unchecked")
ReduceFunction<String> reducer = mock(ReduceFunction.class);

ReducingStateDescriptor<String> descr =
new ReducingStateDescriptor<>("testName", reducer, String.class);

ReducingStateDescriptor<String> copy = CommonTestUtils.createCopySerializable(descr);

assertEquals("testName", copy.getName());
assertNotNull(copy.getSerializer());
assertEquals(StringSerializer.INSTANCE, copy.getSerializer());
}

/**
* FLINK-6775.
*
Expand Down
Loading

0 comments on commit 87d31f5

Please sign in to comment.