| /** |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "TFExtractTopLabels.h" |
| |
| #include "tensorflow/cc/ops/standard_ops.h" |
| |
| namespace org { |
| namespace apache { |
| namespace nifi { |
| namespace minifi { |
| namespace processors { |
| |
| core::Relationship TFExtractTopLabels::Success( // NOLINT |
| "success", |
| "Successful FlowFiles are sent here with labels as attributes"); |
| core::Relationship TFExtractTopLabels::Retry( // NOLINT |
| "retry", |
| "Failures which might work if retried"); |
| core::Relationship TFExtractTopLabels::Failure( // NOLINT |
| "failure", |
| "Failures which will not work if retried"); |
| |
| void TFExtractTopLabels::initialize() { |
| std::set<core::Property> properties; |
| setSupportedProperties(std::move(properties)); |
| |
| std::set<core::Relationship> relationships; |
| relationships.insert(Success); |
| relationships.insert(Retry); |
| relationships.insert(Failure); |
| setSupportedRelationships(std::move(relationships)); |
| } |
| |
| void TFExtractTopLabels::onSchedule(core::ProcessContext *context, core::ProcessSessionFactory *sessionFactory) { |
| } |
| |
| void TFExtractTopLabels::onTrigger(const std::shared_ptr<core::ProcessContext> &context, |
| const std::shared_ptr<core::ProcessSession> &session) { |
| auto flow_file = session->get(); |
| |
| if (!flow_file) { |
| return; |
| } |
| |
| try { |
| |
| // Read labels |
| std::string tf_type; |
| flow_file->getAttribute("tf.type", tf_type); |
| std::shared_ptr<std::vector<std::string>> labels; |
| |
| { |
| std::lock_guard<std::mutex> guard(labels_mtx_); |
| |
| if (tf_type == "labels") { |
| logger_->log_info("Reading new labels..."); |
| auto new_labels = std::make_shared<std::vector<std::string>>(); |
| LabelsReadCallback cb(new_labels); |
| session->read(flow_file, &cb); |
| labels_ = new_labels; |
| logger_->log_info("Read %d new labels", labels_->size()); |
| session->remove(flow_file); |
| return; |
| } |
| |
| labels = labels_; |
| } |
| |
| // Read input tensor from flow file |
| auto input_tensor_proto = std::make_shared<tensorflow::TensorProto>(); |
| TensorReadCallback tensor_cb(input_tensor_proto); |
| session->read(flow_file, &tensor_cb); |
| |
| tensorflow::Tensor input; |
| input.FromProto(*input_tensor_proto); |
| auto input_flat = input.flat<float>(); |
| |
| std::vector<std::pair<uint64_t, float>> scores; |
| |
| for (int i = 0; i < input_flat.size(); i++) { |
| scores.emplace_back(std::make_pair(i, input_flat(i))); |
| } |
| |
| std::sort(scores.begin(), scores.end(), [](const std::pair<uint64_t, float> &a, |
| const std::pair<uint64_t, float> &b) { |
| return a.second > b.second; |
| }); |
| |
| for (int i = 0; i < 5 && i < scores.size(); i++) { |
| if (!labels || scores[i].first > labels->size()) { |
| logger_->log_error("Label index is out of range (are the correct labels loaded?); routing to retry..."); |
| session->transfer(flow_file, Retry); |
| return; |
| } |
| flow_file->addAttribute("tf.top_label_" + std::to_string(i), labels->at(scores[i].first)); |
| } |
| |
| session->transfer(flow_file, Success); |
| |
| } catch (std::exception &exception) { |
| logger_->log_error("Caught Exception %s", exception.what()); |
| session->transfer(flow_file, Failure); |
| this->yield(); |
| } catch (...) { |
| logger_->log_error("Caught Exception"); |
| session->transfer(flow_file, Failure); |
| this->yield(); |
| } |
| } |
| |
| int64_t TFExtractTopLabels::LabelsReadCallback::process(std::shared_ptr<io::BaseStream> stream) { |
| int64_t total_read = 0; |
| std::string label; |
| uint64_t max_label_len = 65536; |
| label.resize(max_label_len); |
| std::string buf; |
| uint64_t label_size = 0; |
| uint64_t buf_size = 8096; |
| buf.resize(buf_size); |
| |
| while (total_read < stream->getSize()) { |
| auto read = stream->read(reinterpret_cast<uint8_t *>(&buf[0]), static_cast<int>(buf_size)); |
| |
| for (auto i = 0; i < read; i++) { |
| if (buf[i] == '\n' || total_read + i == stream->getSize()) { |
| labels_->emplace_back(label.substr(0, label_size)); |
| label_size = 0; |
| } else { |
| label[label_size] = buf[i]; |
| label_size++; |
| } |
| } |
| |
| total_read += read; |
| } |
| |
| return total_read; |
| } |
| |
| int64_t TFExtractTopLabels::TensorReadCallback::process(std::shared_ptr<io::BaseStream> stream) { |
| std::string tensor_proto_buf; |
| tensor_proto_buf.resize(stream->size()); |
| auto num_read = stream->read(reinterpret_cast<uint8_t *>(&tensor_proto_buf[0]), |
| static_cast<int>(stream->size())); |
| |
| if (num_read != stream->size()) { |
| throw std::runtime_error("TensorReadCallback failed to fully read flow file input stream"); |
| } |
| |
| tensor_proto_->ParseFromString(tensor_proto_buf); |
| return num_read; |
| } |
| |
| } /* namespace processors */ |
| } /* namespace minifi */ |
| } /* namespace nifi */ |
| } /* namespace apache */ |
| } /* namespace org */ |