Skip to content

Commit

Permalink
[FLINK-18524][table-common] Fix type inference for Scala varargs
Browse files Browse the repository at this point in the history
This closes apache#12853.
  • Loading branch information
twalthr committed Jul 9, 2020
1 parent d23587c commit ed5fc5c
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import org.junit.runners.Parameterized.Parameters
import org.junit.{Rule, Test}

import scala.annotation.meta.getter
import scala.annotation.varargs

/**
* Scala tests for [[TypeInferenceExtractor]].
Expand Down Expand Up @@ -78,7 +79,8 @@ object TypeInferenceExtractorScalaTest {
def testData: Array[TestSpec] = Array(

// Scala function with data type hint
TestSpec.forScalarFunction(classOf[ScalaScalarFunction])
TestSpec
.forScalarFunction(classOf[ScalaScalarFunction])
.expectNamedArguments("i", "s", "d")
.expectTypedArguments(
DataTypes.INT.notNull().bridgedTo(classOf[Int]),
Expand All @@ -93,8 +95,42 @@ object TypeInferenceExtractorScalaTest {
InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4)))),
TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),

TestSpec
.forScalarFunction(classOf[ScalaPrimitiveVarArgScalarFunction])
.expectOutputMapping(
InputTypeStrategies.varyingSequence(
Array[String]("i", "s", "d"),
Array[ArgumentTypeStrategy](
InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
InputTypeStrategies.explicit(DataTypes.STRING),
InputTypeStrategies.explicit(DataTypes.DOUBLE().notNull().bridgedTo(classOf[Double])))),
TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),

TestSpec
.forScalarFunction(classOf[ScalaBoxedVarArgScalarFunction])
.expectOutputMapping(
InputTypeStrategies.varyingSequence(
Array[String]("i", "s", "d"),
Array[ArgumentTypeStrategy](
InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
InputTypeStrategies.explicit(DataTypes.STRING),
InputTypeStrategies.explicit(DataTypes.DOUBLE()))),
TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),

TestSpec
.forScalarFunction(classOf[ScalaHintVarArgScalarFunction])
.expectOutputMapping(
InputTypeStrategies.varyingSequence(
Array[String]("i", "s", "d"),
Array[ArgumentTypeStrategy](
InputTypeStrategies.explicit(DataTypes.INT.notNull().bridgedTo(classOf[Int])),
InputTypeStrategies.explicit(DataTypes.STRING),
InputTypeStrategies.explicit(DataTypes.DECIMAL(10, 4)))),
TypeStrategies.explicit(DataTypes.BOOLEAN.notNull().bridgedTo(classOf[Boolean]))),

// global output hint with local input overloading
TestSpec.forScalarFunction(classOf[ScalaGlobalOutputFunctionHint])
TestSpec
.forScalarFunction(classOf[ScalaGlobalOutputFunctionHint])
.expectOutputMapping(
InputTypeStrategies.sequence(InputTypeStrategies.explicit(DataTypes.INT)),
TypeStrategies.explicit(DataTypes.INT))
Expand Down Expand Up @@ -122,4 +158,28 @@ object TypeInferenceExtractorScalaTest {
@FunctionHint(input = Array(new DataTypeHint("STRING")))
def eval(n: String): Integer = null
}

private class ScalaPrimitiveVarArgScalarFunction extends ScalarFunction {
@varargs
def eval(
i: Int,
s: String,
d: Double*): Boolean = false
}

private class ScalaBoxedVarArgScalarFunction extends ScalarFunction {
@varargs
def eval(
i: Int,
s: String,
d: java.lang.Double*): Boolean = false
}

private class ScalaHintVarArgScalarFunction extends ScalarFunction {
@varargs
def eval(
i: Int,
s: String,
@DataTypeHint("ARRAY<DECIMAL(10, 4)>") d: java.math.BigDecimal*): Boolean = false
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
Expand Down Expand Up @@ -156,11 +158,13 @@ private Map<FunctionSignatureTemplate, FunctionResultTemplate> extractResultMapp
}
for (Method method : methods) {
try {
final Method correctMethod = correctVarArgMethod(method);

final Map<FunctionSignatureTemplate, FunctionResultTemplate> collectedMappingsPerMethod =
collectMethodMappings(method, global, globalResultOnly, resultExtraction, accessor);
collectMethodMappings(correctMethod, global, globalResultOnly, resultExtraction, accessor);

// check if the method can be called
verifyMappingForMethod(method, collectedMappingsPerMethod, verification);
verifyMappingForMethod(correctMethod, collectedMappingsPerMethod, verification);

// check if method strategies conflict with function strategies
collectedMappingsPerMethod.forEach((signature, result) -> putMapping(collectedMappings, signature, result));
Expand All @@ -174,6 +178,39 @@ private Map<FunctionSignatureTemplate, FunctionResultTemplate> extractResultMapp
return collectedMappings;
}

/**
* Special case for Scala which generates two methods when using var-args (a {@code Seq < String >}
* and {@code String...}). This method searches for the Java-like variant.
*/
private static Method correctVarArgMethod(Method method) {
final int paramCount = method.getParameterCount();
final Class<?>[] paramClasses = method.getParameterTypes();
if (paramCount > 0 && paramClasses[paramCount - 1].getName().equals("scala.collection.Seq")) {
final Type[] paramTypes = method.getGenericParameterTypes();
final ParameterizedType seqType = (ParameterizedType) paramTypes[paramCount - 1];
final Type varArgType = seqType.getActualTypeArguments()[0];
return ExtractionUtils.collectMethods(method.getDeclaringClass(), method.getName())
.stream()
.filter(Method::isVarArgs)
.filter(candidate -> candidate.getParameterCount() == paramCount)
.filter(candidate -> {
final Type[] candidateParamTypes = candidate.getGenericParameterTypes();
for (int i = 0; i < paramCount - 1; i++) {
if (candidateParamTypes[i] != paramTypes[i]) {
return false;
}
}
final Class<?> candidateVarArgType = candidate.getParameterTypes()[paramCount - 1];
return candidateVarArgType.isArray() &&
// check for Object is needed in case of Scala primitives (e.g. Int)
(varArgType == Object.class || candidateVarArgType.getComponentType() == varArgType);
})
.findAny()
.orElse(method);
}
return method;
}

/**
* Extracts mappings from signature to result (either accumulator or output) for the given method. It
* considers both global hints for the entire function and local hints just for this method.
Expand Down Expand Up @@ -368,7 +405,7 @@ private static FunctionArgumentTemplate extractDataTypeArgument(
return FunctionArgumentTemplate.of(((CollectionDataType) type).getElementDataType());
}
// special case for varargs that have been misinterpreted as BYTES
else {
else if (type.equals(DataTypes.BYTES())) {
return FunctionArgumentTemplate.of(DataTypes.TINYINT().notNull().bridgedTo(byte.class));
}
}
Expand Down

0 comments on commit ed5fc5c

Please sign in to comment.