blob: 4ea28983f66e93cd931ab2244eed84ce51e87b94 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Cloud Datastore helper functions.
For internal use only; no backwards-compatibility guarantees.
"""
from __future__ import absolute_import
import errno
import logging
import sys
import time
from builtins import object
from socket import error as SocketError
from future.builtins import next
from past.builtins import unicode
# pylint: disable=ungrouped-imports
from apache_beam.internal.gcp import auth
from apache_beam.utils import retry
# Protect against environments where datastore library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
from google.cloud.proto.datastore.v1 import datastore_pb2
from google.cloud.proto.datastore.v1 import entity_pb2
from google.cloud.proto.datastore.v1 import query_pb2
from google.rpc import code_pb2
from googledatastore import PropertyFilter, CompositeFilter
from googledatastore import helper as datastore_helper
from googledatastore.connection import Datastore
from googledatastore.connection import RPCError
except ImportError:
pass
# pylint: enable=wrong-import-order, wrong-import-position
# pylint: enable=ungrouped-imports
_LOGGER = logging.getLogger(__name__)
def key_comparator(k1, k2):
    """Order two Datastore keys by walking their path elements pairwise.

    Only keys from the same partition can be compared. Path elements are
    compared position by position with `compare_path`, and the first
    non-zero result decides the ordering. When one path is a strict prefix
    of the other, the key with the shorter path sorts first.

    Args:
      k1: first key; must expose `partition_id` and an iterable `path`.
      k2: second key; same shape as `k1`.

    Returns:
      A positive, negative or zero int when `k1` sorts after, before, or
      equal to `k2` respectively.

    Raises:
      ValueError: if the two keys have different partition ids.
    """
    if k1.partition_id != k2.partition_id:
        raise ValueError('Cannot compare keys with different partition ids.')

    remaining = iter(k2.path)
    for left in k1.path:
        right = next(remaining, None)
        if not right:
            # k2 ran out of path elements first, so k1 is the longer key.
            return 1
        outcome = compare_path(left, right)
        if outcome != 0:
            return outcome

    # k1 is exhausted; anything still left in k2 makes k2 the longer key.
    return -1 if next(remaining, None) else 0
def compare_path(p1, p2):
    """Three-way comparison of two key path elements.

    Elements are ordered by `kind` first. For equal kinds, each element is
    expected to carry either an `id` or a `name`, and the rules are:
      1. An element with an `id` sorts before one without an `id`.
      2. If both have an `id`, the ids are compared.
      3. If neither has an `id`, the `name` values are compared.

    Args:
      p1: first path element; needs `kind`, `id`, `name` and `HasField`.
      p2: second path element; same shape as `p1`.

    Returns:
      -1, 0 or 1 when `p1` sorts before, equal to, or after `p2`.
    """
    def three_way(left, right):
        # Classic cmp(): bool subtraction yields -1, 0 or 1.
        return (left > right) - (left < right)

    by_kind = three_way(p1.kind, p2.kind)
    if by_kind != 0:
        return by_kind

    p1_has_id = p1.HasField('id')
    p2_has_id = p2.HasField('id')
    if p1_has_id and p2_has_id:
        return three_way(p1.id, p2.id)
    if p1_has_id:
        return -1
    if p2_has_id:
        return 1
    return three_way(p1.name, p2.name)
def get_datastore(project):
    """Build a Cloud Datastore client for `project`.

    The client is created with the service credentials returned by
    `auth.get_service_credentials()` and targets the
    'batch-datastore.googleapis.com' host.
    """
    return Datastore(
        project,
        auth.get_service_credentials(),
        host='batch-datastore.googleapis.com')
def make_request(project, namespace, query):
    """Build a RunQueryRequest for `query` in the given partition.

    Both the partition and the query are copied into the request, so the
    caller's `query` object is not referenced by the returned request.
    """
    partition = make_partition(project, namespace)
    request = datastore_pb2.RunQueryRequest()
    request.partition_id.CopyFrom(partition)
    request.query.CopyFrom(query)
    return request
def make_partition(project, namespace):
    """Build a PartitionId for `project`, optionally scoped to a namespace.

    `namespace_id` is left unset when `namespace` is None.
    """
    partition_id = entity_pb2.PartitionId()
    partition_id.project_id = project
    if namespace is None:
        return partition_id
    partition_id.namespace_id = namespace
    return partition_id
def retry_on_rpc_error(exception):
    """Retry filter: decide whether `exception` is worth retrying.

    Returns True for an RPCError whose code is one of ABORTED,
    DEADLINE_EXCEEDED, INTERNAL, UNAVAILABLE or UNKNOWN, and for a socket
    error whose errno is ECONNRESET, ETIMEDOUT or EPIPE. Everything else
    is reported as not retryable.
    """
    if isinstance(exception, RPCError):
        # TODO(BEAM-2156): put these codes in a global list and use that
        # instead.
        # https://cloud.google.com/datastore/docs/concepts/errors#error_codes
        retryable_codes = (
            code_pb2.ABORTED,
            code_pb2.DEADLINE_EXCEEDED,
            code_pb2.INTERNAL,
            code_pb2.UNAVAILABLE,
            code_pb2.UNKNOWN,
        )
        return exception.code in retryable_codes

    if isinstance(exception, SocketError):
        retryable_errnos = (errno.ECONNRESET, errno.ETIMEDOUT, errno.EPIPE)
        return exception.errno in retryable_errnos

    return False
def fetch_entities(project, namespace, query, datastore):
    """Return an iterable over the entities matched by `query`.

    This is a thin wrapper that constructs a `QueryIterator`.

    Args:
      project: Project ID
      namespace: Cloud Datastore namespace
      query: Query to be read from
      datastore: Cloud Datastore Client

    Returns:
      A `QueryIterator` over the matching entities.
    """
    entities = QueryIterator(project, namespace, query, datastore)
    return entities
def is_key_valid(key):
    """Tell whether a Cloud Datastore key is complete.

    A key counts as complete when its path is non-empty and the last path
    element has either its `id` or its `name` field set.
    """
    path = key.path
    if not path:
        return False
    last_element = path[-1]
    return last_element.HasField('id') or last_element.HasField('name')
def write_mutations(datastore, project, mutations, throttler,
                    rpc_stats_callback=None, throttle_delay=1):
    """A helper function to write a batch of mutations to Cloud Datastore.

    The mutations are sent as one non-transactional commit. The commit is
    wrapped in `retry.with_exponential_backoff(num_retries=5)` and is
    retried only for errors accepted by `retry_on_rpc_error`. On a retry
    all mutations in the batch are committed again, even if the previous
    attempt was partially successful. If the retry limit is exceeded, the
    last exception from Cloud Datastore is raised.

    Args:
      datastore: googledatastore.connection.Datastore
      project: str, project id
      mutations: list of google.cloud.proto.datastore.v1.datastore_pb2.Mutation
      throttler: AdaptiveThrottler, to use to select requests to be
        throttled.
      rpc_stats_callback: optional function invoked with exactly one of the
        keyword arguments `successes`, `errors` or `throttled_secs`; it is
        called to record successful and failed RPCs to Datastore and time
        spent waiting for throttling. When omitted (None), no statistics
        are reported.
      throttle_delay: float, time in seconds to sleep when throttled.

    Returns a tuple of:
      CommitResponse, the response from Datastore;
      int, the latency of the successful RPC in milliseconds.
    """
    def record_stats(**stats):
        # The callback is optional. Calling it unguarded would raise a
        # TypeError *after* a commit had already succeeded.
        if rpc_stats_callback:
            rpc_stats_callback(**stats)

    commit_request = datastore_pb2.CommitRequest()
    commit_request.mode = datastore_pb2.CommitRequest.NON_TRANSACTIONAL
    commit_request.project_id = project
    for mutation in mutations:
        commit_request.mutations.add().CopyFrom(mutation)

    @retry.with_exponential_backoff(num_retries=5,
                                    retry_filter=retry_on_rpc_error)
    def commit(request):
        # Client-side throttling: wait until the throttler lets us through.
        while throttler.throttle_request(time.time() * 1000):
            _LOGGER.info("Delaying request for %ds due to previous failures",
                         throttle_delay)
            time.sleep(throttle_delay)
            record_stats(throttled_secs=throttle_delay)

        try:
            start_time = time.time()
            response = datastore.commit(request)
            end_time = time.time()

            record_stats(successes=1)
            # The throttler works in milliseconds since the epoch.
            throttler.successful_request(start_time * 1000)
            commit_time_ms = int((end_time - start_time) * 1000)
            return response, commit_time_ms
        except (RPCError, SocketError):
            record_stats(errors=1)
            raise

    response, commit_time_ms = commit(commit_request)
    return response, commit_time_ms
def make_latest_timestamp_query(namespace):
    """Build a Query that fetches the most recent statistics entity.

    The query targets the '__Stat_Total__' kind when `namespace` is None
    and '__Stat_Ns_Total__' otherwise. It orders by `timestamp` descending
    and limits the result to a single entity.
    """
    if namespace is None:
        stat_kind = '__Stat_Total__'
    else:
        stat_kind = '__Stat_Ns_Total__'

    latest_query = query_pb2.Query()
    latest_query.kind.add().name = stat_kind
    # Newest first, and only the newest one.
    datastore_helper.add_property_orders(latest_query, "-timestamp")
    latest_query.limit.value = 1
    return latest_query
def make_kind_stats_query(namespace, kind, latest_timestamp):
    """Build a Query that fetches the statistics entity for one kind.

    The query targets '__Stat_Kind__' when `namespace` is None and
    '__Stat_Ns_Kind__' otherwise. It filters on both `kind_name` equal to
    `kind` and `timestamp` equal to `latest_timestamp`.
    """
    if namespace is None:
        stat_kind = '__Stat_Kind__'
    else:
        stat_kind = '__Stat_Ns_Kind__'

    stats_query = query_pb2.Query()
    stats_query.kind.add().name = stat_kind

    by_kind_name = datastore_helper.set_property_filter(
        query_pb2.Filter(), 'kind_name', PropertyFilter.EQUAL, unicode(kind))
    by_timestamp = datastore_helper.set_property_filter(
        query_pb2.Filter(), 'timestamp', PropertyFilter.EQUAL,
        latest_timestamp)

    # Both conditions must hold, so combine them with AND.
    datastore_helper.set_composite_filter(
        stats_query.filter, CompositeFilter.AND, by_kind_name, by_timestamp)
    return stats_query
class QueryIterator(object):
    """Iterable over the entities matched by a Cloud Datastore query.

    Entities are fetched in batches of at most `_BATCH_SIZE`. Each batch
    request is retried with exponential backoff for errors accepted by
    `retry_on_rpc_error`. Iterating advances the stored cursor and reduces
    the remaining limit held on the instance.
    """

    # Maximum number of results to request per query.
    _BATCH_SIZE = 500

    def __init__(self, project, namespace, query, datastore):
        self._project = project
        self._namespace = namespace
        self._query = query
        self._datastore = datastore
        # No cursor yet: the first batch starts at the beginning.
        self._start_cursor = None
        # A query without a limit is treated as effectively unbounded.
        self._limit = self._query.limit.value or sys.maxsize
        self._req = make_request(project, namespace, query)

    @retry.with_exponential_backoff(num_retries=5,
                                    retry_filter=retry_on_rpc_error)
    def _next_batch(self):
        """Run the query for the next batch and return the raw response."""
        if self._start_cursor is not None:
            # Resume where the previous batch ended.
            self._req.query.start_cursor = self._start_cursor
        # Never ask for more than a batch, nor more than is still wanted.
        self._req.query.limit.value = min(self._BATCH_SIZE, self._limit)
        return self._datastore.run_query(self._req)

    def __iter__(self):
        while True:
            resp = self._next_batch()
            results = resp.batch.entity_results
            for result in results:
                yield result.entity

            self._start_cursor = resp.batch.end_cursor
            fetched = len(results)
            self._limit -= fetched

            # Stop once the query limit has been satisfied.
            if self._limit <= 0:
                return
            # A full batch is taken as a sign that more data is available.
            if fetched == self._BATCH_SIZE:
                continue
            # Otherwise continue only if the response says it is not done.
            if (resp.batch.more_results !=
                    query_pb2.QueryResultBatch.NOT_FINISHED):
                return