blob: 42d0804665e253f7be6c85717b8332a3f1ed5bdc [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for sql_chain module."""
# pytype: skip-file
import unittest
from unittest.mock import patch
import pytest
import apache_beam as beam
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.sql.sql_chain import SqlChain
from apache_beam.runners.interactive.sql.sql_chain import SqlNode
from apache_beam.runners.interactive.testing.mock_ipython import mock_get_ipython
class SqlChainTest(unittest.TestCase):
  """Exercises SqlChain construction, node appending and to_pipeline()."""
  def test_init(self):
    # A chain built with no arguments starts out empty and unbound.
    empty_chain = SqlChain()
    self.assertEqual({}, empty_chain.nodes)
    self.assertIsNone(empty_chain.root)
    self.assertIsNone(empty_chain.current)
    self.assertIsNone(empty_chain.user_pipeline)

  def test_append_first_node(self):
    # The first node appended is retrievable by its output name and is both
    # the root and the current node of the chain.
    first_node = SqlNode(output_name='first', source='a', query='q1')
    empty_chain = SqlChain()
    chain = empty_chain.append(first_node)
    self.assertIs(first_node, chain.get(first_node.output_name))
    self.assertIs(first_node, chain.root)
    self.assertIs(first_node, chain.current)

  def test_append_non_root_node(self):
    # A second node appended after the root is linked as root.next and is
    # retrievable by its own output name.
    empty_chain = SqlChain()
    root_node = SqlNode(output_name='root', source='root', query='q1')
    chain = empty_chain.append(root_node)
    self.assertIsNone(chain.root.next)

    follower = SqlNode(output_name='next_node', source='root', query='q2')
    chain.append(follower)
    self.assertIs(follower, chain.root.next)
    self.assertIs(follower, chain.get(follower.output_name))

  @patch(
      'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.'
      '__rrshift__')
  def test_to_pipeline_only_evaluate_once_per_pipeline_and_node(
      self, mocked_sql_transform):
    # Calling to_pipeline() repeatedly on the same chain must leave the
    # patched transform application at exactly one call.
    pipeline = beam.Pipeline()
    ie.current_env().watch({'p': pipeline})
    first_pcoll = pipeline | 'create pcoll_1' >> beam.Create([1, 2, 3])
    second_pcoll = pipeline | 'create pcoll_2' >> beam.Create([4, 5, 6])
    ie.current_env().watch({'pcoll_1': first_pcoll, 'pcoll_2': second_pcoll})

    root_node = SqlNode(
        output_name='root', source={'pcoll_1', 'pcoll_2'}, query='q1')
    chain = SqlChain(user_pipeline=pipeline).append(root_node)

    # Same check after the first and after the second translation.
    for _ in range(2):
      chain.to_pipeline()
      mocked_sql_transform.assert_called_once()

  @unittest.skipIf(
      not ie.current_env().is_interactive_ready,
      '[interactive] dependency is not installed.')
  @pytest.mark.skipif(
      not ie.current_env().is_interactive_ready,
      reason='[interactive] dependency is not installed.')
  @patch(
      'apache_beam.runners.interactive.sql.sql_chain.SchemaLoadedSqlTransform.'
      '__rrshift__')
  def test_nodes_with_same_outputs(self, mocked_sql_transform):
    # Two nodes sharing one output name are appended from two separate
    # `with cell:` scopes; the patched transform is expected to be applied
    # with labels ending in _1 and then _2.
    pipeline = beam.Pipeline()
    ie.current_env().watch({'p_nodes_with_same_output': pipeline})
    pcoll = pipeline | 'create pcoll' >> beam.Create([1, 2, 3])
    ie.current_env().watch({'pcoll': pcoll})
    chain = SqlChain(user_pipeline=pipeline)
    shared_output = 'output'

    # (query, label the transform should most recently have been called with)
    expectations = (
        ('q1', 'schema_loaded_beam_sql_output_1'),
        ('q2', 'schema_loaded_beam_sql_output_2'),
    )
    with patch('IPython.get_ipython', new_callable=mock_get_ipython) as cell:
      for query, expected_label in expectations:
        # NOTE(review): presumably each entry into `cell` simulates one
        # notebook cell execution -- confirm against mock_get_ipython.
        with cell:
          node_in_cell = SqlNode(shared_output, source='pcoll', query=query)
          chain.append(node_in_cell)
          chain.to_pipeline()
          mocked_sql_transform.assert_called_with(expected_label)
# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
  unittest.main()