blob: 70d6348388b65bf5633bfae6dfd9409839157d68 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.tajo.plan.function;
import org.apache.tajo.catalog.CatalogUtil;
import org.apache.tajo.catalog.FunctionDesc;
import org.apache.tajo.common.TajoDataTypes;
import org.apache.tajo.datum.Datum;
import org.apache.tajo.datum.DatumFactory;
import org.apache.tajo.plan.function.python.PythonScriptEngine;
import org.apache.tajo.storage.Tuple;

import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Invokes a user-defined aggregation function (UDAF) implemented in Python through a
 * {@link PythonScriptEngine}.
 *
 * <p>Intermediate aggregation state is kept on the Tajo side, one {@link PythonAggFunctionContext}
 * per aggregation key. This invoker remembers which context was last loaded into the Python side
 * and, whenever a call arrives for a different context, saves the previous one back to Java and
 * loads the new one into Python (see {@link #updateContextIfNecessary}).
 *
 * <p>No synchronization guards {@code prevContext}; presumably each instance is driven by a single
 * thread — NOTE(review): confirm against the executor that calls it.
 */
public class PythonAggFunctionInvoke extends AggFunctionInvoke implements Cloneable {

  /** Engine that runs the Python UDAF; assigned in {@link #init(FunctionInvokeContext)}. */
  private transient PythonScriptEngine scriptEngine;

  /** Context most recently loaded into the Python side, or {@code null} before the first call. */
  private transient PythonAggFunctionContext prevContext;

  /**
   * Source of context ids. The counter is static, so every invoker in the JVM shares it.
   *
   * <p>An {@link AtomicInteger} keeps ids distinct even when contexts are created from several
   * threads at once. A plain {@code int++} could hand out the same id twice, and then
   * {@link #updateContextIfNecessary} would treat two different contexts as the same one and skip
   * the state swap.
   */
  private static final AtomicInteger NEXT_CONTEXT_ID = new AtomicInteger();

  /**
   * Aggregated result should be kept in Tajo task rather than Python UDAF to control memory usage.
   * {@link PythonAggFunctionContext} is to support executing aggregation with keys.
   * It stores a snapshot of Python UDAF class instance as a json string.
   *
   * For each UDAF call with different aggregation key,
   * {@link PythonAggFunctionInvoke} calls {@link PythonAggFunctionInvoke#updateContextIfNecessary} to backup and restore
   * intermediate aggregation states for the previous key and the current key, respectively.
   */
  public static class PythonAggFunctionContext implements FunctionContext {
    final int id; // id to identify each context; drawn from a JVM-wide atomic counter
    String jsonData; // snapshot of Python class

    /** Creates an empty context with a fresh id. */
    public PythonAggFunctionContext() {
      this.id = NEXT_CONTEXT_ID.getAndIncrement();
    }

    /** @param jsonData JSON snapshot of the Python UDAF instance's state */
    public void setJsonData(String jsonData) {
      this.jsonData = jsonData;
    }

    /** @return the stored JSON snapshot, or {@code null} if none has been set */
    public String getJsonData() {
      return jsonData;
    }
  }

  /**
   * @param functionDesc descriptor of the Python aggregation function to invoke
   */
  public PythonAggFunctionInvoke(FunctionDesc functionDesc) {
    super(functionDesc);
  }

  /**
   * Captures the script engine from the invocation context.
   *
   * @param context invocation context; its script engine must be a {@link PythonScriptEngine}
   *     (a {@link ClassCastException} is thrown otherwise)
   */
  @Override
  public void init(FunctionInvokeContext context) throws IOException {
    this.scriptEngine = (PythonScriptEngine) context.getScriptEngine();
  }

  /** @return a new, empty {@link PythonAggFunctionContext} with a unique id */
  @Override
  public FunctionContext newContext() {
    return new PythonAggFunctionContext();
  }

  /**
   * Context does not need to be updated per every UDAF call.
   * If the current aggregation key is same with the previous one,
   * python-side context doesn't need to be updated because it already contains necessary intermediate result.
   *
   * <p>Otherwise the state of the previously loaded context (if any) is written back to it via
   * {@code updateJavaSideContext}, then the given context is loaded via
   * {@code updatePythonSideContext}, and it becomes the new {@code prevContext}.
   *
   * @param context context for the current call; must be a {@link PythonAggFunctionContext}
   * @throws RuntimeException wrapping any {@link IOException} raised by the script engine
   */
  private void updateContextIfNecessary(FunctionContext context) {
    PythonAggFunctionContext givenContext = (PythonAggFunctionContext) context;
    if (prevContext == null || prevContext.id != givenContext.id) {
      try {
        if (prevContext != null) {
          // Back up the Python-side state of the context we are switching away from.
          scriptEngine.updateJavaSideContext(prevContext);
        }
        // Restore the state of the context we are switching to.
        scriptEngine.updatePythonSideContext(givenContext);
        prevContext = givenContext;
      } catch (IOException e) {
        throw new RuntimeException(e);
      }
    }
  }

  /**
   * Feeds one input row to the Python UDAF under the given context.
   *
   * @param context aggregation context for the row's key
   * @param params function arguments for this row
   */
  @Override
  public void eval(FunctionContext context, Tuple params) {
    updateContextIfNecessary(context);
    scriptEngine.callAggFunc(context, params);
  }

  /**
   * Merges a partial result into the given context. The call is skipped when the first parameter
   * is blank or null.
   *
   * @param context aggregation context to merge into
   * @param params tuple whose field 0 carries the partial result to merge
   */
  @Override
  public void merge(FunctionContext context, Tuple params) {
    if (params.isBlankOrNull(0)) {
      return;
    }
    updateContextIfNecessary(context);
    scriptEngine.callAggFunc(context, params);
  }

  /**
   * @param context aggregation context whose partial result is requested
   * @return the partial result as a TEXT datum (the engine returns it as a JSON string)
   */
  @Override
  public Datum getPartialResult(FunctionContext context) {
    updateContextIfNecessary(context);
    // partial results are stored as json strings.
    String result = scriptEngine.getPartialResult(context);
    return DatumFactory.createText(result);
  }

  /** @return the TEXT type, since partial results are exchanged as JSON strings */
  @Override
  public TajoDataTypes.DataType getPartialResultType() {
    return CatalogUtil.newSimpleDataType(TajoDataTypes.Type.TEXT);
  }

  /**
   * @param context aggregation context to finalize
   * @return the final aggregation result as computed by the script engine
   */
  @Override
  public Datum terminate(FunctionContext context) {
    updateContextIfNecessary(context);
    return scriptEngine.getFinalResult(context);
  }

  /**
   * Returns a shallow copy produced by {@code super.clone()}.
   *
   * <p>NOTE(review): the copy shares {@code scriptEngine} and {@code prevContext} with the
   * original. That appears safe only if the clone is re-initialized via {@code init} before any
   * context has been loaded — confirm against callers.
   */
  @Override
  public Object clone() throws CloneNotSupportedException {
    return super.clone();
  }
}