#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for datastoreio."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import datetime
import math
import unittest
from mock import MagicMock
from mock import call
from mock import patch
# Protect against environments where datastore library is not available.
try:
from apache_beam.io.gcp.datastore.v1 import util
from apache_beam.io.gcp.datastore.v1new import helper
from apache_beam.io.gcp.datastore.v1new import query_splitter
from apache_beam.io.gcp.datastore.v1new.datastoreio import DeleteFromDatastore
from apache_beam.io.gcp.datastore.v1new.datastoreio import ReadFromDatastore
from apache_beam.io.gcp.datastore.v1new.datastoreio import WriteToDatastore
from google.cloud.datastore import client
from google.cloud.datastore import entity
from google.cloud.datastore import helpers
from google.cloud.datastore import key
# Keep this import last so it doesn't import conflicting pb2 modules.
from apache_beam.io.gcp.datastore.v1 import datastoreio_test # pylint: disable=ungrouped-imports
DatastoreioTestBase = datastoreio_test.DatastoreioTest
# TODO(BEAM-4543): Remove TypeError once googledatastore dependency is removed.
except (ImportError, TypeError):
client = None
DatastoreioTestBase = unittest.TestCase


class FakeMutation(object):
def __init__(self, entity=None, key=None):
"""Fake mutation request object.
Requires exactly one of entity or key to be set.
Args:
entity: (``google.cloud.datastore.entity.Entity``) entity representing
this upsert mutation
key: (``google.cloud.datastore.key.Key``) key representing
this delete mutation
"""
self.entity = entity
self.key = key
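
  # Reports the size of the serialized protobuf for the wrapped entity or key,
  # matching the ByteSize() of the mutations a real Batch records.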
def ByteSize(self):
if self.entity is not None:
return helpers.entity_to_protobuf(self.entity).ByteSize()
else:
return self.key.to_protobuf().ByteSize()


class FakeBatch(object):
def __init__(self, all_batch_items=None, commit_count=None):
"""Fake ``google.cloud.datastore.batch.Batch`` object.
Args:
all_batch_items: (list) If set, will append all entities/keys added to
this batch.
commit_count: (list of int) If set, will increment commit_count[0] on
each ``commit``.
"""
self._all_batch_items = all_batch_items
self._commit_count = commit_count
self.mutations = []
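
  # self.mutations mirrors google.cloud.datastore.batch.Batch.mutations; the
  # write DoFn is expected to consult these pending mutations when deciding
  # how large the current commit has grown.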
def put(self, _entity):
assert isinstance(_entity, entity.Entity)
self.mutations.append(FakeMutation(entity=_entity))
if self._all_batch_items is not None:
self._all_batch_items.append(_entity)
def delete(self, _key):
assert isinstance(_key, key.Key)
self.mutations.append(FakeMutation(key=_key))
if self._all_batch_items is not None:
self._all_batch_items.append(_key)
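
  # begin() is a no-op; commit() only bumps the shared counter so tests can
  # assert how many Commit RPCs would have been issued.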
def begin(self):
pass
def commit(self):
if self._commit_count:
self._commit_count[0] += 1


@unittest.skipIf(client is None, 'Datastore dependencies are not installed')
class DatastoreioTest(DatastoreioTestBase):
"""
NOTE: This test inherits test cases from DatastoreioTestBase.
Please prefer to add new test cases to v1/datastoreio_test if possible.
"""
def setUp(self):
self._WRITE_BATCH_INITIAL_SIZE = util.WRITE_BATCH_INITIAL_SIZE
self._mock_client = MagicMock()
self._mock_client.project = self._PROJECT
self._mock_client.namespace = self._NAMESPACE
self._mock_query = MagicMock()
self._mock_query.limit = None
self._mock_query.order = None
self._real_client = client.Client(
project=self._PROJECT, namespace=self._NAMESPACE,
# Don't do any network requests.
_http=MagicMock())
def get_timestamp(self):
return datetime.datetime(2019, 3, 14, 15, 9, 26, 535897)
def test_SplitQueryFn_with_num_splits(self):
with patch.object(helper, 'get_client', return_value=self._mock_client):
num_splits = 23
expected_num_splits = 23
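
      # Stub out query_splitter.get_splits so the DoFn simply gets back
      # num_splits copies of the query; real splitting is not exercised here.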
def fake_get_splits(unused_client, query, num_splits):
return [query] * num_splits
with patch.object(query_splitter, 'get_splits',
side_effect=fake_get_splits):
split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits)
split_queries = split_query_fn.process(self._mock_query)
self.assertEqual(expected_num_splits, len(split_queries))
def test_SplitQueryFn_without_num_splits(self):
with patch.object(helper, 'get_client', return_value=self._mock_client):
# Force _SplitQueryFn to compute the number of query splits
num_splits = 0
expected_num_splits = 23
entity_bytes = (expected_num_splits *
ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
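      # Patch the estimated size so the DoFn derives the split count from
      # _DEFAULT_BUNDLE_SIZE_BYTES, which should yield expected_num_splits.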
with patch.object(
ReadFromDatastore._SplitQueryFn, 'get_estimated_size_bytes',
return_value=entity_bytes):
def fake_get_splits(unused_client, query, num_splits):
return [query] * num_splits
with patch.object(query_splitter, 'get_splits',
side_effect=fake_get_splits):
split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits)
split_queries = split_query_fn.process(self._mock_query)
self.assertEqual(expected_num_splits, len(split_queries))
def test_SplitQueryFn_with_query_limit(self):
"""A test that verifies no split is performed when the query has a limit."""
with patch.object(helper, 'get_client', return_value=self._mock_client):
num_splits = 4
expected_num_splits = 1
self._mock_query.limit = 3
split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits)
split_queries = split_query_fn.process(self._mock_query)
self.assertEqual(expected_num_splits, len(split_queries))
def test_SplitQueryFn_with_exception(self):
"""A test that verifies that no split is performed when failures occur."""
with patch.object(helper, 'get_client', return_value=self._mock_client):
# Force _SplitQueryFn to compute the number of query splits
num_splits = 0
expected_num_splits = 1
entity_bytes = (expected_num_splits *
ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
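      # When get_splits raises QuerySplitterError, the DoFn should fall back
      # to running the original query as a single, unsplit query.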
with patch.object(
ReadFromDatastore._SplitQueryFn, 'get_estimated_size_bytes',
return_value=entity_bytes):
with patch.object(query_splitter, 'get_splits',
side_effect=query_splitter.QuerySplitterError(
"Testing query split error")):
split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits)
split_queries = split_query_fn.process(self._mock_query)
self.assertEqual(expected_num_splits, len(split_queries))
self.assertEqual(self._mock_query, split_queries[0])
def check_DatastoreWriteFn(self, num_entities, use_fixed_batch_size=False):
"""A helper function to test _DatastoreWriteFn."""
with patch.object(helper, 'get_client', return_value=self._mock_client):
entities = helper.create_entities(num_entities)
expected_entities = [entity.to_client_entity() for entity in entities]
all_batch_entities = []
commit_count = [0]
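      # Each client.batch() call returns a fresh FakeBatch that records the
      # entities it receives and the number of commits into the shared lists.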
self._mock_client.batch.side_effect = (
lambda: FakeBatch(all_batch_items=all_batch_entities,
commit_count=commit_count))
datastore_write_fn = WriteToDatastore._DatastoreWriteFn(self._PROJECT)
datastore_write_fn.start_bundle()
for entity in entities:
datastore_write_fn.process(entity)
datastore_write_fn.finish_bundle()
self.assertListEqual([e.key for e in all_batch_entities],
[e.key for e in expected_entities])
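      # The write DoFn sizes its batches dynamically, so it may commit more
      # often than WRITE_BATCH_MAX_SIZE strictly requires; only the lower
      # bound on the commit count is asserted.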
batch_count = math.ceil(num_entities / util.WRITE_BATCH_MAX_SIZE)
self.assertLessEqual(batch_count, commit_count[0])
def test_DatastoreWriteLargeEntities(self):
"""100*100kB entities gets split over two Commit RPCs."""
with patch.object(helper, 'get_client', return_value=self._mock_client):
entities = helper.create_entities(100)
commit_count = [0]
self._mock_client.batch.side_effect = (
lambda: FakeBatch(commit_count=commit_count))
datastore_write_fn = WriteToDatastore._DatastoreWriteFn(
self._PROJECT)
datastore_write_fn.start_bundle()
for entity in entities:
entity.set_properties({'large': u'A' * 100000})
datastore_write_fn.process(entity)
datastore_write_fn.finish_bundle()
self.assertEqual(2, commit_count[0])
def check_estimated_size_bytes(self, entity_bytes, timestamp, namespace=None):
"""A helper method to test get_estimated_size_bytes"""
self._mock_client.namespace = namespace
self._mock_client.query.return_value = self._mock_query
self._mock_query.project = self._PROJECT
self._mock_query.namespace = namespace
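    # The first fetch simulates the newest __Stat_Total__ (or __Stat_Ns_Total__)
    # entry, the second the __Stat_Kind__ entry carrying entity_bytes.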
self._mock_query.fetch.side_effect = [
[{'timestamp': timestamp}],
[{'entity_bytes': entity_bytes}],
]
self._mock_query.kind = self._KIND
split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits=0)
self.assertEqual(entity_bytes,
split_query_fn.get_estimated_size_bytes(self._mock_client,
self._mock_query))
if namespace is None:
ns_keyword = '_'
else:
ns_keyword = '_Ns_'
self._mock_client.query.assert_has_calls([
call(kind='__Stat%sTotal__' % ns_keyword, order=['-timestamp']),
call().fetch(limit=1),
call(kind='__Stat%sKind__' % ns_keyword),
call().add_filter('kind_name', '=', self._KIND),
call().add_filter('timestamp', '=', timestamp),
call().fetch(limit=1),
])
def test_DatastoreDeleteFn(self):
with patch.object(helper, 'get_client', return_value=self._mock_client):
keys = [entity.key for entity in helper.create_entities(10)]
expected_keys = [key.to_client_key() for key in keys]
all_batch_keys = []
self._mock_client.batch.side_effect = (
lambda: FakeBatch(all_batch_items=all_batch_keys))
datastore_delete_fn = DeleteFromDatastore._DatastoreDeleteFn(
self._PROJECT)
datastore_delete_fn.start_bundle()
for key in keys:
datastore_delete_fn.process(key)
datastore_delete_fn.finish_bundle()
self.assertListEqual(all_batch_keys, expected_keys)


# Hide base class from collection by nose.
del DatastoreioTestBase


if __name__ == '__main__':
unittest.main()