| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """Unit tests for datastoreio.""" |
| |
| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
| |
| import datetime |
| import math |
| import unittest |
| |
| from mock import MagicMock |
| from mock import call |
| from mock import patch |
| |
| # Protect against environments where datastore library is not available. |
| try: |
| from apache_beam.io.gcp.datastore.v1 import util |
| from apache_beam.io.gcp.datastore.v1new import helper |
| from apache_beam.io.gcp.datastore.v1new import query_splitter |
| from apache_beam.io.gcp.datastore.v1new.datastoreio import DeleteFromDatastore |
| from apache_beam.io.gcp.datastore.v1new.datastoreio import ReadFromDatastore |
| from apache_beam.io.gcp.datastore.v1new.datastoreio import WriteToDatastore |
| from google.cloud.datastore import client |
| from google.cloud.datastore import entity |
| from google.cloud.datastore import helpers |
| from google.cloud.datastore import key |
| # Keep this import last so it doesn't import conflicting pb2 modules. |
| from apache_beam.io.gcp.datastore.v1 import datastoreio_test # pylint: disable=ungrouped-imports |
| DatastoreioTestBase = datastoreio_test.DatastoreioTest |
| # TODO(BEAM-4543): Remove TypeError once googledatastore dependency is removed. |
| except (ImportError, TypeError): |
| client = None |
| DatastoreioTestBase = unittest.TestCase |
| |
| |
class FakeMutation(object):
  """Stand-in for a Datastore mutation request object.

  Exactly one of ``entity`` or ``key`` must be provided: an entity models an
  upsert mutation, a key models a delete mutation.

  Args:
    entity: (``google.cloud.datastore.entity.Entity``) entity representing
      an upsert mutation.
    key: (``google.cloud.datastore.key.Key``) key representing
      a delete mutation.
  """

  def __init__(self, entity=None, key=None):
    self.entity = entity
    self.key = key

  def ByteSize(self):
    """Return the serialized (protobuf) size of this mutation in bytes."""
    if self.entity is None:
      # Delete mutation: size of the key's protobuf representation.
      return self.key.to_protobuf().ByteSize()
    # Upsert mutation: size of the entity's protobuf representation.
    return helpers.entity_to_protobuf(self.entity).ByteSize()
| |
| |
class FakeBatch(object):
  """Fake ``google.cloud.datastore.batch.Batch`` object.

  Records every mutation locally instead of talking to Datastore, and can
  optionally mirror added items and count commits for test assertions.

  Args:
    all_batch_items: (list) If set, every entity/key added to this batch is
      also appended here.
    commit_count: (list of int) If set, ``commit_count[0]`` is incremented on
      each ``commit``.
  """

  def __init__(self, all_batch_items=None, commit_count=None):
    self._all_batch_items = all_batch_items
    self._commit_count = commit_count
    self.mutations = []

  def _record(self, mutation, item):
    # Track the mutation locally and mirror the raw item for the test to see.
    self.mutations.append(mutation)
    if self._all_batch_items is not None:
      self._all_batch_items.append(item)

  def put(self, an_entity):
    assert isinstance(an_entity, entity.Entity)
    self._record(FakeMutation(entity=an_entity), an_entity)

  def delete(self, a_key):
    assert isinstance(a_key, key.Key)
    self._record(FakeMutation(key=a_key), a_key)

  def begin(self):
    # No-op: a real Batch starts a transaction context here.
    pass

  def commit(self):
    if self._commit_count:
      self._commit_count[0] += 1
| |
| |
@unittest.skipIf(client is None, 'Datastore dependencies are not installed')
class DatastoreioTest(DatastoreioTestBase):
  """Unit tests for the v1new ``datastoreio`` transforms.

  NOTE: This test inherits test cases from DatastoreioTestBase.
  Please prefer to add new test cases to v1/datastoreio_test if possible.
  """
  def setUp(self):
    # Saved but never read in this file; presumably consumed by test cases
    # inherited from DatastoreioTestBase -- TODO confirm in v1/datastoreio_test.
    self._WRITE_BATCH_INITIAL_SIZE = util.WRITE_BATCH_INITIAL_SIZE
    # Mock client/query doubles used by most tests below.
    self._mock_client = MagicMock()
    self._mock_client.project = self._PROJECT
    self._mock_client.namespace = self._NAMESPACE
    self._mock_query = MagicMock()
    # limit must be a real value (not a Mock): test_SplitQueryFn_with_query_limit
    # relies on the splitter inspecting it.
    self._mock_query.limit = None
    self._mock_query.order = None

    # A real client object, but with a mocked HTTP transport so no network
    # requests are ever issued.
    self._real_client = client.Client(
        project=self._PROJECT, namespace=self._NAMESPACE,
        # Don't do any network requests.
        _http=MagicMock())

  def get_timestamp(self):
    """Return a fixed timestamp so timestamp-dependent checks are deterministic."""
    return datetime.datetime(2019, 3, 14, 15, 9, 26, 535897)

  def test_SplitQueryFn_with_num_splits(self):
    """An explicit num_splits yields exactly that many sub-queries."""
    with patch.object(helper, 'get_client', return_value=self._mock_client):
      num_splits = 23
      expected_num_splits = 23

      def fake_get_splits(unused_client, query, num_splits):
        # Pretend the splitter produced num_splits identical sub-queries.
        return [query] * num_splits

      with patch.object(query_splitter, 'get_splits',
                        side_effect=fake_get_splits):
        split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits)
        split_queries = split_query_fn.process(self._mock_query)

        self.assertEqual(expected_num_splits, len(split_queries))

  def test_SplitQueryFn_without_num_splits(self):
    """With num_splits=0, the split count is derived from the estimated size."""
    with patch.object(helper, 'get_client', return_value=self._mock_client):
      # Force _SplitQueryFn to compute the number of query splits
      num_splits = 0
      expected_num_splits = 23
      # Fake an estimated size that maps to exactly expected_num_splits
      # bundles of the default bundle size.
      entity_bytes = (expected_num_splits *
                      ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
      with patch.object(
          ReadFromDatastore._SplitQueryFn, 'get_estimated_size_bytes',
          return_value=entity_bytes):

        def fake_get_splits(unused_client, query, num_splits):
          # Pretend the splitter produced num_splits identical sub-queries.
          return [query] * num_splits

        with patch.object(query_splitter, 'get_splits',
                          side_effect=fake_get_splits):
          split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits)
          split_queries = split_query_fn.process(self._mock_query)

          self.assertEqual(expected_num_splits, len(split_queries))

  def test_SplitQueryFn_with_query_limit(self):
    """A test that verifies no split is performed when the query has a limit."""
    with patch.object(helper, 'get_client', return_value=self._mock_client):
      num_splits = 4
      expected_num_splits = 1
      # A limited query must be executed as-is (a single "split").
      self._mock_query.limit = 3
      split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits)
      split_queries = split_query_fn.process(self._mock_query)

      self.assertEqual(expected_num_splits, len(split_queries))

  def test_SplitQueryFn_with_exception(self):
    """A test that verifies that no split is performed when failures occur."""
    with patch.object(helper, 'get_client', return_value=self._mock_client):
      # Force _SplitQueryFn to compute the number of query splits
      num_splits = 0
      expected_num_splits = 1
      entity_bytes = (expected_num_splits *
                      ReadFromDatastore._DEFAULT_BUNDLE_SIZE_BYTES)
      with patch.object(
          ReadFromDatastore._SplitQueryFn, 'get_estimated_size_bytes',
          return_value=entity_bytes):

        # The splitter raises, so the fn should fall back to the unsplit query.
        with patch.object(query_splitter, 'get_splits',
                          side_effect=query_splitter.QuerySplitterError(
                              "Testing query split error")):
          split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits)
          split_queries = split_query_fn.process(self._mock_query)

          self.assertEqual(expected_num_splits, len(split_queries))
          # Fallback returns the original query unchanged.
          self.assertEqual(self._mock_query, split_queries[0])

  def check_DatastoreWriteFn(self, num_entities, use_fixed_batch_size=False):
    """A helper function to test _DatastoreWriteFn.

    Writes num_entities entities through _DatastoreWriteFn backed by FakeBatch
    and verifies that every entity was added to a batch and that at least
    ceil(num_entities / WRITE_BATCH_MAX_SIZE) commits happened.

    NOTE(review): use_fixed_batch_size is unused in this body; it appears to be
    kept for signature compatibility with base-class callers -- TODO confirm.
    """
    with patch.object(helper, 'get_client', return_value=self._mock_client):
      entities = helper.create_entities(num_entities)
      expected_entities = [entity.to_client_entity() for entity in entities]

      # Infer project from write fn project arg.
      all_batch_entities = []
      commit_count = [0]
      # Every batch the client creates mirrors its items/commits into the
      # shared lists above so the test can inspect them.
      self._mock_client.batch.side_effect = (
          lambda: FakeBatch(all_batch_items=all_batch_entities,
                            commit_count=commit_count))

      datastore_write_fn = WriteToDatastore._DatastoreWriteFn(self._PROJECT)

      datastore_write_fn.start_bundle()
      for entity in entities:
        datastore_write_fn.process(entity)
      datastore_write_fn.finish_bundle()

      # Compare by key: every input entity must have reached some batch.
      self.assertListEqual([e.key for e in all_batch_entities],
                           [e.key for e in expected_entities])
      # At least one commit per WRITE_BATCH_MAX_SIZE entities.
      batch_count = math.ceil(num_entities / util.WRITE_BATCH_MAX_SIZE)
      self.assertLessEqual(batch_count, commit_count[0])

  def test_DatastoreWriteLargeEntities(self):
    """100*100kB entities gets split over two Commit RPCs."""
    with patch.object(helper, 'get_client', return_value=self._mock_client):
      entities = helper.create_entities(100)
      commit_count = [0]
      self._mock_client.batch.side_effect = (
          lambda: FakeBatch(commit_count=commit_count))

      datastore_write_fn = WriteToDatastore._DatastoreWriteFn(
          self._PROJECT)
      datastore_write_fn.start_bundle()
      for entity in entities:
        # Pad each entity to ~100kB so the byte-size limit forces a second
        # commit before the entity-count limit is reached.
        entity.set_properties({'large': u'A' * 100000})
        datastore_write_fn.process(entity)
      datastore_write_fn.finish_bundle()

      self.assertEqual(2, commit_count[0])

  def check_estimated_size_bytes(self, entity_bytes, timestamp, namespace=None):
    """A helper method to test get_estimated_size_bytes.

    Stubs the two Datastore stat queries (latest __Stat_Total__ timestamp,
    then __Stat_Kind__ entity_bytes) and verifies both the returned estimate
    and the exact sequence of query calls issued.
    """
    self._mock_client.namespace = namespace
    self._mock_client.query.return_value = self._mock_query
    self._mock_query.project = self._PROJECT
    self._mock_query.namespace = namespace
    # First fetch returns the stats timestamp, second the size estimate.
    self._mock_query.fetch.side_effect = [
        [{'timestamp': timestamp}],
        [{'entity_bytes': entity_bytes}],
    ]
    self._mock_query.kind = self._KIND

    split_query_fn = ReadFromDatastore._SplitQueryFn(num_splits=0)
    self.assertEqual(entity_bytes,
                     split_query_fn.get_estimated_size_bytes(self._mock_client,
                                                             self._mock_query))

    # Stat kinds are named differently when a namespace is set.
    if namespace is None:
      ns_keyword = '_'
    else:
      ns_keyword = '_Ns_'
    self._mock_client.query.assert_has_calls([
        call(kind='__Stat%sTotal__' % ns_keyword, order=['-timestamp']),
        call().fetch(limit=1),
        call(kind='__Stat%sKind__' % ns_keyword),
        call().add_filter('kind_name', '=', self._KIND),
        call().add_filter('timestamp', '=', timestamp),
        call().fetch(limit=1),
    ])

  def test_DatastoreDeleteFn(self):
    """_DatastoreDeleteFn adds every processed key to a batch."""
    with patch.object(helper, 'get_client', return_value=self._mock_client):
      keys = [entity.key for entity in helper.create_entities(10)]
      expected_keys = [key.to_client_key() for key in keys]

      all_batch_keys = []
      self._mock_client.batch.side_effect = (
          lambda: FakeBatch(all_batch_items=all_batch_keys))

      datastore_delete_fn = DeleteFromDatastore._DatastoreDeleteFn(
          self._PROJECT)

      datastore_delete_fn.start_bundle()
      for key in keys:
        datastore_delete_fn.process(key)
      datastore_delete_fn.finish_bundle()

      self.assertListEqual(all_batch_keys, expected_keys)
| |
| |
# Hide base class from collection by nose: deleting the module-level alias
# prevents the inherited test cases from being discovered and run twice.
del DatastoreioTestBase


if __name__ == '__main__':
  unittest.main()