blob: 68953824d8328d05302bed3f02f97de995b16ab8 [file] [log] [blame]
import unittest
from unittest import TestCase
import itertools
from liminal.runners.airflow.operators.\
kubernetes_pod_operator_with_input_output import _split_list
class TestSplitList(TestCase):
    """Unit tests for ``_split_list(seq, num)``.

    The cases below expect ``_split_list`` to return exactly ``num`` chunks
    that together contain every item of ``seq``, padding with empty chunks
    when ``num`` exceeds ``len(seq)``.
    """

    def setUp(self) -> None:
        """Create a 3-item and a 10-item list of single-key task dicts."""
        self.short_seq = [{f'task_{i}': f'value_{i}'} for i in range(3)]
        self.long_seq = [{f'task_{i}': f'value_{i}'} for i in range(10)]

    def test_seq_equal_num(self):
        """Each item gets its own chunk when ``num == len(seq)``."""
        num = len(self.short_seq)
        result = _split_list(self.short_seq, num)
        expected = [[{'task_0': 'value_0'}], [{'task_1': 'value_1'}],
                    [{'task_2': 'value_2'}]]
        self.assertListEqual(expected, result)

    # NOTE: "grater" is a typo for "greater"; the name is kept so existing
    # test IDs (CI selections, reports) continue to resolve.
    def test_seq_grater_than_num(self):
        """A longer sequence is spread over ``num`` balanced chunks."""
        num = 3
        result = _split_list(self.long_seq, num)
        n_tasks = len(self.long_seq)
        chunk_lengths = [len(chunk) for chunk in result]
        flat_results = list(itertools.chain.from_iterable(result))

        self.assertEqual(num, len(result))
        # 10 items cannot be divided evenly into 3 chunks, so a balanced
        # split has chunk sizes that differ by exactly one.  (A bare
        # ">= 1" check would pass for any lopsided split as well.)
        self.assertEqual(1, max(chunk_lengths) - min(chunk_lengths))
        # No item may be lost or duplicated by the split.
        self.assertEqual(n_tasks, len(flat_results))
        for i in range(n_tasks):
            self.assertIn({f'task_{i}': f'value_{i}'}, flat_results)

    def test_seq_smaller_than_num(self):
        """Surplus chunks are empty lists when ``num > len(seq)``."""
        n_items = len(self.short_seq)
        for num in (8, 9, 10, 11, 12):
            # subTest reports every failing ``num`` instead of stopping at
            # the first one, and labels the failure with the value used.
            with self.subTest(num=num):
                result = _split_list(self.short_seq, num)
                self.assertEqual(len(result), num)
                # Every item must appear alone in a chunk of its own.
                for item in self.short_seq:
                    self.assertIn([item], result)
                # All remaining chunks must be empty.
                empty_chunks = [chunk for chunk in result if chunk == []]
                self.assertEqual([[]] * (num - n_items), empty_chunks)