#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""Cloud Datastore query splitter test."""

from __future__ import absolute_import

import unittest

import mock

# Protect against environments where datastore library is not available.
try:
  from apache_beam.io.gcp.datastore.v1new import helper
  from apache_beam.io.gcp.datastore.v1new import query_splitter
  from apache_beam.io.gcp.datastore.v1new import types
  from apache_beam.io.gcp.datastore.v1new.query_splitter import SplitNotPossibleError
  from google.cloud.datastore import key
  # Keep this import last so it doesn't import conflicting pb2 modules.
  from apache_beam.io.gcp.datastore.v1 import query_splitter_test  # pylint: disable=ungrouped-imports
  # Alias for the v1 test class that QuerySplitterTest below derives from.
  QuerySplitterTestBase = query_splitter_test.QuerySplitterTest

# TODO(BEAM-4543): Remove TypeError once googledatastore dependency is removed.
except (ImportError, TypeError):
  # Placeholders so this module can still be imported. query_splitter being
  # None is the condition checked by the skipIf decorator on QuerySplitterTest.
  # helper, types and key stay undefined in this branch; they are referenced
  # only inside test method bodies.
  query_splitter = None
  SplitNotPossibleError = None
  QuerySplitterTestBase = unittest.TestCase


@unittest.skipIf(query_splitter is None, 'GCP dependencies are not installed')
class QuerySplitterTest(QuerySplitterTestBase):
  """v1new adaptation of QuerySplitterTest.

  NOTE: This test inherits test cases from QuerySplitterTestBase.
  Please prefer to add new test cases to v1/query_splitter_test if possible.
  """
  _PROJECT = 'project'
  _NAMESPACE = 'namespace'

  # v1new implementations exposed as class attributes; test code reaches them
  # through self.split_error / self.query_splitter.
  split_error = SplitNotPossibleError
  query_splitter = query_splitter

  def setUp(self):
    """Overrides base class version with skipIf() decorators.

    Deliberately empty: skipping for this class is decided by the skipIf
    decorator applied to the class itself.
    """

  def create_query(self, kinds=(), order=False, limit=None, offset=None,
                   inequality_filter=False):
    """Builds a v1new types.Query from the given test options.

    Args:
      kinds: sequence of kind names; at most one is supported, and the first
        one (if any) becomes the query kind.
      order: if truthy, the query is ordered by 'prop1'; otherwise the value
        is passed through to types.Query unchanged.
      limit: forwarded to types.Query as the query limit.
      offset: must be None; any other value skips the calling test.
      inequality_filter: if truthy, adds the filter ('prop1', '>', 'value1').

    Returns:
      A types.Query instance.
    """
    if len(kinds) > 1:
      self.skipTest('v1new queries do not support more than one kind.')
    if offset is not None:
      self.skipTest('v1new queries do not support offsets.')

    kind = None
    filters = []
    if kinds:
      kind = kinds[0]
    if order:
      order = ['prop1']
    if inequality_filter:
      filters = [('prop1', '>', 'value1')]

    return types.Query(kind=kind, filters=filters, order=order, limit=limit)

  def test_get_splits_query_with_num_splits_of_one(self):
    """get_splits must reject num_splits == 1 with the split error."""
    query = self.create_query()
    # assertRaisesRegexp was removed in Python 3.12, while Python 2.7 only
    # provides that older spelling; use whichever one this runtime has.
    assert_raises_regex = getattr(self, 'assertRaisesRegex', None)
    if assert_raises_regex is None:
      assert_raises_regex = self.assertRaisesRegexp
    with assert_raises_regex(self.split_error, r'num_splits'):
      query_splitter.get_splits(None, query, 1)

  def test_create_scatter_query(self):
    """Checks kind, limit, order and projection of the scatter query."""
    query = types.Query(kind='shakespeare-demo')
    num_splits = 10
    scatter_query = query_splitter._create_scatter_query(query, num_splits)
    self.assertEqual(scatter_query.kind, query.kind)
    self.assertEqual(scatter_query.limit,
                     (num_splits - 1) * query_splitter.KEYS_PER_SPLIT)
    self.assertEqual(scatter_query.order,
                     [query_splitter.SCATTER_PROPERTY_NAME])
    self.assertEqual(scatter_query.projection,
                     [query_splitter.KEY_PROPERTY_NAME])

  def check_get_splits(self, query, num_splits, num_entities,
                       unused_batch_size):
    """A helper method to test the query_splitter get_splits method.

    The Datastore client query is replaced by a mock whose fetch() returns
    num_entities generated entities. The resulting split queries are then
    checked for count and for contiguous, gap-free key ranges.

    Args:
      query: the query to be split
      num_splits: number of splits
      num_entities: number of scatter entities returned to the splitter.
      unused_batch_size: ignored in v1new since query results are entirely
        handled by the Datastore client.
    """
    # Test for both random long ids and string ids.
    for id_or_name in [True, False]:
      client_entities = helper.create_client_entities(num_entities, id_or_name)

      mock_client = mock.MagicMock()
      mock_client_query = mock.MagicMock()
      mock_client_query.fetch.return_value = client_entities
      # Any types.Query converted to a client query while splitting yields
      # the mock above, so no real Datastore call is made.
      with mock.patch.object(
          types.Query, '_to_client_query', return_value=mock_client_query):
        split_queries = query_splitter.get_splits(
            mock_client, query, num_splits)

      mock_client_query.fetch.assert_called_once()
      # if request num_splits is greater than num_entities, the best it can
      # do is one entity per split.
      expected_num_splits = min(num_splits, num_entities + 1)
      self.assertEqual(len(split_queries), expected_num_splits)

      # Verify no gaps in key ranges. Filters should look like:
      # query1: (__key__ < key1)
      # query2: (__key__ >= key1), (__key__ < key2)
      # ...
      # queryN: (__key__ >=keyN-1)
      prev_client_key = None
      last_query_seen = False
      for split_query in split_queries:
        # Nothing may follow the split that has no upper bound.
        self.assertFalse(last_query_seen)
        lt_key = None
        gte_key = None
        for _filter in split_query.filters:
          self.assertEqual(query_splitter.KEY_PROPERTY_NAME, _filter[0])
          if _filter[1] == '<':
            lt_key = _filter[2]
          elif _filter[1] == '>=':
            gte_key = _filter[2]

        # Case where the scatter query has no results.
        if lt_key is None and gte_key is None:
          self.assertEqual(1, len(split_queries))
          break

        if prev_client_key is None:
          # First split: bounded above only.
          self.assertIsNone(gte_key)
          self.assertIsNotNone(lt_key)
          prev_client_key = lt_key
        else:
          # Later splits must start exactly where the previous one ended.
          self.assertEqual(prev_client_key, gte_key)
          prev_client_key = lt_key
          if lt_key is None:
            last_query_seen = True

      # The key space is only fully covered if the final split is unbounded
      # above (queryN in the layout described before the loop).
      if len(split_queries) > 1:
        self.assertTrue(
            last_query_seen, 'final split query must have no upper key bound')

  def test_client_key_sort_key(self):
    """Sorting keys with client_key_sort_key yields the expected order."""
    k = key.Key('kind1', 1, project=self._PROJECT, namespace=self._NAMESPACE)
    k2 = key.Key('kind2', 'a', parent=k)
    k3 = key.Key('kind2', 'b', parent=k)
    k4 = key.Key('kind1', 'a', project=self._PROJECT, namespace=self._NAMESPACE)
    k5 = key.Key('kind1', 'a', project=self._PROJECT)
    keys = [k5, k, k4, k3, k2, k2, k]
    expected_sort = [k5, k, k, k2, k2, k3, k4]
    keys.sort(key=query_splitter.client_key_sort_key)
    self.assertEqual(expected_sort, keys)


# Hide base class from collection by nose: drop the module-level alias now
# that QuerySplitterTest has been defined from it.
del QuerySplitterTestBase


# Run the tests when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()
