blob: 7e02ff6a416a49ed6c68e26e43fc259178e2a278 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for state caching."""
# pytype: skip-file
import logging
import re
import sys
import threading
import time
import unittest
import weakref
import objsize
from hamcrest import assert_that
from hamcrest import contains_string
from apache_beam.runners.worker.statecache import CacheAware
from apache_beam.runners.worker.statecache import StateCache
from apache_beam.runners.worker.statecache import WeightedValue
from apache_beam.runners.worker.statecache import _LoadingValue
from apache_beam.runners.worker.statecache import get_deep_size
class StateCacheTest(unittest.TestCase):
def test_weakref(self):
  """Weak references held inside cached values must not be sized or break put.

  The cache must not follow a ``weakref.ref`` when computing an entry's
  weight, and storing a weakref itself must work whether or not its
  referent is still alive.
  """
  referenced = WeightedValue('test', 10 << 20)

  class WeightedValueRef():
    def __init__(self):
      self.ref = weakref.ref(referenced)

  state_cache = StateCache(5 << 20)
  deletion_event = threading.Event()
  holder = WeightedValueRef()
  state_cache.put('deep ref', holder)
  # The 10 MB target of the internal weak ref must not be counted, so the
  # reported usage stays at 0 MB.
  self.assertIsNotNone(state_cache.peek('deep ref'))
  self.assertEqual(
      state_cache.describe_stats(),
      'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
      'evictions 0')
  state_cache.invalidate_all()
  # Storing a weakref itself should succeed regardless of whether the
  # referent is alive or already collected.
  holder_ref = weakref.ref(holder, lambda value: deletion_event.set())
  state_cache.put('not deleted ref', holder_ref)
  del holder
  deletion_event.wait()
  state_cache.put('deleted', holder_ref)
def test_weakref_proxy(self):
  """Weak proxies held inside cached values must not be sized or break put.

  Same contract as test_weakref, but exercised with ``weakref.proxy``
  instead of ``weakref.ref``.
  """
  referenced = WeightedValue('test', 10 << 20)

  class WeightedValueRef():
    def __init__(self):
      self.ref = weakref.ref(referenced)

  state_cache = StateCache(5 << 20)
  deletion_event = threading.Event()
  holder = WeightedValueRef()
  state_cache.put('deep ref', holder)
  # The 10 MB target of the internal weak ref must not be counted, so the
  # reported usage stays at 0 MB.
  self.assertIsNotNone(state_cache.peek('deep ref'))
  self.assertEqual(
      state_cache.describe_stats(),
      'used/max 0/5 MB, hit 100.00%, lookups 1, avg load time 0 ns, loads 0, '
      'evictions 0')
  state_cache.invalidate_all()
  # Storing a weak proxy itself should succeed regardless of whether the
  # referent is alive or already collected.
  holder_proxy = weakref.proxy(holder, lambda value: deletion_event.set())
  state_cache.put('not deleted', holder_proxy)
  del holder
  deletion_event.wait()
  state_cache.put('deleted', holder_proxy)
def test_size_of_fails(self):
  """Values whose __sizeof__ raises are still stored and warned about once."""
  class BadSizeOf(object):
    def __sizeof__(self):
      raise RuntimeError("TestRuntimeError")

  cache = StateCache(5 << 20)
  with self.assertLogs('apache_beam.runners.worker.statecache',
                       level='WARNING') as log_capture:
    cache.put('key', BadSizeOf())
    self.assertEqual(1, len(log_capture.output))
    self.assertIn('Failed to size', log_capture.output[0])
    # A second failing value must not emit a second warning (no log spam).
    cache.put('key', BadSizeOf())
    self.assertEqual(1, len(log_capture.output))
def test_empty_cache_peek(self):
  """Peeking a missing key returns None and is counted as a lookup miss."""
  state_cache = StateCache(5 << 20)
  self.assertIsNone(state_cache.peek("key"))
  # One lookup, zero hits -> 0.00% hit rate.
  self.assertEqual(
      state_cache.describe_stats(),
      (
          'used/max 0/5 MB, hit 0.00%, lookups 1, '
          'avg load time 0 ns, loads 0, evictions 0'))
def test_put_peek(self):
  """put() stores a weighted value; peek() reads it back without loading."""
  state_cache = StateCache(5 << 20)
  state_cache.put("key", WeightedValue("value", 1 << 20))
  self.assertEqual(state_cache.size(), 1)
  self.assertEqual(state_cache.peek("key"), "value")
  self.assertIsNone(state_cache.peek("key2"))
  # One hit ("key") out of two lookups -> 50.00% hit rate.
  self.assertEqual(
      state_cache.describe_stats(),
      (
          'used/max 1/5 MB, hit 50.00%, lookups 2, '
          'avg load time 0 ns, loads 0, evictions 0'))
def test_default_sized_put(self):
  """Values stored without an explicit weight are sized automatically."""
  cache = StateCache(5 << 20)
  for key in ("key", "key2", "key3"):
    cache.put(key, bytearray(1 << 20))
  self.assertEqual(cache.peek("key3"), bytearray(1 << 20))
  for key in ("key4", "key5"):
    cache.put(key, bytearray(1 << 20))
  # Each bytearray weighs slightly over 1 MB, so all five cannot fit in
  # the 5 MB budget and exactly one entry is evicted.
  self.assertEqual(
      cache.describe_stats(),
      (
          'used/max 4/5 MB, hit 100.00%, lookups 1, '
          'avg load time 0 ns, loads 0, evictions 1'))
def test_max_size(self):
  """Inserting beyond capacity evicts instead of growing the cache."""
  bounded = StateCache(2 << 20)
  bounded.put("key", WeightedValue("value", 1 << 20))
  bounded.put("key2", WeightedValue("value2", 1 << 20))
  self.assertEqual(bounded.size(), 2)
  # A third 1 MB entry would exceed the 2 MB budget, so one entry must go.
  bounded.put("key3", WeightedValue("value3", 1 << 20))
  self.assertEqual(bounded.size(), 2)
  self.assertEqual(
      bounded.describe_stats(),
      (
          'used/max 2/2 MB, hit 100.00%, lookups 0, '
          'avg load time 0 ns, loads 0, evictions 1'))
def test_invalidate_all(self):
  """invalidate_all() drops every entry without counting as evictions."""
  state_cache = StateCache(5 << 20)
  state_cache.put("key", WeightedValue("value", 1 << 20))
  state_cache.put("key2", WeightedValue("value2", 1 << 20))
  self.assertEqual(state_cache.size(), 2)
  state_cache.invalidate_all()
  self.assertEqual(state_cache.size(), 0)
  self.assertIsNone(state_cache.peek("key"))
  self.assertIsNone(state_cache.peek("key2"))
  # The two misses are lookups, but invalidation reports no evictions.
  self.assertEqual(
      state_cache.describe_stats(),
      (
          'used/max 0/5 MB, hit 0.00%, lookups 2, '
          'avg load time 0 ns, loads 0, evictions 0'))
def test_lru(self):
  """Evictions remove the least-recently-used entry.

  Both reads (peek) and writes (put) refresh an entry's recency, so the
  exact statement order below determines which key each insertion evicts.
  """
  cache = StateCache(5 << 20)
  cache.put("key", WeightedValue("value", 1 << 20))
  cache.put("key2", WeightedValue("value2", 1 << 20))
  cache.put("key3", WeightedValue("value0", 1 << 20))
  # Overwriting an existing key replaces its value in place.
  cache.put("key3", WeightedValue("value3", 1 << 20))
  cache.put("key4", WeightedValue("value4", 1 << 20))
  cache.put("key5", WeightedValue("value0", 1 << 20))
  cache.put("key5", WeightedValue(["value5"], 1 << 20))
  self.assertEqual(cache.size(), 5)
  self.assertEqual(cache.peek("key"), "value")
  self.assertEqual(cache.peek("key2"), "value2")
  self.assertEqual(cache.peek("key3"), "value3")
  self.assertEqual(cache.peek("key4"), "value4")
  self.assertEqual(cache.peek("key5"), ["value5"])
  # insert another key to trigger cache eviction
  cache.put("key6", WeightedValue("value6", 1 << 20))
  self.assertEqual(cache.size(), 5)
  # least recently used key should be gone ("key")
  self.assertEqual(cache.peek("key"), None)
  # trigger a read on "key2" so it is no longer the LRU entry
  cache.peek("key2")
  # insert another key to trigger cache eviction
  cache.put("key7", WeightedValue("value7", 1 << 20))
  self.assertEqual(cache.size(), 5)
  # least recently used key should be gone ("key3")
  self.assertEqual(cache.peek("key3"), None)
  # insert another key to trigger cache eviction
  cache.put("key8", WeightedValue("put", 1 << 20))
  self.assertEqual(cache.size(), 5)
  # insert another key to trigger cache eviction
  cache.put("key9", WeightedValue("value8", 1 << 20))
  self.assertEqual(cache.size(), 5)
  # least recently used key should be gone ("key4")
  self.assertEqual(cache.peek("key4"), None)
  # make "key5" used by writing to it
  cache.put("key5", WeightedValue("val", 1 << 20))
  # least recently used key should be gone ("key6")
  self.assertEqual(cache.peek("key6"), None)
  # 6 hits out of 10 lookups; 5 inserts past capacity -> 5 evictions.
  self.assertEqual(
      cache.describe_stats(),
      (
          'used/max 5/5 MB, hit 60.00%, lookups 10, '
          'avg load time 0 ns, loads 0, evictions 5'))
def test_get(self):
  """get() invokes the loader on a miss and caches the loaded value.

  A loader that raises must leave the cache without an entry, and load
  timing/count statistics must reflect every load attempt.
  """
  def check_key(key):
    self.assertEqual(key, "key")
    time.sleep(0.5)
    return "value"

  def raise_exception(key):
    time.sleep(0.5)
    raise Exception("TestException")

  cache = StateCache(5 << 20)
  self.assertEqual("value", cache.get("key", check_key))
  # The _LoadingValue placeholder must have been replaced by the loaded
  # value once get() returns.
  with cache._lock:
    self.assertFalse(isinstance(cache._cache["key"], _LoadingValue))
  self.assertEqual("value", cache.peek("key"))
  cache.invalidate_all()
  with self.assertRaisesRegex(Exception, "TestException"):
    cache.get("key", raise_exception)
  # The cache should not have the value after the failing load causing
  # check_key to load the value.
  self.assertEqual("value", cache.get("key", check_key))
  with cache._lock:
    self.assertFalse(isinstance(cache._cache["key"], _LoadingValue))
  self.assertEqual("value", cache.peek("key"))
  # Three loads happened: check_key, raise_exception, check_key.
  assert_that(cache.describe_stats(), contains_string(", loads 3,"))
  load_time_ns = re.search(
      ", avg load time (.+) ns,", cache.describe_stats()).group(1)
  # Average load time should be larger than the sleep time (0.5 s) and
  # less than 2x the sleep time.
  self.assertGreater(int(load_time_ns), 0.5 * 1_000_000_000)
  self.assertLess(int(load_time_ns), 1_000_000_000)
def test_concurrent_get_waits(self):
  """Concurrent get() calls for one key run the loader exactly once.

  Two threads request the same key; the loser must block until the
  winner's load completes and then observe the same loaded value.
  """
  # Released once per loader invocation; used to prove a single load.
  event = threading.Semaphore(0)
  # Parties: the two loader threads plus the main test thread.
  threads_running = threading.Barrier(3)

  def wait_for_event(key):
    # While loading, the entry must be a _LoadingValue placeholder.
    with cache._lock:
      self.assertTrue(isinstance(cache._cache["key"], _LoadingValue))
    event.release()
    return "value"

  cache = StateCache(5 << 20)

  def load_key(output):
    threads_running.wait()
    output["value"] = cache.get("key", wait_for_event)
    output["time"] = time.time_ns()

  t1_output = {}
  t1 = threading.Thread(
      target=load_key, args=(t1_output, ))
  t1.start()
  t2_output = {}
  t2 = threading.Thread(
      target=load_key, args=(t2_output, ))
  t2.start()
  # Wait for both threads to start
  threads_running.wait()
  # Record the time and wait for the load to start
  current_time_ns = time.time_ns()
  event.acquire()
  t1.join()
  t2.join()
  # Ensure that only one thread did the loading and not both by checking that
  # the semaphore was only released once
  self.assertFalse(event.acquire(blocking=False))
  # Ensure that the load time is greater than the set time ensuring that
  # both loads had to wait for the event
  self.assertLessEqual(current_time_ns, t1_output["time"])
  self.assertLessEqual(current_time_ns, t2_output["time"])
  self.assertEqual("value", t1_output["value"])
  self.assertEqual("value", t2_output["value"])
  self.assertEqual("value", cache.peek("key"))
def test_concurrent_get_superseded_by_put(self):
  """A put() that lands during an in-flight load supersedes the load.

  The caller of get() still receives the value its loader produced, but
  the cache afterwards holds the value written by put().
  """
  load_happening = threading.Event()
  finish_loading = threading.Event()

  def wait_for_event(key):
    # Signal that the load started, then stall until told to finish so
    # the main thread can interleave a put().
    load_happening.set()
    finish_loading.wait()
    return "value"

  cache = StateCache(5 << 20)

  def load_key(output):
    output["value"] = cache.get("key", wait_for_event)

  t1_output = {}
  t1 = threading.Thread(
      target=load_key, args=(t1_output, ))
  t1.start()
  # Wait for the load to start, update the key, and then let the load finish
  load_happening.wait()
  cache.put("key", "value2")
  finish_loading.set()
  t1.join()
  # Ensure that the original value is loaded and returned and not the
  # updated value
  self.assertEqual("value", t1_output["value"])
  # Ensure that the updated value supersedes the loaded value.
  self.assertEqual("value2", cache.peek("key"))
def test_is_cached_enabled(self):
  """A positive capacity enables the cache; zero capacity disables it."""
  enabled_cache = StateCache(1 << 20)
  self.assertEqual(True, enabled_cache.is_cache_enabled())
  self.assertEqual(
      enabled_cache.describe_stats(),
      (
          'used/max 0/1 MB, hit 100.00%, lookups 0, '
          'avg load time 0 ns, loads 0, evictions 0'))
  disabled_cache = StateCache(0)
  self.assertEqual(False, disabled_cache.is_cache_enabled())
  self.assertEqual(
      disabled_cache.describe_stats(),
      (
          'used/max 0/0 MB, hit 100.00%, lookups 0, '
          'avg load time 0 ns, loads 0, evictions 0'))
def test_get_referents_for_cache(self):
  """Only objects returned by get_referents_for_cache() are sized."""
  class GetReferentsForCache(CacheAware):
    def __init__(self):
      self.counted = bytearray(1 << 20)
      self.uncounted = bytearray(2 << 20)

    def get_referents_for_cache(self):
      # Expose only the 1 MB member for sizing.
      return [self.counted]

  cache = StateCache(5 << 20)
  cache.put("key", GetReferentsForCache())
  # The 2 MB member is skipped, so usage reports 1 MB rather than 3 MB.
  self.assertEqual(
      cache.describe_stats(),
      (
          'used/max 1/5 MB, hit 100.00%, lookups 0, '
          'avg load time 0 ns, loads 0, evictions 0'))
def test_get_deep_size_builtin_objects(self):
  """`statecache.get_deep_size` should agree with `objsize.get_deep_size`
  for built-in objects; the two implementations only diverge when the
  input involves `CacheAware` or an otherwise filtered object.
  """
  primitives = [
      1,  # int
      2.0,  # float
      1 + 1j,  # complex
      True,  # bool
      'hello,world',  # str
      b'\00\01\02',  # bytes
  ]
  collections_of_primitives = [
      [3, 4, 5],  # list
      (6, 7),  # tuple
      {'a', 'b', 'c'},  # set
      {
          'k': 8, 'l': 9
      },  # dict
  ]
  for obj in primitives:
    message = f'different size for obj: `{obj}`, type: {type(obj)}'
    self.assertEqual(get_deep_size(obj), objsize.get_deep_size(obj), message)
    # Primitives hold no references, so deep size equals the flat size.
    self.assertEqual(get_deep_size(obj), sys.getsizeof(obj), message)
  for obj in collections_of_primitives:
    self.assertEqual(
        get_deep_size(obj),
        objsize.get_deep_size(obj),
        f'different size for obj: `{obj}`, type: {type(obj)}')
def test_current_weight_between_get_and_put(self):
  """A value loaded via get() weighs the same as one stored via put()."""
  sentinel = 1234567
  loaded_cache = StateCache(100)
  loaded_cache.get("key", lambda k: sentinel)
  stored_cache = StateCache(100)
  stored_cache.put("key", sentinel)
  self.assertEqual(
      loaded_cache._current_weight, stored_cache._current_weight)
if __name__ == '__main__':
  # Surface INFO-level log output when the tests are run directly.
  logging.getLogger().setLevel(logging.INFO)
  unittest.main()