| /** |
| * |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "c2/protocols/RESTProtocol.h" |
| |
| #include <algorithm> |
| #include <memory> |
| #include <utility> |
| #include <map> |
| #include <string> |
| #include <vector> |
| #include <list> |
| |
| namespace org { |
| namespace apache { |
| namespace nifi { |
| namespace minifi { |
| namespace c2 { |
| |
| const C2Payload RESTProtocol::parseJsonResponse(const C2Payload &payload, const std::vector<char> &response) { |
| rapidjson::Document root; |
| |
| try { |
| rapidjson::ParseResult ok = root.Parse(response.data(), response.size()); |
| if (ok) { |
| std::string requested_operation = getOperation(payload); |
| |
| std::string identifier; |
| if (root.HasMember("operationid")) { |
| identifier = root["operationid"].GetString(); |
| } else if (root.HasMember("operationId")) { |
| identifier = root["operationId"].GetString(); |
| } else if (root.HasMember("identifier")) { |
| identifier = root["identifier"].GetString(); |
| } |
| if (root["requested_operations"].Size() == 0 && root["requestedOperations"].Size() == 0) |
| return std::move(C2Payload(payload.getOperation(), state::UpdateState::READ_COMPLETE, true)); |
| |
| C2Payload new_payload(payload.getOperation(), state::UpdateState::NESTED, true); |
| |
| if (!identifier.empty()) |
| new_payload.setIdentifier(identifier); |
| auto array = root["requested_operations"].GetArray(); |
| if (root["requested_operations"].Size() == 0) |
| array = root["requestedOperations"].GetArray(); |
| for (const rapidjson::Value& request : array) { |
| Operation newOp = stringToOperation(request["operation"].GetString()); |
| C2Payload nested_payload(newOp, state::UpdateState::READ_COMPLETE, true); |
| C2ContentResponse new_command(newOp); |
| new_command.delay = 0; |
| new_command.required = true; |
| new_command.ttl = -1; |
| |
| // set the identifier if one exists |
| if (request.HasMember("operationid")) { |
| if (request["operationid"].IsNumber()) |
| new_command.ident = std::to_string(request["operationid"].GetInt64()); |
| else if (request["operationid"].IsString()) |
| new_command.ident = request["operationid"].GetString(); |
| else |
| throw(Exception(SITE2SITE_EXCEPTION, "Invalid type for operationid")); |
| nested_payload.setIdentifier(new_command.ident); |
| } else if (request.HasMember("operationId")) { |
| if (request["operationId"].IsNumber()) |
| new_command.ident = std::to_string(request["operationId"].GetInt64()); |
| else if (request["operationId"].IsString()) |
| new_command.ident = request["operationId"].GetString(); |
| else |
| throw(Exception(SITE2SITE_EXCEPTION, "Invalid type for operationId")); |
| nested_payload.setIdentifier(new_command.ident); |
| } else if (request.HasMember("identifier")) { |
| if (request["identifier"].IsNumber()) |
| new_command.ident = std::to_string(request["identifier"].GetInt64()); |
| else if (request["identifier"].IsString()) |
| new_command.ident = request["identifier"].GetString(); |
| else |
| throw(Exception(SITE2SITE_EXCEPTION, "Invalid type for operationid")); |
| nested_payload.setIdentifier(new_command.ident); |
| } |
| |
| if (request.HasMember("name")) { |
| new_command.name = request["name"].GetString(); |
| } else if (request.HasMember("operand")) { |
| new_command.name = request["operand"].GetString(); |
| } |
| |
| if (request.HasMember("content") && request["content"].MemberCount() > 0) { |
| if (request["content"].IsArray()) { |
| for (const auto &member : request["content"].GetArray()) |
| new_command.operation_arguments[member.GetString()] = member.GetString(); |
| } else { |
| for (const auto &member : request["content"].GetObject()) |
| new_command.operation_arguments[member.name.GetString()] = member.value.GetString(); |
| } |
| } else if (request.HasMember("args") && request["args"].MemberCount() > 0) { |
| if (request["args"].IsArray()) { |
| for (const auto &member : request["args"].GetArray()) |
| new_command.operation_arguments[member.GetString()] = member.GetString(); |
| } else { |
| for (const auto &member : request["args"].GetObject()) |
| new_command.operation_arguments[member.name.GetString()] = member.value.GetString(); |
| } |
| } |
| nested_payload.addContent(std::move(new_command)); |
| new_payload.addPayload(std::move(nested_payload)); |
| } |
| |
| // we have a response for this request |
| return new_payload; |
| // } |
| } |
| } catch (...) { |
| } |
| return std::move(C2Payload(payload.getOperation(), state::UpdateState::READ_COMPLETE, true)); |
| } |
| |
| void setJsonStr(const std::string& key, const state::response::ValueNode& value, rapidjson::Value& parent, rapidjson::Document::AllocatorType& alloc) { // NOLINT |
| rapidjson::Value keyVal; |
| rapidjson::Value valueVal; |
| const char* c_key = key.c_str(); |
| |
| auto base_type = value.getValue(); |
| keyVal.SetString(c_key, key.length(), alloc); |
| |
| if (auto sub_type = std::dynamic_pointer_cast<state::response::IntValue>(base_type)) { |
| valueVal.SetInt(sub_type->getValue()); |
| } else if (auto sub_type = std::dynamic_pointer_cast<state::response::Int64Value>(base_type)) { |
| valueVal.SetInt64(sub_type->getValue()); |
| } else if (auto sub_type = std::dynamic_pointer_cast<state::response::BoolValue>(base_type)) { |
| valueVal.SetBool(sub_type->getValue()); |
| } else { |
| auto str = base_type->getStringValue(); |
| const char* c_val = str.c_str(); |
| valueVal.SetString(c_val, str.length(), alloc); |
| } |
| parent.AddMember(keyVal, valueVal, alloc); |
| } |
| |
| rapidjson::Value RESTProtocol::getStringValue(const std::string& value, rapidjson::Document::AllocatorType& alloc) { // NOLINT |
| rapidjson::Value Val; |
| Val.SetString(value.c_str(), value.length(), alloc); |
| return Val; |
| } |
| |
| void RESTProtocol::mergePayloadContent(rapidjson::Value &target, const C2Payload &payload, rapidjson::Document::AllocatorType &alloc) { |
| const std::vector<C2ContentResponse> &content = payload.getContent(); |
| bool all_empty = content.size() > 0 ? true : false; |
| bool is_parent_array = target.IsArray(); |
| |
| for (const auto &payload_content : content) { |
| for (auto content : payload_content.operation_arguments) { |
| if (!content.second.empty()) { |
| all_empty = false; |
| break; |
| } |
| } |
| if (!all_empty) |
| break; |
| } |
| |
| if (all_empty) { |
| if (!is_parent_array) { |
| target.SetArray(); |
| is_parent_array = true; |
| } |
| rapidjson::Value arr(rapidjson::kArrayType); |
| for (const auto &payload_content : content) { |
| for (auto content : payload_content.operation_arguments) { |
| rapidjson::Value keyVal; |
| keyVal.SetString(content.first.c_str(), content.first.length(), alloc); |
| if (is_parent_array) |
| target.PushBack(keyVal, alloc); |
| else |
| arr.PushBack(keyVal, alloc); |
| } |
| } |
| |
| if (!is_parent_array) { |
| rapidjson::Value sub_key = getStringValue(payload.getLabel(), alloc); |
| target.AddMember(sub_key, arr, alloc); |
| } |
| return; |
| } |
| for (const auto &payload_content : content) { |
| rapidjson::Value payload_content_values(rapidjson::kObjectType); |
| bool use_sub_option = true; |
| if (payload_content.op == payload.getOperation()) { |
| for (auto content : payload_content.operation_arguments) { |
| setJsonStr(content.first, content.second, target, alloc); |
| } |
| } else { |
| } |
| if (use_sub_option) { |
| rapidjson::Value sub_key = getStringValue(payload_content.name, alloc); |
| } |
| } |
| } |
| |
| std::string RESTProtocol::serializeJsonRootPayload(const C2Payload& payload) { |
| rapidjson::Document json_payload(payload.isContainer() ? rapidjson::kArrayType : rapidjson::kObjectType); |
| rapidjson::Document::AllocatorType &alloc = json_payload.GetAllocator(); |
| |
| rapidjson::Value opReqStrVal; |
| std::string operation_request_str = getOperation(payload); |
| opReqStrVal.SetString(operation_request_str.c_str(), operation_request_str.length(), alloc); |
| json_payload.AddMember("operation", opReqStrVal, alloc); |
| |
| std::string operationid = payload.getIdentifier(); |
| if (operationid.length() > 0) { |
| json_payload.AddMember("operationid", getStringValue(operationid, alloc), alloc); |
| json_payload.AddMember("operationId", getStringValue(operationid, alloc), alloc); |
| json_payload.AddMember("identifier", getStringValue(operationid, alloc), alloc); |
| } |
| |
| mergePayloadContent(json_payload, payload, alloc); |
| |
| for (const auto &nested_payload : payload.getNestedPayloads()) { |
| if (!minimize_updates_ || (minimize_updates_ && !containsPayload(nested_payload))) { |
| rapidjson::Value np_key = getStringValue(nested_payload.getLabel(), alloc); |
| rapidjson::Value np_value = serializeJsonPayload(nested_payload, alloc); |
| if (minimize_updates_) { |
| nested_payloads_.insert(std::pair<std::string, C2Payload>(nested_payload.getLabel(), nested_payload)); |
| } |
| json_payload.AddMember(np_key, np_value, alloc); |
| } |
| } |
| |
| rapidjson::StringBuffer buffer; |
| rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer); |
| json_payload.Accept(writer); |
| return buffer.GetString(); |
| } |
| |
| bool RESTProtocol::containsPayload(const C2Payload &o) { |
| auto it = nested_payloads_.find(o.getLabel()); |
| if (it != nested_payloads_.end()) { |
| return it->second == o; |
| } |
| return false; |
| } |
| |
| rapidjson::Value RESTProtocol::serializeJsonPayload(const C2Payload &payload, rapidjson::Document::AllocatorType &alloc) { |
| // get the name from the content |
| rapidjson::Value json_payload(payload.isContainer() ? rapidjson::kArrayType : rapidjson::kObjectType); |
| |
| std::map<std::string, std::list<rapidjson::Value*>> children; |
| |
| for (const auto &nested_payload : payload.getNestedPayloads()) { |
| rapidjson::Value* child_payload = new rapidjson::Value(serializeJsonPayload(nested_payload, alloc)); |
| |
| children[nested_payload.getLabel()].push_back(child_payload); |
| } |
| |
| for (auto child_vector : children) { |
| rapidjson::Value children_json; |
| rapidjson::Value newMemberKey = getStringValue(child_vector.first, alloc); |
| if (child_vector.second.size() > 1) { |
| children_json.SetArray(); |
| for (auto child : child_vector.second) { |
| if (json_payload.IsArray()) |
| json_payload.PushBack(child->Move(), alloc); |
| else |
| children_json.PushBack(child->Move(), alloc); |
| } |
| if (!json_payload.IsArray()) |
| json_payload.AddMember(newMemberKey, children_json, alloc); |
| } else if (child_vector.second.size() == 1) { |
| rapidjson::Value* first = child_vector.second.front(); |
| if (first->IsObject() && first->HasMember(newMemberKey)) { |
| if (json_payload.IsArray()) |
| json_payload.PushBack((*first)[newMemberKey].Move(), alloc); |
| else |
| json_payload.AddMember(newMemberKey, (*first)[newMemberKey].Move(), alloc); |
| } else { |
| if (json_payload.IsArray()) { |
| json_payload.PushBack(first->Move(), alloc); |
| } else { |
| json_payload.AddMember(newMemberKey, first->Move(), alloc); |
| } |
| } |
| } |
| for (rapidjson::Value* child : child_vector.second) |
| delete child; |
| } |
| |
| mergePayloadContent(json_payload, payload, alloc); |
| return json_payload; |
| } |
| |
| std::string RESTProtocol::getOperation(const C2Payload &payload) { |
| switch (payload.getOperation()) { |
| case Operation::ACKNOWLEDGE: |
| return "acknowledge"; |
| case Operation::HEARTBEAT: |
| return "heartbeat"; |
| case Operation::RESTART: |
| return "restart"; |
| case Operation::DESCRIBE: |
| return "describe"; |
| case Operation::STOP: |
| return "stop"; |
| case Operation::START: |
| return "start"; |
| case Operation::UPDATE: |
| return "update"; |
| default: |
| return "heartbeat"; |
| } |
| } |
| |
| Operation RESTProtocol::stringToOperation(const std::string str) { |
| std::string op = str; |
| std::transform(str.begin(), str.end(), op.begin(), ::tolower); |
| if (op == "heartbeat") { |
| return Operation::HEARTBEAT; |
| } else if (op == "acknowledge") { |
| return Operation::ACKNOWLEDGE; |
| } else if (op == "update") { |
| return Operation::UPDATE; |
| } else if (op == "describe") { |
| return Operation::DESCRIBE; |
| } else if (op == "restart") { |
| return Operation::RESTART; |
| } else if (op == "clear") { |
| return Operation::CLEAR; |
| } else if (op == "stop") { |
| return Operation::STOP; |
| } else if (op == "start") { |
| return Operation::START; |
| } |
| return Operation::HEARTBEAT; |
| } |
| |
| } /* namespace c2 */ |
| } /* namespace minifi */ |
| } /* namespace nifi */ |
| } /* namespace apache */ |
| } /* namespace org */ |