Skip to content

Commit

Permalink
[FLINK-12170] [table-planner-blink] Add support for generating optimi…
Browse files Browse the repository at this point in the history
…zed logical plan for Over aggregate (apache#8157)
  • Loading branch information
godfreyhe authored and KurtYoung committed Apr 17, 2019
1 parent c3a9293 commit e42f6b5
Show file tree
Hide file tree
Showing 26 changed files with 2,573 additions and 21 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ public class PlannerConfigOptions {
"3. L and R shuffle by c1 and c2\n" +
"It can reduce some shuffle cost someTimes.");

public static final ConfigOption<Boolean> SQL_OPTIMIZER_SMJ_REMOVE_SORT_ENABLE =
key("sql.optimizer.smj.remove-sort.enable")
public static final ConfigOption<Boolean> SQL_OPTIMIZER_SMJ_REMOVE_SORT_ENABLED =
key("sql.optimizer.smj.remove-sort.enabled")
.defaultValue(false)
.withDescription("When true, the optimizer will try to remove redundant sort for SortMergeJoin. " +
"However that will increase optimization time. Default value is false.");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.apache.flink.api.common.typeinfo.TypeInformation;

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.AND;
Expand All @@ -33,6 +34,7 @@
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.IS_NULL;
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.LESS_THAN;
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.MINUS;
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.NOT;
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.OR;
import static org.apache.flink.table.expressions.BuiltInFunctionDefinitions.PLUS;

Expand Down Expand Up @@ -69,6 +71,10 @@ public static Expression or(Expression... args) {
return new CallExpression(OR, Arrays.asList(args));
}

public static Expression not(Expression arg) {
return new CallExpression(NOT, Collections.singletonList(arg));
}

public static Expression isNull(Expression input) {
return call(IS_NULL, input);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions.aggfunctions;

import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.type.InternalType;
import org.apache.flink.table.type.InternalTypes;

import static org.apache.flink.table.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.expressions.ExpressionBuilder.literal;
import static org.apache.flink.table.expressions.ExpressionBuilder.plus;

/**
* built-in dense_rank aggregate function.
*/
public class DenseRankAggFunction extends RankLikeAggFunctionBase {

public DenseRankAggFunction(InternalType[] orderKeyTypes) {
super(orderKeyTypes);
}

@Override
public UnresolvedReferenceExpression[] aggBufferAttributes() {
UnresolvedReferenceExpression[] aggBufferAttrs = new UnresolvedReferenceExpression[1 + lastValues.length];
aggBufferAttrs[0] = sequence;
System.arraycopy(lastValues, 0, aggBufferAttrs, 1, lastValues.length);
return aggBufferAttrs;
}

@Override
public InternalType[] getAggBufferTypes() {
InternalType[] aggBufferTypes = new InternalType[1 + orderKeyTypes.length];
aggBufferTypes[0] = InternalTypes.LONG;
System.arraycopy(orderKeyTypes, 0, aggBufferTypes, 1, orderKeyTypes.length);
return aggBufferTypes;
}

@Override
public Expression[] initialValuesExpressions() {
Expression[] initExpressions = new Expression[1 + orderKeyTypes.length];
// sequence = 0L
initExpressions[0] = literal(0L);
for (int i = 0; i < orderKeyTypes.length; ++i) {
// lastValue_i = init value
initExpressions[i + 1] = generateInitLiteral(orderKeyTypes[i]);
}
return initExpressions;
}

@Override
public Expression[] accumulateExpressions() {
Expression[] accExpressions = new Expression[1 + operands().length];
// sequence = if (lastValues equalTo orderKeys) sequence else sequence + 1
accExpressions[0] = ifThenElse(orderKeyEqualsExpression(), sequence, plus(sequence, literal(1L)));
Expression[] operands = operands();
for (int i = 0; i < operands.length; ++i) {
// lastValue_i = orderKey[i]
accExpressions[i + 1] = operands[i];
}
return accExpressions;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,12 @@
* limitations under the License.
*/

package org.apache.flink.table.functions;
package org.apache.flink.table.functions.aggfunctions;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.functions.aggfunctions.DeclarativeAggregateFunction;
import org.apache.flink.table.runtime.over.frame.OffsetOverFrame;
import org.apache.flink.table.type.DecimalType;
import org.apache.flink.table.type.InternalType;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions.aggfunctions;

import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.type.InternalType;
import org.apache.flink.table.type.InternalTypes;

import static org.apache.flink.table.expressions.ExpressionBuilder.and;
import static org.apache.flink.table.expressions.ExpressionBuilder.equalTo;
import static org.apache.flink.table.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.expressions.ExpressionBuilder.literal;
import static org.apache.flink.table.expressions.ExpressionBuilder.not;
import static org.apache.flink.table.expressions.ExpressionBuilder.plus;

/**
* built-in rank aggregate function.
*/
public class RankAggFunction extends RankLikeAggFunctionBase {

private UnresolvedReferenceExpression currNumber = new UnresolvedReferenceExpression("currNumber");

public RankAggFunction(InternalType[] orderKeyTypes) {
super(orderKeyTypes);
}

@Override
public UnresolvedReferenceExpression[] aggBufferAttributes() {
UnresolvedReferenceExpression[] aggBufferAttrs = new UnresolvedReferenceExpression[2 + lastValues.length];
aggBufferAttrs[0] = currNumber;
aggBufferAttrs[1] = sequence;
System.arraycopy(lastValues, 0, aggBufferAttrs, 2, lastValues.length);
return aggBufferAttrs;
}

@Override
public InternalType[] getAggBufferTypes() {
InternalType[] aggBufferTypes = new InternalType[2 + orderKeyTypes.length];
aggBufferTypes[0] = InternalTypes.LONG;
aggBufferTypes[1] = InternalTypes.LONG;
System.arraycopy(orderKeyTypes, 0, aggBufferTypes, 2, orderKeyTypes.length);
return aggBufferTypes;
}

@Override
public Expression[] initialValuesExpressions() {
Expression[] initExpressions = new Expression[2 + orderKeyTypes.length];
// currNumber = 0L
initExpressions[0] = literal(0L);
// sequence = 0L
initExpressions[1] = literal(0L);
for (int i = 0; i < orderKeyTypes.length; ++i) {
// lastValue_i = init value
initExpressions[i + 2] = generateInitLiteral(orderKeyTypes[i]);
}
return initExpressions;
}

@Override
public Expression[] accumulateExpressions() {
Expression[] accExpressions = new Expression[2 + operands().length];
// currNumber = currNumber + 1
accExpressions[0] = plus(currNumber, literal(1L));
// sequence = if (lastValues equalTo orderKeys and sequence != 0) sequence else currNumber
accExpressions[1] = ifThenElse(and(orderKeyEqualsExpression(), not(equalTo(sequence, literal(0L)))),
sequence, currNumber);
Expression[] operands = operands();
for (int i = 0; i < operands.length; ++i) {
// lastValue_i = orderKey[i]
accExpressions[i + 2] = operands[i];
}
return accExpressions;
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.flink.table.functions.aggfunctions;

import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.common.typeinfo.Types;
import org.apache.flink.table.api.TableException;
import org.apache.flink.table.expressions.Expression;
import org.apache.flink.table.expressions.UnresolvedReferenceExpression;
import org.apache.flink.table.type.DecimalType;
import org.apache.flink.table.type.InternalType;
import org.apache.flink.table.type.InternalTypes;

import java.sql.Date;
import java.sql.Time;
import java.sql.Timestamp;

import static org.apache.flink.table.expressions.ExpressionBuilder.and;
import static org.apache.flink.table.expressions.ExpressionBuilder.equalTo;
import static org.apache.flink.table.expressions.ExpressionBuilder.ifThenElse;
import static org.apache.flink.table.expressions.ExpressionBuilder.isNull;
import static org.apache.flink.table.expressions.ExpressionBuilder.literal;

/**
* built-in rank like aggregate function, e.g. rank, dense_rank
*/
public abstract class RankLikeAggFunctionBase extends DeclarativeAggregateFunction {
protected UnresolvedReferenceExpression sequence = new UnresolvedReferenceExpression("sequence");
protected UnresolvedReferenceExpression[] lastValues;
protected InternalType[] orderKeyTypes;

public RankLikeAggFunctionBase(InternalType[] orderKeyTypes) {
this.orderKeyTypes = orderKeyTypes;
lastValues = new UnresolvedReferenceExpression[orderKeyTypes.length];
for (int i = 0; i < orderKeyTypes.length; ++i) {
lastValues[i] = new UnresolvedReferenceExpression("lastValue_" + i);
}
}

@Override
public int operandCount() {
return orderKeyTypes.length;
}

@Override
public TypeInformation getResultType() {
return Types.LONG;
}

@Override
public Expression[] retractExpressions() {
throw new TableException("This function does not support retraction.");
}

@Override
public Expression[] mergeExpressions() {
throw new TableException("This function does not support merge.");
}

@Override
public Expression getValueExpression() {
return sequence;
}

protected Expression orderKeyEqualsExpression() {
Expression[] orderKeyEquals = new Expression[orderKeyTypes.length];
for (int i = 0; i < orderKeyTypes.length; ++i) {
// pseudo code:
// if (lastValue_i is null) {
// if (operand(i) is null) true else false
// } else {
// lastValue_i equalTo orderKey(i)
// }
Expression lasValue = lastValues[i];
orderKeyEquals[i] = ifThenElse(isNull(lasValue),
ifThenElse(isNull(operand(i)), literal(true), literal(false)),
equalTo(lasValue, operand(i)));
}
if (orderKeyEquals.length == 0) {
return literal(true);
} else {
return and(orderKeyEquals);
}
}

protected Expression generateInitLiteral(InternalType orderType) {
if (orderType.equals(InternalTypes.BOOLEAN)) {
return literal(false);
} else if (orderType.equals(InternalTypes.BYTE)) {
return literal((byte) 0);
} else if (orderType.equals(InternalTypes.SHORT)) {
return literal((short) 0);
} else if (orderType.equals(InternalTypes.INT)) {
return literal(0);
} else if (orderType.equals(InternalTypes.LONG)) {
return literal(0L);
} else if (orderType.equals(InternalTypes.FLOAT)) {
return literal(0.0f);
} else if (orderType.equals(InternalTypes.DOUBLE)) {
return literal(0.0d);
} else if (orderType instanceof DecimalType) {
return literal(java.math.BigDecimal.ZERO);
} else if (orderType.equals(InternalTypes.DATE)) {
return literal(new Date(0));
} else if (orderType.equals(InternalTypes.TIME)) {
return literal(new Time(0));
} else if (orderType.equals(InternalTypes.TIMESTAMP)) {
return literal(new Timestamp(0));
} else {
throw new TableException("Unsupported type: " + orderType);
}
}
}
Loading

0 comments on commit e42f6b5

Please sign in to comment.