#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""This module implements IO classes to read and write data on MongoDB.
Read from MongoDB
-----------------
:class:`ReadFromMongoDB` is a ``PTransform`` that reads from a configured
MongoDB source and returns a ``PCollection`` of dict representing MongoDB
documents.
To configure MongoDB source, the URI to connect to MongoDB server, database
name, collection name needs to be provided.
Example usage::
pipeline | ReadFromMongoDB(uri='mongodb://localhost:27017',
db='testdb',
coll='input')
Write to MongoDB:
-----------------
:class:`WriteToMongoDB` is a ``PTransform`` that writes MongoDB documents to
configured sink, and the write is conducted through a mongodb bulk_write of
``ReplaceOne`` operations. If the document's _id field already existed in the
MongoDB collection, it results in an overwrite, otherwise, a new document
will be inserted.
Example usage::
pipeline | WriteToMongoDB(uri='mongodb://localhost:27017',
db='testdb',
coll='output',
batch_size=10)
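
A combined read-then-write pipeline might look like the following (a minimal
sketch; the ``testdb``, ``input``, and ``output`` names and the local server
address are hypothetical)::

  import apache_beam as beam
  from apache_beam.io.mongodbio import ReadFromMongoDB, WriteToMongoDB

  with beam.Pipeline() as p:
    (p | ReadFromMongoDB(uri='mongodb://localhost:27017',
                         db='testdb',
                         coll='input')
       | WriteToMongoDB(uri='mongodb://localhost:27017',
                        db='testdb',
                        coll='output',
                        batch_size=10))
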
No backward compatibility guarantees. Everything in this module is experimental.
"""
from __future__ import absolute_import
import logging
from bson import objectid
from pymongo import MongoClient
from pymongo import ReplaceOne
import apache_beam as beam
from apache_beam.io import iobase
from apache_beam.io.range_trackers import OffsetRangeTracker
from apache_beam.transforms import DoFn
from apache_beam.transforms import PTransform
from apache_beam.transforms import Reshuffle
from apache_beam.utils.annotations import experimental
__all__ = ['ReadFromMongoDB', 'WriteToMongoDB']
@experimental()
class ReadFromMongoDB(PTransform):
"""A ``PTransfrom`` to read MongoDB documents into a ``PCollection``.
"""
def __init__(self,
uri='mongodb://localhost:27017',
db=None,
coll=None,
filter=None,
projection=None,
extra_client_params=None):
"""Initialize a :class:`ReadFromMongoDB`
Args:
uri (str): The MongoDB connection string following the URI format
db (str): The MongoDB database name
coll (str): The MongoDB collection name
filter: A `bson.SON
<https://api.mongodb.com/python/current/api/bson/son.html>`_ object
specifying elements which must be present for a document to be included
in the result set
projection: A list of field names that should be returned in the result
set or a dict specifying the fields to include or exclude
extra_client_params(dict): Optional `MongoClient
<https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
parameters
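
    Example (a minimal sketch; the field names in ``filter`` and
    ``projection`` are hypothetical, and both follow pymongo ``find()``
    semantics)::

      ReadFromMongoDB(uri='mongodb://localhost:27017',
                      db='testdb',
                      coll='input',
                      filter={'age': {'$gt': 21}},
                      projection=['name', 'age'])
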
Returns:
:class:`~apache_beam.transforms.ptransform.PTransform`
"""
if extra_client_params is None:
extra_client_params = {}
if not isinstance(db, str):
      raise ValueError('ReadFromMongoDB db param must be specified as a '
                       'string')
if not isinstance(coll, str):
      raise ValueError('ReadFromMongoDB coll param must be specified as a '
                       'string')
self._mongo_source = _BoundedMongoSource(
uri=uri,
db=db,
coll=coll,
filter=filter,
projection=projection,
extra_client_params=extra_client_params)
def expand(self, pcoll):
return pcoll | iobase.Read(self._mongo_source)
class _BoundedMongoSource(iobase.BoundedSource):
def __init__(self,
uri=None,
db=None,
coll=None,
filter=None,
projection=None,
extra_client_params=None):
if extra_client_params is None:
extra_client_params = {}
if filter is None:
filter = {}
self.uri = uri
self.db = db
self.coll = coll
self.filter = filter
self.projection = projection
self.spec = extra_client_params
self.doc_count = self._get_document_count()
self.avg_doc_size = self._get_avg_document_size()
self.client = None
def estimate_size(self):
return self.avg_doc_size * self.doc_count
  def split(self, desired_bundle_size, start_position=None,
            stop_position=None):
# use document cursor index as the start and stop positions
if start_position is None:
start_position = 0
if stop_position is None:
stop_position = self.doc_count
    # estimate how many documents should be included in each split bundle
desired_bundle_count = desired_bundle_size // self.avg_doc_size
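    # e.g. a desired_bundle_size of 1048576 bytes with an avg_doc_size of
    # 1024 bytes yields bundles of roughly 1024 documents each; both values
    # are estimates derived from collection stats, not exact byte counts.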
bundle_start = start_position
while bundle_start < stop_position:
bundle_end = min(stop_position, bundle_start + desired_bundle_count)
yield iobase.SourceBundle(weight=bundle_end - bundle_start,
source=self,
start_position=bundle_start,
stop_position=bundle_end)
bundle_start = bundle_end
def get_range_tracker(self, start_position, stop_position):
if start_position is None:
start_position = 0
if stop_position is None:
stop_position = self.doc_count
return OffsetRangeTracker(start_position, stop_position)
def read(self, range_tracker):
with MongoClient(self.uri, **self.spec) as client:
# docs is a MongoDB Cursor
docs = client[self.db][self.coll].find(
filter=self.filter, projection=self.projection
)[range_tracker.start_position():range_tracker.stop_position()]
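      # the cursor slice above starts at start_position, so a claimed index
      # must be rebased by subtracting start_position before indexing into
      # it, as done below.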
for index in range(range_tracker.start_position(),
range_tracker.stop_position()):
if not range_tracker.try_claim(index):
return
yield docs[index - range_tracker.start_position()]
def display_data(self):
res = super(_BoundedMongoSource, self).display_data()
res['uri'] = self.uri
res['database'] = self.db
res['collection'] = self.coll
res['filter'] = self.filter
res['project'] = self.projection
res['mongo_client_spec'] = self.spec
return res
def _get_avg_document_size(self):
with MongoClient(self.uri, **self.spec) as client:
size = client[self.db].command('collstats', self.coll).get('avgObjSize')
if size is None or size <= 0:
        raise ValueError('Collection %s not found or average doc size is '
                         'incorrect' % self.coll)
return size
def _get_document_count(self):
with MongoClient(self.uri, **self.spec) as client:
return max(client[self.db][self.coll].count_documents(self.filter), 0)
@experimental()
class WriteToMongoDB(PTransform):
"""WriteToMongoDB is a ``PTransform`` that writes a ``PCollection`` of
mongodb document to the configured MongoDB server.
In order to make the document writes idempotent so that the bundles are
retry-able without creating duplicates, the PTransform added 2 transformations
before final write stage:
a ``GenerateId`` transform and a ``Reshuffle`` transform.::
-----------------------------------------------
Pipeline --> |GenerateId --> Reshuffle --> WriteToMongoSink|
-----------------------------------------------
(WriteToMongoDB)
The ``GenerateId`` transform adds a random and unique*_id* field to the
documents if they don't already have one, it uses the same format as MongoDB
default. The ``Reshuffle`` transform makes sure that no fusion happens between
``GenerateId`` and the final write stage transform,so that the set of
documents and their unique IDs are not regenerated if final write step is
retried due to a failure. This prevents duplicate writes of the same document
with different unique IDs.
"""
def __init__(self,
uri='mongodb://localhost:27017',
db=None,
coll=None,
batch_size=100,
extra_client_params=None):
"""
Args:
uri (str): The MongoDB connection string following the URI format
db (str): The MongoDB database name
coll (str): The MongoDB collection name
      batch_size(int): Number of documents per bulk_write to MongoDB,
        defaults to 100
extra_client_params(dict): Optional `MongoClient
<https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_
parameters as keyword arguments
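
    Example (a minimal sketch; the names and the
    ``serverSelectionTimeoutMS`` value are hypothetical, and
    ``extra_client_params`` entries are forwarded to ``MongoClient``)::

      WriteToMongoDB(uri='mongodb://localhost:27017',
                     db='testdb',
                     coll='output',
                     batch_size=10,
                     extra_client_params={'serverSelectionTimeoutMS': 5000})
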
Returns:
:class:`~apache_beam.transforms.ptransform.PTransform`
"""
if extra_client_params is None:
extra_client_params = {}
if not isinstance(db, str):
raise ValueError('WriteToMongoDB db param must be specified as a string')
if not isinstance(coll, str):
raise ValueError('WriteToMongoDB coll param must be specified as a '
'string')
self._uri = uri
self._db = db
self._coll = coll
self._batch_size = batch_size
self._spec = extra_client_params
def expand(self, pcoll):
return pcoll \
| beam.ParDo(_GenerateObjectIdFn()) \
| Reshuffle() \
| beam.ParDo(_WriteMongoFn(self._uri, self._db, self._coll,
self._batch_size, self._spec))
class _GenerateObjectIdFn(DoFn):
def process(self, element, *args, **kwargs):
    # if the _id field already exists we keep it as it is; otherwise the
    # ptransform generates a new _id field to achieve an idempotent write to
    # mongodb.
    if '_id' not in element:
      # objectid.ObjectId() generates a unique identifier that follows the
      # mongodb default format; if _id is not present in a document, the
      # mongodb server generates one with this same function upon write.
      # However, the uniqueness of the generated id may not be guaranteed if
      # the workload is distributed across too many processes. See more on
      # the ObjectId format at
      # https://docs.mongodb.com/manual/reference/bson-types/#objectid.
element['_id'] = objectid.ObjectId()
yield element
class _WriteMongoFn(DoFn):
def __init__(self,
uri=None,
db=None,
coll=None,
batch_size=100,
extra_params=None):
if extra_params is None:
extra_params = {}
self.uri = uri
self.db = db
self.coll = coll
self.spec = extra_params
self.batch_size = batch_size
self.batch = []
def finish_bundle(self):
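    # flush any partially-filled batch when the bundle finishes so that
    # buffered documents are not dropped.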
self._flush()
def process(self, element, *args, **kwargs):
self.batch.append(element)
if len(self.batch) >= self.batch_size:
self._flush()
def _flush(self):
if len(self.batch) == 0:
return
with _MongoSink(self.uri, self.db, self.coll, self.spec) as sink:
sink.write(self.batch)
self.batch = []
def display_data(self):
res = super(_WriteMongoFn, self).display_data()
res['uri'] = self.uri
res['database'] = self.db
res['collection'] = self.coll
res['mongo_client_params'] = self.spec
res['batch_size'] = self.batch_size
return res
class _MongoSink(object):
def __init__(self, uri=None, db=None, coll=None, extra_params=None):
if extra_params is None:
extra_params = {}
self.uri = uri
self.db = db
self.coll = coll
self.spec = extra_params
self.client = None
def write(self, documents):
if self.client is None:
self.client = MongoClient(host=self.uri, **self.spec)
requests = []
for doc in documents:
      # match the document based on its _id field; if it is not found in the
      # current collection, insert a new document, otherwise overwrite it.
requests.append(
ReplaceOne(filter={'_id': doc.get('_id', None)},
replacement=doc,
upsert=True))
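      # e.g. a doc {'_id': ObjectId('...'), 'name': 'a'} replaces any
      # existing document with the same _id, or is inserted if none matches.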
resp = self.client[self.db][self.coll].bulk_write(requests)
    logging.debug('BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, '
                  'nMatched:%d, Errors:%s',
                  resp.modified_count, resp.upserted_count, resp.matched_count,
                  resp.bulk_api_result.get('writeErrors'))
def __enter__(self):
if self.client is None:
self.client = MongoClient(host=self.uri, **self.spec)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
if self.client is not None:
self.client.close()