Skip to content

Commit

Permalink
Use joni to replace java.regex for regex sql functions
Browse files Browse the repository at this point in the history
  • Loading branch information
haozhun committed May 1, 2015
1 parent 2aae001 commit c1fd79f
Show file tree
Hide file tree
Showing 5 changed files with 260 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.regex.Pattern;

import static com.facebook.presto.metadata.FunctionRegistry.operatorInfo;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
Expand All @@ -72,7 +71,6 @@ public class FunctionListBuilder
Slice.class,
boolean.class,
Boolean.class,
Pattern.class,
Regex.class,
JsonPath.class);

Expand All @@ -82,7 +80,6 @@ public class FunctionListBuilder
Slice.class,
boolean.class,
int.class,
Pattern.class,
Regex.class,
JsonPath.class);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,30 @@
import com.facebook.presto.spi.PrestoException;
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.block.VariableWidthBlockBuilder;
import com.facebook.presto.spi.type.StandardTypes;
import com.facebook.presto.spi.type.VarcharType;
import com.facebook.presto.type.RegexpType;
import com.facebook.presto.type.SqlType;
import com.google.common.primitives.Ints;
import io.airlift.jcodings.specific.NonStrictUTF8Encoding;
import io.airlift.joni.Matcher;
import io.airlift.joni.Option;
import io.airlift.joni.Regex;
import io.airlift.joni.Region;
import io.airlift.joni.Syntax;
import io.airlift.joni.exception.ValueException;
import io.airlift.slice.DynamicSliceOutput;
import io.airlift.slice.Slice;
import io.airlift.slice.SliceOutput;
import io.airlift.slice.Slices;

import javax.annotation.Nullable;

import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.regex.PatternSyntaxException;
import java.nio.charset.StandardCharsets;

import static com.facebook.presto.spi.StandardErrorCode.INVALID_FUNCTION_ARGUMENT;
import static com.facebook.presto.spi.type.VarcharType.VARCHAR;
import static com.facebook.presto.type.ArrayType.toStackRepresentation;
import static com.facebook.presto.type.TypeUtils.buildStructuralSlice;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Arrays.asList;

public final class RegexpFunctions
{
Expand All @@ -49,75 +51,198 @@ private RegexpFunctions()

@ScalarOperator(OperatorType.CAST)
@SqlType(RegexpType.NAME)
public static Pattern castToRegexp(@SqlType(StandardTypes.VARCHAR) Slice pattern)
public static Regex castToRegexp(@SqlType(StandardTypes.VARCHAR) Slice pattern)
{
Regex regex;
try {
return Pattern.compile(pattern.toString(UTF_8));
// When normal UTF8 encoding instead of non-strict UTF8) is used, joni can infinite loop when invalid UTF8 slice is supplied to it.
regex = new Regex(pattern.getBytes(), 0, pattern.length(), Option.DEFAULT, NonStrictUTF8Encoding.INSTANCE, Syntax.Java);
}
catch (PatternSyntaxException e) {
catch (Exception e) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, e);
}
return regex;
}

@Description("returns substrings matching a regular expression")
@ScalarFunction
@SqlType(StandardTypes.BOOLEAN)
public static boolean regexpLike(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Pattern pattern)
public static boolean regexpLike(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Regex pattern)
{
return pattern.matcher(source.toString(UTF_8)).find();
Matcher m = pattern.matcher(source.getBytes());
int offset = m.search(0, source.length(), Option.DEFAULT);
return offset != -1;
}

@Description("removes substrings matching a regular expression")
@ScalarFunction
@SqlType(StandardTypes.VARCHAR)
public static Slice regexpReplace(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Pattern pattern)
public static Slice regexpReplace(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Regex pattern)
{
return regexpReplace(source, pattern, Slices.EMPTY_SLICE);
}

@Description("replaces substrings matching a regular expression by given string")
@ScalarFunction
@SqlType(StandardTypes.VARCHAR)
public static Slice regexpReplace(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Pattern pattern, @SqlType(StandardTypes.VARCHAR) Slice replacement)
public static Slice regexpReplace(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Regex pattern, @SqlType(StandardTypes.VARCHAR) Slice replacement)
{
Matcher matcher = pattern.matcher(source.toString(UTF_8));
String replaced = matcher.replaceAll(replacement.toString(UTF_8));
return Slices.copiedBuffer(replaced, UTF_8);
Matcher matcher = pattern.matcher(source.getBytes());
SliceOutput sliceOutput = new DynamicSliceOutput(source.length() + replacement.length() * 5);

int lastEnd = 0;
int nextStart = 0; // nextStart is the same as lastEnd, unless the last match was zero-width. In such case, nextStart is lastEnd + 1.
while (true) {
int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
if (offset == -1) {
break;
}
if (matcher.getEnd() == matcher.getBegin()) {
nextStart = matcher.getEnd() + 1;
}
else {
nextStart = matcher.getEnd();
}
Slice sliceBetweenReplacements = source.slice(lastEnd, matcher.getBegin() - lastEnd);
lastEnd = matcher.getEnd();
sliceOutput.appendBytes(sliceBetweenReplacements);
appendReplacement(sliceOutput, source, pattern, matcher.getEagerRegion(), replacement);
}
sliceOutput.appendBytes(source.slice(lastEnd, source.length() - lastEnd));

return sliceOutput.slice();
}

private static void appendReplacement(SliceOutput result, Slice source, Regex pattern, Region region, Slice replacement)
{
// Handle the following items:
// 1. ${name};
// 2. $0, $1, $123 (group 123, if exists; or group 12, if exists; or group 1);
// 3. \\, \$, \t (literal 't').
// 4. Anything that doesn't starts with \ or $ is considered regular bytes

int idx = 0;

while (idx < replacement.length()) {
byte nextByte = replacement.getByte(idx);
if (nextByte == '$') {
idx++;
if (idx == replacement.length()) { // not using checkArgument because `.toStringUtf8` is expensive
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: " + replacement.toStringUtf8());
}
nextByte = replacement.getByte(idx);
int backref;
if (nextByte == '{') { // case 1 in the above comment
idx++;
int startCursor = idx;
while (idx < replacement.length()) {
nextByte = replacement.getByte(idx);
if (nextByte == '}') {
break;
}
idx++;
}
byte[] groupName = replacement.getBytes(startCursor, idx - startCursor);
try {
backref = pattern.nameToBackrefNumber(groupName, 0, groupName.length, region);
}
catch (ValueException e) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: unknown group { " + new String(groupName, StandardCharsets.UTF_8) + " }");
}
idx++;
}
else { // case 2 in the above comment
backref = nextByte - '0';
if (backref < 0 || backref > 9) { // not using checkArgument because `.toStringUtf8` is expensive
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: " + replacement.toStringUtf8());
}
if (region.numRegs <= backref) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: unknown group " + backref);
}
idx++;
while (idx < replacement.length()) { // Adaptive group number: find largest group num that is not greater than actual number of groups
int nextDigit = replacement.getByte(idx) - '0';
if (nextDigit < 0 || nextDigit > 9) {
break;
}
int newBackref = (backref * 10) + nextDigit;
if (region.numRegs <= newBackref) {
break;
}
backref = newBackref;
idx++;
}
}
int beg = region.beg[backref];
int end = region.end[backref];
if (beg != -1 && end != -1) { // the specific group doesn't exist in the current match, skip
result.appendBytes(source.slice(beg, end - beg));
}
}
else { // case 3 and 4 in the above comment
if (nextByte == '\\') {
idx++;
if (idx == replacement.length()) { // not using checkArgument because `.toStringUtf8` is expensive
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Illegal replacement sequence: " + replacement.toStringUtf8());
}
nextByte = replacement.getByte(idx);
}
result.appendByte(nextByte);
idx++;
}
}
}

@Description("string(s) extracted using the given pattern")
@ScalarFunction
@SqlType("array<varchar>")
public static Slice regexpExtractAll(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Pattern pattern)
public static Slice regexpExtractAll(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Regex pattern)
{
return regexpExtractAll(source, pattern, 0);
}

@Description("group(s) extracted using the given pattern")
@ScalarFunction
@SqlType("array<varchar>")
public static Slice regexpExtractAll(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Pattern pattern, @SqlType(StandardTypes.BIGINT) long group)
public static Slice regexpExtractAll(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Regex pattern, @SqlType(StandardTypes.BIGINT) long groupIndex)
{
Matcher matcher = pattern.matcher(source.toString(UTF_8));
validateGroup(group, matcher);
BlockBuilder blockBuilder = new VariableWidthBlockBuilder(new BlockBuilderStatus(), 1024);
while (matcher.find()) {
String string = matcher.group(Ints.checkedCast(group));
if (string == null) {
Matcher matcher = pattern.matcher(source.getBytes());
validateGroup(groupIndex, matcher.getEagerRegion());
BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 32);
int group = Ints.checkedCast(groupIndex);

int nextStart = 0;
while (true) {
int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
if (offset == -1) {
break;
}
if (matcher.getEnd() == matcher.getBegin()) {
nextStart = matcher.getEnd() + 1;
}
else {
nextStart = matcher.getEnd();
}
Region region = matcher.getEagerRegion();
int beg = region.beg[group];
int end = region.end[group];
if (beg == -1 || end == -1) {
blockBuilder.appendNull();
}
else {
VarcharType.VARCHAR.writeString(blockBuilder, string);
Slice slice = source.slice(beg, end - beg);
VARCHAR.writeSlice(blockBuilder, slice);
}
}

return buildStructuralSlice(blockBuilder);
}

@Nullable
@Description("string extracted using the given pattern")
@ScalarFunction
@SqlType(StandardTypes.VARCHAR)
public static Slice regexpExtract(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Pattern pattern)
public static Slice regexpExtract(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Regex pattern)
{
return regexpExtract(source, pattern, 0);
}
Expand All @@ -126,36 +251,65 @@ public static Slice regexpExtract(@SqlType(StandardTypes.VARCHAR) Slice source,
@Description("returns regex group of extracted string with a pattern")
@ScalarFunction
@SqlType(StandardTypes.VARCHAR)
public static Slice regexpExtract(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Pattern pattern, @SqlType(StandardTypes.BIGINT) long group)
public static Slice regexpExtract(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Regex pattern, @SqlType(StandardTypes.BIGINT) long groupIndex)
{
Matcher matcher = pattern.matcher(source.toString(UTF_8));
validateGroup(group, matcher);
if (!matcher.find()) {
Matcher matcher = pattern.matcher(source.getBytes());
validateGroup(groupIndex, matcher.getEagerRegion());
int group = Ints.checkedCast(groupIndex);

int offset = matcher.search(0, source.length(), Option.DEFAULT);
if (offset == -1) {
return null;
}
String extracted = matcher.group(Ints.checkedCast(group));
if (extracted == null) {
Region region = matcher.getEagerRegion();
int beg = region.beg[group];
int end = region.end[group];
if (beg == -1) {
// end == -1 must be true
return null;
}
return Slices.utf8Slice(extracted);

Slice slice = source.slice(beg, end - beg);
return slice;
}

@ScalarFunction
@Description("returns array of strings split by pattern")
@SqlType("array<varchar>")
public static Slice regexpSplit(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Pattern pattern)
public static Slice regexpSplit(@SqlType(StandardTypes.VARCHAR) Slice source, @SqlType(RegexpType.NAME) Regex pattern)
{
String[] result = pattern.split(source.toStringUtf8());
return toStackRepresentation(asList(result), VARCHAR);
Matcher matcher = pattern.matcher(source.getBytes());
BlockBuilder blockBuilder = VARCHAR.createBlockBuilder(new BlockBuilderStatus(), 32);

int lastEnd = 0;
int nextStart = 0;
while (true) {
int offset = matcher.search(nextStart, source.length(), Option.DEFAULT);
if (offset == -1) {
break;
}
if (matcher.getEnd() == matcher.getBegin()) {
nextStart = matcher.getEnd() + 1;
}
else {
nextStart = matcher.getEnd();
}
Slice slice = source.slice(lastEnd, matcher.getBegin() - lastEnd);
lastEnd = matcher.getEnd();
VARCHAR.writeSlice(blockBuilder, slice);
}
VARCHAR.writeSlice(blockBuilder, source.slice(lastEnd, source.length() - lastEnd));

return buildStructuralSlice(blockBuilder);
}

private static void validateGroup(long group, Matcher matcher)
private static void validateGroup(long group, Region region)
{
if (group < 0) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, "Group cannot be negative");
}
if (group > matcher.groupCount()) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Pattern has %d groups. Cannot access group %d", matcher.groupCount(), group));
if (group > region.numRegs - 1) {
throw new PrestoException(INVALID_FUNCTION_ARGUMENT, format("Pattern has %d groups. Cannot access group %d", region.numRegs - 1, group));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,7 @@
import com.facebook.presto.spi.block.BlockBuilder;
import com.facebook.presto.spi.block.BlockBuilderStatus;
import com.facebook.presto.spi.type.AbstractType;

import java.util.regex.Pattern;
import io.airlift.joni.Regex;

import static com.facebook.presto.spi.StandardErrorCode.INTERNAL_ERROR;
import static com.facebook.presto.type.TypeUtils.parameterizedTypeName;
Expand All @@ -33,7 +32,7 @@ public class RegexpType

public RegexpType()
{
super(parameterizedTypeName(NAME), Pattern.class);
super(parameterizedTypeName(NAME), Regex.class);
}

@Override
Expand Down
Loading

0 comments on commit c1fd79f

Please sign in to comment.