#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""Tests for transforms that use the SQL Expansion service."""
# pytype: skip-file
import logging
import typing
import unittest
import pytest
import apache_beam as beam
from apache_beam import coders
from apache_beam.options.pipeline_options import StandardOptions
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.util import assert_that
from apache_beam.testing.util import equal_to
from apache_beam.transforms.sql import SqlTransform
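
# Element types used by the tests below. Registering a RowCoder for each
# NamedTuple gives elements of that type a schema, which SqlTransform needs in
# order to map them to SQL rows across the language boundary.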
SimpleRow = typing.NamedTuple(
    "SimpleRow", [("id", int), ("str", str), ("flt", float)])
coders.registry.register_coder(SimpleRow, coders.RowCoder)

Enrich = typing.NamedTuple("Enrich", [("id", int), ("metadata", str)])
coders.registry.register_coder(Enrich, coders.RowCoder)

Shopper = typing.NamedTuple(
    "Shopper", [("shopper", str), ("cart", typing.Mapping[str, int])])
coders.registry.register_coder(Shopper, coders.RowCoder)

@pytest.mark.xlang_sql_expansion_service
@unittest.skipIf(
TestPipeline().get_pipeline_options().view_as(StandardOptions).runner is
None,
"Must be run with a runner that supports staging java artifacts.")
class SqlTransformTest(unittest.TestCase):
"""Tests that exercise the cross-language SqlTransform (implemented in java).
Note this test must be executed with pipeline options that run jobs on a local
job server. The easiest way to accomplish this is to run the
`validatesCrossLanguageRunnerPythonUsingSql` gradle target for a particular
job server, which will start the runner and job server for you. For example,
`:runners:flink:1.13:job-server:validatesCrossLanguageRunnerPythonUsingSql` to
test on Flink 1.13.
Alternatively, you may be able to iterate faster if you run the tests directly
using a runner like `FlinkRunner`, which can start a local Flink cluster and
job server for you:
$ pip install -e './sdks/python[gcp,test]'
$ pytest apache_beam/transforms/sql_test.py \\
--test-pipeline-options="--runner=FlinkRunner"
"""
_multiprocess_can_split_ = True
def test_generate_data(self):
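    """SqlTransform used as a source: a SELECT of literals with no input."""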
with TestPipeline() as p:
out = p | SqlTransform(
"""SELECT
CAST(1 AS INT) AS `id`,
CAST('foo' AS VARCHAR) AS `str`,
CAST(3.14 AS DOUBLE) AS `flt`""")
      assert_that(out, equal_to([(1, "foo", 3.14)]))

  def test_project(self):
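    """Projects a subset of columns (`id`, `flt`) from the input rows."""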
with TestPipeline() as p:
out = (
p | beam.Create([SimpleRow(1, "foo", 3.14)])
| SqlTransform("SELECT `id`, `flt` FROM PCOLLECTION"))
      assert_that(out, equal_to([(1, 3.14)]))

  def test_filter(self):
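    """Filters input rows with a WHERE clause."""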
with TestPipeline() as p:
out = (
p
| beam.Create([SimpleRow(1, "foo", 3.14), SimpleRow(2, "bar", 1.414)])
| SqlTransform("SELECT * FROM PCOLLECTION WHERE `str` = 'bar'"))
assert_that(out, equal_to([(2, "bar", 1.414)]))
def test_agg(self):
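    """Groups by `str` and computes COUNT, SUM, and AVG aggregates."""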
with TestPipeline() as p:
out = (
p
| beam.Create([
SimpleRow(1, "foo", 1.),
SimpleRow(1, "foo", 2.),
SimpleRow(1, "foo", 3.),
SimpleRow(2, "bar", 1.414),
SimpleRow(2, "bar", 1.414),
SimpleRow(2, "bar", 1.414),
SimpleRow(2, "bar", 1.414),
])
| SqlTransform(
"""
SELECT
`str`,
COUNT(*) AS `count`,
SUM(`id`) AS `sum`,
AVG(`flt`) AS `avg`
FROM PCOLLECTION GROUP BY `str`"""))
assert_that(out, equal_to([("foo", 3, 3, 2), ("bar", 4, 8, 1.414)]))
def test_tagged_join(self):
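    """Joins two tagged input PCollections, referenced by tag name in SQL."""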
with TestPipeline() as p:
enrich = (
p | "Create enrich" >> beam.Create(
[Enrich(1, "a"), Enrich(2, "b"), Enrich(26, "z")]))
simple = (
p | "Create simple" >> beam.Create([
SimpleRow(1, "foo", 3.14),
SimpleRow(26, "bar", 1.11),
SimpleRow(1, "baz", 2.34)
]))
out = ({
'simple': simple, 'enrich': enrich
}
| SqlTransform(
"""
SELECT
simple.`id` AS `id`,
enrich.metadata AS metadata
FROM simple
JOIN enrich
ON simple.`id` = enrich.`id`"""))
assert_that(out, equal_to([(1, "a"), (26, "z"), (1, "a")]))
def test_row(self):
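    """Uses beam.Row elements, whose schema is inferred, as SQL input."""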
with TestPipeline() as p:
out = (
p
| beam.Create([1, 2, 10])
| beam.Map(lambda x: beam.Row(a=x, b=str(x)))
| SqlTransform("SELECT a*a as s, LENGTH(b) AS c FROM PCOLLECTION"))
      assert_that(out, equal_to([(1, 1), (4, 1), (100, 2)]))

  def test_zetasql_generate_data(self):
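    """Like test_generate_data, but using the ZetaSQL dialect."""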
with TestPipeline() as p:
out = p | SqlTransform(
"""SELECT
CAST(1 AS INT64) AS `int`,
CAST('foo' AS STRING) AS `str`,
CAST(3.14 AS FLOAT64) AS `flt`""",
dialect="zetasql")
assert_that(out, equal_to([(1, "foo", 3.14)]))
def test_windowing_before_sql(self):
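    """Applies fixed windowing upstream; the SQL aggregate runs per window."""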
with TestPipeline() as p:
out = (
p | beam.Create([
SimpleRow(5, "foo", 1.),
SimpleRow(15, "bar", 2.),
SimpleRow(25, "baz", 3.)
])
| beam.Map(lambda v: beam.window.TimestampedValue(v, v.id)).
with_output_types(SimpleRow)
| beam.WindowInto(
beam.window.FixedWindows(10)).with_output_types(SimpleRow)
| SqlTransform("SELECT COUNT(*) as `count` FROM PCOLLECTION"))
      assert_that(out, equal_to([(1, ), (1, ), (1, )]))

  def test_map(self):
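    """Verifies that a map-typed column round-trips through SqlTransform."""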
with TestPipeline() as p:
out = (
p
| beam.Create([
Shopper('bob', {
'bananas': 6, 'cherries': 3
}),
Shopper('alice', {
'apples': 2, 'bananas': 3
})
]).with_output_types(Shopper)
| SqlTransform("SELECT * FROM PCOLLECTION WHERE shopper = 'alice'"))
      assert_that(out, equal_to([('alice', {'apples': 2, 'bananas': 3})]))


if __name__ == "__main__":
logging.getLogger().setLevel(logging.INFO)
unittest.main()