blob: 6c4eacfc7e21b14cb0375dfb5203957f862d96b7 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from __future__ import absolute_import
import logging
import unittest
import mock
from bson import objectid
from pymongo import ReplaceOne
import apache_beam as beam
from apache_beam.io import ReadFromMongoDB
from apache_beam.io import WriteToMongoDB
from apache_beam.io import source_test_utils
from apache_beam.io.mongodbio import _BoundedMongoSource
from apache_beam.io.mongodbio import _GenerateObjectIdFn
from apache_beam.io.mongodbio import _MongoSink
from apache_beam.io.mongodbio import _WriteMongoFn
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
class MongoSourceTest(unittest.TestCase):
@mock.patch('apache_beam.io.mongodbio._BoundedMongoSource'
'._get_document_count')
@mock.patch('apache_beam.io.mongodbio._BoundedMongoSource'
'._get_avg_document_size')
def setUp(self, mock_size, mock_count):
mock_size.return_value = 10
mock_count.return_value = 5
self.mongo_source = _BoundedMongoSource('mongodb://test', 'testdb',
'testcoll')
def test_estimate_size(self):
self.assertEqual(self.mongo_source.estimate_size(), 50)
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_split(self, mock_client):
# desired bundle size is 1 times of avg doc size, each bundle contains 1
# documents
mock_client.return_value.__enter__.return_value.__getitem__.return_value \
.__getitem__.return_value.find.return_value = [{'x': 1}, {'x': 2},
{'x': 3}, {'x': 4},
{'x': 5}]
for size in [10, 20, 100]:
splits = list(
self.mongo_source.split(start_position=0,
stop_position=5,
desired_bundle_size=size))
reference_info = (self.mongo_source, None, None)
sources_info = ([(split.source, split.start_position, split.stop_position)
for split in splits])
source_test_utils.assert_sources_equal_reference_source(
reference_info, sources_info)
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_dynamic_work_rebalancing(self, mock_client):
splits = list(self.mongo_source.split(desired_bundle_size=3000))
mock_client.return_value.__enter__.return_value.__getitem__.return_value \
.__getitem__.return_value.find.return_value = [{'x': 1}, {'x': 2},
{'x': 3}, {'x': 4},
{'x': 5}]
assert len(splits) == 1
source_test_utils.assert_split_at_fraction_exhaustive(
splits[0].source, splits[0].start_position, splits[0].stop_position)
@mock.patch('apache_beam.io.mongodbio.OffsetRangeTracker')
def test_get_range_tracker(self, mock_tracker):
self.mongo_source.get_range_tracker(None, None)
mock_tracker.assert_called_with(0, 5)
self.mongo_source.get_range_tracker(10, 20)
mock_tracker.assert_called_with(10, 20)
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_read(self, mock_client):
mock_tracker = mock.MagicMock()
mock_tracker.try_claim.return_value = True
mock_tracker.start_position.return_value = 0
mock_tracker.stop_position.return_value = 2
mock_client.return_value.__enter__.return_value.__getitem__.return_value\
.__getitem__.return_value.find.return_value = [{'x':1}, {'x':2}]
result = []
for i in self.mongo_source.read(mock_tracker):
result.append(i)
self.assertListEqual([{'x': 1}, {'x': 2}], result)
def test_display_data(self):
data = self.mongo_source.display_data()
self.assertTrue('uri' in data)
self.assertTrue('database' in data)
self.assertTrue('collection' in data)
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test__get_avg_document_size(self, mock_client):
mock_client.return_value.__enter__.return_value.__getitem__\
.return_value.command.return_value = {'avgObjSize': 5}
self.assertEqual(5, self.mongo_source._get_avg_document_size())
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_get_document_count(self, mock_client):
mock_client.return_value.__enter__.return_value.__getitem__ \
.return_value.__getitem__.return_value.count_documents.return_value = 10
self.assertEqual(10, self.mongo_source._get_document_count())
class ReadFromMongoDBTest(unittest.TestCase):
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_read_from_mongodb(self, mock_client):
objects = [{'x': 1}, {'x': 2}]
mock_client.return_value.__enter__.return_value.__getitem__.return_value. \
command.return_value = {'avgObjSize': 1}
mock_client.return_value.__enter__.return_value.__getitem__.return_value. \
__getitem__.return_value.find.return_value = objects
mock_client.return_value.__enter__.return_value.__getitem__.return_value. \
__getitem__.return_value.count_documents.return_value = 2
with TestPipeline() as p:
docs = p | 'ReadFromMongoDB' >> ReadFromMongoDB(
uri='mongodb://test', db='db', coll='collection')
assert_that(docs, equal_to(objects))
class GenerateObjectIdFnTest(unittest.TestCase):
def test_process(self):
with TestPipeline() as p:
output = (p | "Create" >> beam.Create([{
'x': 1
}, {
'x': 2,
'_id': 123
}])
| "Generate ID" >> beam.ParDo(_GenerateObjectIdFn())
| "Check" >> beam.Map(lambda x: '_id' in x))
assert_that(output, equal_to([True] * 2))
class WriteMongoFnTest(unittest.TestCase):
@mock.patch('apache_beam.io.mongodbio._MongoSink')
def test_process(self, mock_sink):
docs = [{'x': 1}, {'x': 2}, {'x': 3}]
with TestPipeline() as p:
_ = (p | "Create" >> beam.Create(docs)
| "Write" >> beam.ParDo(_WriteMongoFn(batch_size=2)))
p.run()
self.assertEqual(
2, mock_sink.return_value.__enter__.return_value.write.call_count)
def test_display_data(self):
data = _WriteMongoFn(batch_size=10).display_data()
self.assertEqual(10, data['batch_size'])
class MongoSinkTest(unittest.TestCase):
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_write(self, mock_client):
docs = [{'x': 1}, {'x': 2}, {'x': 3}]
_MongoSink(uri='test', db='test', coll='test').write(docs)
self.assertTrue(mock_client.return_value.__getitem__.return_value.
__getitem__.return_value.bulk_write.called)
class WriteToMongoDBTest(unittest.TestCase):
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_write_to_mongodb_with_existing_id(self, mock_client):
id = objectid.ObjectId()
docs = [{'x': 1, '_id': id}]
expected_update = [ReplaceOne({'_id': id}, {'x': 1, '_id': id}, True, None)]
with TestPipeline() as p:
_ = (p | "Create" >> beam.Create(docs)
| "Write" >> WriteToMongoDB(db='test', coll='test'))
p.run()
mock_client.return_value.__getitem__.return_value.__getitem__.\
return_value.bulk_write.assert_called_with(expected_update)
@mock.patch('apache_beam.io.mongodbio.MongoClient')
def test_write_to_mongodb_with_generated_id(self, mock_client):
docs = [{'x': 1}]
expected_update = [
ReplaceOne({'_id': mock.ANY}, {
'x': 1,
'_id': mock.ANY
}, True, None)
]
with TestPipeline() as p:
_ = (p | "Create" >> beam.Create(docs)
| "Write" >> WriteToMongoDB(db='test', coll='test'))
p.run()
mock_client.return_value.__getitem__.return_value.__getitem__. \
return_value.bulk_write.assert_called_with(expected_update)
if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
unittest.main()