#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
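"""Unit tests for the experimental Cloud Spanner IO connector
(``apache_beam.io.gcp.experimental.spannerio``)."""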
import datetime
import logging
import random
import string
import typing
import unittest
import mock
import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
# Protect against environments where spanner library is not available.
# pylint: disable=wrong-import-order, wrong-import-position, ungrouped-imports
# pylint: disable=unused-import
try:
from google.cloud import spanner
from apache_beam.io.gcp.experimental.spannerio import create_transaction
from apache_beam.io.gcp.experimental.spannerio import ReadOperation
from apache_beam.io.gcp.experimental.spannerio import ReadFromSpanner
from apache_beam.io.gcp.experimental.spannerio import WriteMutation
from apache_beam.io.gcp.experimental.spannerio import MutationGroup
from apache_beam.io.gcp.experimental.spannerio import WriteToSpanner
from apache_beam.io.gcp.experimental.spannerio import _BatchFn
except ImportError:
spanner = None
# pylint: enable=wrong-import-order, wrong-import-position, ungrouped-imports
# pylint: enable=unused-import
MAX_DB_NAME_LENGTH = 30
TEST_PROJECT_ID = 'apache-beam-testing'
TEST_INSTANCE_ID = 'beam-test'
TEST_DATABASE_PREFIX = 'spanner-testdb-'
FAKE_TRANSACTION_INFO = {"session_id": "qwerty", "transaction_id": "qwerty"}
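# Six fake result rows; the read tests below split these across three fake
# partitions, two rows per partition.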
FAKE_ROWS = [[1, 'Alice'], [2, 'Bob'], [3, 'Carl'], [4, 'Dan'], [5, 'Evan'],
[6, 'Floyd']]
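# Cloud Spanner database ids are limited to MAX_DB_NAME_LENGTH (30) characters.
# With the 15-character TEST_DATABASE_PREFIX this leaves 14 random characters
# (the extra -1 below presumably keeps the generated name strictly under the
# limit).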
def _generate_database_name():
mask = string.ascii_lowercase + string.digits
length = MAX_DB_NAME_LENGTH - 1 - len(TEST_DATABASE_PREFIX)
return TEST_DATABASE_PREFIX + ''.join(
random.choice(mask) for i in range(length))
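# Builds a single (table, columns, rows) tuple with random 100-character
# values for the 'users' table; not referenced directly by the unit tests
# below.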
def _generate_test_data():
mask = string.ascii_lowercase + string.digits
length = 100
return [(
'users', ['Key', 'Value'],
[(x, ''.join(random.choice(mask) for _ in range(length)))
for x in range(1, 5)])]
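# The read tests patch the Spanner Client and BatchSnapshot classes so no real
# Cloud Spanner calls are made; the mocked snapshot serves slices of FAKE_ROWS.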
@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.Client')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.BatchSnapshot')
class SpannerReadTest(unittest.TestCase):
def test_read_with_query_batch(
self, mock_batch_snapshot_class, mock_client_class):
mock_snapshot = mock.MagicMock()
mock_snapshot.generate_query_batches.return_value = [{
'query': {
'sql': 'SELECT * FROM users'
}, 'partition': 'test_partition'
} for _ in range(3)]
mock_snapshot.process_query_batch.side_effect = [
FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
]
ro = [ReadOperation.query("Select * from users")]
pipeline = TestPipeline()
read = (
pipeline
| 'read' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
sql="SELECT * FROM users"))
readall = (
pipeline
| 'read all' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
read_operations=ro))
readpipeline = (
pipeline
| 'create reads' >> beam.Create(ro)
| 'reads' >> ReadFromSpanner(
TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
pipeline.run()
assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
def test_read_with_table_batch(
self, mock_batch_snapshot_class, mock_client_class):
mock_snapshot = mock.MagicMock()
mock_snapshot.generate_read_batches.return_value = [{
'read': {
'table': 'users',
'keyset': {
'all': True
},
'columns': ['Key', 'Value'],
'index': ''
},
'partition': 'test_partition'
} for _ in range(3)]
mock_snapshot.process_read_batch.side_effect = [
FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
]
ro = [ReadOperation.table("users", ["Key", "Value"])]
pipeline = TestPipeline()
read = (
pipeline
| 'read' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
table="users",
columns=["Key", "Value"]))
readall = (
pipeline
| 'read all' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
read_operations=ro))
readpipeline = (
pipeline
| 'create reads' >> beam.Create(ro)
| 'reads' >> ReadFromSpanner(
TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
pipeline.run()
assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
with self.assertRaises(ValueError):
# Test the exception raised when the user passes read operations both in
# the constructor and in the pipeline.
_ = (
pipeline | 'reads error' >> ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
table="users"))
pipeline.run()
def test_read_with_index(self, mock_batch_snapshot_class, mock_client_class):
mock_snapshot = mock.MagicMock()
mock_snapshot.generate_read_batches.return_value = [{
'read': {
'table': 'users',
'keyset': {
'all': True
},
'columns': ['Key', 'Value'],
'index': ''
},
'partition': 'test_partition'
} for _ in range(3)]
mock_snapshot.process_read_batch.side_effect = [
FAKE_ROWS[0:2], FAKE_ROWS[2:4], FAKE_ROWS[4:]
]
ro = [ReadOperation.table("users", ["Key", "Value"], index="Key")]
pipeline = TestPipeline()
read = (
pipeline
| 'read' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
table="users",
columns=["Key", "Value"]))
readall = (
pipeline
| 'read all' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
read_operations=ro))
readpipeline = (
pipeline
| 'create reads' >> beam.Create(ro)
| 'reads' >> ReadFromSpanner(
TEST_PROJECT_ID, TEST_INSTANCE_ID, _generate_database_name()))
pipeline.run()
assert_that(read, equal_to(FAKE_ROWS), label='checkRead')
assert_that(readall, equal_to(FAKE_ROWS), label='checkReadAll')
assert_that(readpipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
with self.assertRaises(ValueError):
# Test the exception raised when the user passes read operations both in
# the constructor and in the pipeline.
_ = (
pipeline | 'reads error' >> ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
table="users"))
pipeline.run()
def test_read_with_transaction(
self, mock_batch_snapshot_class, mock_client_class):
mock_client = mock.MagicMock()
mock_instance = mock.MagicMock()
mock_database = mock.MagicMock()
mock_snapshot = mock.MagicMock()
mock_client_class.return_value = mock_client
mock_client.instance.return_value = mock_instance
mock_instance.database.return_value = mock_database
mock_database.batch_snapshot.return_value = mock_snapshot
mock_batch_snapshot_class.return_value = mock_snapshot
mock_batch_snapshot_class.from_dict.return_value = mock_snapshot
mock_snapshot.to_dict.return_value = FAKE_TRANSACTION_INFO
mock_session = mock.MagicMock()
mock_transaction_ctx = mock.MagicMock()
mock_transaction = mock.MagicMock()
mock_snapshot._get_session.return_value = mock_session
mock_session.transaction.return_value = mock_transaction
mock_transaction.__enter__.return_value = mock_transaction_ctx
mock_transaction_ctx.execute_sql.return_value = FAKE_ROWS
ro = [ReadOperation.query("Select * from users")]
p = TestPipeline()
transaction = (
p | create_transaction(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
exact_staleness=datetime.timedelta(seconds=10)))
read_query = (
p | 'with query' >> ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
transaction=transaction,
sql="Select * from users"))
read_table = (
p | 'with table' >> ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
transaction=transaction,
table="users",
columns=["Key", "Value"]))
read_indexed_table = (
p | 'with index' >> ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
transaction=transaction,
table="users",
index="Key",
columns=["Key", "Value"]))
read = (
p | 'read all' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
transaction=transaction,
read_operations=ro))
read_pipeline = (
p
| 'create read operations' >> beam.Create(ro)
| 'reads' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
transaction=transaction))
p.run()
assert_that(read_query, equal_to(FAKE_ROWS), label='checkQuery')
assert_that(read_table, equal_to(FAKE_ROWS), label='checkTable')
assert_that(
read_indexed_table, equal_to(FAKE_ROWS), label='checkTableIndex')
assert_that(read, equal_to(FAKE_ROWS), label='checkReadAll')
assert_that(read_pipeline, equal_to(FAKE_ROWS), label='checkReadPipeline')
with self.assertRaises(ValueError):
# Test the exception raised when the user passes read operations both in
# the constructor and in the pipeline.
_ = (
p
| 'create read operations2' >> beam.Create(ro)
| 'reads with error' >> ReadFromSpanner(
TEST_PROJECT_ID,
TEST_INSTANCE_ID,
_generate_database_name(),
transaction=transaction,
read_operations=ro))
p.run()
def test_invalid_transaction(
self, mock_batch_snapshot_class, mock_client_class):
with self.assertRaises(ValueError):
p = TestPipeline()
transaction = (
p | beam.Create([{
"invalid": "transaction"
}]).with_output_types(typing.Any))
_ = (
p | 'with query' >> ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
transaction=transaction,
sql="Select * from users"))
p.run()
def test_display_data(self, *args):
dd_sql = ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
sql="Select * from users").display_data()
dd_table = ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
table="users",
columns=['id', 'name']).display_data()
dd_transaction = ReadFromSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
table="users",
columns=['id', 'name'],
transaction={
"transaction_id": "test123", "session_id": "test456"
}).display_data()
self.assertTrue("sql" in dd_sql)
self.assertTrue("table" in dd_table)
self.assertTrue("table" in dd_transaction)
self.assertTrue("transaction" in dd_transaction)
@unittest.skipIf(spanner is None, 'GCP dependencies are not installed.')
@mock.patch('apache_beam.io.gcp.experimental.spannerio.Client')
@mock.patch('google.cloud.spanner_v1.database.BatchCheckout')
class SpannerWriteTest(unittest.TestCase):
def test_spanner_write(self, mock_batch_snapshot_class, mock_batch_checkout):
ks = spanner.KeySet(keys=[[1233], [1234]])
mutations = [
WriteMutation.delete("roles", ks),
WriteMutation.insert(
"roles", ("key", "rolename"), [('1233', "mutations-inset-1233")]),
WriteMutation.insert(
"roles", ("key", "rolename"), [('1234', "mutations-inset-1234")]),
WriteMutation.update(
"roles", ("key", "rolename"),
[('1234', "mutations-inset-1233-updated")]),
]
p = TestPipeline()
_ = (
p
| beam.Create(mutations)
| WriteToSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
max_batch_size_bytes=1024))
res = p.run()
res.wait_until_finish()
metric_results = res.metrics().query(
MetricsFilter().with_name("SpannerBatches"))
batches_counter = metric_results['counters'][0]
self.assertEqual(batches_counter.committed, 2)
self.assertEqual(batches_counter.attempted, 2)
def test_spanner_bundles_size(
self, mock_batch_snapshot_class, mock_batch_checkout):
ks = spanner.KeySet(keys=[[1233], [1234]])
mutations = [
WriteMutation.delete("roles", ks),
WriteMutation.insert(
"roles", ("key", "rolename"), [('1234', "mutations-inset-1234")])
] * 50
p = TestPipeline()
_ = (
p
| beam.Create(mutations)
| WriteToSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
max_batch_size_bytes=1024))
res = p.run()
res.wait_until_finish()
metric_results = res.metrics().query(
MetricsFilter().with_name('SpannerBatches'))
batches_counter = metric_results['counters'][0]
self.assertEqual(batches_counter.committed, 53)
self.assertEqual(batches_counter.attempted, 53)
def test_spanner_write_mutation_groups(
self, mock_batch_snapshot_class, mock_batch_checkout):
ks = spanner.KeySet(keys=[[1233], [1234]])
mutation_groups = [
MutationGroup([
WriteMutation.insert(
"roles", ("key", "rolename"),
[('9001233', "mutations-inset-1233")]),
WriteMutation.insert(
"roles", ("key", "rolename"),
[('9001234', "mutations-inset-1234")])
]),
MutationGroup([
WriteMutation.update(
"roles", ("key", "rolename"),
[('9001234', "mutations-inset-9001233-updated")])
]),
MutationGroup([WriteMutation.delete("roles", ks)])
]
p = TestPipeline()
_ = (
p
| beam.Create(mutation_groups)
| WriteToSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
max_batch_size_bytes=100))
res = p.run()
res.wait_until_finish()
metric_results = res.metrics().query(
MetricsFilter().with_name('SpannerBatches'))
batches_counter = metric_results['counters'][0]
self.assertEqual(batches_counter.committed, 3)
self.assertEqual(batches_counter.attempted, 3)
def test_batch_byte_size(
self, mock_batch_snapshot_class, mock_batch_checkout):
# Each mutation group below is 58 bytes in size.
mutation_group = [
MutationGroup([
WriteMutation.insert(
"roles",
("key", "rolename"), [('1234', "mutations-inset-1234")])
])
] * 50
with TestPipeline() as p:
# The 50 mutation groups total 2900 bytes (58 * 50). To split them into two
# batches, the batch size should be 1450 bytes (2900 / 2), and each batch
# should then contain 25 mutation groups.
res = (
p | beam.Create(mutation_group)
| beam.ParDo(
_BatchFn(
max_batch_size_bytes=1450,
max_number_rows=50,
max_number_cells=500))
| beam.Map(lambda x: len(x)))
assert_that(res, equal_to([25] * 2))
def test_batch_disable(self, mock_batch_snapshot_class, mock_batch_checkout):
mutation_group = [
MutationGroup([
WriteMutation.insert(
"roles",
("key", "rolename"), [('1234', "mutations-inset-1234")])
])
] * 4
with TestPipeline() as p:
# To disable batching, set any one of the batching parameters to a low
# value or to zero.
res = (
p | beam.Create(mutation_group)
| beam.ParDo(
_BatchFn(
max_batch_size_bytes=1450,
max_number_rows=0,
max_number_cells=500))
| beam.Map(lambda x: len(x)))
assert_that(res, equal_to([1] * 4))
def test_batch_max_rows(self, mock_batch_snapshot_class, mock_batch_checkout):
mutation_group = [
MutationGroup([
WriteMutation.insert(
"roles", ("key", "rolename"),
[
('1234', "mutations-inset-1234"),
('1235', "mutations-inset-1235"),
])
])
] * 50
with TestPipeline() as p:
# There are 50 mutation groups in total, each containing two rows, so the
# total number of rows is 100 (50 * 2). With at most 10 rows per batch,
# there should be 10 batches of 5 mutation groups each.
res = (
p | beam.Create(mutation_group)
| beam.ParDo(
_BatchFn(
max_batch_size_bytes=1048576,
max_number_rows=10,
max_number_cells=500))
| beam.Map(lambda x: len(x)))
assert_that(res, equal_to([5] * 10))
def test_batch_max_cells(
self, mock_batch_snapshot_class, mock_batch_checkout):
mutation_group = [
MutationGroup([
WriteMutation.insert(
"roles", ("key", "rolename"),
[
('1234', "mutations-inset-1234"),
('1235', "mutations-inset-1235"),
])
])
] * 50
with TestPipeline() as p:
# There are 50 mutation groups in total, each containing two rows of two
# columns, i.e. 4 cells, so the total number of cells is 200 (50 * 4).
# With at most 50 cells per batch, each batch holds
# floor(50 cells / 4 cells per group) = 12 mutation groups, giving four
# batches of 12 groups and a fifth batch with the remaining 2 groups.
res = (
p | beam.Create(mutation_group)
| beam.ParDo(
_BatchFn(
max_batch_size_bytes=1048576,
max_number_rows=500,
max_number_cells=50))
| beam.Map(lambda x: len(x)))
assert_that(res, equal_to([12, 12, 12, 12, 2]))
def test_write_mutation_error(self, *args):
with self.assertRaises(ValueError):
# `WriteMutation` accepts only one operation at a time.
WriteMutation(insert="table-name", update="table-name")
def test_display_data(self, *args):
data = WriteToSpanner(
project_id=TEST_PROJECT_ID,
instance_id=TEST_INSTANCE_ID,
database_id=_generate_database_name(),
max_batch_size_bytes=1024).display_data()
self.assertTrue("project_id" in data)
self.assertTrue("instance_id" in data)
self.assertTrue("pool" in data)
self.assertTrue("database" in data)
self.assertTrue("batch_size" in data)
self.assertTrue("max_number_rows" in data)
self.assertTrue("max_number_cells" in data)
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()