Skip to content

Commit

Permalink
[FLINK-4793] [types] Improve lambda constructor reference handling
Browse files Browse the repository at this point in the history
This closes apache#2621.
  • Loading branch information
twalthr committed Oct 12, 2016
1 parent 6731ec1 commit 1dda3ad
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 121 deletions.
7 changes: 7 additions & 0 deletions flink-core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,13 @@ under the License.
</exclusions>
</dependency>

<!-- ASM is needed for type extraction -->
<dependency>
<groupId>org.ow2.asm</groupId>
<artifactId>asm-all</artifactId>
<version>${asm.version}</version>
</dependency>

<!-- test dependencies -->
<dependency>
<groupId>org.apache.flink</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@

package org.apache.flink.api.common.functions.util;

import java.lang.reflect.Method;

import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.Function;
import org.apache.flink.api.common.functions.RichFunction;
Expand Down Expand Up @@ -62,73 +60,6 @@ public static RuntimeContext getFunctionRuntimeContext(Function function, Runtim
return defaultContext;
}
}

public static Method checkAndExtractLambdaMethod(Function function) {
try {
// get serialized lambda
Object serializedLambda = null;
for (Class<?> clazz = function.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
try {
Method replaceMethod = clazz.getDeclaredMethod("writeReplace");
replaceMethod.setAccessible(true);
Object serialVersion = replaceMethod.invoke(function);

// check if class is a lambda function
if (serialVersion.getClass().getName().equals("java.lang.invoke.SerializedLambda")) {

// check if SerializedLambda class is present
try {
Class.forName("java.lang.invoke.SerializedLambda");
}
catch (Exception e) {
throw new UnsupportedOperationException("User code tries to use lambdas, but framework is running with a Java version < 8");
}
serializedLambda = serialVersion;
break;
}
}
catch (NoSuchMethodException e) {
// thrown if the method is not there. fall through the loop
}
}

// not a lambda method -> return null
if (serializedLambda == null) {
return null;
}

// find lambda method
Method implClassMethod = serializedLambda.getClass().getDeclaredMethod("getImplClass");
Method implMethodNameMethod = serializedLambda.getClass().getDeclaredMethod("getImplMethodName");

String className = (String) implClassMethod.invoke(serializedLambda);
String methodName = (String) implMethodNameMethod.invoke(serializedLambda);

Class<?> implClass = Class.forName(className.replace('/', '.'), true, Thread.currentThread().getContextClassLoader());

Method[] methods = implClass.getDeclaredMethods();
Method parameterizedMethod = null;
for (Method method : methods) {
if(method.getName().equals(methodName)) {
if(parameterizedMethod != null) {
// It is very unlikely that a class contains multiple e.g. "lambda$2()" but its possible
// Actually, the signature need to be checked, but this is very complex
throw new Exception("Lambda method name is not unique.");
}
else {
parameterizedMethod = method;
}
}
}
if (parameterizedMethod == null) {
throw new Exception("No lambda method found.");
}
return parameterizedMethod;
}
catch (Exception e) {
throw new RuntimeException("Could not extract lambda method out of function: " + e.getClass().getSimpleName() + " - " + e.getMessage(), e);
}
}

/**
* Private constructor to prevent instantiation.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.typeutils;

import org.apache.flink.annotation.Internal;

/**
* Type extraction always contains some uncertainty due to unpredictable JVM differences
* between vendors or versions. This exception is thrown if an assumption failed during extraction.
*/
@Internal
public class TypeExtractionException extends Exception {

private static final long serialVersionUID = 1L;

/**
* Creates a new exception with no message.
*/
public TypeExtractionException() {
super();
}

/**
* Creates a new exception with the given message.
*
* @param message The exception message.
*/
public TypeExtractionException(String message) {
super(message);
}

/**
* Creates a new exception with the given message and cause.
*
* @param message The exception message.
* @param e cause
*/
public TypeExtractionException(String message, Throwable e) {
super(message, e);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.api.java.typeutils;

import java.lang.reflect.Constructor;
import java.lang.reflect.Method;
import java.lang.reflect.Type;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import org.apache.flink.annotation.Internal;
import org.apache.flink.api.common.functions.Function;
import static org.objectweb.asm.Type.getConstructorDescriptor;
import static org.objectweb.asm.Type.getMethodDescriptor;

@Internal
public class TypeExtractionUtils {

private TypeExtractionUtils() {
// do not allow instantiation
}

/**
* Similar to a Java 8 Executable but with a return type.
*/
public static class LambdaExecutable {

private Type[] parameterTypes;
private Type returnType;
private String name;
private Object executable;

public LambdaExecutable(Constructor<?> constructor) {
this.parameterTypes = constructor.getGenericParameterTypes();
this.returnType = constructor.getDeclaringClass();
this.name = constructor.getName();
this.executable = constructor;
}

public LambdaExecutable(Method method) {
this.parameterTypes = method.getGenericParameterTypes();
this.returnType = method.getGenericReturnType();
this.name = method.getName();
this.executable = method;
}

public Type[] getParameterTypes() {
return parameterTypes;
}

public Type getReturnType() {
return returnType;
}

public String getName() {
return name;
}

public boolean executablesEquals(Method m) {
return executable.equals(m);
}

public boolean executablesEquals(Constructor<?> c) {
return executable.equals(c);
}
}

public static LambdaExecutable checkAndExtractLambda(Function function) throws TypeExtractionException {
try {
// get serialized lambda
Object serializedLambda = null;
for (Class<?> clazz = function.getClass(); clazz != null; clazz = clazz.getSuperclass()) {
try {
Method replaceMethod = clazz.getDeclaredMethod("writeReplace");
replaceMethod.setAccessible(true);
Object serialVersion = replaceMethod.invoke(function);

// check if class is a lambda function
if (serialVersion.getClass().getName().equals("java.lang.invoke.SerializedLambda")) {

// check if SerializedLambda class is present
try {
Class.forName("java.lang.invoke.SerializedLambda");
}
catch (Exception e) {
throw new TypeExtractionException("User code tries to use lambdas, but framework is running with a Java version < 8");
}
serializedLambda = serialVersion;
break;
}
}
catch (NoSuchMethodException e) {
// thrown if the method is not there. fall through the loop
}
}

// not a lambda method -> return null
if (serializedLambda == null) {
return null;
}

// find lambda method
Method implClassMethod = serializedLambda.getClass().getDeclaredMethod("getImplClass");
Method implMethodNameMethod = serializedLambda.getClass().getDeclaredMethod("getImplMethodName");
Method implMethodSig = serializedLambda.getClass().getDeclaredMethod("getImplMethodSignature");

String className = (String) implClassMethod.invoke(serializedLambda);
String methodName = (String) implMethodNameMethod.invoke(serializedLambda);
String methodSig = (String) implMethodSig.invoke(serializedLambda);

Class<?> implClass = Class.forName(className.replace('/', '.'), true, Thread.currentThread().getContextClassLoader());

// find constructor
if (methodName.equals("<init>")) {
Constructor<?>[] constructors = implClass.getDeclaredConstructors();
for (Constructor<?> constructor : constructors) {
if(getConstructorDescriptor(constructor).equals(methodSig)) {
return new LambdaExecutable(constructor);
}
}
}
// find method
else {
List<Method> methods = getAllDeclaredMethods(implClass);
for (Method method : methods) {
if(method.getName().equals(methodName) && getMethodDescriptor(method).equals(methodSig)) {
return new LambdaExecutable(method);
}
}
}
throw new TypeExtractionException("No lambda method found.");
}
catch (Exception e) {
throw new TypeExtractionException("Could not extract lambda method out of function: " +
e.getClass().getSimpleName() + " - " + e.getMessage(), e);
}
}

/**
* Returns all declared methods of a class including methods of superclasses.
*/
public static List<Method> getAllDeclaredMethods(Class<?> clazz) {
List<Method> result = new ArrayList<>();
while (clazz != null) {
Method[] methods = clazz.getDeclaredMethods();
Collections.addAll(result, methods);
clazz = clazz.getSuperclass();
}
return result;
}
}
Loading

0 comments on commit 1dda3ad

Please sign in to comment.