blob: 27ccb2421844d0c27a867d34b3bb0b88a32e30b3 [file]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Unit tests for byte-limited queue."""
import queue
import threading
import time
import unittest
from apache_beam.utils.byte_limited_queue import ByteLimitedQueue
class ByteLimitedQueueTest(unittest.TestCase):
  """Tests ByteLimitedQueue limits, timeouts and blocked-writer ordering."""
  def _assert_occupancy(self, bq, expected_bytes, expected_items):
    """Asserts the queue's byte_size() and then its qsize()."""
    self.assertEqual(bq.byte_size(), expected_bytes)
    self.assertEqual(bq.qsize(), expected_items)

  def test_unbounded(self):
    """A queue built without limits takes 201 puts and tracks both totals."""
    unbounded = ByteLimitedQueue()
    sizes = range(201)
    for size in sizes:
      unbounded.put(str(size), size)
    self._assert_occupancy(unbounded, sum(sizes), len(sizes))

  def test_put_and_get(self):
    """Non-blocking puts raise queue.Full only when maxbytes is exceeded."""
    limited = ByteLimitedQueue(maxbytes=200)
    for size in (50, 140):
      limited.put(str(size), size)
    self._assert_occupancy(limited, 190, 2)

    # 190 + 20 would go past the 200 byte limit.
    with self.assertRaises(queue.Full):
      limited.put('20', 20, block=False)

    # 190 + 10 reaches the limit without exceeding it.
    limited.put('10', 10, block=False)
    self._assert_occupancy(limited, 200, 3)

    self.assertEqual(limited.get(), '50')
    self._assert_occupancy(limited, 150, 2)

    # With 50 bytes released, the previously rejected size is accepted.
    limited.put('20', 20, block=False)

  def test_dual_limit(self):
    """Either limit (3 items or 100 bytes) makes the queue full."""
    capped = ByteLimitedQueue(maxsize=3, maxbytes=100)
    for size in (30, 40, 20):
      capped.put(str(size), size)
    self._assert_occupancy(capped, 90, 3)

    # Rejected on element count (3 items), although only 90 bytes are held.
    with self.assertRaises(queue.Full):
      capped.put('10', 10, block=False)

    for expected in ('30', '40'):
      self.assertEqual(capped.get(), expected)
    capped.put('10', 10)

    # Rejected on byte count: 20 + 10 + 90 is more than 100.
    with self.assertRaises(queue.Full):
      capped.put('90', 90, block=False)

    self.assertEqual(capped.get(), '20')
    # 10 + 90 is accepted once the 20 byte item has been removed.
    capped.put('90', 90, block=False)

  def test_multithreading(self):
    """Two producers and one consumer move every item through the queue."""
    shared = ByteLimitedQueue(maxsize=0, maxbytes=100)
    sentinel = 'POISON'
    consumed = []

    def produce():
      for value in range(101):
        shared.put(str(value), value)

    def consume():
      # Keep reading until the sentinel is dequeued.
      for item in iter(shared.get, sentinel):
        consumed.append(int(item))

    producers = [threading.Thread(target=produce) for _ in range(2)]
    consumer = threading.Thread(target=consume)
    for worker in producers + [consumer]:
      worker.start()
    for worker in producers:
      worker.join()

    # Both producers are done, so the sentinel is the last item enqueued.
    shared.put(sentinel, 0)
    consumer.join()

    self.assertEqual(len(consumed), 202)
    self.assertEqual(sum(consumed), 2 * sum(range(101)))

  def test_put_timeout(self):
    """A timed put raises queue.Full on expiry, or succeeds once space frees."""
    full_queue = ByteLimitedQueue(maxsize=0, maxbytes=10)
    full_queue.put('10', 10)

    # The queue is completely full, so a short timed put should give up.
    with self.assertRaises(queue.Full):
      full_queue.put('5', 5, timeout=0.01)

    def consume_later():
      time.sleep(0.05)
      full_queue.get()

    # This thread frees up space after 50ms.
    releaser = threading.Thread(target=consume_later)
    releaser.start()

    # The put should succeed once the consumer runs; the high timeout is
    # there to avoid flakiness.
    full_queue.put('item', 5, timeout=60)
    releaser.join()

  def test_get_timeout(self):
    """A get on an empty queue raises queue.Empty, or waits for a producer."""
    q = ByteLimitedQueue(maxsize=0, maxbytes=100)

    # Non-blocking, zero-timeout and short-timeout gets all find it empty.
    for get_kwargs in ({'block': False}, {'timeout': 0.0}, {'timeout': .01}):
      with self.assertRaises(queue.Empty):
        q.get(**get_kwargs)

    q.put('1', 1)
    self.assertEqual('1', q.get(timeout=0))
    q.put('2', 2)
    self.assertEqual('2', q.get(timeout=0.1))

    def produce_later():
      time.sleep(0.05)
      q.put('3', 3)

    # This thread produces an item after 50ms.
    supplier = threading.Thread(target=produce_later)
    supplier.start()

    # The get should succeed once the producer runs; the high timeout is
    # there to avoid flakiness.
    self.assertEqual('3', q.get(timeout=60))
    supplier.join()

  def test_negative_timeout(self):
    """Both put and get reject a negative timeout with ValueError."""
    q = ByteLimitedQueue()
    with self.assertRaises(ValueError):
      q.put('5', 5, timeout=-1)
    with self.assertRaises(ValueError):
      q.get(timeout=-1)

  def test_single_element_override(self):
    """An item larger than maxbytes is still admitted to an empty queue."""
    tiny = ByteLimitedQueue(maxbytes=10)
    # Size 50 exceeds maxbytes=10, yet the non-blocking put is expected to
    # succeed because nothing else is in the queue.
    tiny.put('50', 50, block=False)
    self._assert_occupancy(tiny, 50, 1)

  def test_fairness(self):
    """Blocked puts are expected to be admitted in the order they arrived."""
    q = ByteLimitedQueue(maxbytes=10)
    # Make the queue non-empty so that the large item below has to block.
    q.put('first', 2)
    self.assertEqual(q.blocked_byte_size(), 0)

    # Start one background put at a time: first an item that exceeds the
    # limit, then two smaller ones. Poll until each writer is counted in
    # blocked_byte_size() before starting the next.
    writers = []
    for item, size, wait_for, blocked_total in (
        ('too_large', 9, 1, 9),
        ('small1', 1, 10, 10),
        ('small2', 1, 11, 11),
    ):
      writer = threading.Thread(target=q.put, args=(item, size))
      writer.start()
      while q.blocked_byte_size() < wait_for:
        time.sleep(0.005)
      self.assertEqual(q.blocked_byte_size(), blocked_total)
      writers.append(writer)
    large_writer, small_writer1, small_writer2 = writers

    # Verify all items are received in order.
    self.assertEqual(q.get(), 'first')
    large_writer.join()
    small_writer1.join()
    self.assertEqual(q.get(), 'too_large')
    small_writer2.join()
    self.assertEqual(q.get(), 'small1')
    self.assertEqual(q.get(), 'small2')

  def test_blocked_waiter_timeout_multiple(self):
    """Timed-out writers stop counting as blocked; patient writers succeed."""
    q = ByteLimitedQueue(maxbytes=10)
    q.put('initial', 5)

    outcomes = []
    outcomes_lock = threading.Lock()

    def attempt_put(name, size, timeout_val):
      # Record whether the timed put went through or hit queue.Full.
      try:
        q.put(name, size, timeout=timeout_val)
      except queue.Full:
        result = 'timeout'
      else:
        result = 'success'
      with outcomes_lock:
        outcomes.append((name, result))

    writers = {
        name: threading.Thread(
            target=attempt_put, args=(name, size, timeout_val))
        for name, size, timeout_val in (
            ('t1', 8, 0.2),
            ('t2', 8, 60.0),
            ('t3', 3, 0.1),
            ('t4', 3, 60.0),
            ('t5', 3, 0.1),
        )
    }
    for writer in writers.values():
      writer.start()

    # Wait for the short-timeout threads.
    for name in ('t5', 't3', 't1'):
      writers[name].join()

    # The writers still waiting should now be just t2 (8) and t4 (3).
    self.assertEqual(q.blocked_byte_size(), 11)

    self.assertEqual(q.get(), 'initial')
    writers['t2'].join()
    self.assertGreater(q.blocked_byte_size(), 0)
    self.assertIn(q.get(), ('t2', 't4'))

    writers['t4'].join()
    self.assertEqual(q.blocked_byte_size(), 0)
    self.assertIn(q.get(), ('t2', 't4'))

    with outcomes_lock:
      for expected in (
          ('t1', 'timeout'),
          ('t2', 'success'),
          ('t3', 'timeout'),
          ('t4', 'success'),
          ('t5', 'timeout'),
      ):
        self.assertIn(expected, outcomes)
# Invoke unittest's command-line runner when this file is run as a script.
if __name__ == '__main__':
  unittest.main()