Skip to content

Commit

Permalink
[FLINK-6948] [serializer] Harden EnumValueSerializer to detect change…
Browse files Browse the repository at this point in the history
…d enum indices

This PR changes the seriailization format of the ScalaEnumSerializerConfigSnapshot to also include the
ordinal value of an enum value when being deserialized. This allows to detect if the ordinal values
have been changed and, thus, if migration is required.

IMPORTANT: This PR changes the serialization format of ScalaEnumSerializerConfigSnapshot.

Remove backwards compatibility path for 1.3.1

This closes apache#4142.
  • Loading branch information
tillrohrmann authored and tzulitai committed Jun 20, 2017
1 parent e520023 commit 228faf8
Show file tree
Hide file tree
Showing 7 changed files with 99 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ import org.apache.flink.api.java.typeutils.runtime.{DataInputViewStream, DataOut
import org.apache.flink.core.memory.{DataInputView, DataOutputView}
import org.apache.flink.util.{InstantiationUtil, Preconditions}

import scala.collection.mutable.ListBuffer

/**
* Serializer for [[Enumeration]] values.
*/
Expand Down Expand Up @@ -88,18 +90,23 @@ class EnumValueSerializer[E <: Enumeration](val enum: E) extends TypeSerializer[
case enumSerializerConfigSnapshot: EnumValueSerializer.ScalaEnumSerializerConfigSnapshot[_] =>
val enumClass = enum.getClass.asInstanceOf[Class[E]]
if (enumClass.equals(enumSerializerConfigSnapshot.getEnumClass)) {
val previousEnumConstants = enumSerializerConfigSnapshot.getEnumConstants
val previousEnumConstants:List[(String, Int)] =
enumSerializerConfigSnapshot.getEnumConstants

if (previousEnumConstants != null) {
for (i <- enum.values.iterator) {
// skip the check for all newly added fields
if (i.id < previousEnumConstants.length) {
if (!previousEnumConstants(i.id).equals(i.toString)) {
// compatible only if new enum constants are only appended,
// and original constants must be in the exact same order

for ((previousEnumConstant, idx) <- previousEnumConstants) {
val enumValue = try {
enum(idx)
} catch {
case _: NoSuchElementException =>
// couldn't find an enum value for the given index
return CompatibilityResult.requiresMigration()
}
}

if (!previousEnumConstant.equals(enumValue.toString)) {
// compatible only if new enum constants are only appended,
// and original constants must be in the exact same order
return CompatibilityResult.requiresMigration()
}
}
}
Expand All @@ -120,12 +127,12 @@ object EnumValueSerializer {
extends TypeSerializerConfigSnapshot {

var enumClass: Class[E] = _
var enumConstants: List[String] = _
var enumConstants: List[(String, Int)] = _

def this(enum: E) = {
this()
this.enumClass = Preconditions.checkNotNull(enum).getClass.asInstanceOf[Class[E]]
this.enumConstants = enum.values.toList.map(_.toString)
this.enumConstants = enum.values.toList.map(x => (x.toString, x.id))
}

override def write(out: DataOutputView): Unit = {
Expand All @@ -135,7 +142,12 @@ object EnumValueSerializer {
val outViewWrapper = new DataOutputViewStream(out)
try {
InstantiationUtil.serializeObject(outViewWrapper, enumClass)
InstantiationUtil.serializeObject(outViewWrapper, enumConstants)

out.writeInt(enumConstants.length)
for ((name, idx) <- enumConstants) {
out.writeUTF(name)
out.writeInt(idx)
}
} finally if (outViewWrapper != null) outViewWrapper.close()
}
}
Expand All @@ -150,8 +162,24 @@ object EnumValueSerializer {
enumClass = InstantiationUtil.deserializeObject(
inViewWrapper, getUserCodeClassLoader)

enumConstants = InstantiationUtil.deserializeObject(
inViewWrapper, getUserCodeClassLoader)
if (getReadVersion == 1) {
// read null from input stream
InstantiationUtil.deserializeObject(inViewWrapper, getUserCodeClassLoader)
enumConstants = List()
} else if (getReadVersion == 2) {
val length = in.readInt()
val listBuffer = ListBuffer[(String, Int)]()

for (_ <- 0 until length) {
val name = in.readUTF()
val idx = in.readInt()
listBuffer += ((name, idx))
}

enumConstants = listBuffer.toList
} else {
throw new IOException(s"Cannot deserialize ${getClass.getSimpleName} with version $getReadVersion.")
}
} catch {
case e: ClassNotFoundException =>
throw new IOException("The requested enum class cannot be found in classpath.", e)
Expand All @@ -164,7 +192,7 @@ object EnumValueSerializer {

def getEnumClass: Class[E] = enumClass

def getEnumConstants: List[String] = enumConstants
def getEnumConstants: List[(String, Int)] = enumConstants

override def equals(obj: scala.Any): Boolean = {
if (obj == this) {
Expand All @@ -184,10 +212,13 @@ object EnumValueSerializer {
override def hashCode(): Int = {
enumClass.hashCode() * 31 + enumConstants.hashCode()
}

override def getCompatibleVersions: Array[Int] = {
Array(1, 2)
}
}

object ScalaEnumSerializerConfigSnapshot {
val VERSION = 1
val VERSION = 2
}

}
31 changes: 31 additions & 0 deletions flink-scala/src/test/resources/log4j-test.properties
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http:https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Set root logger level to OFF to not flood build logs
# set manually to INFO for debugging purposes
log4j.rootLogger=OFF, testlogger

# A1 is set to be a ConsoleAppender.
log4j.appender.testlogger=org.apache.log4j.ConsoleAppender
log4j.appender.testlogger.target = System.err
log4j.appender.testlogger.layout=org.apache.log4j.PatternLayout
log4j.appender.testlogger.layout.ConversionPattern=%-4r [%t] %-5p %c %x - %m%n

# suppress the irrelevant (wrong) warnings from the netty channel handler
log4j.logger.org.jboss.netty.channel.DefaultChannelPipeline=ERROR
log4j.logger.org.apache.zookeeper=OFF
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,16 @@ class EnumValueSerializerUpgradeTest extends TestLogger with JUnitSuiteLike {
|}
""".stripMargin

val enumE =
s"""
|@SerialVersionUID(1L)
|object $enumName extends Enumeration {
| val A = Value(42)
| val B = Value(5)
| val C = Value(1337)
|}
""".stripMargin

/**
* Check that identical enums don't require migration
*/
Expand Down Expand Up @@ -106,6 +116,16 @@ class EnumValueSerializerUpgradeTest extends TestLogger with JUnitSuiteLike {
assertTrue(checkCompatibility(enumA, enumD).isRequiresMigration)
}

/**
* Check that changing the enum ids causes a migration
*/
@Test
def checkDifferentIds(): Unit = {
assertTrue(
"Different ids should cause a migration.",
checkCompatibility(enumA, enumE).isRequiresMigration)
}

def checkCompatibility(enumSourceA: String, enumSourceB: String)
: CompatibilityResult[Enumeration#Value] = {
import EnumValueSerializerUpgradeTest._
Expand Down
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.

0 comments on commit 228faf8

Please sign in to comment.