blob: 80c66ae479d58ad716330542c08df89fdf1a818b [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Cloud Datastore query splitter test."""
from __future__ import absolute_import
import sys
import unittest
from mock import MagicMock
from mock import call
# Protect against environments where datastore library is not available.
# pylint: disable=wrong-import-order, wrong-import-position
try:
from apache_beam.io.gcp.datastore.v1 import fake_datastore
from apache_beam.io.gcp.datastore.v1 import query_splitter
from google.cloud.proto.datastore.v1 import datastore_pb2
from google.cloud.proto.datastore.v1 import query_pb2
from google.cloud.proto.datastore.v1.query_pb2 import PropertyFilter
except (ImportError, TypeError):
datastore_pb2 = None
query_splitter = None
# pylint: enable=wrong-import-order, wrong-import-position
@unittest.skipIf(
    sys.version_info[0] == 3,
    'v1/query_splitter does not support Python 3 TODO: BEAM-4543')
@unittest.skipIf(datastore_pb2 is None, 'GCP dependencies are not installed')
class QuerySplitterTest(unittest.TestCase):
  """Tests for query_splitter.get_splits and its scatter-query helpers.

  NOTE(review): the skip decorators previously annotated ``setUp``; they now
  decorate the class so unittest's collection-time skip check
  (``__unittest_skip__`` on the class / test method) applies directly, and
  the empty ``setUp`` is no longer needed.
  """

  # Overridable hooks: a subclass can point these at a different splitter
  # implementation and the validation-error type it raises.
  split_error = ValueError
  query_splitter = query_splitter

  def create_query(self, kinds=(), order=False, limit=None, offset=None,
                   inequality_filter=False):
    """Builds a query_pb2.Query with the requested features.

    Args:
      kinds: iterable of kind names to add to the query.
      order: if True, adds an (empty) sort order.
      limit: if not None, sets the query limit.
      offset: if not None, sets the query offset.
      inequality_filter: if True, adds a GREATER_THAN property filter.

    Returns:
      the constructed query proto.
    """
    query = query_pb2.Query()
    for kind in kinds:
      query.kind.add().name = kind
    if order:
      query.order.add()
    if limit is not None:
      query.limit.value = limit
    if offset is not None:
      query.offset = offset
    if inequality_filter:
      test_filter = query.filter.composite_filter.filters.add()
      test_filter.property_filter.op = PropertyFilter.GREATER_THAN
    return query

  def test_get_splits_query_with_multiple_kinds(self):
    """Splitting a multi-kind query must be rejected."""
    query = self.create_query(kinds=['a', 'b'])
    with self.assertRaisesRegexp(self.split_error, r'one kind'):
      self.query_splitter.get_splits(None, query, 4)

  def test_get_splits_query_with_order(self):
    """Splitting a query with a sort order must be rejected."""
    query = self.create_query(kinds=['a'], order=True)
    with self.assertRaisesRegexp(self.split_error, r'sort orders'):
      self.query_splitter.get_splits(None, query, 3)

  def test_get_splits_query_with_unsupported_filter(self):
    """Splitting a query with an inequality filter must be rejected."""
    query = self.create_query(kinds=['a'], inequality_filter=True)
    with self.assertRaisesRegexp(self.split_error, r'inequality filters'):
      self.query_splitter.get_splits(None, query, 2)

  def test_get_splits_query_with_limit(self):
    """Splitting a query with a limit must be rejected."""
    query = self.create_query(kinds=['a'], limit=10)
    with self.assertRaisesRegexp(self.split_error, r'limit set'):
      self.query_splitter.get_splits(None, query, 2)

  def test_get_splits_query_with_offset(self):
    """Splitting a query with an offset must be rejected."""
    query = self.create_query(kinds=['a'], offset=10)
    with self.assertRaisesRegexp(self.split_error, r'offset set'):
      self.query_splitter.get_splits(None, query, 2)

  def test_create_scatter_query(self):
    """The scatter query mirrors the kind and requests the right key count."""
    query = self.create_query(kinds=['shakespeare-demo'])
    num_splits = 10
    scatter_query = self.query_splitter._create_scatter_query(query, num_splits)
    self.assertEqual(scatter_query.kind[0], query.kind[0])
    # num_splits - 1 split points are needed, each backed by KEYS_PER_SPLIT
    # scatter keys.
    self.assertEqual(scatter_query.limit.value,
                     (num_splits - 1) * self.query_splitter.KEYS_PER_SPLIT)
    self.assertEqual(scatter_query.order[0].direction,
                     query_pb2.PropertyOrder.ASCENDING)
    self.assertEqual(scatter_query.projection[0].property.name,
                     self.query_splitter.KEY_PROPERTY_NAME)

  def test_get_splits_with_two_splits(self):
    query = self.create_query(kinds=['shakespeare-demo'])
    num_splits = 2
    num_entities = 97
    batch_size = 9
    self.check_get_splits(query, num_splits, num_entities, batch_size)

  def test_get_splits_with_multiple_splits(self):
    query = self.create_query(kinds=['shakespeare-demo'])
    num_splits = 4
    num_entities = 369
    batch_size = 12
    self.check_get_splits(query, num_splits, num_entities, batch_size)

  def test_get_splits_with_large_num_splits(self):
    query = self.create_query(kinds=['shakespeare-demo'])
    num_splits = 10
    num_entities = 4
    batch_size = 10
    self.check_get_splits(query, num_splits, num_entities, batch_size)

  def test_get_splits_with_small_num_entities(self):
    query = self.create_query(kinds=['shakespeare-demo'])
    num_splits = 4
    num_entities = 50
    batch_size = 10
    self.check_get_splits(query, num_splits, num_entities, batch_size)

  def test_get_splits_with_batch_size_exact_multiple(self):
    """Test get_splits when num scatter keys is a multiple of batch size."""
    query = self.create_query(kinds=['shakespeare-demo'])
    num_splits = 4
    num_entities = 400
    batch_size = 32
    self.check_get_splits(query, num_splits, num_entities, batch_size)

  def test_get_splits_with_large_batch_size(self):
    """Test get_splits when all scatter keys are returned in a single req."""
    query = self.create_query(kinds=['shakespeare-demo'])
    num_splits = 4
    num_entities = 400
    batch_size = 500
    self.check_get_splits(query, num_splits, num_entities, batch_size)

  def test_get_splits_with_num_splits_gt_entities(self):
    query = self.create_query(kinds=['shakespeare-demo'])
    num_splits = 10
    num_entities = 4
    batch_size = 10
    self.check_get_splits(query, num_splits, num_entities, batch_size)

  def check_get_splits(self, query, num_splits, num_entities, batch_size):
    """A helper method to test the query_splitter get_splits method.

    Args:
      query: the query to be split
      num_splits: number of splits
      num_entities: number of scatter entities contained in the fake datastore.
      batch_size: the number of entities returned by fake datastore in one req.
    """
    # Test for random long ids, string ids, and a mix of both.
    for id_type in (True, False, None):
      if id_type is None:
        # Mixed case: half long ids, half string ids.
        entities = fake_datastore.create_entities(num_entities, False)
        entities.extend(fake_datastore.create_entities(num_entities, True))
        # Use a per-iteration local instead of mutating num_entities, so the
        # earlier iterations are unaffected regardless of case ordering.
        entity_count = 2 * num_entities
      else:
        entities = fake_datastore.create_entities(num_entities, id_type)
        entity_count = num_entities
      mock_datastore = MagicMock()
      # Assign a fake run_query method as a side_effect to the mock.
      mock_datastore.run_query.side_effect = \
          fake_datastore.create_run_query(entities, batch_size)

      split_queries = self.query_splitter.get_splits(
          mock_datastore, query, num_splits)

      # If the requested num_splits is greater than the number of entities,
      # the best the splitter can do is one entity per split.
      expected_num_splits = min(num_splits, entity_count + 1)
      self.assertEqual(len(split_queries), expected_num_splits)

      expected_requests = self.create_scatter_requests(
          query, num_splits, batch_size, entity_count)
      expected_calls = [call(req) for req in expected_requests]
      self.assertEqual(expected_calls, mock_datastore.run_query.call_args_list)

  def create_scatter_requests(self, query, num_splits, batch_size,
                              num_entities):
    """Creates a list of expected scatter requests from the query splitter.

    This list of requests returned is used to verify that the query splitter
    made the same number of requests in the same order to datastore.
    """
    requests = []
    count = (num_splits - 1) * self.query_splitter.KEYS_PER_SPLIT
    start_cursor = ''
    i = 0
    scatter_query = self.query_splitter._create_scatter_query(query, count)
    # One request per batch until either all requested scatter keys or all
    # available entities have been fetched, whichever comes first.
    while i < count and i < num_entities:
      request = datastore_pb2.RunQueryRequest()
      request.query.CopyFrom(scatter_query)
      request.query.start_cursor = start_cursor
      request.query.limit.value = count - i
      requests.append(request)
      i += batch_size
      # The fake datastore encodes the next cursor as the stringified offset.
      start_cursor = str(i)
    return requests
# Allow running this test module directly (e.g. `python query_splitter_test.py`).
if __name__ == '__main__':
  unittest.main()