// blob: 8d0092d3b511b1065186eaaba77b8f4065136441 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "TFApplyGraph.h"
#include <core/ProcessContext.h>
#include <core/ProcessSession.h>
#include <tensorflow/cc/ops/standard_ops.h>
namespace org {
namespace apache {
namespace nifi {
namespace minifi {
namespace processors {
// "Input Node" property: name of the graph node that receives the input
// tensor. The default is the empty string; onSchedule() logs an error when the
// configured value is empty, and onTrigger() passes the value as the feed name
// when running the TensorFlow session.
core::Property TFApplyGraph::InputNode(
core::PropertyBuilder::createProperty("Input Node")
->withDescription(
"The node of the TensorFlow graph to feed tensor inputs to")
->withDefaultValue("")
->build());
// "Output Node" property: name of the graph node whose tensor is fetched as
// the result. The default is the empty string; onSchedule() logs an error when
// the configured value is empty, and onTrigger() passes the value as the fetch
// name when running the TensorFlow session.
core::Property TFApplyGraph::OutputNode(
core::PropertyBuilder::createProperty("Output Node")
->withDescription(
"The node of the TensorFlow graph to read tensor outputs from")
->withDefaultValue("")
->build());
// Relationship that onTrigger() routes a flow file to after its content has
// been replaced with the serialized output tensor.
core::Relationship TFApplyGraph::Success( // NOLINT
"success",
"Successful graph application outputs");
// Relationship that onTrigger() routes a tensor flow file to when no graph
// definition has been received yet.
core::Relationship TFApplyGraph::Retry( // NOLINT
"retry",
"Inputs which fail graph application but may work if sent again");
// Relationship that onTrigger() routes a flow file to when an exception is
// caught while reading the graph, creating the session, or applying the graph.
core::Relationship TFApplyGraph::Failure( // NOLINT
"failure",
"Failures which will not work if retried");
void TFApplyGraph::initialize() {
std::set<core::Property> properties;
properties.insert(InputNode);
properties.insert(OutputNode);
setSupportedProperties(std::move(properties));
std::set<core::Relationship> relationships;
relationships.insert(Success);
relationships.insert(Retry);
relationships.insert(Failure);
setSupportedRelationships(std::move(relationships));
}
/**
 * Loads the configured input and output node names into input_node_ and
 * output_node_.
 *
 * An empty value is only reported through the error log; scheduling is not
 * aborted.
 *
 * @param context        source of the configured property values
 * @param sessionFactory not used by this processor
 */
void TFApplyGraph::onSchedule(core::ProcessContext *context, core::ProcessSessionFactory *sessionFactory) {
  // Node the input tensor is fed to.
  const auto input_property_name = InputNode.getName();
  context->getProperty(input_property_name, input_node_);
  const bool input_missing = input_node_.empty();
  if (input_missing) {
    logger_->log_error("Invalid input node");
  }

  // Node the output tensor is fetched from.
  const auto output_property_name = OutputNode.getName();
  context->getProperty(output_property_name, output_node_);
  const bool output_missing = output_node_.empty();
  if (output_missing) {
    logger_->log_error("Invalid output node");
  }
}
/**
 * Handles a single incoming flow file.
 *
 * - If the flow file carries the attribute tf.type == "graph", its content is
 *   parsed as a new graph definition, the graph version is incremented and the
 *   flow file is removed.
 * - Otherwise the content is parsed as a TensorProto, fed to input_node_ of the
 *   current graph, and the tensor fetched from output_node_ is written back to
 *   the flow file, which is then routed to Success.
 * - If no graph has been received yet, the flow file is routed to Retry.
 * - Any exception raised along the way routes the flow file to Failure and
 *   yields the processor.
 *
 * TensorFlow sessions are pooled in tf_context_q_ and tagged with the graph
 * version they were created from, so that sessions built from an outdated
 * graph are discarded instead of reused.
 *
 * @param context process context (not used here)
 * @param session process session used to get, read, write and route the flow file
 */
void TFApplyGraph::onTrigger(const std::shared_ptr<core::ProcessContext> &context,
                             const std::shared_ptr<core::ProcessSession> &session) {
  auto flow_file = session->get();

  if (!flow_file) {
    return;
  }

  try {
    // Read graph
    std::string tf_type;
    flow_file->getAttribute("tf.type", tf_type);

    std::shared_ptr<tensorflow::GraphDef> graph_def;
    uint32_t graph_version = 0;

    {
      std::lock_guard<std::mutex> guard(graph_def_mtx_);

      if ("graph" == tf_type) {
        logger_->log_info("Reading new graph def");

        // Parse into a local object first: if reading throws, the currently
        // installed graph and its version stay consistent with the pooled
        // sessions instead of being replaced by an empty graph.
        auto new_graph_def = std::make_shared<tensorflow::GraphDef>();
        GraphReadCallback graph_cb(new_graph_def);
        session->read(flow_file, &graph_cb);

        graph_def_ = new_graph_def;
        graph_version_++;
        logger_->log_info("Read graph version: %i", graph_version_);
        session->remove(flow_file);
        return;
      }

      // Snapshot the shared state so the rest of the work runs without the lock.
      graph_version = graph_version_;
      graph_def = graph_def_;
    }

    if (!graph_def) {
      logger_->log_error("Cannot process input because no graph has been defined");
      session->transfer(flow_file, Retry);
      return;
    }

    // Use an existing context, if one is available
    std::shared_ptr<TFContext> ctx;

    if (tf_context_q_.try_dequeue(ctx)) {
      logger_->log_debug("Using available TensorFlow context");

      if (ctx->graph_version != graph_version) {
        logger_->log_info("Allowing session with stale graph to expire");
        ctx = nullptr;
      }
    }

    if (!ctx) {
      logger_->log_info("Creating new TensorFlow context");
      tensorflow::SessionOptions options;
      ctx = std::make_shared<TFContext>();
      ctx->tf_session.reset(tensorflow::NewSession(options));
      ctx->graph_version = graph_version;

      // NewSession() reports failure by returning a null pointer.
      if (!ctx->tf_session) {
        throw std::runtime_error("Failed to create TensorFlow session: NewSession returned null");
      }

      auto status = ctx->tf_session->Create(*graph_def);

      if (!status.ok()) {
        std::string msg = "Failed to create TensorFlow session: ";
        msg.append(status.ToString());
        throw std::runtime_error(msg);
      }
    }

    // Apply graph
    // Read input tensor from flow file
    auto input_tensor_proto = std::make_shared<tensorflow::TensorProto>();
    TensorReadCallback tensor_cb(input_tensor_proto);
    session->read(flow_file, &tensor_cb);

    tensorflow::Tensor input;

    // FromProto() returns false when the proto cannot be converted to a tensor.
    if (!input.FromProto(*input_tensor_proto)) {
      throw std::runtime_error("Failed to convert flow file content to a TensorFlow tensor");
    }

    std::vector<tensorflow::Tensor> outputs;
    auto status = ctx->tf_session->Run({{input_node_, input}}, {output_node_}, {}, &outputs);

    if (!status.ok()) {
      std::string msg = "Failed to apply TensorFlow graph: ";
      msg.append(status.ToString());
      throw std::runtime_error(msg);
    }

    // Write each output tensor into the incoming flow file and route it to
    // Success. Only one fetch (output_node_) is requested above.
    // NOTE(review): presumably Run() yields exactly one tensor here; if it ever
    // returned several, the same flow file would be rewritten and transferred
    // repeatedly - confirm.
    for (const auto &output : outputs) {
      auto tensor_proto = std::make_shared<tensorflow::TensorProto>();
      output.AsProtoTensorContent(tensor_proto.get());

      logger_->log_info("Writing output tensor flow file");
      TensorWriteCallback write_cb(tensor_proto);
      session->write(flow_file, &write_cb);
      session->transfer(flow_file, Success);
    }

    // Make context available for use again
    if (tf_context_q_.size_approx() < getMaxConcurrentTasks()) {
      logger_->log_debug("Releasing TensorFlow context");
      tf_context_q_.enqueue(ctx);
    } else {
      logger_->log_info("Destroying TensorFlow context because it is no longer needed");
    }
  } catch (const std::exception &exception) {
    logger_->log_error("Caught Exception %s", exception.what());
    session->transfer(flow_file, Failure);
    this->yield();
  } catch (...) {
    logger_->log_error("Caught Exception");
    session->transfer(flow_file, Failure);
    this->yield();
  }
}
/**
 * Reads the whole stream and parses it as a serialized GraphDef into
 * graph_def_.
 *
 * @param stream flow file content stream
 * @return number of bytes read from the stream
 * @throws std::runtime_error if the stream cannot be read completely or its
 *         content is not a parsable protobuf message
 */
int64_t TFApplyGraph::GraphReadCallback::process(std::shared_ptr<io::BaseStream> stream) {
  // Query the size once so the buffer, the read length and the check agree.
  const auto stream_size = stream->getSize();

  std::string graph_proto_buf;
  graph_proto_buf.resize(stream_size);
  auto num_read = stream->readData(reinterpret_cast<uint8_t *>(&graph_proto_buf[0]),
                                   static_cast<int>(stream_size));

  // Compare as unsigned on purpose: a negative error return becomes a huge
  // value and is rejected just like a short read.
  if (static_cast<uint64_t>(num_read) != static_cast<uint64_t>(stream_size)) {
    throw std::runtime_error("GraphReadCallback failed to fully read flow file input stream");
  }

  // ParseFromString() returns false on malformed input; do not accept a
  // partially parsed graph.
  if (!graph_def_->ParseFromString(graph_proto_buf)) {
    throw std::runtime_error("GraphReadCallback failed to parse graph definition from flow file content");
  }

  return num_read;
}
/**
 * Reads the whole stream and parses it as a serialized TensorProto into
 * tensor_proto_.
 *
 * @param stream flow file content stream
 * @return number of bytes read from the stream
 * @throws std::runtime_error if the stream cannot be read completely or its
 *         content is not a parsable protobuf message
 */
int64_t TFApplyGraph::TensorReadCallback::process(std::shared_ptr<io::BaseStream> stream) {
  // Query the size once so the buffer, the read length and the check agree.
  const auto stream_size = stream->getSize();

  std::string tensor_proto_buf;
  tensor_proto_buf.resize(stream_size);
  auto num_read = stream->readData(reinterpret_cast<uint8_t *>(&tensor_proto_buf[0]),
                                   static_cast<int>(stream_size));

  // Compare as unsigned on purpose: a negative error return becomes a huge
  // value and is rejected just like a short read.
  if (static_cast<uint64_t>(num_read) != static_cast<uint64_t>(stream_size)) {
    throw std::runtime_error("TensorReadCallback failed to fully read flow file input stream");
  }

  // ParseFromString() returns false on malformed input; do not hand a
  // partially parsed tensor to the graph.
  if (!tensor_proto_->ParseFromString(tensor_proto_buf)) {
    throw std::runtime_error("TensorReadCallback failed to parse tensor from flow file content");
  }

  return num_read;
}
/**
 * Serializes tensor_proto_ and writes the resulting bytes to the stream.
 *
 * @param stream flow file content stream to write to
 * @return number of bytes written
 * @throws std::runtime_error if fewer bytes were written than were serialized
 */
int64_t TFApplyGraph::TensorWriteCallback::process(std::shared_ptr<io::BaseStream> stream) {
  auto serialized = tensor_proto_->SerializeAsString();
  const auto byte_count = serialized.size();
  auto *raw_bytes = reinterpret_cast<uint8_t *>(&serialized[0]);

  auto written = stream->writeData(raw_bytes, static_cast<int>(byte_count));

  // Happy path first: everything went out.
  if (written == byte_count) {
    return written;
  }

  throw std::runtime_error("TensorWriteCallback failed to fully write flow file output stream");
}
} /* namespace processors */
} /* namespace minifi */
} /* namespace nifi */
} /* namespace apache */
} /* namespace org */