| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #include "arrow/flight/types.h" |
| |
| #include <memory> |
| #include <sstream> |
| #include <utility> |
| |
| #include "arrow/flight/serialization_internal.h" |
| #include "arrow/io/memory.h" |
| #include "arrow/ipc/dictionary.h" |
| #include "arrow/ipc/reader.h" |
| #include "arrow/status.h" |
| #include "arrow/table.h" |
| #include "arrow/util/uri.h" |
| |
| namespace arrow { |
| namespace flight { |
| |
// URI scheme names for the gRPC transport variants. Note that
// Location::ForGrpcTcp / ForGrpcTls / ForGrpcUnix in this file spell the same
// scheme strings as literals rather than referencing these constants.
// NOTE(review): the pointers themselves are not const (`const char*`, not
// `const char* const`), so they are reassignable; tightening that would need a
// matching change to the header declarations.
const char* kSchemeGrpc = "grpc";
const char* kSchemeGrpcTcp = "grpc+tcp";
const char* kSchemeGrpcUnix = "grpc+unix";
const char* kSchemeGrpcTls = "grpc+tls";

// Type tag returned by FlightStatusDetail::type_id(). UnwrapStatus compares a
// detail's type_id() against this pointer by address, not by string contents.
const char* kErrorDetailTypeId = "flight::FlightStatusDetail";
| |
| const char* FlightStatusDetail::type_id() const { return kErrorDetailTypeId; } |
| |
| std::string FlightStatusDetail::ToString() const { return CodeAsString(); } |
| |
| FlightStatusCode FlightStatusDetail::code() const { return code_; } |
| |
| std::string FlightStatusDetail::extra_info() const { return extra_info_; } |
| |
| void FlightStatusDetail::set_extra_info(std::string extra_info) { |
| extra_info_ = std::move(extra_info); |
| } |
| |
| std::string FlightStatusDetail::CodeAsString() const { |
| switch (code()) { |
| case FlightStatusCode::Internal: |
| return "Internal"; |
| case FlightStatusCode::TimedOut: |
| return "TimedOut"; |
| case FlightStatusCode::Cancelled: |
| return "Cancelled"; |
| case FlightStatusCode::Unauthenticated: |
| return "Unauthenticated"; |
| case FlightStatusCode::Unauthorized: |
| return "Unauthorized"; |
| case FlightStatusCode::Unavailable: |
| return "Unavailable"; |
| default: |
| return "Unknown"; |
| } |
| } |
| |
| std::shared_ptr<FlightStatusDetail> FlightStatusDetail::UnwrapStatus( |
| const arrow::Status& status) { |
| if (!status.detail() || status.detail()->type_id() != kErrorDetailTypeId) { |
| return nullptr; |
| } |
| return std::dynamic_pointer_cast<FlightStatusDetail>(status.detail()); |
| } |
| |
| Status MakeFlightError(FlightStatusCode code, std::string message, |
| std::string extra_info) { |
| StatusCode arrow_code = arrow::StatusCode::IOError; |
| return arrow::Status(arrow_code, std::move(message), |
| std::make_shared<FlightStatusDetail>(code, std::move(extra_info))); |
| } |
| |
| bool FlightDescriptor::Equals(const FlightDescriptor& other) const { |
| if (type != other.type) { |
| return false; |
| } |
| switch (type) { |
| case PATH: |
| return path == other.path; |
| case CMD: |
| return cmd == other.cmd; |
| default: |
| return false; |
| } |
| } |
| |
| std::string FlightDescriptor::ToString() const { |
| std::stringstream ss; |
| ss << "FlightDescriptor<"; |
| switch (type) { |
| case PATH: { |
| bool first = true; |
| ss << "path = '"; |
| for (const auto& p : path) { |
| if (!first) { |
| ss << "/"; |
| } |
| first = false; |
| ss << p; |
| } |
| ss << "'"; |
| break; |
| } |
| case CMD: |
| ss << "cmd = '" << cmd << "'"; |
| break; |
| default: |
| break; |
| } |
| ss << ">"; |
| return ss.str(); |
| } |
| |
| Status SchemaResult::GetSchema(ipc::DictionaryMemo* dictionary_memo, |
| std::shared_ptr<Schema>* out) const { |
| io::BufferReader schema_reader(raw_schema_); |
| return ipc::ReadSchema(&schema_reader, dictionary_memo).Value(out); |
| } |
| |
| Status FlightDescriptor::SerializeToString(std::string* out) const { |
| pb::FlightDescriptor pb_descriptor; |
| RETURN_NOT_OK(internal::ToProto(*this, &pb_descriptor)); |
| |
| if (!pb_descriptor.SerializeToString(out)) { |
| return Status::IOError("Serialized descriptor exceeded 2 GiB limit"); |
| } |
| return Status::OK(); |
| } |
| |
| Status FlightDescriptor::Deserialize(const std::string& serialized, |
| FlightDescriptor* out) { |
| pb::FlightDescriptor pb_descriptor; |
| if (!pb_descriptor.ParseFromString(serialized)) { |
| return Status::Invalid("Not a valid descriptor"); |
| } |
| return internal::FromProto(pb_descriptor, out); |
| } |
| |
| bool Ticket::Equals(const Ticket& other) const { return ticket == other.ticket; } |
| |
| Status Ticket::SerializeToString(std::string* out) const { |
| pb::Ticket pb_ticket; |
| internal::ToProto(*this, &pb_ticket); |
| |
| if (!pb_ticket.SerializeToString(out)) { |
| return Status::IOError("Serialized ticket exceeded 2 GiB limit"); |
| } |
| return Status::OK(); |
| } |
| |
| Status Ticket::Deserialize(const std::string& serialized, Ticket* out) { |
| pb::Ticket pb_ticket; |
| if (!pb_ticket.ParseFromString(serialized)) { |
| return Status::Invalid("Not a valid ticket"); |
| } |
| return internal::FromProto(pb_ticket, out); |
| } |
| |
| arrow::Result<FlightInfo> FlightInfo::Make(const Schema& schema, |
| const FlightDescriptor& descriptor, |
| const std::vector<FlightEndpoint>& endpoints, |
| int64_t total_records, int64_t total_bytes) { |
| FlightInfo::Data data; |
| data.descriptor = descriptor; |
| data.endpoints = endpoints; |
| data.total_records = total_records; |
| data.total_bytes = total_bytes; |
| RETURN_NOT_OK(internal::SchemaToString(schema, &data.schema)); |
| return FlightInfo(data); |
| } |
| |
| Status FlightInfo::GetSchema(ipc::DictionaryMemo* dictionary_memo, |
| std::shared_ptr<Schema>* out) const { |
| if (reconstructed_schema_) { |
| *out = schema_; |
| return Status::OK(); |
| } |
| io::BufferReader schema_reader(data_.schema); |
| RETURN_NOT_OK(ipc::ReadSchema(&schema_reader, dictionary_memo).Value(&schema_)); |
| reconstructed_schema_ = true; |
| *out = schema_; |
| return Status::OK(); |
| } |
| |
| Status FlightInfo::SerializeToString(std::string* out) const { |
| pb::FlightInfo pb_info; |
| RETURN_NOT_OK(internal::ToProto(*this, &pb_info)); |
| |
| if (!pb_info.SerializeToString(out)) { |
| return Status::IOError("Serialized FlightInfo exceeded 2 GiB limit"); |
| } |
| return Status::OK(); |
| } |
| |
| Status FlightInfo::Deserialize(const std::string& serialized, |
| std::unique_ptr<FlightInfo>* out) { |
| pb::FlightInfo pb_info; |
| if (!pb_info.ParseFromString(serialized)) { |
| return Status::Invalid("Not a valid FlightInfo"); |
| } |
| FlightInfo::Data data; |
| RETURN_NOT_OK(internal::FromProto(pb_info, &data)); |
| out->reset(new FlightInfo(data)); |
| return Status::OK(); |
| } |
| |
| Location::Location() { uri_ = std::make_shared<arrow::internal::Uri>(); } |
| |
| Status Location::Parse(const std::string& uri_string, Location* location) { |
| return location->uri_->Parse(uri_string); |
| } |
| |
| Status Location::ForGrpcTcp(const std::string& host, const int port, Location* location) { |
| std::stringstream uri_string; |
| uri_string << "grpc+tcp://" << host << ':' << port; |
| return Location::Parse(uri_string.str(), location); |
| } |
| |
| Status Location::ForGrpcTls(const std::string& host, const int port, Location* location) { |
| std::stringstream uri_string; |
| uri_string << "grpc+tls://" << host << ':' << port; |
| return Location::Parse(uri_string.str(), location); |
| } |
| |
| Status Location::ForGrpcUnix(const std::string& path, Location* location) { |
| std::stringstream uri_string; |
| uri_string << "grpc+unix://" << path; |
| return Location::Parse(uri_string.str(), location); |
| } |
| |
| std::string Location::ToString() const { return uri_->ToString(); } |
| std::string Location::scheme() const { |
| std::string scheme = uri_->scheme(); |
| if (scheme.empty()) { |
| // Default to grpc+tcp |
| return "grpc+tcp"; |
| } |
| return scheme; |
| } |
| |
| bool Location::Equals(const Location& other) const { |
| return ToString() == other.ToString(); |
| } |
| |
| bool FlightEndpoint::Equals(const FlightEndpoint& other) const { |
| return ticket == other.ticket && locations == other.locations; |
| } |
| |
| Status MetadataRecordBatchReader::ReadAll( |
| std::vector<std::shared_ptr<RecordBatch>>* batches) { |
| FlightStreamChunk chunk; |
| |
| while (true) { |
| RETURN_NOT_OK(Next(&chunk)); |
| if (!chunk.data) break; |
| batches->emplace_back(std::move(chunk.data)); |
| } |
| return Status::OK(); |
| } |
| |
| Status MetadataRecordBatchReader::ReadAll(std::shared_ptr<Table>* table) { |
| std::vector<std::shared_ptr<RecordBatch>> batches; |
| RETURN_NOT_OK(ReadAll(&batches)); |
| ARROW_ASSIGN_OR_RAISE(auto schema, GetSchema()); |
| return Table::FromRecordBatches(schema, std::move(batches)).Value(table); |
| } |
| |
| Status MetadataRecordBatchWriter::Begin(const std::shared_ptr<Schema>& schema) { |
| return Begin(schema, ipc::IpcWriteOptions::Defaults()); |
| } |
| |
| namespace { |
| class MetadataRecordBatchReaderAdapter : public RecordBatchReader { |
| public: |
| explicit MetadataRecordBatchReaderAdapter( |
| std::shared_ptr<Schema> schema, std::shared_ptr<MetadataRecordBatchReader> delegate) |
| : schema_(std::move(schema)), delegate_(std::move(delegate)) {} |
| std::shared_ptr<Schema> schema() const override { return schema_; } |
| Status ReadNext(std::shared_ptr<RecordBatch>* batch) override { |
| FlightStreamChunk next; |
| while (true) { |
| RETURN_NOT_OK(delegate_->Next(&next)); |
| if (!next.data && !next.app_metadata) { |
| // EOS |
| *batch = nullptr; |
| return Status::OK(); |
| } else if (next.data) { |
| *batch = std::move(next.data); |
| return Status::OK(); |
| } |
| // Got metadata, but no data (which is valid) - read the next message |
| } |
| } |
| |
| private: |
| std::shared_ptr<Schema> schema_; |
| std::shared_ptr<MetadataRecordBatchReader> delegate_; |
| }; |
| }; // namespace |
| |
| arrow::Result<std::shared_ptr<RecordBatchReader>> MakeRecordBatchReader( |
| std::shared_ptr<MetadataRecordBatchReader> reader) { |
| ARROW_ASSIGN_OR_RAISE(auto schema, reader->GetSchema()); |
| return std::make_shared<MetadataRecordBatchReaderAdapter>(std::move(schema), |
| std::move(reader)); |
| } |
| |
| SimpleFlightListing::SimpleFlightListing(const std::vector<FlightInfo>& flights) |
| : position_(0), flights_(flights) {} |
| |
| SimpleFlightListing::SimpleFlightListing(std::vector<FlightInfo>&& flights) |
| : position_(0), flights_(std::move(flights)) {} |
| |
| Status SimpleFlightListing::Next(std::unique_ptr<FlightInfo>* info) { |
| if (position_ >= static_cast<int>(flights_.size())) { |
| *info = nullptr; |
| return Status::OK(); |
| } |
| *info = std::unique_ptr<FlightInfo>(new FlightInfo(std::move(flights_[position_++]))); |
| return Status::OK(); |
| } |
| |
| SimpleResultStream::SimpleResultStream(std::vector<Result>&& results) |
| : results_(std::move(results)), position_(0) {} |
| |
| Status SimpleResultStream::Next(std::unique_ptr<Result>* result) { |
| if (position_ >= results_.size()) { |
| *result = nullptr; |
| return Status::OK(); |
| } |
| *result = std::unique_ptr<Result>(new Result(std::move(results_[position_++]))); |
| return Status::OK(); |
| } |
| |
| Status BasicAuth::Deserialize(const std::string& serialized, BasicAuth* out) { |
| pb::BasicAuth pb_result; |
| pb_result.ParseFromString(serialized); |
| return internal::FromProto(pb_result, out); |
| } |
| |
| Status BasicAuth::Serialize(const BasicAuth& basic_auth, std::string* out) { |
| pb::BasicAuth pb_result; |
| RETURN_NOT_OK(internal::ToProto(basic_auth, &pb_result)); |
| *out = pb_result.SerializeAsString(); |
| return Status::OK(); |
| } |
| } // namespace flight |
| } // namespace arrow |