#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import logging
import os
import tempfile
import unittest
from unittest import TestCase

from liminal.build import liminal_apps_builder
from liminal.kubernetes import volume_util
from liminal.runners.airflow.tasks import python
from tests.util import dag_test_utils


class TestPythonTask(TestCase):
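    """Tests liminal.runners.airflow.tasks.python.PythonTask.

    Builds the example liminal apps, applies python tasks to a test DAG,
    executes them and verifies the files written to a shared local volume.
    """
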
    _VOLUME_NAME = 'myvol1'

def setUp(self) -> None:
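        # Start from a clean slate: remove any volume left over from a previous run.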
volume_util.delete_local_volume(self._VOLUME_NAME)
self.temp_dir = tempfile.mkdtemp()
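        # Build the docker images defined by the liminal apps under ../liminal.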
liminal_apps_builder.build_liminal_apps(
            os.path.join(os.path.dirname(__file__), '../liminal'))

def test_apply_task_to_dag(self):
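        """Applies two python tasks to a DAG, executes the DAG and asserts
        the inputs/outputs written to the shared volume."""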
dag = dag_test_utils.create_dag()
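        # First task writes NUM_FILES input files across NUM_SPLITS
        # subdirectories of the mounted volume.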
task0 = self.__create_python_task(dag,
'my_input_task',
None,
'my_python_task_img',
'python -u write_inputs.py',
env_vars={
'NUM_FILES': 10,
'NUM_SPLITS': 3
})
task0.apply_task_to_dag()
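        # Second task, child of the first, writes one output file per input,
        # parallelized across 3 executors.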
task1 = self.__create_python_task(dag,
'my_output_task',
dag.tasks[0],
'my_parallelized_python_task_img',
'python -u write_outputs.py',
executors=3)
task1.apply_task_to_dag()
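        # Execute each task in the DAG directly, with an empty context.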
for task in dag.tasks:
logging.info(f'Executing task {task.task_id}')
task.execute({})
inputs_dir = os.path.join(self.temp_dir, 'inputs')
outputs_dir = os.path.join(self.temp_dir, 'outputs')
        self.assertListEqual(sorted(os.listdir(self.temp_dir)), ['inputs', 'outputs'])
        inputs_dir_contents = sorted(os.listdir(inputs_dir))
        self.assertListEqual(inputs_dir_contents, ['0', '1', '2'])
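        # Inputs are distributed round-robin (index mod 3) across the split
        # directories.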
self.assertListEqual(sorted(os.listdir(os.path.join(inputs_dir, '0'))),
['input0.json', 'input3.json', 'input6.json', 'input9.json'])
self.assertListEqual(sorted(os.listdir(os.path.join(inputs_dir, '1'))),
['input1.json', 'input4.json', 'input7.json'])
self.assertListEqual(sorted(os.listdir(os.path.join(inputs_dir, '2'))),
['input2.json', 'input5.json', 'input8.json'])
self.assertListEqual(sorted(os.listdir(outputs_dir)),
['output0.txt', 'output1.txt',
'output2.txt', 'output3.txt',
'output4.txt', 'output5.txt',
'output6.txt', 'output7.txt',
'output8.txt', 'output9.txt'])
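        # Each outputN.txt is expected to contain the corresponding 'myvalN'.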
for filename in os.listdir(outputs_dir):
with open(os.path.join(outputs_dir, filename)) as f:
expected_file_content = filename.replace('output', 'myval').replace('.txt', '')
                self.assertEqual(f.read(), expected_file_content)

def __create_python_task(self,
dag,
task_id,
parent,
image,
cmd,
env_vars=None,
executors=None):
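        """Creates a PythonTask whose task_config mounts the test volume."""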
task_config = {
'task': task_id,
'cmd': cmd,
'image': image,
'env_vars': env_vars if env_vars is not None else {},
'mounts': [
{
'mount': 'mymount',
'volume': self._VOLUME_NAME,
'path': '/mnt/vol1'
}
]
}
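        # 'executors' is optional; only the parallelized task sets it.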
if executors:
task_config['executors'] = executors
return python.PythonTask(dag=dag,
liminal_config={
'volumes': [
{
'volume': self._VOLUME_NAME,
'local': {
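                                                 # On macOS, tempfile creates dirs under
                                                 # /var/folders, a symlink to /private/var/folders;
                                                 # docker volume mounts need the resolved path.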
'path': self.temp_dir.replace(
"/var/folders",
"/private/var/folders"
)
}
}
]},
pipeline_config={
'pipeline': 'my_pipeline'
},
task_config=task_config,
parent=parent,
                                 trigger_rule='all_success')


if __name__ == '__main__':
unittest.main()