Merge pull request #9342 from [BEAM-7866][BEAM-5148] Cherry-picks mongodb fixes to 2.15.0 release branch

diff --git a/sdks/python/apache_beam/io/mongodbio.py b/sdks/python/apache_beam/io/mongodbio.py
index db3646b..57e0c7c 100644
--- a/sdks/python/apache_beam/io/mongodbio.py
+++ b/sdks/python/apache_beam/io/mongodbio.py
@@ -52,21 +52,35 @@
 """
 
 from __future__ import absolute_import
+from __future__ import division
 
 import logging
-
-from bson import objectid
-from pymongo import MongoClient
-from pymongo import ReplaceOne
+import struct
 
 import apache_beam as beam
 from apache_beam.io import iobase
-from apache_beam.io.range_trackers import OffsetRangeTracker
+from apache_beam.io.range_trackers import OrderedPositionRangeTracker
 from apache_beam.transforms import DoFn
 from apache_beam.transforms import PTransform
 from apache_beam.transforms import Reshuffle
 from apache_beam.utils.annotations import experimental
 
+try:
+  # Mongodb has its own bundled bson, which is not compatible with bson package.
+  # (https://github.com/py-bson/bson/issues/82). Try to import objectid and if
+  # it fails because bson package is installed, MongoDB IO will not work but at
+  # least rest of the SDK will work.
+  from bson import objectid
+
+  # pymongo also internally depends on bson.
+  from pymongo import ASCENDING
+  from pymongo import DESCENDING
+  from pymongo import MongoClient
+  from pymongo import ReplaceOne
+except ImportError:
+  objectid = None
+  logging.warning("Could not find a compatible bson package.")
+
 __all__ = ['ReadFromMongoDB', 'WriteToMongoDB']
 
 
@@ -139,50 +153,49 @@
     self.filter = filter
     self.projection = projection
     self.spec = extra_client_params
-    self.doc_count = self._get_document_count()
-    self.avg_doc_size = self._get_avg_document_size()
-    self.client = None
 
   def estimate_size(self):
-    return self.avg_doc_size * self.doc_count
+    with MongoClient(self.uri, **self.spec) as client:
+      return client[self.db].command('collstats', self.coll).get('size')
 
   def split(self, desired_bundle_size, start_position=None, stop_position=None):
-    # use document cursor index as the start and stop positions
-    if start_position is None:
-      start_position = 0
-    if stop_position is None:
-      stop_position = self.doc_count
+    start_position, stop_position = self._replace_none_positions(
+        start_position, stop_position)
 
-    # get an estimate on how many documents should be included in a split batch
-    desired_bundle_count = desired_bundle_size // self.avg_doc_size
+    desired_bundle_size_in_mb = desired_bundle_size // 1024 // 1024
+    split_keys = self._get_split_keys(desired_bundle_size_in_mb, start_position,
+                                      stop_position)
 
     bundle_start = start_position
-    while bundle_start < stop_position:
-      bundle_end = min(stop_position, bundle_start + desired_bundle_count)
-      yield iobase.SourceBundle(weight=bundle_end - bundle_start,
+    for split_key_id in split_keys:
+      if bundle_start >= stop_position:
+        break
+      bundle_end = min(stop_position, split_key_id)
+      yield iobase.SourceBundle(weight=desired_bundle_size_in_mb,
                                 source=self,
                                 start_position=bundle_start,
                                 stop_position=bundle_end)
       bundle_start = bundle_end
+    # add range of last split_key to stop_position
+    if bundle_start < stop_position:
+      yield iobase.SourceBundle(weight=desired_bundle_size_in_mb,
+                                source=self,
+                                start_position=bundle_start,
+                                stop_position=stop_position)
 
   def get_range_tracker(self, start_position, stop_position):
-    if start_position is None:
-      start_position = 0
-    if stop_position is None:
-      stop_position = self.doc_count
-    return OffsetRangeTracker(start_position, stop_position)
+    start_position, stop_position = self._replace_none_positions(
+        start_position, stop_position)
+    return _ObjectIdRangeTracker(start_position, stop_position)
 
   def read(self, range_tracker):
     with MongoClient(self.uri, **self.spec) as client:
-      # docs is a MongoDB Cursor
-      docs = client[self.db][self.coll].find(
-          filter=self.filter, projection=self.projection
-      )[range_tracker.start_position():range_tracker.stop_position()]
-      for index in range(range_tracker.start_position(),
-                         range_tracker.stop_position()):
-        if not range_tracker.try_claim(index):
+      coll = client[self.db][self.coll]  # user projection must keep '_id'
+      all_filters = self._merge_id_filter(range_tracker)
+      for doc in coll.find(filter=all_filters, projection=self.projection):
+        if not range_tracker.try_claim(doc['_id']):
           return
-        yield docs[index - range_tracker.start_position()]
+        yield doc
 
   def display_data(self):
     res = super(_BoundedMongoSource, self).display_data()
@@ -194,18 +207,146 @@
     res['mongo_client_spec'] = self.spec
     return res
 
-  def _get_avg_document_size(self):
+  def _get_split_keys(self, desired_chunk_size_in_mb, start_pos, end_pos):
+    # calls mongodb splitVector command to get document ids at split position
+    # for desired bundle size, if desired chunk size smaller than 1mb, use
+    # mongodb default split size of 1mb.
+    if desired_chunk_size_in_mb < 1:
+      desired_chunk_size_in_mb = 1
+    if start_pos >= end_pos:
+      # single document not splittable
+      return []
     with MongoClient(self.uri, **self.spec) as client:
-      size = client[self.db].command('collstats', self.coll).get('avgObjSize')
-      if size is None or size <= 0:
-        raise ValueError(
-            'Collection %s not found or average doc size is '
-            'incorrect', self.coll)
-      return size
+      name_space = '%s.%s' % (self.db, self.coll)
+      return (client[self.db].command(
+          'splitVector',
+          name_space,
+          keyPattern={'_id': 1},  # Ascending index
+          min={'_id': start_pos},
+          max={'_id': end_pos},
+          maxChunkSize=desired_chunk_size_in_mb)['splitKeys'])
 
-  def _get_document_count(self):
+  def _merge_id_filter(self, range_tracker):
+    # Merge the default filter with refined _id field range of range_tracker.
+    # see more at https://docs.mongodb.com/manual/reference/operator/query/and/
+    all_filters = {
+        '$and': [
+            self.filter.copy(),
+            # add additional range filter to query. $gte specifies start
+            # position(inclusive) and $lt specifies the end position(exclusive),
+            # see more at
+            # https://docs.mongodb.com/manual/reference/operator/query/gte/ and
+            # https://docs.mongodb.com/manual/reference/operator/query/lt/
+            {
+                '_id': {
+                    '$gte': range_tracker.start_position(),
+                    '$lt': range_tracker.stop_position()
+                }
+            },
+        ]
+    }
+
+    return all_filters
+
+  def _get_head_document_id(self, sort_order):
     with MongoClient(self.uri, **self.spec) as client:
-      return max(client[self.db][self.coll].count_documents(self.filter), 0)
+      cursor = client[self.db][self.coll].find(filter={}, projection=[]).sort([
+          ('_id', sort_order)
+      ]).limit(1)
+      try:
+        return cursor[0]['_id']
+      except IndexError:
+        raise ValueError('Empty Mongodb collection')
+
+  def _replace_none_positions(self, start_position, stop_position):
+    if start_position is None:
+      start_position = self._get_head_document_id(ASCENDING)
+    if stop_position is None:
+      last_doc_id = self._get_head_document_id(DESCENDING)
+      # increment last doc id binary value by 1 to make sure the last document
+      # is not excluded
+      stop_position = _ObjectIdHelper.increment_id(last_doc_id, 1)
+    return start_position, stop_position
+
+
+class _ObjectIdHelper(object):
+  """A Utility class to manipulate bson object ids."""
+
+  @classmethod
+  def id_to_int(cls, id):
+    """
+    Args:
+      id: ObjectId required for each MongoDB document _id field.
+
+    Returns: Converted integer value of ObjectId's 12 bytes binary value.
+
+    """
+    # converts object id binary to integer
+    # id object is bytes type with size of 12
+    ints = struct.unpack('>III', id.binary)
+    return (ints[0] << 64) + (ints[1] << 32) + ints[2]
+
+  @classmethod
+  def int_to_id(cls, number):
+    """
+    Args:
+      number(int): The integer value to be used to convert to ObjectId.
+
+    Returns: The ObjectId that has the 12 bytes binary converted from the
+      integer value.
+
+    """
+    # converts integer value to object id. Int value should be less than
+    # (2 ^ 96) so it can be convert to 12 bytes required by object id.
+    if number < 0 or number >= (1 << 96):
+      raise ValueError('number value must be within [0, %s)' % (1 << 96))
+    ints = [(number & 0xffffffff0000000000000000) >> 64,
+            (number & 0x00000000ffffffff00000000) >> 32,
+            number & 0x0000000000000000ffffffff]
+
+    bytes = struct.pack('>III', *ints)
+    return objectid.ObjectId(bytes)
+
+  @classmethod
+  def increment_id(cls, object_id, inc):
+    """
+    Args:
+      object_id: The ObjectId to change.
+      inc(int): The incremental int value to be added to ObjectId.
+
+    Returns: The ObjectId whose 12 bytes binary is object_id's value plus inc.
+
+    Raises: ValueError if the result falls outside of [0, 2 ^ 96).
+    """
+    id_number = _ObjectIdHelper.id_to_int(object_id)
+    new_number = id_number + inc
+    if new_number < 0 or new_number >= (1 << 96):
+      raise ValueError('invalid incremental, inc value must be within ['
+                       '%s, %s)' % (-id_number, (1 << 96) - id_number))
+    return _ObjectIdHelper.int_to_id(new_number)
+
+
+class _ObjectIdRangeTracker(OrderedPositionRangeTracker):
+  """RangeTracker for tracking mongodb _id of bson ObjectId type."""
+
+  def position_to_fraction(self, pos, start, end):
+    pos_number = _ObjectIdHelper.id_to_int(pos)
+    start_number = _ObjectIdHelper.id_to_int(start)
+    end_number = _ObjectIdHelper.id_to_int(end)
+    return (pos_number - start_number) / (end_number - start_number)
+
+  def fraction_to_position(self, fraction, start, end):
+    start_number = _ObjectIdHelper.id_to_int(start)
+    end_number = _ObjectIdHelper.id_to_int(end)
+    total = end_number - start_number
+    pos = int(total * fraction + start_number)
+    # make sure split position is larger than start position and smaller than
+    # end position.
+    if pos <= start_number:
+      return _ObjectIdHelper.increment_id(start, 1)
+    if pos >= end_number:
+      return _ObjectIdHelper.increment_id(end, -1)
+    return _ObjectIdHelper.int_to_id(pos)
 
 
 @experimental()
diff --git a/sdks/python/apache_beam/io/mongodbio_it_test.py b/sdks/python/apache_beam/io/mongodbio_it_test.py
index 9b2ddcb..bfc6099 100644
--- a/sdks/python/apache_beam/io/mongodbio_it_test.py
+++ b/sdks/python/apache_beam/io/mongodbio_it_test.py
@@ -42,17 +42,18 @@
                       default=default_coll,
                       help='mongo uri string for connection')
   parser.add_argument('--num_documents',
-                      default=1000,
+                      default=100000,
                       help='The expected number of documents to be generated '
                       'for write or read',
                       type=int)
   parser.add_argument('--batch_size',
-                      default=100,
+                      default=10000,
                       help=('batch size for writing to mongodb'))
   known_args, pipeline_args = parser.parse_known_args(argv)
 
   # Test Write to MongoDB
   with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
+    start_time = time.time()
     logging.info('Writing %d documents to mongodb' % known_args.num_documents)
     docs = [{
         'number': x,
@@ -60,14 +61,14 @@
         'number_mod_3': x % 3
     } for x in range(known_args.num_documents)]
 
-    start_time = time.time()
     _ = p | 'Create documents' >> beam.Create(docs) \
           | 'WriteToMongoDB' >> beam.io.WriteToMongoDB(known_args.mongo_uri,
                                                        known_args.mongo_db,
                                                        known_args.mongo_coll,
                                                        known_args.batch_size)
-    logging.info('Writing %d documents to mongodb finished in %.3f seconds' %
-                 (known_args.num_documents, time.time() - start_time))
+  elapsed = time.time() - start_time
+  logging.info('Writing %d documents to mongodb finished in %.3f seconds' %
+               (known_args.num_documents, elapsed))
 
   # Test Read from MongoDB
   with TestPipeline(options=PipelineOptions(pipeline_args)) as p:
@@ -80,11 +81,12 @@
                                         known_args.mongo_coll,
                                         projection=['number']) \
           | 'Map' >> beam.Map(lambda doc: doc['number'])
-
     assert_that(
         r, equal_to([number for number in range(known_args.num_documents)]))
-    logging.info('Read %d documents from mongodb finished in %.3f seconds' %
-                 (known_args.num_documents, time.time() - start_time))
+
+  elapsed = time.time() - start_time
+  logging.info('Read %d documents from mongodb finished in %.3f seconds' %
+               (known_args.num_documents, elapsed))
 
 
 if __name__ == "__main__":
diff --git a/sdks/python/apache_beam/io/mongodbio_test.py b/sdks/python/apache_beam/io/mongodbio_test.py
index 6c4eacf..3f07ec1 100644
--- a/sdks/python/apache_beam/io/mongodbio_test.py
+++ b/sdks/python/apache_beam/io/mongodbio_test.py
@@ -15,12 +15,18 @@
 #
 
 from __future__ import absolute_import
+from __future__ import division
 
+import datetime
 import logging
+import random
+import sys
 import unittest
+from unittest import TestCase
 
 import mock
 from bson import objectid
+from pymongo import ASCENDING
 from pymongo import ReplaceOne
 
 import apache_beam as beam
@@ -30,38 +36,139 @@
 from apache_beam.io.mongodbio import _BoundedMongoSource
 from apache_beam.io.mongodbio import _GenerateObjectIdFn
 from apache_beam.io.mongodbio import _MongoSink
+from apache_beam.io.mongodbio import _ObjectIdHelper
+from apache_beam.io.mongodbio import _ObjectIdRangeTracker
 from apache_beam.io.mongodbio import _WriteMongoFn
 from apache_beam.testing.test_pipeline import TestPipeline
 from apache_beam.testing.util import assert_that
 from apache_beam.testing.util import equal_to
 
 
+class _MockMongoColl(object):
+  """Fake mongodb collection cursor."""
+
+  def __init__(self, docs):
+    self.docs = docs
+
+  def _filter(self, filter):
+    # Always returns a list of docs (never the collection itself) so that
+    # find() and count_documents() can both consume the result.
+    if not filter or not filter.get('$and'):
+      return self.docs
+    # second '$and' clause is the _id range: $gte inclusive, $lt exclusive.
+    id_range = filter['$and'][1]['_id']
+    start = id_range.get('$gte')
+    end = id_range.get('$lt')
+    assert start is not None
+    assert end is not None
+    match = []
+    for doc in self.docs:
+      if doc['_id'] < start or doc['_id'] >= end:
+        continue
+      match.append(doc)
+    return match
+
+  def find(self, filter=None, **kwargs):
+    return _MockMongoColl(self._filter(filter))
+
+  def sort(self, sort_items):
+    key, order = sort_items[0]
+    self.docs = sorted(self.docs,
+                       key=lambda x: x[key],
+                       reverse=(order != ASCENDING))
+    return self
+
+  def limit(self, num):
+    return _MockMongoColl(self.docs[0:num])
+
+  def count_documents(self, filter):
+    return len(self._filter(filter))
+
+  def __getitem__(self, index):
+    return self.docs[index]
+
+
+class _MockMongoDb(object):
+  """Fake Mongo Db."""
+
+  def __init__(self, docs):
+    self.docs = docs
+
+  def __getitem__(self, coll_name):
+    return _MockMongoColl(self.docs)
+
+  def command(self, command, *args, **kwargs):
+    if command == 'collstats':
+      return {'size': 5, 'avgSize': 1}
+    elif command == 'splitVector':
+      return self.get_split_keys(command, *args, **kwargs)
+
+  def get_split_keys(self, command, ns, min, max, maxChunkSize, **kwargs):
+    # simulate mongo db splitVector command, return split keys based on chunk
+    # size, assuming every doc is of size 1mb
+    start_id = min['_id']
+    end_id = max['_id']
+    if start_id >= end_id:
+      return {'splitKeys': []}
+    start_index = 0
+    end_index = 0
+    # get split range of [min, max]
+    for doc in self.docs:
+      if doc['_id'] < start_id:
+        start_index += 1
+      if doc['_id'] <= end_id:
+        end_index += 1
+      else:
+        break
+    # Return ids of elements in the range with chunk size skip and exclude
+    # head element. For simplicity of tests every document is considered 1Mb
+    # by default.
+    return {
+        'splitKeys':
+        [x['_id'] for x in self.docs[start_index:end_index:maxChunkSize]][1:]
+    }
+
+
+class _MockMongoClient(object):
+  def __init__(self, docs):
+    self.docs = docs
+
+  def __getitem__(self, db_name):
+    return _MockMongoDb(self.docs)
+
+  def __enter__(self):
+    return self
+
+  def __exit__(self, exc_type, exc_val, exc_tb):
+    pass
+
+
 class MongoSourceTest(unittest.TestCase):
-  @mock.patch('apache_beam.io.mongodbio._BoundedMongoSource'
-              '._get_document_count')
-  @mock.patch('apache_beam.io.mongodbio._BoundedMongoSource'
-              '._get_avg_document_size')
-  def setUp(self, mock_size, mock_count):
-    mock_size.return_value = 10
-    mock_count.return_value = 5
+  @mock.patch('apache_beam.io.mongodbio.MongoClient')
+  def setUp(self, mock_client):
+    self._ids = [
+        objectid.ObjectId.from_datetime(
+            datetime.datetime(year=2020, month=i + 1, day=i + 1))
+        for i in range(5)
+    ]
+    self._docs = [{'_id': self._ids[i], 'x': i} for i in range(len(self._ids))]
+    mock_client.return_value = _MockMongoClient(self._docs)
+
     self.mongo_source = _BoundedMongoSource('mongodb://test', 'testdb',
                                             'testcoll')
 
-  def test_estimate_size(self):
-    self.assertEqual(self.mongo_source.estimate_size(), 50)
+  @mock.patch('apache_beam.io.mongodbio.MongoClient')
+  def test_estimate_size(self, mock_client):
+    mock_client.return_value = _MockMongoClient(self._docs)
+    self.assertEqual(self.mongo_source.estimate_size(), 5)
 
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_split(self, mock_client):
-    # desired bundle size is 1 times of avg doc size, each bundle contains 1
-    # documents
-    mock_client.return_value.__enter__.return_value.__getitem__.return_value \
-      .__getitem__.return_value.find.return_value = [{'x': 1}, {'x': 2},
-                                                     {'x': 3}, {'x': 4},
-                                                     {'x': 5}]
-    for size in [10, 20, 100]:
+    mock_client.return_value = _MockMongoClient(self._docs)
+    for size in [i * 1024 * 1024 for i in (1, 2, 10)]:
       splits = list(
-          self.mongo_source.split(start_position=0,
-                                  stop_position=5,
+          self.mongo_source.split(start_position=None,
+                                  stop_position=None,
                                   desired_bundle_size=size))
 
       reference_info = (self.mongo_source, None, None)
@@ -72,36 +179,60 @@
 
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_dynamic_work_rebalancing(self, mock_client):
-    splits = list(self.mongo_source.split(desired_bundle_size=3000))
-    mock_client.return_value.__enter__.return_value.__getitem__.return_value \
-      .__getitem__.return_value.find.return_value = [{'x': 1}, {'x': 2},
-                                                     {'x': 3}, {'x': 4},
-                                                     {'x': 5}]
+    mock_client.return_value = _MockMongoClient(self._docs)
+    splits = list(
+        self.mongo_source.split(desired_bundle_size=3000 * 1024 * 1024))
     assert len(splits) == 1
     source_test_utils.assert_split_at_fraction_exhaustive(
         splits[0].source, splits[0].start_position, splits[0].stop_position)
 
-  @mock.patch('apache_beam.io.mongodbio.OffsetRangeTracker')
-  def test_get_range_tracker(self, mock_tracker):
-    self.mongo_source.get_range_tracker(None, None)
-    mock_tracker.assert_called_with(0, 5)
-    self.mongo_source.get_range_tracker(10, 20)
-    mock_tracker.assert_called_with(10, 20)
+  @mock.patch('apache_beam.io.mongodbio.MongoClient')
+  def test_get_range_tracker(self, mock_client):
+    mock_client.return_value = _MockMongoClient(self._docs)
+    self.assertIsInstance(self.mongo_source.get_range_tracker(None, None),
+                          _ObjectIdRangeTracker)
 
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_read(self, mock_client):
     mock_tracker = mock.MagicMock()
-    mock_tracker.try_claim.return_value = True
-    mock_tracker.start_position.return_value = 0
-    mock_tracker.stop_position.return_value = 2
-
-    mock_client.return_value.__enter__.return_value.__getitem__.return_value\
-      .__getitem__.return_value.find.return_value = [{'x':1}, {'x':2}]
-
-    result = []
-    for i in self.mongo_source.read(mock_tracker):
-      result.append(i)
-    self.assertListEqual([{'x': 1}, {'x': 2}], result)
+    test_cases = [
+        {
+            # range covers the first(inclusive) to third(exclusive) documents
+            'start': self._ids[0],
+            'stop': self._ids[2],
+            'expected': self._docs[0:2]
+        },
+        {
+            # range covers from the first to the third documents
+            'start': _ObjectIdHelper.int_to_id(0),  # smallest possible id
+            'stop': self._ids[2],
+            'expected': self._docs[0:2]
+        },
+        {
+            # range covers from the third to last documents
+            'start': self._ids[2],
+            'stop': _ObjectIdHelper.int_to_id(2**96 - 1),  # largest possible id
+            'expected': self._docs[2:]
+        },
+        {
+            # range covers all documents
+            'start': _ObjectIdHelper.int_to_id(0),
+            'stop': _ObjectIdHelper.int_to_id(2**96 - 1),
+            'expected': self._docs
+        },
+        {
+            # range doesn't include any document
+            'start': _ObjectIdHelper.increment_id(self._ids[2], 1),
+            'stop': _ObjectIdHelper.increment_id(self._ids[3], -1),
+            'expected': []
+        },
+    ]
+    mock_client.return_value = _MockMongoClient(self._docs)
+    for case in test_cases:
+      mock_tracker.start_position.return_value = case['start']
+      mock_tracker.stop_position.return_value = case['stop']
+      result = list(self.mongo_source.read(mock_tracker))
+      self.assertListEqual(case['expected'], result)
 
   def test_display_data(self):
     data = self.mongo_source.display_data()
@@ -109,35 +240,17 @@
     self.assertTrue('database' in data)
     self.assertTrue('collection' in data)
 
-  @mock.patch('apache_beam.io.mongodbio.MongoClient')
-  def test__get_avg_document_size(self, mock_client):
-    mock_client.return_value.__enter__.return_value.__getitem__\
-      .return_value.command.return_value = {'avgObjSize': 5}
-    self.assertEqual(5, self.mongo_source._get_avg_document_size())
-
-  @mock.patch('apache_beam.io.mongodbio.MongoClient')
-  def test_get_document_count(self, mock_client):
-    mock_client.return_value.__enter__.return_value.__getitem__ \
-      .return_value.__getitem__.return_value.count_documents.return_value = 10
-
-    self.assertEqual(10, self.mongo_source._get_document_count())
-
 
 class ReadFromMongoDBTest(unittest.TestCase):
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
   def test_read_from_mongodb(self, mock_client):
-    objects = [{'x': 1}, {'x': 2}]
-    mock_client.return_value.__enter__.return_value.__getitem__.return_value. \
-      command.return_value = {'avgObjSize': 1}
-    mock_client.return_value.__enter__.return_value.__getitem__.return_value. \
-      __getitem__.return_value.find.return_value = objects
-    mock_client.return_value.__enter__.return_value.__getitem__.return_value. \
-      __getitem__.return_value.count_documents.return_value = 2
+    documents = [{'_id': objectid.ObjectId(), 'x': i} for i in range(3)]
+    mock_client.return_value = _MockMongoClient(documents)
 
     with TestPipeline() as p:
       docs = p | 'ReadFromMongoDB' >> ReadFromMongoDB(
           uri='mongodb://test', db='db', coll='collection')
-      assert_that(docs, equal_to(objects))
+      assert_that(docs, equal_to(documents))
 
 
 class GenerateObjectIdFnTest(unittest.TestCase):
@@ -190,7 +303,7 @@
       _ = (p | "Create" >> beam.Create(docs)
            | "Write" >> WriteToMongoDB(db='test', coll='test'))
       p.run()
-      mock_client.return_value.__getitem__.return_value.__getitem__.\
+      mock_client.return_value.__getitem__.return_value.__getitem__. \
         return_value.bulk_write.assert_called_with(expected_update)
 
   @mock.patch('apache_beam.io.mongodbio.MongoClient')
@@ -210,6 +323,74 @@
         return_value.bulk_write.assert_called_with(expected_update)
 
 
+class ObjectIdHelperTest(TestCase):
+  def test_conversion(self):
+    test_cases = [
+        (objectid.ObjectId('000000000000000000000000'), 0),
+        (objectid.ObjectId('000000000000000100000000'), 2**32),
+        (objectid.ObjectId('0000000000000000ffffffff'), 2**32 - 1),
+        (objectid.ObjectId('000000010000000000000000'), 2**64),
+        (objectid.ObjectId('00000000ffffffffffffffff'), 2**64 - 1),
+        (objectid.ObjectId('ffffffffffffffffffffffff'), 2**96 - 1),
+    ]
+    for (id, number) in test_cases:
+      self.assertEqual(id, _ObjectIdHelper.int_to_id(number))
+      self.assertEqual(number, _ObjectIdHelper.id_to_int(id))
+
+    # random tests
+    for _ in range(100):
+      id = objectid.ObjectId()
+      if sys.version_info[0] < 3:
+        number = int(id.binary.encode('hex'), 16)
+      else:  # PY3
+        number = int(id.binary.hex(), 16)
+      self.assertEqual(id, _ObjectIdHelper.int_to_id(number))
+      self.assertEqual(number, _ObjectIdHelper.id_to_int(id))
+
+  def test_increment_id(self):
+    test_cases = [
+        (objectid.ObjectId('000000000000000100000000'),
+         objectid.ObjectId('0000000000000000ffffffff')),
+        (objectid.ObjectId('000000010000000000000000'),
+         objectid.ObjectId('00000000ffffffffffffffff')),
+    ]
+    for (first, second) in test_cases:
+      self.assertEqual(second, _ObjectIdHelper.increment_id(first, -1))
+      self.assertEqual(first, _ObjectIdHelper.increment_id(second, 1))
+
+    for _ in range(100):
+      id = objectid.ObjectId()
+      self.assertLess(id, _ObjectIdHelper.increment_id(id, 1))
+      self.assertGreater(id, _ObjectIdHelper.increment_id(id, -1))
+
+
+class ObjectRangeTrackerTest(TestCase):
+  def test_fraction_position_conversion(self):
+    start_int = 0
+    stop_int = 2**96 - 1
+    start = _ObjectIdHelper.int_to_id(start_int)
+    stop = _ObjectIdHelper.int_to_id(stop_int)
+    test_cases = ([start_int, stop_int, 2**32, 2**32 - 1, 2**64, 2**64 - 1] +
+                  [random.randint(start_int, stop_int) for _ in range(100)])
+    tracker = _ObjectIdRangeTracker()
+    for pos in test_cases:
+      id = _ObjectIdHelper.int_to_id(pos - start_int)
+      desired_fraction = (pos - start_int) / (stop_int - start_int)
+      self.assertAlmostEqual(tracker.position_to_fraction(id, start, stop),
+                             desired_fraction,
+                             places=20)
+
+      convert_id = tracker.fraction_to_position(
+          (pos - start_int) / (stop_int - start_int), start, stop)
+      # due to precision loss, the convert fraction is only gonna be close to
+      # original fraction.
+      convert_fraction = tracker.position_to_fraction(convert_id, start, stop)
+
+      self.assertGreater(convert_id, start)
+      self.assertLess(convert_id, stop)
+      self.assertAlmostEqual(convert_fraction, desired_fraction, places=20)
+
+
 if __name__ == '__main__':
   logging.getLogger().setLevel(logging.INFO)
   unittest.main()