| /** |
| * |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| #include "c2/protocols/RESTProtocol.h" |
| |
#include <algorithm>
#include <cctype>
#include <list>
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
| |
| #include "core/TypedValues.h" |
| #include "utils/gsl.h" |
| |
| namespace org { |
| namespace apache { |
| namespace nifi { |
| namespace minifi { |
| namespace c2 { |
| |
| #ifdef WIN32 |
| #pragma push_macro("GetObject") |
| #undef GetObject |
| #endif |
| |
| AnnotatedValue parseAnnotatedValue(const rapidjson::Value& jsonValue) { |
| AnnotatedValue result; |
| if (jsonValue.IsObject() && jsonValue.HasMember("value")) { |
| result = jsonValue["value"].GetString(); |
| for (const auto& annotation : jsonValue.GetObject()) { |
| if (annotation.name.GetString() == std::string("value")) { |
| continue; |
| } |
| result.annotations[annotation.name.GetString()] = parseAnnotatedValue(annotation.value); |
| } |
| } else if (jsonValue.IsBool()) { |
| result = jsonValue.GetBool(); |
| } else { |
| result = jsonValue.GetString(); |
| } |
| return result; |
| } |
| |
| const C2Payload RESTProtocol::parseJsonResponse(const C2Payload &payload, const std::vector<char> &response) { |
| rapidjson::Document root; |
| |
| try { |
| rapidjson::ParseResult ok = root.Parse(response.data(), response.size()); |
| if (ok) { |
| std::string requested_operation = getOperation(payload); |
| |
| std::string identifier; |
| for (auto key : {"operationid", "operationId", "identifier"}) { |
| if (root.HasMember(key)) { |
| identifier = root[key].GetString(); |
| break; |
| } |
| } |
| |
| int size = 0; |
| for (auto key : {"requested_operations", "requestedOperations"}) { |
| if (root.HasMember(key)) { |
| size = root[key].Size(); |
| break; |
| } |
| } |
| |
| // neither must be there. We don't want assign array yet and cause an assertion error |
| if (size == 0) |
| return C2Payload(payload.getOperation(), state::UpdateState::READ_COMPLETE); |
| |
| C2Payload new_payload(payload.getOperation(), state::UpdateState::NESTED); |
| if (!identifier.empty()) |
| new_payload.setIdentifier(identifier); |
| |
| auto array = root.HasMember("requested_operations") ? root["requested_operations"].GetArray() : root["requestedOperations"].GetArray(); |
| |
| for (const rapidjson::Value& request : array) { |
| Operation newOp = stringToOperation(request["operation"].GetString()); |
| C2Payload nested_payload(newOp, state::UpdateState::READ_COMPLETE); |
| C2ContentResponse new_command(newOp); |
| new_command.delay = 0; |
| new_command.required = true; |
| new_command.ttl = -1; |
| |
| // set the identifier if one exists |
| for (auto key : {"operationid", "operationId", "identifier"}) { |
| if (request.HasMember(key)) { |
| if (request[key].IsNumber()) { |
| new_command.ident = std::to_string(request[key].GetInt64()); |
| } else if (request[key].IsString()) { |
| new_command.ident = request[key].GetString(); |
| } else { |
| throw Exception(SITE2SITE_EXCEPTION, "Invalid type for " + std::string{key}); |
| } |
| nested_payload.setIdentifier(new_command.ident); |
| break; |
| } |
| } |
| |
| if (request.HasMember("name")) { |
| new_command.name = request["name"].GetString(); |
| } else if (request.HasMember("operand")) { |
| new_command.name = request["operand"].GetString(); |
| } |
| |
| for (auto key : {"content", "args"}) { |
| if (request.HasMember(key) && request[key].IsObject()) { |
| for (const auto &member : request[key].GetObject()) { |
| new_command.operation_arguments[member.name.GetString()] = parseAnnotatedValue(member.value); |
| } |
| break; |
| } |
| } |
| |
| nested_payload.addContent(std::move(new_command)); |
| new_payload.addPayload(std::move(nested_payload)); |
| } |
| |
| // we have a response for this request |
| return new_payload; |
| // } |
| } |
| } catch (...) { |
| } |
| return C2Payload(payload.getOperation(), state::UpdateState::READ_COMPLETE); |
| } |
| |
| void setJsonStr(const std::string& key, const state::response::ValueNode& value, rapidjson::Value& parent, rapidjson::Document::AllocatorType& alloc) { // NOLINT |
| rapidjson::Value keyVal; |
| rapidjson::Value valueVal; |
| const char* c_key = key.c_str(); |
| auto base_type = value.getValue(); |
| keyVal.SetString(c_key, gsl::narrow<rapidjson::SizeType>(key.length()), alloc); |
| |
| auto type_index = base_type->getTypeIndex(); |
| if (auto sub_type = std::dynamic_pointer_cast<core::TransformableValue>(base_type)) { |
| auto str = base_type->getStringValue(); |
| const char* c_val = str.c_str(); |
| valueVal.SetString(c_val, gsl::narrow<rapidjson::SizeType>(str.length()), alloc); |
| } else { |
| if (type_index == state::response::Value::BOOL_TYPE) { |
| bool value = false; |
| base_type->convertValue(value); |
| valueVal.SetBool(value); |
| } else if (type_index == state::response::Value::INT_TYPE) { |
| int value = 0; |
| base_type->convertValue(value); |
| valueVal.SetInt(value); |
| } else if (type_index == state::response::Value::UINT32_TYPE) { |
| uint32_t value = 0; |
| base_type->convertValue(value); |
| valueVal.SetUint(value); |
| } else if (type_index == state::response::Value::INT64_TYPE) { |
| int64_t value = 0; |
| base_type->convertValue(value); |
| valueVal.SetInt64(value); |
| } else if (type_index == state::response::Value::UINT64_TYPE) { |
| int64_t value = 0; |
| base_type->convertValue(value); |
| valueVal.SetInt64(value); |
| } else { |
| auto str = base_type->getStringValue(); |
| const char* c_val = str.c_str(); |
| valueVal.SetString(c_val, gsl::narrow<rapidjson::SizeType>(str.length()), alloc); |
| } |
| } |
| parent.AddMember(keyVal, valueVal, alloc); |
| } |
| |
| rapidjson::Value RESTProtocol::getStringValue(const std::string& value, rapidjson::Document::AllocatorType& alloc) { // NOLINT |
| rapidjson::Value Val; |
| Val.SetString(value.c_str(), gsl::narrow<rapidjson::SizeType>(value.length()), alloc); |
| return Val; |
| } |
| |
| void RESTProtocol::mergePayloadContent(rapidjson::Value &target, const C2Payload &payload, rapidjson::Document::AllocatorType &alloc) { |
| const std::vector<C2ContentResponse> &content = payload.getContent(); |
| bool all_empty = !content.empty(); |
| bool is_parent_array = target.IsArray(); |
| |
| for (const auto &payload_content : content) { |
| for (const auto &op_arg : payload_content.operation_arguments) { |
| if (!op_arg.second.empty()) { |
| all_empty = false; |
| break; |
| } |
| } |
| if (!all_empty) |
| break; |
| } |
| |
| if (all_empty) { |
| if (!is_parent_array) { |
| target.SetArray(); |
| is_parent_array = true; |
| } |
| rapidjson::Value arr(rapidjson::kArrayType); |
| for (const auto &payload_content : content) { |
| for (const auto& op_arg : payload_content.operation_arguments) { |
| rapidjson::Value keyVal; |
| keyVal.SetString(op_arg.first.c_str(), gsl::narrow<rapidjson::SizeType>(op_arg.first.length()), alloc); |
| if (is_parent_array) { |
| target.PushBack(keyVal, alloc); |
| } else { |
| arr.PushBack(keyVal, alloc); |
| } |
| } |
| } |
| |
| if (!is_parent_array) { |
| rapidjson::Value sub_key = getStringValue(payload.getLabel(), alloc); |
| target.AddMember(sub_key, arr, alloc); |
| } |
| return; |
| } |
| for (const auto &payload_content : content) { |
| rapidjson::Value payload_content_values(rapidjson::kObjectType); |
| bool use_sub_option = true; |
| if (payload_content.op == payload.getOperation()) { |
| for (const auto& op_arg : payload_content.operation_arguments) { |
| if (!op_arg.second.empty()) { |
| setJsonStr(op_arg.first, op_arg.second, target, alloc); |
| } |
| } |
| } else { |
| } |
| if (use_sub_option) { |
| rapidjson::Value sub_key = getStringValue(payload_content.name, alloc); |
| } |
| } |
| } |
| |
| std::string RESTProtocol::serializeJsonRootPayload(const C2Payload& payload) { |
| rapidjson::Document json_payload(payload.isContainer() ? rapidjson::kArrayType : rapidjson::kObjectType); |
| rapidjson::Document::AllocatorType &alloc = json_payload.GetAllocator(); |
| |
| rapidjson::Value opReqStrVal; |
| std::string operation_request_str = getOperation(payload); |
| opReqStrVal.SetString(operation_request_str.c_str(), gsl::narrow<rapidjson::SizeType>(operation_request_str.length()), alloc); |
| json_payload.AddMember("operation", opReqStrVal, alloc); |
| |
| std::string operationid = payload.getIdentifier(); |
| if (operationid.length() > 0) { |
| json_payload.AddMember("operationId", getStringValue(operationid, alloc), alloc); |
| std::string operationStateStr = "FULLY_APPLIED"; |
| switch (payload.getStatus().getState()) { |
| case state::UpdateState::FULLY_APPLIED: |
| operationStateStr = "FULLY_APPLIED"; |
| break; |
| case state::UpdateState::PARTIALLY_APPLIED: |
| operationStateStr = "PARTIALLY_APPLIED"; |
| break; |
| case state::UpdateState::READ_ERROR: |
| operationStateStr = "OPERATION_NOT_UNDERSTOOD"; |
| break; |
| case state::UpdateState::SET_ERROR: |
| default: |
| operationStateStr = "NOT_APPLIED"; |
| } |
| |
| rapidjson::Value opstate(rapidjson::kObjectType); |
| |
| opstate.AddMember("state", getStringValue(operationStateStr, alloc), alloc); |
| const auto details = payload.getRawData(); |
| |
| opstate.AddMember("details", getStringValue(std::string(details.data(), details.size()), alloc), alloc); |
| |
| json_payload.AddMember("operationState", opstate, alloc); |
| json_payload.AddMember("identifier", getStringValue(operationid, alloc), alloc); |
| } |
| |
| mergePayloadContent(json_payload, payload, alloc); |
| |
| for (const auto &nested_payload : payload.getNestedPayloads()) { |
| if (!minimize_updates_ || (minimize_updates_ && !containsPayload(nested_payload))) { |
| rapidjson::Value np_key = getStringValue(nested_payload.getLabel(), alloc); |
| rapidjson::Value np_value = serializeJsonPayload(nested_payload, alloc); |
| if (minimize_updates_) { |
| nested_payloads_.insert(std::pair<std::string, C2Payload>(nested_payload.getLabel(), nested_payload)); |
| } |
| json_payload.AddMember(np_key, np_value, alloc); |
| } |
| } |
| |
| rapidjson::StringBuffer buffer; |
| rapidjson::PrettyWriter<rapidjson::StringBuffer> writer(buffer); |
| json_payload.Accept(writer); |
| return buffer.GetString(); |
| } |
| |
| bool RESTProtocol::containsPayload(const C2Payload &o) { |
| auto it = nested_payloads_.find(o.getLabel()); |
| if (it != nested_payloads_.end()) { |
| return it->second == o; |
| } |
| return false; |
| } |
| |
| rapidjson::Value RESTProtocol::serializeConnectionQueues(const C2Payload &payload, std::string &label, rapidjson::Document::AllocatorType &alloc) { |
| rapidjson::Value json_payload(payload.isContainer() ? rapidjson::kArrayType : rapidjson::kObjectType); |
| |
| C2Payload adjusted(payload.getOperation(), payload.getIdentifier(), payload.isRaw()); |
| |
| auto name = payload.getLabel(); |
| std::string uuid; |
| C2ContentResponse updatedContent(payload.getOperation()); |
| for (const C2ContentResponse &content : payload.getContent()) { |
| for (const auto& op_arg : content.operation_arguments) { |
| if (op_arg.first == "uuid") { |
| uuid = op_arg.second.to_string(); |
| } |
| updatedContent.operation_arguments.insert(op_arg); |
| } |
| } |
| updatedContent.name = uuid; |
| adjusted.setLabel(uuid); |
| adjusted.setIdentifier(uuid); |
| c2::AnnotatedValue nd; |
| // name should be what was previously the TLN ( top level node ) |
| nd = name; |
| updatedContent.operation_arguments.insert(std::make_pair("name", nd)); |
| // the rvalue reference is an unfortunate side effect of the underlying API decision. |
| adjusted.addContent(std::move(updatedContent), true); |
| mergePayloadContent(json_payload, adjusted, alloc); |
| label = uuid; |
| return json_payload; |
| } |
| |
| rapidjson::Value RESTProtocol::serializeJsonPayload(const C2Payload &payload, rapidjson::Document::AllocatorType &alloc) { |
| // get the name from the content |
| rapidjson::Value json_payload(payload.isContainer() ? rapidjson::kArrayType : rapidjson::kObjectType); |
| |
| std::vector<ValueObject> children; |
| |
| bool isQueue = payload.getLabel() == "queues"; |
| |
| for (const auto &nested_payload : payload.getNestedPayloads()) { |
| std::string label = nested_payload.getLabel(); |
| rapidjson::Value* child_payload = new rapidjson::Value(isQueue ? serializeConnectionQueues(nested_payload, label, alloc) : serializeJsonPayload(nested_payload, alloc)); |
| |
| if (nested_payload.isCollapsible()) { |
| bool combine = false; |
| for (auto &subordinate : children) { |
| if (subordinate.name == label) { |
| subordinate.values.push_back(child_payload); |
| combine = true; |
| break; |
| } |
| } |
| if (!combine) { |
| ValueObject obj; |
| obj.name = label; |
| obj.values.push_back(child_payload); |
| children.push_back(obj); |
| } |
| } else { |
| ValueObject obj; |
| obj.name = label; |
| obj.values.push_back(child_payload); |
| children.push_back(obj); |
| } |
| } |
| |
| for (auto child_vector : children) { |
| rapidjson::Value children_json; |
| rapidjson::Value newMemberKey = getStringValue(child_vector.name, alloc); |
| if (child_vector.values.size() > 1) { |
| children_json.SetArray(); |
| for (auto child : child_vector.values) { |
| if (json_payload.IsArray()) |
| json_payload.PushBack(child->Move(), alloc); |
| else |
| children_json.PushBack(child->Move(), alloc); |
| } |
| if (!json_payload.IsArray()) |
| json_payload.AddMember(newMemberKey, children_json, alloc); |
| } else if (child_vector.values.size() == 1) { |
| rapidjson::Value* first = child_vector.values.front(); |
| if (first->IsObject() && first->HasMember(newMemberKey)) { |
| if (json_payload.IsArray()) |
| json_payload.PushBack((*first)[newMemberKey].Move(), alloc); |
| else |
| json_payload.AddMember(newMemberKey, (*first)[newMemberKey].Move(), alloc); |
| } else { |
| if (json_payload.IsArray()) { |
| json_payload.PushBack(first->Move(), alloc); |
| } else { |
| json_payload.AddMember(newMemberKey, first->Move(), alloc); |
| } |
| } |
| } |
| for (rapidjson::Value* child : child_vector.values) |
| delete child; |
| } |
| |
| mergePayloadContent(json_payload, payload, alloc); |
| return json_payload; |
| } |
| |
| std::string RESTProtocol::getOperation(const C2Payload &payload) { |
| switch (payload.getOperation()) { |
| case Operation::ACKNOWLEDGE: |
| return "acknowledge"; |
| case Operation::HEARTBEAT: |
| return "heartbeat"; |
| case Operation::RESTART: |
| return "restart"; |
| case Operation::DESCRIBE: |
| return "describe"; |
| case Operation::STOP: |
| return "stop"; |
| case Operation::START: |
| return "start"; |
| case Operation::UPDATE: |
| return "update"; |
| default: |
| return "heartbeat"; |
| } |
| } |
| |
| Operation RESTProtocol::stringToOperation(const std::string str) { |
| std::string op = str; |
| std::transform(str.begin(), str.end(), op.begin(), ::tolower); |
| if (op == "heartbeat") { |
| return Operation::HEARTBEAT; |
| } else if (op == "acknowledge") { |
| return Operation::ACKNOWLEDGE; |
| } else if (op == "update") { |
| return Operation::UPDATE; |
| } else if (op == "describe") { |
| return Operation::DESCRIBE; |
| } else if (op == "restart") { |
| return Operation::RESTART; |
| } else if (op == "clear") { |
| return Operation::CLEAR; |
| } else if (op == "stop") { |
| return Operation::STOP; |
| } else if (op == "start") { |
| return Operation::START; |
| } |
| return Operation::HEARTBEAT; |
| } |
| #ifdef WIN32 |
| #pragma pop_macro("GetObject") |
| #endif |
| } // namespace c2 |
| } // namespace minifi |
| } // namespace nifi |
| } // namespace apache |
| } // namespace org |