# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Unit tests for the Cloud Datastore client
"""
from unittest.mock import MagicMock, ANY

import mock
import pytest
from mock.mock import call
from google.cloud import datastore

from config import Origin, Config
from datastore_client import DatastoreClient, DatastoreException
from models import SdkEnum
from test_utils import _get_examples

@mock.patch("google.cloud.datastore.Client")
def test_save_to_cloud_datastore_when_schema_version_not_found(
mock_datastore_client
):
"""
    Test saving examples to the cloud datastore when the schema version is not found
"""
with pytest.raises(
DatastoreException,
match="Schema versions not found. Schema versions must be downloaded during application startup",
):
examples = _get_examples(1)
client = DatastoreClient("MOCK_PROJECT_ID", Config.DEFAULT_NAMESPACE)
client.save_to_cloud_datastore(examples, SdkEnum.JAVA, Origin.PG_EXAMPLES)
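
# Happy-path test below: parametrized over the example origin (which determines
# the Datastore key prefix), the target namespace, and whether the example is
# multifile and/or uses a Kafka dataset.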
@pytest.mark.parametrize("is_multifile", [False, True])
@pytest.mark.parametrize("with_kafka", [False, True])
@pytest.mark.parametrize(
"origin, key_prefix",
[
pytest.param(Origin.PG_EXAMPLES, "", id="PG_EXAMPLES"),
pytest.param(Origin.PG_BEAMDOC, "PG_BEAMDOC_", id="PG_BEAMDOC"),
pytest.param(Origin.TB_EXAMPLES, "TB_EXAMPLES_", id="TB_EXAMPLES"),
],
)
@pytest.mark.parametrize("namespace", [Config.DEFAULT_NAMESPACE, "Staging"])
@mock.patch("datastore_client.DatastoreClient._get_all_examples")
@mock.patch("datastore_client.DatastoreClient._get_actual_schema_version_key")
@mock.patch("google.cloud.datastore.Client")
def test_save_to_cloud_datastore_in_the_usual_case(
mock_client,
mock_get_schema,
mock_get_examples,
create_test_example,
origin,
key_prefix,
with_kafka,
is_multifile,
namespace,
):
"""
Test saving examples to the cloud datastore in the usual case
"""
mock_schema_key = MagicMock()
mock_get_schema.return_value = mock_schema_key
mock_examples = MagicMock()
mock_get_examples.return_value = mock_examples
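    # Stub the schema version lookup and the fetch of already-stored examples so
    # that save_to_cloud_datastore runs entirely against mocks.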
project_id = "MOCK_PROJECT_ID"
examples = [create_test_example(is_multifile=is_multifile, with_kafka=with_kafka)]
client = DatastoreClient(project_id, namespace)
client.save_to_cloud_datastore(examples, SdkEnum.JAVA, origin)
    mock_client.assert_called_once_with(namespace=namespace, project=project_id)
mock_get_schema.assert_called_once()
mock_get_examples.assert_called_once()
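    # Expected sequence of low-level Datastore calls: a single transaction that
    # writes the example, snippet, PCollection objects and file entities, plus
    # dataset entities when Kafka is involved.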
calls = [
call().transaction(),
call().transaction().__enter__(),
call().key("pg_sdks", "SDK_JAVA"),
call().key("pg_examples", key_prefix + "SDK_JAVA_MOCK_NAME"),
call().put(ANY),
call().key("pg_snippets", key_prefix + "SDK_JAVA_MOCK_NAME"),
]
if with_kafka:
calls.append(
call().key(
"pg_datasets", "dataset_id_1"
), # used in the nested datasets construction
)
calls.extend(
[
call().put(ANY),
call().key("pg_pc_objects", key_prefix + "SDK_JAVA_MOCK_NAME_GRAPH"),
call().key("pg_pc_objects", key_prefix + "SDK_JAVA_MOCK_NAME_OUTPUT"),
call().key("pg_pc_objects", key_prefix + "SDK_JAVA_MOCK_NAME_LOG"),
call().put_multi([ANY, ANY, ANY]),
call().key("pg_files", key_prefix + "SDK_JAVA_MOCK_NAME_0"),
call().put(ANY),
]
)
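    # Multifile examples write each additional file as a separate pg_files entity.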
if is_multifile:
calls.extend(
[
call().key("pg_files", key_prefix + "SDK_JAVA_MOCK_NAME_1"),
call().key("pg_files", key_prefix + "SDK_JAVA_MOCK_NAME_2"),
call().put_multi([ANY, ANY]),
]
)
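    # Examples with Kafka also persist their dataset entities.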
if with_kafka:
calls.extend(
[
call().key("pg_datasets", "dataset_id_1"),
call().put_multi([ANY]),
]
)
calls.append(
call().transaction().__exit__(None, None, None),
)
mock_client.assert_has_calls(calls, any_order=False)
    # Saving must never delete existing entities; check the mocked client instance.
    mock_client.return_value.delete_multi.assert_not_called()