# blob: d5747b03a91964eca4b39e21d986cb4465cddcee [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for utils module."""
# pytype: skip-file
import unittest
from typing import NamedTuple
from typing import Optional
from typing import Union
from unittest.mock import patch
import pytest
import apache_beam as beam
from apache_beam.options.pipeline_options import GoogleCloudOptions
from apache_beam.options.pipeline_options import SetupOptions
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive.sql.utils import DataflowOptionsForm
from apache_beam.runners.interactive.sql.utils import find_pcolls
from apache_beam.runners.interactive.sql.utils import pformat_dict
from apache_beam.runners.interactive.sql.utils import pformat_namedtuple
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
class ANamedTuple(NamedTuple):
  # Minimal two-field schema fixture; the tests in this file pass it to
  # register_coder_for_schema and pformat_namedtuple.
  a: int
  b: str
class OptionalUnionType(NamedTuple):
  # Fixture whose only field has a nested Optional[Union[...]] annotation;
  # the tests in this file pass it to pformat_namedtuple.
  unnamed: Optional[Union[int, str]]
class UtilsTest(unittest.TestCase):
  """Tests for the helper functions imported from interactive.sql.utils."""
  def test_register_coder_for_schema(self):
    # The registry must not resolve ANamedTuple to a RowCoder before the
    # explicit registration, and must do so afterwards.
    self.assertNotIsInstance(
        beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)
    register_coder_for_schema(ANamedTuple)
    self.assertIsInstance(
        beam.coders.registry.get_coder(ANamedTuple), beam.coders.RowCoder)

  def test_find_pcolls(self):
    # interactive_beam.collect is replaced by a no-op for the lookup
    # (presumably find_pcolls would otherwise try to materialize data).
    with patch('apache_beam.runners.interactive.interactive_beam.collect',
               lambda _: None):
      found = find_pcolls(
          'SELECT * FROM pcoll_1 JOIN pcoll_2\nUSING (common_column)',
          {
              'pcoll_1': None, 'pcoll_2': None
          })
      self.assertIn('pcoll_1', found)
      self.assertIn('pcoll_2', found)

  def test_replace_single_pcoll_token(self):
    sql = 'SELECT * FROM abc WHERE a=1 AND b=2'
    # A name that does not occur in the query leaves it untouched.
    replaced_sql = replace_single_pcoll_token(sql, 'wow')
    self.assertEqual(replaced_sql, sql)
    # A matching table name is swapped for the PCOLLECTION token.
    replaced_sql = replace_single_pcoll_token(sql, 'abc')
    self.assertEqual(
        replaced_sql, 'SELECT * FROM PCOLLECTION WHERE a=1 AND b=2')

  def test_pformat_namedtuple(self):
    actual = pformat_namedtuple(ANamedTuple)
    self.assertEqual("ANamedTuple(a: <class 'int'>, b: <class 'str'>)", actual)

  def test_pformat_namedtuple_with_unnamed_fields(self):
    actual = pformat_namedtuple(OptionalUnionType)
    # Parameters of a Union type can be in any order.
    possible_expected = (
        'OptionalUnionType(unnamed: typing.Union[int, str, NoneType])',
        'OptionalUnionType(unnamed: typing.Union[str, int, NoneType])')
    self.assertIn(actual, possible_expected)

  def test_pformat_dict(self):
    actual = pformat_dict({'a': 1, 'b': '2'})
    self.assertEqual('{\na: 1,\nb: 2\n}', actual)
# The same skip condition is declared for both unittest and pytest.
@unittest.skipIf(
    not ie.current_env().is_interactive_ready,
    '[interactive] dependency is not installed.')
@pytest.mark.skipif(
    not ie.current_env().is_interactive_ready,
    reason='[interactive] dependency is not installed.')
class OptionsFormTest(unittest.TestCase):
  """Tests converting a filled-in DataflowOptionsForm to pipeline options."""
  def test_dataflow_options_form(self):
    p = beam.Pipeline()
    pcoll = p | beam.Create([1, 2, 3])
    # google.auth is mocked so that default() yields a fixed project id
    # instead of consulting real credentials.
    with patch('google.auth') as ga:
      ga.default = lambda: ['', 'default_project_id']
      df_form = DataflowOptionsForm('pcoll', pcoll)
      df_form.display_for_input()
      # NOTE(review): entries are addressed by position; judging by the
      # assertions below, [2] looks like the GCS bucket root and [3] the
      # extra package input -- confirm against DataflowOptionsForm.
      df_form.entries[2].input.value = 'gs://test-bucket'
      df_form.entries[3].input.value = 'a-pkg'
      options = df_form.to_options()
      cloud_options = options.view_as(GoogleCloudOptions)
      self.assertEqual(cloud_options.project, 'default_project_id')
      self.assertEqual(cloud_options.region, 'us-central1')
      self.assertEqual(
          cloud_options.staging_location, 'gs://test-bucket/staging')
      self.assertEqual(cloud_options.temp_location, 'gs://test-bucket/temp')
      self.assertIsNotNone(options.view_as(SetupOptions).requirements_file)
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
  unittest.main()