| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # |
| |
| """This module implements IO classes to read and write data on MongoDB. |
| |
| |
| Read from MongoDB |
| ----------------- |
:class:`ReadFromMongoDB` is a ``PTransform`` that reads from a configured
MongoDB source and returns a ``PCollection`` of dicts, each representing a
MongoDB document.
To configure the MongoDB source, provide the URI of the MongoDB server, the
database name, and the collection name.
| |
| Example usage:: |
| |
| pipeline | ReadFromMongoDB(uri='mongodb://localhost:27017', |
| db='testdb', |
| coll='input') |
| |
To read from MongoDB Atlas, use the ``bucket_auto`` option to enable the
``$bucketAuto`` MongoDB aggregation instead of the ``splitVector`` command,
which is a high-privilege command that cannot be granted to any user in
Atlas.
| |
| Example usage:: |
| |
| pipeline | ReadFromMongoDB(uri='mongodb+srv://user:pwd@cluster0.mongodb.net', |
| db='testdb', |
| coll='input', |
| bucket_auto=True) |
| |
| |
Write to MongoDB
| ----------------- |
:class:`WriteToMongoDB` is a ``PTransform`` that writes MongoDB documents to
the configured sink. The write is performed through a MongoDB ``bulk_write``
of ``ReplaceOne`` operations. If a document's ``_id`` field already exists in
the MongoDB collection, the existing document is overwritten; otherwise, a
new document is inserted.
| |
| Example usage:: |
| |
| pipeline | WriteToMongoDB(uri='mongodb://localhost:27017', |
| db='testdb', |
| coll='output', |
| batch_size=10) |
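
When documents carry their own ``_id`` values, re-running the pipeline
replaces the matching documents instead of inserting duplicates. A minimal
sketch (database, collection and field names are illustrative)::

  docs = [{'_id': 1, 'value': 'a'}, {'_id': 2, 'value': 'b'}]
  pipeline | beam.Create(docs) | WriteToMongoDB(uri='mongodb://localhost:27017',
                                                db='testdb',
                                                coll='output')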
| |
| |
| No backward compatibility guarantees. Everything in this module is experimental. |
| """ |
| |
| # pytype: skip-file |
| |
| import itertools |
| import json |
| import logging |
| import math |
| import struct |
| from typing import Union |
| |
| import apache_beam as beam |
| from apache_beam.io import iobase |
| from apache_beam.io.range_trackers import LexicographicKeyRangeTracker |
| from apache_beam.io.range_trackers import OffsetRangeTracker |
| from apache_beam.io.range_trackers import OrderedPositionRangeTracker |
| from apache_beam.transforms import DoFn |
| from apache_beam.transforms import PTransform |
| from apache_beam.transforms import Reshuffle |
| |
| _LOGGER = logging.getLogger(__name__) |
| |
| try: |
# pymongo has its own bundled bson, which is not compatible with the
# standalone bson package (https://github.com/py-bson/bson/issues/82). Try to
# import objectid, and if it fails because the standalone bson package is
# installed, MongoDB IO will not work, but at least the rest of the SDK will
# work.
| from bson import json_util |
| from bson import objectid |
| from bson.objectid import ObjectId |
| |
| # pymongo also internally depends on bson. |
| from pymongo import ASCENDING |
| from pymongo import DESCENDING |
| from pymongo import MongoClient |
| from pymongo import ReplaceOne |
| except ImportError: |
| objectid = None |
| json_util = None |
| ObjectId = None |
| ASCENDING = 1 |
| DESCENDING = -1 |
| MongoClient = None |
| ReplaceOne = None |
| _LOGGER.warning("Could not find a compatible bson package.") |
| |
| __all__ = ["ReadFromMongoDB", "WriteToMongoDB"] |
| |
| |
| class ReadFromMongoDB(PTransform): |
| """A ``PTransform`` to read MongoDB documents into a ``PCollection``.""" |
| def __init__( |
| self, |
| uri="mongodb://localhost:27017", |
| db=None, |
| coll=None, |
| filter=None, |
| projection=None, |
| extra_client_params=None, |
| bucket_auto=False, |
| ): |
| """Initialize a :class:`ReadFromMongoDB` |
| |
| Args: |
| uri (str): The MongoDB connection string following the URI format. |
| db (str): The MongoDB database name. |
| coll (str): The MongoDB collection name. |
| filter: A `bson.SON |
| <https://api.mongodb.com/python/current/api/bson/son.html>`_ object |
| specifying elements which must be present for a document to be included |
| in the result set. |
| projection: A list of field names that should be returned in the result |
| set or a dict specifying the fields to include or exclude. |
| extra_client_params(dict): Optional `MongoClient |
| <https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_ |
| parameters. |
bucket_auto (bool): If :data:`True`, use the MongoDB `$bucketAuto` aggregation
to split the collection into bundles instead of the `splitVector` command,
which does not work with MongoDB Atlas.
If :data:`False` (the default), use the `splitVector` command for bundling.
| |
| Returns: |
| :class:`~apache_beam.transforms.ptransform.PTransform` |
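
Example usage with a filter and projection (values are illustrative; the
filter is merged with the source's internal ``_id`` range filter and both
arguments are passed to pymongo's ``find``)::

  pipeline | ReadFromMongoDB(uri='mongodb://localhost:27017',
                             db='testdb',
                             coll='input',
                             filter={'status': 'active'},
                             projection=['_id', 'status'])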
| """ |
| if extra_client_params is None: |
| extra_client_params = {} |
| if not isinstance(db, str): |
| raise ValueError("ReadFromMongDB db param must be specified as a string") |
| if not isinstance(coll, str): |
| raise ValueError( |
| "ReadFromMongDB coll param must be specified as a string") |
| self._mongo_source = _BoundedMongoSource( |
| uri=uri, |
| db=db, |
| coll=coll, |
| filter=filter, |
| projection=projection, |
| extra_client_params=extra_client_params, |
| bucket_auto=bucket_auto, |
| ) |
| |
| def expand(self, pcoll): |
| return pcoll | iobase.Read(self._mongo_source) |
| |
| |
| class _ObjectIdRangeTracker(OrderedPositionRangeTracker): |
| """RangeTracker for tracking mongodb _id of bson ObjectId type.""" |
| def position_to_fraction( |
| self, |
| pos: ObjectId, |
| start: ObjectId, |
| end: ObjectId, |
| ): |
| """Returns the fraction of keys in the range [start, end) that |
| are less than the given key. |
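
For illustration (ids created with ``_ObjectIdHelper.int_to_id``; the values
are arbitrary)::

  start, end = _ObjectIdHelper.int_to_id(0), _ObjectIdHelper.int_to_id(100)
  pos = _ObjectIdHelper.int_to_id(25)
  _ObjectIdRangeTracker(start, end).position_to_fraction(pos, start, end)
  # -> 0.25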
| """ |
| pos_number = _ObjectIdHelper.id_to_int(pos) |
| start_number = _ObjectIdHelper.id_to_int(start) |
| end_number = _ObjectIdHelper.id_to_int(end) |
| return (pos_number - start_number) / (end_number - start_number) |
| |
| def fraction_to_position( |
| self, |
| fraction: float, |
| start: ObjectId, |
| end: ObjectId, |
| ): |
| """Converts a fraction between 0 and 1 |
| to a position between start and end. |
| """ |
| start_number = _ObjectIdHelper.id_to_int(start) |
| end_number = _ObjectIdHelper.id_to_int(end) |
| total = end_number - start_number |
| pos = int(total * fraction + start_number) |
| # make sure split position is larger than start position and smaller than |
| # end position. |
| if pos <= start_number: |
| return _ObjectIdHelper.increment_id(start, 1) |
| |
| if pos >= end_number: |
| return _ObjectIdHelper.increment_id(end, -1) |
| |
| return _ObjectIdHelper.int_to_id(pos) |
| |
| |
| class _BoundedMongoSource(iobase.BoundedSource): |
| """A MongoDB source that reads a finite amount of input records. |
| |
This class defines the following operations, which can be used to read the
MongoDB source efficiently.
| |
| * Size estimation - method ``estimate_size()`` may return an accurate |
| estimation in bytes for the size of the source. |
| * Splitting into bundles of a given size - method ``split()`` can be used to |
| split the source into a set of sub-sources (bundles) based on a desired |
| bundle size. |
| * Getting a RangeTracker - method ``get_range_tracker()`` should return a |
| ``RangeTracker`` object for a given position range for the position type |
| of the records returned by the source. |
| * Reading the data - method ``read()`` can be used to read data from the |
| source while respecting the boundaries defined by a given |
| ``RangeTracker``. |
| |
A runner will read the source in two steps.
| |
| (1) Method ``get_range_tracker()`` will be invoked with start and end |
| positions to obtain a ``RangeTracker`` for the range of positions the |
runner intends to read. The source must define a default initial start and
end position range. These positions must be used if the start and/or end
positions passed to the method ``get_range_tracker()`` are ``None``.
| (2) Method read() will be invoked with the ``RangeTracker`` obtained in the |
| previous step. |
| |
| **Mutability** |
| |
| A ``_BoundedMongoSource`` object should not be mutated while |
| its methods (for example, ``read()``) are being invoked by a runner. Runner |
| implementations may invoke methods of ``_BoundedMongoSource`` objects through |
| multi-threaded and/or reentrant execution modes. |
| """ |
| def __init__( |
| self, |
| uri=None, |
| db=None, |
| coll=None, |
| filter=None, |
| projection=None, |
| extra_client_params=None, |
| bucket_auto=False, |
| ): |
| if extra_client_params is None: |
| extra_client_params = {} |
| if filter is None: |
| filter = {} |
| self.uri = uri |
| self.db = db |
| self.coll = coll |
| self.filter = filter |
| self.projection = projection |
| self.spec = extra_client_params |
| self.bucket_auto = bucket_auto |
| |
| def estimate_size(self): |
| with MongoClient(self.uri, **self.spec) as client: |
| return client[self.db].command("collstats", self.coll).get("size") |
| |
| def _estimate_average_document_size(self): |
| with MongoClient(self.uri, **self.spec) as client: |
| return client[self.db].command("collstats", self.coll).get("avgObjSize") |
| |
| def split( |
| self, |
| desired_bundle_size: int, |
| start_position: Union[int, str, bytes, ObjectId] = None, |
| stop_position: Union[int, str, bytes, ObjectId] = None, |
| ): |
| """Splits the source into a set of bundles. |
| |
| Bundles should be approximately of size ``desired_bundle_size`` bytes. |
| |
| Args: |
| desired_bundle_size: the desired size (in bytes) of the bundles returned. |
start_position: if specified, the given position must be used as the
starting position of the first bundle.
stop_position: if specified, the given position must be used as the ending
position of the last bundle.
| Returns: |
an iterator of ``SourceBundle`` objects that give information about
the generated bundles.
| """ |
| |
| desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024 |
| |
# If the desired chunk size is smaller than 1 MB, use the MongoDB default
# split size of 1 MB.
| desired_bundle_size_in_mb = max(desired_bundle_size_in_mb, 1) |
| |
| is_initial_split = start_position is None and stop_position is None |
| start_position, stop_position = self._replace_none_positions( |
| start_position, stop_position |
| ) |
| |
| if self.bucket_auto: |
| # Use $bucketAuto for bundling |
| split_keys = [] |
| weights = [] |
| for bucket in self._get_auto_buckets( |
| desired_bundle_size_in_mb, |
| start_position, |
| stop_position, |
| is_initial_split, |
| ): |
| split_keys.append({"_id": bucket["_id"]["max"]}) |
| weights.append(bucket["count"]) |
| else: |
| # Use splitVector for bundling |
| split_keys = self._get_split_keys( |
| desired_bundle_size_in_mb, start_position, stop_position) |
| weights = itertools.cycle((desired_bundle_size_in_mb, )) |
| |
| bundle_start = start_position |
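# Convert the ordered split keys into contiguous [start, stop) bundles;
# each bundle ends at the next split key, clamped to stop_position.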
| for split_key_id, weight in zip(split_keys, weights): |
| if bundle_start >= stop_position: |
| break |
| bundle_end = min(stop_position, split_key_id["_id"]) |
| yield iobase.SourceBundle( |
| weight=weight, |
| source=self, |
| start_position=bundle_start, |
| stop_position=bundle_end, |
| ) |
| bundle_start = bundle_end |
| # add range of last split_key to stop_position |
| if bundle_start < stop_position: |
# bucket_auto mode can reach this point if no split occurred, e.g. when the
# range contains a single document
| weight = 1 if self.bucket_auto else desired_bundle_size_in_mb |
| yield iobase.SourceBundle( |
| weight=weight, |
| source=self, |
| start_position=bundle_start, |
| stop_position=stop_position, |
| ) |
| |
| def get_range_tracker( |
| self, |
| start_position: Union[int, str, ObjectId] = None, |
| stop_position: Union[int, str, ObjectId] = None, |
| ) -> Union[ |
| _ObjectIdRangeTracker, OffsetRangeTracker, LexicographicKeyRangeTracker]: |
| """Returns a RangeTracker for a given position range depending on type. |
| |
| Args: |
start_position: starting position of the range. If 'None', the default
start position of the source must be used.
stop_position: ending position of the range. If 'None', the default stop
position of the source must be used.
| Returns: |
an ``_ObjectIdRangeTracker``, ``OffsetRangeTracker``
or ``LexicographicKeyRangeTracker``, depending on the type of the given
position range.
| """ |
| start_position, stop_position = self._replace_none_positions( |
| start_position, stop_position |
| ) |
| |
| if isinstance(start_position, ObjectId): |
| return _ObjectIdRangeTracker(start_position, stop_position) |
| |
| if isinstance(start_position, int): |
| return OffsetRangeTracker(start_position, stop_position) |
| |
| if isinstance(start_position, str): |
| return LexicographicKeyRangeTracker(start_position, stop_position) |
| |
| raise NotImplementedError( |
| f"RangeTracker for {type(start_position)} not implemented!") |
| |
| def read(self, range_tracker): |
| """Returns an iterator that reads data from the source. |
| |
| The returned set of data must respect the boundaries defined by the given |
| ``RangeTracker`` object. For example: |
| |
| * Returned set of data must be for the range |
| ``[range_tracker.start_position, range_tracker.stop_position)``. Note |
| that a source may decide to return records that start after |
| ``range_tracker.stop_position``. See documentation in class |
``RangeTracker`` for more details. Also, note that the framework might
invoke ``range_tracker.try_split()`` to perform dynamic split
operations. ``range_tracker.stop_position`` may be updated
dynamically due to successful dynamic split operations.
| * Method ``range_tracker.try_split()`` must be invoked for every record |
| that starts at a split point. |
| * Method ``range_tracker.record_current_position()`` may be invoked for |
| records that do not start at split points. |
| |
| Args: |
| range_tracker: a ``RangeTracker`` whose boundaries must be respected |
| when reading data from the source. A runner that reads this |
source must pass a ``RangeTracker`` object that is not
| ``None``. |
| Returns: |
| an iterator of data read by the source. |
| """ |
| with MongoClient(self.uri, **self.spec) as client: |
| all_filters = self._merge_id_filter( |
| range_tracker.start_position(), range_tracker.stop_position()) |
| docs_cursor = ( |
| client[self.db][self.coll].find( |
| filter=all_filters, |
| projection=self.projection).sort([("_id", ASCENDING)])) |
| for doc in docs_cursor: |
| if not range_tracker.try_claim(doc["_id"]): |
| return |
| yield doc |
| |
| def display_data(self): |
| """Returns the display data associated to a pipeline component.""" |
| res = super().display_data() |
| res["database"] = self.db |
| res["collection"] = self.coll |
| res["filter"] = json.dumps(self.filter, default=json_util.default) |
| res["projection"] = str(self.projection) |
| res["bucket_auto"] = self.bucket_auto |
| return res |
| |
| @staticmethod |
| def _range_is_not_splittable( |
| start_pos: Union[int, str, ObjectId], |
| end_pos: Union[int, str, ObjectId], |
| ): |
| """Return `True` if splitting range doesn't make sense |
| (single document is not splittable), |
| Return `False` otherwise. |
| """ |
| return (( |
| isinstance(start_pos, ObjectId) and |
| start_pos >= _ObjectIdHelper.increment_id(end_pos, -1)) or |
| (isinstance(start_pos, int) and start_pos >= end_pos - 1) or |
| (isinstance(start_pos, str) and start_pos >= end_pos)) |
| |
| def _get_split_keys( |
| self, |
| desired_chunk_size_in_mb: int, |
| start_pos: Union[int, str, ObjectId], |
| end_pos: Union[int, str, ObjectId], |
| ): |
| """Calls MongoDB `splitVector` command |
| to get document ids at split position. |
| """ |
| # single document not splittable |
| if self._range_is_not_splittable(start_pos, end_pos): |
| return [] |
| |
| with MongoClient(self.uri, **self.spec) as client: |
| name_space = "%s.%s" % (self.db, self.coll) |
| return client[self.db].command( |
| "splitVector", |
| name_space, |
| keyPattern={"_id": 1}, # Ascending index |
| min={"_id": start_pos}, |
| max={"_id": end_pos}, |
| maxChunkSize=desired_chunk_size_in_mb, |
| )["splitKeys"] |
| |
| def _get_auto_buckets( |
| self, |
| desired_chunk_size_in_mb: int, |
| start_pos: Union[int, str, ObjectId], |
| end_pos: Union[int, str, ObjectId], |
| is_initial_split: bool, |
| ) -> list: |
| """Use MongoDB `$bucketAuto` aggregation to split collection into bundles |
| instead of `splitVector` command, which does not work with MongoDB Atlas. |
| """ |
| # single document not splittable |
| if self._range_is_not_splittable(start_pos, end_pos): |
| return [] |
| |
| if is_initial_split and not self.filter: |
| # total collection size in MB |
| size_in_mb = self.estimate_size() / float(1 << 20) |
| else: |
| # size of documents within start/end id range and possibly filtered |
| documents_count = self._count_id_range(start_pos, end_pos) |
| avg_document_size = self._estimate_average_document_size() |
| size_in_mb = documents_count * avg_document_size / float(1 << 20) |
| |
| if size_in_mb == 0: |
# no documents in the range, nothing to split (possibly a result of filtering)
| return [] |
| |
| bucket_count = math.ceil(size_in_mb / desired_chunk_size_in_mb) |
with MongoClient(self.uri, **self.spec) as client:
| pipeline = [ |
| { |
| # filter by positions and by the custom filter if any |
| "$match": self._merge_id_filter(start_pos, end_pos) |
| }, |
| { |
| "$bucketAuto": { |
| "groupBy": "$_id", "buckets": bucket_count |
| } |
| }, |
| ] |
| buckets = list( |
# Use the `allowDiskUse` option to avoid the 100 MB RAM aggregation limit
| client[self.db][self.coll].aggregate(pipeline, allowDiskUse=True)) |
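# The last bucket's max boundary is the highest matched _id; extend it to
# end_pos so the final bundle covers the requested range up to stop_position.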
| if buckets: |
| buckets[-1]["_id"]["max"] = end_pos |
| |
| return buckets |
| |
| def _merge_id_filter( |
| self, |
| start_position: Union[int, str, bytes, ObjectId], |
| stop_position: Union[int, str, bytes, ObjectId] = None, |
| ) -> dict: |
| """Merge the default filter (if any) with refined _id field range |
| of range_tracker. |
| $gte specifies start position (inclusive) |
| and $lt specifies the end position (exclusive), |
| see more at |
| https://docs.mongodb.com/manual/reference/operator/query/gte/ and |
| https://docs.mongodb.com/manual/reference/operator/query/lt/ |
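
For illustration, with a custom filter ``{'status': 'active'}`` the merged
filter looks like::

  {'$and': [{'status': 'active'},
            {'_id': {'$gte': start_position, '$lt': stop_position}}]}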
| """ |
| |
| if stop_position is None: |
| id_filter = {"_id": {"$gte": start_position}} |
| else: |
| id_filter = {"_id": {"$gte": start_position, "$lt": stop_position}} |
| |
| if self.filter: |
| all_filters = { |
| # see more at |
| # https://docs.mongodb.com/manual/reference/operator/query/and/ |
| "$and": [self.filter.copy(), id_filter] |
| } |
| else: |
| all_filters = id_filter |
| |
| return all_filters |
| |
| def _get_head_document_id(self, sort_order): |
| with MongoClient(self.uri, **self.spec) as client: |
| cursor = ( |
| client[self.db][self.coll].find(filter={}, projection=[]).sort([ |
| ("_id", sort_order) |
| ]).limit(1)) |
| try: |
| return cursor[0]["_id"] |
| |
| except IndexError: |
| raise ValueError("Empty Mongodb collection") |
| |
| def _replace_none_positions(self, start_position, stop_position): |
| |
| if start_position is None: |
| start_position = self._get_head_document_id(ASCENDING) |
| if stop_position is None: |
| last_doc_id = self._get_head_document_id(DESCENDING) |
| # increment last doc id binary value by 1 to make sure the last document |
| # is not excluded |
| if isinstance(last_doc_id, ObjectId): |
| stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1) |
| elif isinstance(last_doc_id, int): |
| stop_position = last_doc_id + 1 |
| elif isinstance(last_doc_id, str): |
| stop_position = last_doc_id + '\x00' |
| |
| return start_position, stop_position |
| |
| def _count_id_range(self, start_position, stop_position): |
| """Number of documents between start_position (inclusive) |
| and stop_position (exclusive), respecting the custom filter if any. |
| """ |
| with MongoClient(self.uri, **self.spec) as client: |
| return client[self.db][self.coll].count_documents( |
| filter=self._merge_id_filter(start_position, stop_position)) |
| |
| |
| class _ObjectIdHelper: |
| """A Utility class to manipulate bson object ids.""" |
| @classmethod |
| def id_to_int(cls, _id: Union[int, ObjectId]) -> int: |
| """ |
| Args: |
_id: The ``ObjectId`` (or ``int``) value of a MongoDB document ``_id`` field.
| |
Returns: The integer value of the ObjectId's 12-byte binary representation.
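
Example (requires the bson package; the id value is illustrative)::

  _ObjectIdHelper.id_to_int(ObjectId('000000000000000000000019'))  # -> 25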
| """ |
| if isinstance(_id, int): |
| return _id |
| |
# Convert the ObjectId binary value (a bytes object of size 12) to an integer.
| ints = struct.unpack(">III", _id.binary) |
| return (ints[0] << 64) + (ints[1] << 32) + ints[2] |
| |
| @classmethod |
| def int_to_id(cls, number): |
| """ |
| Args: |
| number(int): The integer value to be used to convert to ObjectId. |
| |
| Returns: The ObjectId that has the 12 bytes binary converted from the |
| integer value. |
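
Example (requires the bson package; the value is illustrative)::

  _ObjectIdHelper.int_to_id(25)  # -> ObjectId('000000000000000000000019')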
| """ |
# Convert an integer value to an ObjectId. The value must be less than
# 2**96 so it can be converted to the 12 bytes required by ObjectId.
| if number < 0 or number >= (1 << 96): |
| raise ValueError("number value must be within [0, %s)" % (1 << 96)) |
| ints = [ |
| (number & 0xFFFFFFFF0000000000000000) >> 64, |
| (number & 0x00000000FFFFFFFF00000000) >> 32, |
| number & 0x0000000000000000FFFFFFFF, |
| ] |
| |
| number_bytes = struct.pack(">III", *ints) |
| return ObjectId(number_bytes) |
| |
| @classmethod |
| def increment_id( |
| cls, |
| _id: ObjectId, |
| inc: int, |
| ) -> ObjectId: |
| """ |
| Increment object_id binary value by inc value and return new object id. |
| |
| Args: |
| _id: The `_id` to change. |
| inc(int): The incremental int value to be added to `_id`. |
| |
| Returns: |
| `_id` incremented by `inc` value |
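
Example (requires the bson package; the id value is illustrative)::

  _ObjectIdHelper.increment_id(ObjectId('000000000000000000000019'), -1)
  # -> ObjectId('000000000000000000000018')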
| """ |
| id_number = _ObjectIdHelper.id_to_int(_id) |
| new_number = id_number + inc |
| if new_number < 0 or new_number >= (1 << 96): |
raise ValueError(
    "invalid incremental, inc value must be within ["
    "%s, %s)" % (0 - id_number, (1 << 96) - id_number))
| return _ObjectIdHelper.int_to_id(new_number) |
| |
| |
| class WriteToMongoDB(PTransform): |
| """WriteToMongoDB is a ``PTransform`` that writes a ``PCollection`` of |
MongoDB documents to the configured MongoDB server.
| |
In order to make the document writes idempotent so that bundles can be
retried without creating duplicates, the PTransform adds two transforms
before the final write stage:
a ``GenerateId`` transform and a ``Reshuffle`` transform.::
| |
| ----------------------------------------------- |
| Pipeline --> |GenerateId --> Reshuffle --> WriteToMongoSink| |
| ----------------------------------------------- |
| (WriteToMongoDB) |
| |
The ``GenerateId`` transform adds a random and unique ``_id`` field to the
documents if they don't already have one, using the same format as the
MongoDB default. The ``Reshuffle`` transform makes sure that no fusion
happens between ``GenerateId`` and the final write stage transform, so that
the set of documents and their unique IDs is not regenerated if the final
write step is retried due to a failure. This prevents duplicate writes of
the same document with different unique IDs.
| |
| """ |
| def __init__( |
| self, |
| uri="mongodb://localhost:27017", |
| db=None, |
| coll=None, |
| batch_size=100, |
| extra_client_params=None, |
| ): |
| """ |
| |
| Args: |
uri (str): The MongoDB connection string following the URI format.
db (str): The MongoDB database name.
coll (str): The MongoDB collection name.
batch_size (int): Number of documents per bulk_write to MongoDB,
defaults to 100.
| extra_client_params(dict): Optional `MongoClient |
| <https://api.mongodb.com/python/current/api/pymongo/mongo_client.html>`_ |
| parameters as keyword arguments |
| |
| Returns: |
| :class:`~apache_beam.transforms.ptransform.PTransform` |
| |
| """ |
| if extra_client_params is None: |
| extra_client_params = {} |
| if not isinstance(db, str): |
| raise ValueError("WriteToMongoDB db param must be specified as a string") |
| if not isinstance(coll, str): |
| raise ValueError( |
| "WriteToMongoDB coll param must be specified as a string") |
| self._uri = uri |
| self._db = db |
| self._coll = coll |
| self._batch_size = batch_size |
| self._spec = extra_client_params |
| |
| def expand(self, pcoll): |
| return ( |
| pcoll |
| | beam.ParDo(_GenerateObjectIdFn()) |
| | Reshuffle() |
| | beam.ParDo( |
| _WriteMongoFn( |
| self._uri, self._db, self._coll, self._batch_size, self._spec))) |
| |
| |
| class _GenerateObjectIdFn(DoFn): |
| def process(self, element, *args, **kwargs): |
# if the _id field already exists we keep it as is; otherwise the ptransform
# generates a new _id field to achieve idempotent writes to mongodb.
| if "_id" not in element: |
# objectid.ObjectId() generates a unique identifier that follows the mongodb
# default format; if _id is not present in a document, the mongodb server
# generates it with this same function upon write. However, the uniqueness
# of the generated id may not be guaranteed if the workload is distributed
# across too many processes. See more on the ObjectId format at
# https://docs.mongodb.com/manual/reference/bson-types/#objectid.
| element["_id"] = objectid.ObjectId() |
| |
| yield element |
| |
| |
| class _WriteMongoFn(DoFn): |
| def __init__( |
| self, uri=None, db=None, coll=None, batch_size=100, extra_params=None): |
| if extra_params is None: |
| extra_params = {} |
| self.uri = uri |
| self.db = db |
| self.coll = coll |
| self.spec = extra_params |
| self.batch_size = batch_size |
| self.batch = [] |
| |
| def finish_bundle(self): |
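# Flush any documents still buffered when the bundle finishes.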
| self._flush() |
| |
| def process(self, element, *args, **kwargs): |
| self.batch.append(element) |
| if len(self.batch) >= self.batch_size: |
| self._flush() |
| |
| def _flush(self): |
| if len(self.batch) == 0: |
| return |
| with _MongoSink(self.uri, self.db, self.coll, self.spec) as sink: |
| sink.write(self.batch) |
| self.batch = [] |
| |
| def display_data(self): |
| res = super().display_data() |
| res["database"] = self.db |
| res["collection"] = self.coll |
| res["batch_size"] = self.batch_size |
| return res |
| |
| |
| class _MongoSink: |
| def __init__(self, uri=None, db=None, coll=None, extra_params=None): |
| if extra_params is None: |
| extra_params = {} |
| self.uri = uri |
| self.db = db |
| self.coll = coll |
| self.spec = extra_params |
| self.client = None |
| |
| def write(self, documents): |
| if self.client is None: |
| self.client = MongoClient(host=self.uri, **self.spec) |
| requests = [] |
| for doc in documents: |
# Match the document based on the _id field; if not found in the current
# collection, insert a new document, otherwise overwrite it.
| requests.append( |
| ReplaceOne( |
| filter={"_id": doc.get("_id", None)}, |
| replacement=doc, |
| upsert=True)) |
| resp = self.client[self.db][self.coll].bulk_write(requests) |
| _LOGGER.debug( |
| "BulkWrite to MongoDB result in nModified:%d, nUpserted:%d, " |
| "nMatched:%d, Errors:%s" % ( |
| resp.modified_count, |
| resp.upserted_count, |
| resp.matched_count, |
| resp.bulk_api_result.get("writeErrors"), |
| )) |
| |
| def __enter__(self): |
| if self.client is None: |
| self.client = MongoClient(host=self.uri, **self.spec) |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| if self.client is not None: |
| self.client.close() |