blob: 0143dc5b9efe1e7ac30943925c468fb6799a8f66 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#' Load a Python Flight server
#'
#' Imports a Python module (typically one that defines a Flight server) from a
#' directory on disk, using reticulate.
#'
#' @param name string Python module name
#' @param path file system path where the Python module is found. Default is
#' to look in the `inst/` directory for included modules (i.e. the root of the
#' installed arrow package).
#' @return The imported Python module.
#' @export
#' @examplesIf FALSE
#' load_flight_server("demo_flight_server")
load_flight_server <- function(name, path = system.file(package = "arrow")) {
  reticulate::import_from_path(module = name, path = path)
}
#' Connect to a Flight server
#'
#' Builds a `<scheme>://<host>:<port>` location string and opens a pyarrow
#' Flight client against it.
#'
#' @param host string hostname to connect to
#' @param port integer port to connect on
#' @param scheme URL scheme, default is "grpc+tcp"
#' @return A `pyarrow.flight.FlightClient`.
#' @export
flight_connect <- function(host = "localhost", port, scheme = "grpc+tcp") {
  pyarrow <- reticulate::import("pyarrow")
  # Location URI in the form "<scheme>://<host>:<port>".
  uri <- paste0(scheme, "://", host, ":", port)
  client_class <- pyarrow$flight$FlightClient
  client_class(uri)
}
#' Send data to a Flight server
#'
#' @param client `pyarrow.flight.FlightClient`, as returned by [flight_connect()]
#' @param data `data.frame`, [RecordBatch], or [Table] to upload
#' @param path string identifier to store the data under
#' @param overwrite logical: if `path` exists on `client` already, should we
#' replace it with the contents of `data`? Default is `TRUE`; if `FALSE` and
#' `path` exists, the function will error.
#' @return `client`, invisibly.
#' @export
flight_put <- function(client, data, path, overwrite = TRUE) {
  if (!overwrite && flight_path_exists(client, path)) {
    stop(path, " exists.", call. = FALSE)
  }
  if (is.data.frame(data)) {
    # Plain R data frames are first turned into an Arrow Table so they can be
    # converted to a Python object below.
    data <- Table$create(data)
  }
  py_data <- reticulate::r_to_py(data)
  # do_put() returns a (writer, metadata reader) pair; only the writer is used.
  writer <- client$do_put(descriptor_for_path(path), py_data$schema)[[1]]
  # Close the writer on every exit path, so that an error while writing does
  # not leave the upload stream open.
  on.exit(writer$close(), add = TRUE)
  if (inherits(data, "RecordBatch")) {
    writer$write_batch(py_data)
  } else {
    writer$write_table(py_data)
  }
  invisible(client)
}
#' Get data from a Flight server
#'
#' Opens a stream for `path` and reads all of it in a single call.
#'
#' @param client `pyarrow.flight.FlightClient`, as returned by [flight_connect()]
#' @param path string identifier under which data is stored
#' @return A [Table]
#' @export
flight_get <- function(client, path) {
  flight_reader(client, path)$read_all()
}
# Open a stream reader (via do_get()) for the data stored under `path`.
# TODO: could use this as a RecordBatch iterator, call $read_chunk() on this
flight_reader <- function(client, path) {
  descriptor <- descriptor_for_path(path)
  info <- client$get_flight_info(descriptor)
  # Hack: assume a single ticket, on the same server as client is already connected
  first_endpoint <- info$endpoints[[1]]
  client$do_get(first_endpoint$ticket)
}
# Build a pyarrow FlightDescriptor that identifies a Flight by `path`.
descriptor_for_path <- function(path) {
  descriptor_class <- reticulate::import("pyarrow")$flight$FlightDescriptor
  descriptor_class$for_path(path)
}
#' See available resources on a Flight server
#'
#' @inheritParams flight_get
#' @return `list_flights()` returns a character vector of paths.
#' `flight_path_exists()` returns a logical value, the equivalent of `path %in% list_flights()`
#' @export
list_flights <- function(client) {
  paths <- reticulate::iterate(
    client$list_flights(),
    function(x) as.character(x$descriptor$path[[1]]),
    simplify = FALSE
  )
  # iterate() yields a list, and with no flights it cannot simplify that to a
  # vector. Flatten explicitly so the documented character return type also
  # holds for an empty server (character(0) rather than list()).
  vapply(paths, identity, character(1))
}
#' @rdname list_flights
#' @export
flight_path_exists <- function(client, path) {
  # The tryCatch() value is the function's result. It is returned directly
  # (not via an assignment) so the logical prints when called interactively.
  tryCatch(
    {
      client$get_flight_info(descriptor_for_path(path))
      TRUE
    },
    error = function(e) {
      # "Not found" surfaces as an ArrowKeyError. Any other failure is
      # re-raised rather than reported as a missing path.
      if (!any(grepl("ArrowKeyError", conditionMessage(e), fixed = TRUE))) {
        stop(e)
      }
      FALSE
    }
  )
}