blob: c4b8f55c6c2f9b001c387b22d6ca851ebf47b8c6 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include "vec/functions/function_wasm.h"
#include <brpc/controller.h>
#include <fmt/format.h>
#include <algorithm>
#include <memory>
#include <string>
#include <utility>
#include "gutil/strings/substitute.h"
#include "runtime/exec_env.h"
#include "runtime/user_function_cache.h"
#include "vec/columns/column.h"
#include "vec/data_types/data_type.h"
#include "vec/data_types/serde/data_type_serde.h"
#include "vec/functions/function.h"
namespace doris::vectorized {
// Builds a wasm-backed function from its thrift descriptor `fn`.
//
// Stores the declared argument and return types as given, and additionally
// records a nullability-stripped copy of every argument type in
// `_not_nullable_argument_types`. `_is_nullable` ends up true when at least
// one declared argument type is nullable.
FunctionWasm::FunctionWasm(const TFunction& fn, const DataTypes& argument_types,
                           const DataTypePtr& return_type)
        : _argument_types(argument_types), _return_type(return_type), _tfn(fn) {
    _is_nullable = false;
    for (const auto& declared_type : argument_types) {
        if (!declared_type->is_nullable()) {
            // Already plain: keep the type as-is.
            _not_nullable_argument_types.push_back(declared_type);
            continue;
        }
        // Nullable argument: remember that fact and keep only the nested type.
        _is_nullable = true;
        _not_nullable_argument_types.push_back(remove_nullable(declared_type));
    }
}
// Prepares the wasm runtime state for this function.
//
// Only acts for the FRAGMENT_LOCAL scope; every other scope returns OK without
// doing anything. For FRAGMENT_LOCAL it:
//   1. resolves the local path of the module through UserFunctionCache,
//   2. reads the whole module text from that path,
//   3. creates a WasmFunctionManager, stores it as function state and registers
//      the function (name + symbol) with the module text.
//
// Returns the error from the cache lookup if it fails, and an InternalError
// when the module file cannot be opened or yields no content.
Status FunctionWasm::open(FunctionContext* context, FunctionContext::FunctionStateScope scope) {
    if (scope != FunctionContext::FRAGMENT_LOCAL) {
        return Status::OK();
    }

    std::string local_location;
    auto* function_cache = UserFunctionCache::instance();
    RETURN_IF_ERROR(function_cache->get_watpath(_tfn.id, _tfn.hdfs_location, _tfn.checksum,
                                                &local_location));

    // Load the module before publishing any state so that a missing or
    // unreadable file is reported instead of registering an empty body.
    std::ifstream wat_file(local_location.c_str());
    if (!wat_file.is_open()) {
        return Status::InternalError(
                fmt::format("Failed to open wasm module file {} for function {}", local_location,
                            _tfn.name.function_name));
    }
    std::stringstream str_stream;
    str_stream << wat_file.rdbuf();
    const std::string wasm_body = str_stream.str();
    if (wasm_body.empty()) {
        return Status::InternalError(
                fmt::format("Wasm module file {} for function {} is empty or unreadable",
                            local_location, _tfn.name.function_name));
    }

    std::shared_ptr<WasmFunctionManager> manager = std::make_shared<WasmFunctionManager>();
    // NOTE(review): the manager is stored under THREAD_LOCAL although it is
    // created while opening the FRAGMENT_LOCAL scope; execute() reads it back
    // from THREAD_LOCAL. Confirm this is intended and survives context cloning.
    context->set_function_state(FunctionContext::THREAD_LOCAL, manager);
    manager->RegisterFunction(_tfn.name.function_name, _tfn.scalar_fn.symbol, wasm_body);
    return Status::OK();
}
// Calls the registered wasm function once per input row and stores the
// produced column at position `result` of `block`.
//
// Supported value types (arguments and return value alike):
//   wasm i32 <-> Int32   (int32_t)
//   wasm i64 <-> Int64   (int64_t)
//   wasm f32 <-> Float32 (float)
//   wasm f64 <-> Float64 (double)
//
// Parameters:
//   context          - supplies the WasmFunctionManager stored as THREAD_LOCAL state.
//   block            - holds the argument columns and receives the result column.
//   arguments        - positions of the argument columns inside `block`.
//   result           - position of the result column inside `block`.
//   input_rows_count - number of rows to evaluate.
//   dry_run          - not used by this implementation.
//
// A row with at least one NULL argument is not passed to wasm; it is marked
// NULL in the result, which therefore requires a nullable return type.
//
// Returns InternalError when the function state is missing, when an argument or
// return type is not one of the supported types, when an argument column does
// not match its declared type, when a NULL argument meets a non-nullable return
// type, or when the wasm call returns nothing or a value of another kind than
// the declared return type.
Status FunctionWasm::execute(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
                             size_t result, size_t input_rows_count, bool dry_run) {
    auto* manager = reinterpret_cast<WasmFunctionManager*>(
            context->get_function_state(FunctionContext::THREAD_LOCAL));
    if (manager == nullptr) {
        return Status::InternalError(fmt::format(
                "Wasm function {} has no initialized state; open() must succeed before execute()",
                _tfn.name.function_name));
    }

    // Maps a non-nullable Doris type onto the wasm value kind that carries it.
    // Returns false for every type that has no wasm counterpart here.
    const auto to_wasm_kind = [](const DataTypePtr& type, wasmtime::ValKind* kind) -> bool {
        WhichDataType which_type(type);
        if (which_type.is_int32()) {
            *kind = wasmtime::ValKind::I32;
            return true;
        }
        if (which_type.is_int64()) {
            *kind = wasmtime::ValKind::I64;
            return true;
        }
        if (which_type.is_float32()) {
            *kind = wasmtime::ValKind::F32;
            return true;
        }
        if (which_type.is_float64()) {
            *kind = wasmtime::ValKind::F64;
            return true;
        }
        return false;
    };

    // ---- result column -----------------------------------------------------
    auto return_type = _return_type;
    const bool result_nullable = return_type->is_nullable();
    ColumnUInt8::MutablePtr null_map = nullptr;
    if (result_nullable) {
        return_type = remove_nullable(_return_type);
        // Created zero-filled: every row starts out as "not NULL".
        null_map = ColumnUInt8::create(input_rows_count, 0);
    }
    wasmtime::ValKind result_kind = wasmtime::ValKind::I32;
    if (!to_wasm_kind(return_type, &result_kind)) {
        return Status::InternalError(
                fmt::format("Unsupported return type {} for wasm function {}",
                            return_type->get_name(), _tfn.name.function_name));
    }
    auto result_col = return_type->create_column();
    result_col->resize(input_rows_count);

    // ---- argument columns --------------------------------------------------
    const size_t arg_count = arguments.size();
    if (arg_count != _not_nullable_argument_types.size()) {
        return Status::InternalError(
                fmt::format("Wasm function {} expects {} arguments but got {}",
                            _tfn.name.function_name, _not_nullable_argument_types.size(),
                            arg_count));
    }

    // Per-argument data resolved once, so the row loop does no type dispatch
    // beyond a switch and no repeated nullable unwrapping.
    struct WasmArgument {
        ColumnPtr column;       // full-size argument column, possibly nullable
        ColumnPtr values;       // same column with nullability stripped
        wasmtime::ValKind kind; // wasm kind the values are passed as
    };
    std::vector<WasmArgument> wasm_args;
    wasm_args.reserve(arg_count);

    // check type : defined datatype same with param datatype
    for (size_t arg_idx = 0; arg_idx < arg_count; ++arg_idx) {
        ColumnWithTypeAndName& column = block.get_by_position(arguments[arg_idx]);
        DataTypePtr data_type = column.type;
        if (data_type->is_nullable() && !_is_nullable) {
            return Status::InternalError(fmt::format(
                    "Defined datatype is not nullable, but param datatype is nullable"));
        }
        if (data_type->is_nullable()) {
            data_type = remove_nullable(data_type);
        }
        DCHECK(_not_nullable_argument_types[arg_idx]->equals(*data_type))
                << " input column's type is " + data_type->get_name()
                << " does not equal to required type "
                << _not_nullable_argument_types[arg_idx]->get_name();

        wasmtime::ValKind arg_kind = wasmtime::ValKind::I32;
        if (!to_wasm_kind(_not_nullable_argument_types[arg_idx], &arg_kind)) {
            return Status::InternalError(
                    fmt::format("Unsupported type {} for argument {} of wasm function {}",
                                _not_nullable_argument_types[arg_idx]->get_name(), arg_idx,
                                _tfn.name.function_name));
        }

        ColumnPtr full_col = column.column->convert_to_full_column_if_const();
        ColumnPtr value_col = full_col;
        if (full_col->is_nullable()) {
            value_col = remove_nullable(full_col);
        }

        // Verify the concrete column class once; the row loop relies on it.
        bool column_matches = false;
        switch (arg_kind) {
        case wasmtime::ValKind::I32:
            column_matches = check_and_get_column<ColumnInt32>(*value_col) != nullptr;
            break;
        case wasmtime::ValKind::I64:
            column_matches = check_and_get_column<ColumnInt64>(*value_col) != nullptr;
            break;
        case wasmtime::ValKind::F32:
            column_matches = check_and_get_column<ColumnFloat32>(*value_col) != nullptr;
            break;
        case wasmtime::ValKind::F64:
            column_matches = check_and_get_column<ColumnFloat64>(*value_col) != nullptr;
            break;
        default:
            break;
        }
        if (!column_matches) {
            return Status::InternalError(
                    fmt::format("Column of argument {} of wasm function {} does not hold {} values",
                                arg_idx, _tfn.name.function_name,
                                _not_nullable_argument_types[arg_idx]->get_name()));
        }

        wasm_args.push_back(WasmArgument {full_col, value_col, arg_kind});
    }

    // step1. process column value to wasm param
    // step2. call wasm function
    // step3. return wasm result to column value
    // TODO: vec the code to call wasm fun
    std::vector<wasmtime::Val> params;
    params.reserve(arg_count);
    for (size_t row = 0; row < input_rows_count; ++row) {
        params.clear();
        bool has_null_argument = false;
        for (const auto& arg : wasm_args) {
            if (arg.column->is_null_at(row)) {
                has_null_argument = true;
                break;
            }
            // The static_casts are safe: the column class was verified above.
            switch (arg.kind) {
            case wasmtime::ValKind::I32:
                params.emplace_back(static_cast<const ColumnInt32&>(*arg.values).get_data()[row]);
                break;
            case wasmtime::ValKind::I64:
                params.emplace_back(static_cast<const ColumnInt64&>(*arg.values).get_data()[row]);
                break;
            case wasmtime::ValKind::F32:
                params.emplace_back(
                        static_cast<const ColumnFloat32&>(*arg.values).get_data()[row]);
                break;
            case wasmtime::ValKind::F64:
                params.emplace_back(
                        static_cast<const ColumnFloat64&>(*arg.values).get_data()[row]);
                break;
            default:
                break;
            }
        }

        if (has_null_argument) {
            // A NULL input yields a NULL output, which needs a null map to land in.
            if (!result_nullable) {
                return Status::InternalError(fmt::format(
                        "Wasm function {} got a NULL argument but its return type is not nullable",
                        _tfn.name.function_name));
            }
            null_map->get_data()[row] = 1;
            continue;
        }

        auto rets = manager->runElemFunc(_tfn.name.function_name, params);
        if (rets.empty()) {
            return Status::InternalError(fmt::format("Wasm function {} returned no value",
                                                     _tfn.name.function_name));
        }
        auto ret = rets.at(0);
        // Writing a value of another width into the typed result column would
        // corrupt memory, so the returned kind must match the declared type.
        if (ret.kind() != result_kind) {
            return Status::InternalError(fmt::format(
                    "Wasm function {} returned a value that does not match return type {}",
                    _tfn.name.function_name, return_type->get_name()));
        }
        switch (result_kind) {
        case wasmtime::ValKind::I32:
            static_cast<ColumnInt32&>(*result_col).get_data()[row] = ret.i32();
            break;
        case wasmtime::ValKind::I64:
            static_cast<ColumnInt64&>(*result_col).get_data()[row] = ret.i64();
            break;
        case wasmtime::ValKind::F32:
            static_cast<ColumnFloat32&>(*result_col).get_data()[row] = ret.f32();
            break;
        case wasmtime::ValKind::F64:
            static_cast<ColumnFloat64&>(*result_col).get_data()[row] = ret.f64();
            break;
        default:
            break;
        }
    }

    if (result_nullable) {
        block.replace_by_position(
                result, ColumnNullable::create(std::move(result_col), std::move(null_map)));
    } else {
        block.replace_by_position(result, std::move(result_col));
    }
    return Status::OK();
}
} // namespace doris::vectorized