/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package hivemall.nlp.tokenizer;

import hivemall.UDFWithOptions;
import hivemall.utils.hadoop.HiveUtils;
import hivemall.utils.io.HttpUtils;
import hivemall.utils.io.IOUtils;
import hivemall.utils.lang.ExceptionUtils;
import hivemall.utils.lang.Preconditions;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.Reader;
import java.io.StringReader;
import java.net.HttpURLConnection;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nonnull;
import javax.annotation.Nullable;
import org.apache.commons.cli.CommandLine;
import org.apache.commons.cli.Options;
import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentException;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.UDFType;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory;
import org.apache.hadoop.io.Text;
import org.apache.lucene.analysis.CharArraySet;
import org.apache.lucene.analysis.TokenStream;
import org.apache.lucene.analysis.ja.JapaneseAnalyzer;
import org.apache.lucene.analysis.ja.JapaneseTokenizer;
import org.apache.lucene.analysis.ja.JapaneseTokenizer.Mode;
import org.apache.lucene.analysis.ja.dict.UserDictionary;
import org.apache.lucene.analysis.ja.tokenattributes.PartOfSpeechAttribute;
import org.apache.lucene.analysis.tokenattributes.CharTermAttribute;
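
/**
 * Hive UDF that tokenizes Japanese text using Lucene's Kuromoji morphological analyzer.
 * Arguments beyond the first are optional constants: the tokenization mode (or an option
 * string such as "-mode search -pos"), stop words, stop POS tags, and a user dictionary
 * given either as array&lt;string&gt; CSV rows or as a URL string pointing to a CSV file.
 *
 * Illustrative query (a usage sketch, not taken from the original source):
 *
 *   select tokenize_ja(description, "search") from products;
 */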
@Description(name = "tokenize_ja",
value = "_FUNC_(String line [, const string mode = \"normal\", const array<string> stopWords, const array<string> stopTags, const array<string> userDict (or string userDictURL)])"
+ " - returns tokenized strings in array<string>",
extended = "select tokenize_ja(\"kuromojiを使った分かち書きのテストです。第二引数にはnormal/search/extendedを指定できます。デフォルトではnormalモードです。\");\n"
+ "\n"
+ "> [\"kuromoji\",\"使う\",\"分かち書き\",\"テスト\",\"第\",\"二\",\"引数\",\"normal\",\"search\",\"extended\",\"指定\",\"デフォルト\",\"normal\",\" モード\"]\n")
@UDFType(deterministic = true, stateful = false)
public final class KuromojiUDF extends UDFWithOptions {
private static final int CONNECT_TIMEOUT_MS = 10000; // 10 sec
private static final int READ_TIMEOUT_MS = 60000; // 60 sec
private static final long MAX_INPUT_STREAM_SIZE = 32L * 1024L * 1024L; // 32 MiB
private Mode _mode;
private boolean _returnPos;
private transient Object[] _result;
@Nullable
private String[] _stopWordsArray;
private Set<String> _stopTags;
@Nullable
private Object _userDictObj; // String[] or String
// workaround to avoid org.apache.hive.com.esotericsoftware.kryo.KryoException: java.util.ConcurrentModificationException
private transient JapaneseAnalyzer _analyzer;
@Override
protected Options getOptions() {
Options opts = new Options();
opts.addOption("mode", true,
"The tokenization mode. One of ['normal', 'search', 'extended', 'default' (normal)]");
opts.addOption("pos", false, "Return part-of-speech information");
return opts;
}
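
/**
 * Parses an option string passed as the second argument, for example "-mode search -pos"
 * (an illustrative value; any combination of the options above is accepted).
 */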
@Override
protected CommandLine processOptions(String optionValue) throws UDFArgumentException {
CommandLine cl = parseOptions(optionValue);
if (cl.hasOption("mode")) {
String modeStr = cl.getOptionValue("mode");
this._mode = tokenizationMode(modeStr);
}
this._returnPos = cl.hasOption("pos");
return cl;
}
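
/**
 * Validates the 1-5 arguments and resolves the constant ones (mode/options, stop words,
 * stop tags, user dictionary). Returns a struct&lt;tokens:array&lt;string&gt;, pos:array&lt;string&gt;&gt;
 * inspector when "-pos" is given, and a plain array&lt;string&gt; inspector otherwise.
 */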
@Override
public ObjectInspector initialize(ObjectInspector[] arguments) throws UDFArgumentException {
final int arglen = arguments.length;
if (arglen < 1 || arglen > 5) {
showHelp("Invalid number of arguments for `tokenize_ja`: " + arglen);
}
this._mode = Mode.NORMAL;
if (arglen >= 2) {
String arg1 = HiveUtils.getConstString(arguments[1]);
if (arg1 != null) {
if (arg1.startsWith("-")) {
processOptions(arg1);
} else {
this._mode = tokenizationMode(arg1);
}
}
}
if (arglen >= 3 && !HiveUtils.isVoidOI(arguments[2])) {
this._stopWordsArray = HiveUtils.getConstStringArray(arguments[2]);
}
this._stopTags =
(arglen >= 4) ? stopTags(arguments[3]) : JapaneseAnalyzer.getDefaultStopTags();
if (arglen >= 5) {
if (HiveUtils.isConstListOI(arguments[4])) {
this._userDictObj = HiveUtils.getConstStringArray(arguments[4]);
} else if (HiveUtils.isConstString(arguments[4])) {
this._userDictObj = HiveUtils.getConstString(arguments[4]);
} else {
throw new UDFArgumentException(
"User dictionary MUST be given as a constant array<string> or a constant string (URL)");
}
}
this._analyzer = null;
if (_returnPos) {
this._result = new Object[2];
ArrayList<String> fieldNames = new ArrayList<String>();
ArrayList<ObjectInspector> fieldOIs = new ArrayList<ObjectInspector>();
fieldNames.add("tokens");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableStringObjectInspector));
fieldNames.add("pos");
fieldOIs.add(ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableStringObjectInspector));
return ObjectInspectorFactory.getStandardStructObjectInspector(fieldNames, fieldOIs);
} else {
return ObjectInspectorFactory.getStandardListObjectInspector(
PrimitiveObjectInspectorFactory.writableStringObjectInspector);
}
}
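
/**
 * Tokenizes one row. The JapaneseAnalyzer is created lazily on the first call (see the
 * Kryo workaround note on the transient field above), rebuilt from the constant settings
 * resolved in initialize().
 */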
@Override
public Object evaluate(DeferredObject[] arguments) throws HiveException {
if (_analyzer == null) {
CharArraySet stopWords = stopWords(_stopWordsArray);
UserDictionary userDict = null;
if (_userDictObj instanceof String[]) {
userDict = userDictionary((String[]) _userDictObj);
} else if (_userDictObj instanceof String) {
userDict = userDictionary((String) _userDictObj);
}
this._analyzer = new JapaneseAnalyzer(userDict, _mode, stopWords, _stopTags);
}
Object arg0 = arguments[0].get();
if (arg0 == null) {
return null;
}
String line = arg0.toString();
if (_returnPos) {
return parseLine(_analyzer, line, _result);
} else {
return parseLine(_analyzer, line);
}
}
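
/**
 * Tokenizes the line and fills the two-element result array with the token list and the
 * matching part-of-speech list.
 */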
@Nonnull
private static Object[] parseLine(@Nonnull JapaneseAnalyzer analyzer, @Nonnull String line,
@Nonnull Object[] result) throws HiveException {
Objects.requireNonNull(result);
Preconditions.checkArgument(result.length == 2);
final List<Text> tokens = new ArrayList<Text>(32);
final List<Text> pos = new ArrayList<Text>(32);
TokenStream stream = null;
try {
stream = analyzer.tokenStream("", line);
if (stream != null) {
analyzeTokens(stream, tokens, pos);
}
} catch (IOException e) {
IOUtils.closeQuietly(analyzer);
throw new HiveException(e);
} finally {
IOUtils.closeQuietly(stream);
}
result[0] = tokens;
result[1] = pos;
return result;
}
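
/** Tokens-only variant of the method above. */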
@Nonnull
private static List<Text> parseLine(@Nonnull JapaneseAnalyzer analyzer, @Nonnull String line)
throws HiveException {
final List<Text> tokens = new ArrayList<Text>(32);
TokenStream stream = null;
try {
stream = analyzer.tokenStream("", line);
if (stream != null) {
analyzeTokens(stream, tokens);
}
} catch (IOException e) {
IOUtils.closeQuietly(analyzer);
throw new HiveException(e);
} finally {
IOUtils.closeQuietly(stream);
}
return tokens;
}
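
/** Releases the cached analyzer when the operator is closed. */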
@Override
public void close() throws IOException {
IOUtils.closeQuietly(_analyzer);
}
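
/**
 * Maps a mode name to a JapaneseTokenizer.Mode; "default" selects Lucene's
 * JapaneseTokenizer.DEFAULT_MODE.
 */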
@Nonnull
private static Mode tokenizationMode(@Nonnull final String arg) throws UDFArgumentException {
final Mode mode;
if ("NORMAL".equalsIgnoreCase(arg)) {
mode = Mode.NORMAL;
} else if ("SEARCH".equalsIgnoreCase(arg)) {
mode = Mode.SEARCH;
} else if ("EXTENDED".equalsIgnoreCase(arg)) {
mode = Mode.EXTENDED;
} else if ("DEFAULT".equalsIgnoreCase(arg)) {
mode = JapaneseTokenizer.DEFAULT_MODE;
} else {
throw new UDFArgumentException(
"Expected NORMAL|SEARCH|EXTENDED|DEFAULT but got an unexpected mode: " + arg);
}
return mode;
}
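
/**
 * null keeps Lucene's default Japanese stop word set, an empty array disables stop word
 * filtering, and any other array becomes a case-insensitive stop word set.
 */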
@Nonnull
private static CharArraySet stopWords(@Nullable final String[] array)
throws UDFArgumentException {
if (array == null) {
return JapaneseAnalyzer.getDefaultStopSet();
}
if (array.length == 0) {
return CharArraySet.EMPTY_SET;
}
return new CharArraySet(Arrays.asList(array), /* ignoreCase */true);
}
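
/**
 * A void/null argument keeps Lucene's default stop tags, an empty array disables POS
 * filtering, and any other array becomes the stop tag set.
 */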
@Nonnull
private static Set<String> stopTags(@Nonnull final ObjectInspector oi)
throws UDFArgumentException {
if (HiveUtils.isVoidOI(oi)) {
return JapaneseAnalyzer.getDefaultStopTags();
}
final String[] array = HiveUtils.getConstStringArray(oi);
if (array == null) {
return JapaneseAnalyzer.getDefaultStopTags();
}
final int length = array.length;
if (length == 0) {
return Collections.emptySet();
}
final Set<String> results = new HashSet<String>(length);
for (int i = 0; i < length; i++) {
String s = array[i];
if (s != null) {
results.add(s);
}
}
return results;
}
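
/**
 * Builds a user dictionary from CSV rows in Kuromoji's user dictionary format
 * (surface form, space-separated segments, space-separated readings, part-of-speech).
 * An illustrative row (not from the original source):
 * "日本経済新聞,日本 経済 新聞,ニホン ケイザイ シンブン,カスタム名詞".
 */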
@Nullable
private static UserDictionary userDictionary(@Nullable final String[] userDictArray)
throws UDFArgumentException {
if (userDictArray == null) {
return null;
}
final StringBuilder builder = new StringBuilder();
for (String row : userDictArray) {
builder.append(row).append('\n');
}
final Reader reader = new StringReader(builder.toString());
try {
return UserDictionary.open(reader); // return null if empty
} catch (Throwable e) {
throw new UDFArgumentException(
"Failed to create user dictionary based on the given array<string>: "
+ builder.toString() + '\n' + ExceptionUtils.prettyPrintStackTrace(e));
}
}
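
/**
 * Fetches a user dictionary CSV over HTTP. The response may be GZIP-compressed, is
 * capped at MAX_INPUT_STREAM_SIZE bytes, and must be strictly valid UTF-8; malformed
 * bytes are reported as errors rather than silently replaced.
 */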
@Nullable
private static UserDictionary userDictionary(@Nullable final String userDictURL)
throws UDFArgumentException {
if (userDictURL == null) {
return null;
}
final HttpURLConnection conn;
try {
conn = HttpUtils.getHttpURLConnection(userDictURL);
} catch (IllegalArgumentException | IOException e) {
throw new UDFArgumentException("Failed to create HTTP connection to the URL: "
+ userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e));
}
// advertise GZIP support so the server may send a compressed response
conn.setRequestProperty("Accept-Encoding", "gzip");
conn.setConnectTimeout(CONNECT_TIMEOUT_MS); // connect() fails with an exception on timeout
conn.setReadTimeout(READ_TIMEOUT_MS); // subsequent reads fail with an exception on timeout
final int responseCode;
try {
responseCode = conn.getResponseCode();
} catch (IOException e) {
throw new UDFArgumentException("Failed to get response code: " + userDictURL + '\n'
+ ExceptionUtils.prettyPrintStackTrace(e));
}
if (responseCode != 200) {
throw new UDFArgumentException("Got unexpected HTTP response code " + responseCode
+ " for the URL: " + userDictURL);
}
final InputStream is;
try {
is = IOUtils.decodeInputStream(
HttpUtils.getLimitedInputStream(conn, MAX_INPUT_STREAM_SIZE));
} catch (NullPointerException | IOException e) {
throw new UDFArgumentException("Failed to get input stream from the connection: "
+ userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e));
}
CharsetDecoder decoder =
StandardCharsets.UTF_8.newDecoder()
.onMalformedInput(CodingErrorAction.REPORT)
.onUnmappableCharacter(CodingErrorAction.REPORT);
final Reader reader = new InputStreamReader(is, decoder);
try {
return UserDictionary.open(reader); // return null if empty
} catch (Throwable e) {
throw new UDFArgumentException(
"Failed to parse the file in CSV format (UTF-8 encoding is expected): "
+ userDictURL + '\n' + ExceptionUtils.prettyPrintStackTrace(e));
}
}
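
/** Drains the stream, collecting one Text per token. */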
private static void analyzeTokens(@Nonnull final TokenStream stream,
@Nonnull final List<Text> tokens) throws IOException {
// fetch the term attribute once; incrementToken() updates it in place
CharTermAttribute termAttr = stream.getAttribute(CharTermAttribute.class);
stream.reset();
while (stream.incrementToken()) {
String term = termAttr.toString();
tokens.add(new Text(term));
}
}
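
/** Variant that also collects the part-of-speech tag of each token. */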
private static void analyzeTokens(@Nonnull final TokenStream stream,
@Nonnull final List<Text> tokenResult, @Nonnull final List<Text> posResult)
throws IOException {
// fetch the attribute views once; incrementToken() updates them in place
CharTermAttribute termAttr = stream.getAttribute(CharTermAttribute.class);
PartOfSpeechAttribute posAttr = stream.addAttribute(PartOfSpeechAttribute.class);
stream.reset();
while (stream.incrementToken()) {
String term = termAttr.toString();
tokenResult.add(new Text(term));
String pos = posAttr.getPartOfSpeech();
posResult.add(new Text(pos));
}
}
@Override
public String getDisplayString(String[] children) {
return "tokenize_ja(" + Arrays.toString(children) + ')';
}
}