blob: 106f7542b2302602b3e49df13f8d7bcaa9fcb0bd [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# pytype: skip-file
import copy
import itertools
import random
import threading
import unittest
from apache_beam.metrics.cells import BoundedTrieData
from apache_beam.metrics.cells import CounterCell
from apache_beam.metrics.cells import DistributionCell
from apache_beam.metrics.cells import DistributionData
from apache_beam.metrics.cells import GaugeCell
from apache_beam.metrics.cells import GaugeData
from apache_beam.metrics.cells import StringSetCell
from apache_beam.metrics.cells import StringSetData
from apache_beam.metrics.cells import _BoundedTrieNode
from apache_beam.metrics.metricbase import MetricName
class TestCounterCell(unittest.TestCase):
@classmethod
def _modify_counter(cls, d):
for i in range(cls.NUM_ITERATIONS):
d.inc(i)
NUM_THREADS = 5
NUM_ITERATIONS = 100
def test_parallel_access(self):
# We create NUM_THREADS threads that concurrently modify the counter.
threads = []
c = CounterCell()
for _ in range(TestCounterCell.NUM_THREADS):
t = threading.Thread(
target=TestCounterCell._modify_counter, args=(c, ))
threads.append(t)
t.start()
for t in threads:
t.join()
total = (
self.NUM_ITERATIONS * (self.NUM_ITERATIONS - 1) // 2 * self.NUM_THREADS)
self.assertEqual(c.get_cumulative(), total)
def test_basic_operations(self):
c = CounterCell()
c.inc(2)
self.assertEqual(c.get_cumulative(), 2)
c.dec(10)
self.assertEqual(c.get_cumulative(), -8)
c.dec()
self.assertEqual(c.get_cumulative(), -9)
c.inc()
self.assertEqual(c.get_cumulative(), -8)
def test_start_time_set(self):
c = CounterCell()
c.inc(2)
name = MetricName('namespace', 'name1')
mi = c.to_runner_api_monitoring_info(name, 'transform_id')
self.assertGreater(mi.start_time.seconds, 0)
class TestDistributionCell(unittest.TestCase):
@classmethod
def _modify_distribution(cls, d):
for i in range(cls.NUM_ITERATIONS):
d.update(i)
NUM_THREADS = 5
NUM_ITERATIONS = 100
def test_parallel_access(self):
# We create NUM_THREADS threads that concurrently modify the distribution.
threads = []
d = DistributionCell()
for _ in range(TestDistributionCell.NUM_THREADS):
t = threading.Thread(
target=TestDistributionCell._modify_distribution, args=(d, ))
threads.append(t)
t.start()
for t in threads:
t.join()
total = (
self.NUM_ITERATIONS * (self.NUM_ITERATIONS - 1) // 2 * self.NUM_THREADS)
count = (self.NUM_ITERATIONS * self.NUM_THREADS)
self.assertEqual(
d.get_cumulative(),
DistributionData(total, count, 0, self.NUM_ITERATIONS - 1))
def test_basic_operations(self):
d = DistributionCell()
d.update(10)
self.assertEqual(d.get_cumulative(), DistributionData(10, 1, 10, 10))
d.update(2)
self.assertEqual(d.get_cumulative(), DistributionData(12, 2, 2, 10))
d.update(900)
self.assertEqual(d.get_cumulative(), DistributionData(912, 3, 2, 900))
def test_integer_only(self):
d = DistributionCell()
d.update(3.1)
d.update(3.2)
d.update(3.3)
self.assertEqual(d.get_cumulative(), DistributionData(9, 3, 3, 3))
def test_start_time_set(self):
d = DistributionCell()
d.update(3.1)
name = MetricName('namespace', 'name1')
mi = d.to_runner_api_monitoring_info(name, 'transform_id')
self.assertGreater(mi.start_time.seconds, 0)
class TestGaugeCell(unittest.TestCase):
def test_basic_operations(self):
g = GaugeCell()
g.set(10)
self.assertEqual(g.get_cumulative().value, GaugeData(10).value)
g.set(2)
self.assertEqual(g.get_cumulative().value, 2)
def test_integer_only(self):
g = GaugeCell()
g.set(3.3)
self.assertEqual(g.get_cumulative().value, 3)
def test_combine_appropriately(self):
g1 = GaugeCell()
g1.set(3)
g2 = GaugeCell()
g2.set(1)
# THe second Gauge, with value 1 was the most recent, so it should be
# the final result.
result = g2.combine(g1)
self.assertEqual(result.data.value, 1)
def test_start_time_set(self):
g1 = GaugeCell()
g1.set(3)
name = MetricName('namespace', 'name1')
mi = g1.to_runner_api_monitoring_info(name, 'transform_id')
self.assertGreater(mi.start_time.seconds, 0)
class TestStringSetCell(unittest.TestCase):
def test_not_leak_mutable_set(self):
c = StringSetCell()
c.add('test')
c.add('another')
s = c.get_cumulative()
self.assertEqual(s, StringSetData({'test', 'another'}, 11))
s.add('yet another')
self.assertEqual(c.get_cumulative(), StringSetData({'test', 'another'}, 11))
def test_combine_appropriately(self):
s1 = StringSetCell()
s1.add('1')
s1.add('2')
s2 = StringSetCell()
s2.add('1')
s2.add('3')
result = s2.combine(s1)
self.assertEqual(result.data, StringSetData({'1', '2', '3'}))
def test_add_size_tracked_correctly(self):
s = StringSetCell()
s.add('1')
s.add('2')
self.assertEqual(s.data.string_size, 2)
s.add('2')
s.add('3')
self.assertEqual(s.data.string_size, 3)
class TestBoundedTrieNode(unittest.TestCase):
@classmethod
def random_segments_fixed_depth(cls, n, depth, overlap, rand):
if depth == 0:
yield from ((), ) * n
else:
seen = []
to_string = lambda ix: chr(ord('a') + ix) if ix < 26 else f'z{ix}'
for suffix in cls.random_segments_fixed_depth(n, depth - 1, overlap,
rand):
if not seen or rand.random() > overlap:
prefix = to_string(len(seen))
seen.append(prefix)
else:
prefix = rand.choice(seen)
yield (prefix, ) + suffix
@classmethod
def random_segments(cls, n, min_depth, max_depth, overlap, rand):
for depth, segments in zip(
itertools.cycle(range(min_depth, max_depth + 1)),
cls.random_segments_fixed_depth(n, max_depth, overlap, rand)):
yield segments[:depth]
def assert_covers(self, node, expected, max_truncated=0):
self.assert_covers_flattened(node.flattened(), expected, max_truncated)
def assert_covers_flattened(self, flattened, expected, max_truncated=0):
expected = set(expected)
# Split node into the exact and truncated segments.
partitioned = {True: set(), False: set()}
for segments in flattened:
partitioned[segments[-1]].add(segments[:-1])
exact, truncated = partitioned[False], partitioned[True]
# Check we cover both parts.
self.assertLessEqual(len(truncated), max_truncated, truncated)
self.assertTrue(exact.issubset(expected), exact - expected)
seen_truncated = set()
for segments in expected - exact:
found = 0
for ix in range(len(segments)):
if segments[:ix] in truncated:
seen_truncated.add(segments[:ix])
found += 1
if found != 1:
self.fail(
f"Expected exactly one prefix of {segments} "
f"to occur in {truncated}, found {found}")
self.assertEqual(seen_truncated, truncated, truncated - seen_truncated)
def run_covers_test(self, flattened, expected, max_truncated):
def parse(s):
return tuple(s.strip('*')) + (s.endswith('*'), )
self.assert_covers_flattened([parse(s) for s in flattened],
[tuple(s) for s in expected],
max_truncated)
def test_covers_exact(self):
self.run_covers_test(['ab', 'ac', 'cd'], ['ab', 'ac', 'cd'], 0)
with self.assertRaises(AssertionError):
self.run_covers_test(['ab', 'ac', 'cd'], ['ac', 'cd'], 0)
with self.assertRaises(AssertionError):
self.run_covers_test(['ab', 'ac'], ['ab', 'ac', 'cd'], 0)
with self.assertRaises(AssertionError):
self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'cd'], 0)
def test_covers_trunacted(self):
self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'cd'], 1)
self.run_covers_test(['a*', 'cd'], ['ab', 'ac', 'abcde', 'cd'], 1)
with self.assertRaises(AssertionError):
self.run_covers_test(['ab', 'ac', 'cd'], ['ac', 'cd'], 1)
with self.assertRaises(AssertionError):
self.run_covers_test(['ab', 'ac'], ['ab', 'ac', 'cd'], 1)
with self.assertRaises(AssertionError):
self.run_covers_test(['a*', 'c*'], ['ab', 'ac', 'cd'], 1)
with self.assertRaises(AssertionError):
self.run_covers_test(['a*', 'c*'], ['ab', 'ac'], 1)
def run_test(self, to_add):
everything = list(set(to_add))
all_prefixees = set(
segments[:ix] for segments in everything for ix in range(len(segments)))
everything_deduped = set(everything) - all_prefixees
# Check basic addition.
node = _BoundedTrieNode()
total_size = node.size()
self.assertEqual(total_size, 1)
for segments in everything:
total_size += node.add(segments)
self.assertEqual(node.size(), len(everything_deduped), node)
self.assertEqual(node.size(), total_size, node)
self.assert_covers(node, everything_deduped)
# Check merging
node0 = _BoundedTrieNode()
node0.add_all(everything[0::2])
node1 = _BoundedTrieNode()
node1.add_all(everything[1::2])
pre_merge_size = node0.size()
merge_delta = node0.merge(node1)
self.assertEqual(node0.size(), pre_merge_size + merge_delta)
self.assertEqual(node0, node)
# Check trimming.
if node.size() > 1:
trim_delta = node.trim()
self.assertLess(trim_delta, 0, node)
self.assertEqual(node.size(), total_size + trim_delta)
self.assert_covers(node, everything_deduped, max_truncated=1)
if node.size() > 1:
trim2_delta = node.trim()
self.assertLess(trim2_delta, 0)
self.assertEqual(node.size(), total_size + trim_delta + trim2_delta)
self.assert_covers(node, everything_deduped, max_truncated=2)
# Adding after trimming should be a no-op.
node_copy = copy.deepcopy(node)
for segments in everything:
self.assertEqual(node.add(segments), 0)
self.assertEqual(node, node_copy)
# Merging after trimming should be a no-op.
self.assertEqual(node.merge(node0), 0)
self.assertEqual(node.merge(node1), 0)
self.assertEqual(node, node_copy)
if node._truncated:
expected_delta = 0
else:
expected_delta = 2
# Adding something new is not.
new_values = [('new1', ), ('new2', 'new2.1')]
self.assertEqual(node.add_all(new_values), expected_delta)
self.assert_covers(
node, list(everything_deduped) + new_values, max_truncated=2)
# Nor is merging something new.
new_values_node = _BoundedTrieNode()
new_values_node.add_all(new_values)
self.assertEqual(node_copy.merge(new_values_node), expected_delta)
self.assert_covers(
node_copy, list(everything_deduped) + new_values, max_truncated=2)
def run_fuzz(self, iterations=10, **params):
for _ in range(iterations):
seed = random.getrandbits(64)
segments = self.random_segments(**params, rand=random.Random(seed))
try:
self.run_test(segments)
except:
print("SEED", seed)
raise
def test_trivial(self):
self.run_test([('a', 'b'), ('a', 'c')])
def test_flat(self):
self.run_test([('a', 'a'), ('b', 'b'), ('c', 'c')])
def test_deep(self):
self.run_test([('a', ) * 10, ('b', ) * 12])
def test_small(self):
self.run_fuzz(n=5, min_depth=2, max_depth=3, overlap=0.5)
def test_medium(self):
self.run_fuzz(n=20, min_depth=2, max_depth=4, overlap=0.5)
def test_large_sparse(self):
self.run_fuzz(n=120, min_depth=2, max_depth=4, overlap=0.2)
def test_large_dense(self):
self.run_fuzz(n=120, min_depth=2, max_depth=4, overlap=0.8)
def test_bounded_trie_data_combine(self):
empty = BoundedTrieData()
# The merging here isn't complicated we're just ensuring that
# BoundedTrieData invokes _BoundedTrieNode correctly.
singletonA = BoundedTrieData(singleton=('a', 'a'))
singletonB = BoundedTrieData(singleton=('b', 'b'))
lots_root = _BoundedTrieNode()
lots_root.add_all([('c', 'c'), ('d', 'd')])
lots = BoundedTrieData(root=lots_root)
self.assertEqual(empty.get_result(), set())
self.assertEqual(
empty.combine(singletonA).get_result(), set([('a', 'a', False)]))
self.assertEqual(
singletonA.combine(empty).get_result(), set([('a', 'a', False)]))
self.assertEqual(
singletonA.combine(singletonB).get_result(),
set([('a', 'a', False), ('b', 'b', False)]))
self.assertEqual(
singletonA.combine(lots).get_result(),
set([('a', 'a', False), ('c', 'c', False), ('d', 'd', False)]))
self.assertEqual(
lots.combine(singletonA).get_result(),
set([('a', 'a', False), ('c', 'c', False), ('d', 'd', False)]))
def test_bounded_trie_data_combine_trim(self):
left = _BoundedTrieNode()
left.add_all([('a', 'x'), ('b', 'd')])
right = _BoundedTrieNode()
right.add_all([('a', 'y'), ('c', 'd')])
self.assertEqual(
BoundedTrieData(root=left).combine(
BoundedTrieData(root=right, bound=3)).get_result(),
set([('a', True), ('b', 'd', False), ('c', 'd', False)]))
def test_merge_on_empty_node(self):
root1 = _BoundedTrieNode()
root2 = _BoundedTrieNode()
root2.add_all([["a", "b", "c"], ["a", "b", "d"], ["a", "e"]])
self.assertEqual(2, root1.merge(root2))
self.assertEqual(3, root1.size())
self.assertFalse(root1._truncated)
def test_merge_with_empty_node(self):
root1 = _BoundedTrieNode()
root1.add_all([["a", "b", "c"], ["a", "b", "d"], ["a", "e"]])
root2 = _BoundedTrieNode()
self.assertEqual(0, root1.merge(root2))
self.assertEqual(3, root1.size())
self.assertFalse(root1._truncated)
if __name__ == '__main__':
unittest.main()