DRILL-7818: SplitPart (SPLIT_PART) UDF works correctly only with one-row data
Fix the SPLIT_PART Drill function.
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
index b9c53be..7b2aaba 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
@@ -387,63 +387,48 @@
@FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL,
outputWidthCalculatorType = OutputWidthCalculatorType.CUSTOM_FIXED_WIDTH_DEFAULT)
public static class SplitPart implements DrillSimpleFunc {
- @Param VarCharHolder str;
- @Param VarCharHolder splitter;
- @Param IntHolder index;
+ @Param
+ VarCharHolder in; // UTF-8 input string that gets split
+ @Param
+ VarCharHolder delimiter; // UTF-8 separator; decoded once in setup()
- @Output VarCharHolder out;
+ @Param
+ IntHolder index; // 1-based position of the part to return
+
+ @Workspace
+ com.google.common.base.Splitter splitter; // created in setup(), reused by each eval() call
+
+ @Inject
+ DrillBuf buffer; // holds the returned part; replaced in eval() by the result of reallocIfNeeded
+
+ @Output
+ VarCharHolder out; // start/end/buffer are all set by eval()
@Override
- public void setup() {}
-
- @Override
- public void eval() {
+ public void setup() { // validates index and builds the splitter. NOTE(review): index and delimiter are read here, so presumably both must be constant for the query - confirm
if (index.value < 1) {
throw org.apache.drill.common.exceptions.UserException.functionError()
.message("Index in split_part must be positive, value provided was " + index.value).build();
}
- int bufPos = str.start;
- out.start = bufPos;
- boolean beyondLastIndex = false;
- int splitterLen = (splitter.end - splitter.start);
- for (int i = 1; i < index.value + 1; i++) {
- //Do string match.
- final int pos = org.apache.drill.exec.expr.fn.impl.StringFunctionUtil.stringLeftMatchUTF8(str.buffer,
- bufPos, str.end,
- splitter.buffer, splitter.start, splitter.end);
- if (pos < 0) {
- // this is the last iteration, it is okay to hit the end of the string
- if (i == index.value) {
- bufPos = str.end;
- // when the output is terminated by the end of the string we do not want
- // to subtract the length of the splitter from the output at the end of
- // the function below
- splitterLen = 0;
- break;
- } else {
- beyondLastIndex = true;
- break;
- }
- } else {
- // Count the # of characters. (one char could have 1-4 bytes)
- // unlike the position function don't add 1, we are not translating the positions into SQL user level 1 based indices
- bufPos = org.apache.drill.exec.expr.fn.impl.StringFunctionUtil.getUTF8CharLength(str.buffer, str.start, pos)
- + splitterLen;
- // if this is the second to last iteration, store the position again, as the start and end of the
- // string to be returned need to be available
- if (i == index.value - 1) {
- out.start = bufPos;
- }
- }
- }
- if (beyondLastIndex) {
- out.start = 0;
- out.end = 0;
- out.buffer = str.buffer;
- } else {
- out.buffer = str.buffer;
- out.end = bufPos - splitterLen;
- }
+ String split = org.apache.drill.exec.expr.fn.impl.StringFunctionHelpers.
+ toStringFromUTF8(delimiter.start, delimiter.end, delimiter.buffer);
+ splitter = com.google.common.base.Splitter.on(split); // kept in the @Workspace field for use by eval()
+
+ }
+
+ @Override
+ public void eval() { // decodes the input, splits it, and copies the requested part into the output buffer
+ String inputString =
+ org.apache.drill.exec.expr.fn.impl.StringFunctionHelpers.toStringFromUTF8(in.start, in.end, in.buffer);
+ int arrayIndex = index.value - 1; // 1-based index -> 0-based position. NOTE(review): the index < 1 check lives only in setup(), not here
+ String result =
+ (String) com.google.common.collect.Iterables.get(splitter.split(inputString), arrayIndex, ""); // "" is the default value handed to Iterables.get
+ byte[] strBytes = result.getBytes(com.google.common.base.Charsets.UTF_8);
+
+ out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length); // store the returned buffer back in the field as well as in out
+ out.start = 0; // the part is written at offset 0 of the output buffer, so out no longer points into the input
+ out.end = strBytes.length;
+ out.buffer.setBytes(0, strBytes);
}
}
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
index 0f79daa..8965edf 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
@@ -55,17 +55,19 @@
@Test
public void testSplitPart() throws Exception {
testBuilder()
- .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', 1) res1 from (values(1))")
+ .sqlQuery("select split_part(a, '~@~', 1) res1 from (values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
.ordered()
.baselineColumns("res1")
.baselineValues("abc")
+ .baselineValues("qwe")
.go();
testBuilder()
- .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', 2) res1 from (values(1))")
+ .sqlQuery("select split_part(a, '~@~', 2) res1 from (values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
.ordered()
.baselineColumns("res1")
.baselineValues("def")
+ .baselineValues("rty")
.go();
// invalid index