blob: 1dc42e0bf9594c98efb3963d1b2f118f347f13d1 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Module of beam_sql cell magic that executes a Beam SQL.
Only works within an IPython kernel.
"""
import importlib
import keyword
import logging
from typing import Dict
from typing import Optional
from typing import Tuple
from typing import Union
import apache_beam as beam
from apache_beam.pvalue import PValue
from apache_beam.runners.interactive import cache_manager as cache
from apache_beam.runners.interactive import interactive_beam as ib
from apache_beam.runners.interactive import interactive_environment as ie
from apache_beam.runners.interactive import pipeline_instrument as inst
from apache_beam.runners.interactive.cache_manager import FileBasedCacheManager
from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
from apache_beam.runners.interactive.sql.utils import find_pcolls
from apache_beam.runners.interactive.sql.utils import is_namedtuple
from apache_beam.runners.interactive.sql.utils import pcolls_by_name
from apache_beam.runners.interactive.sql.utils import register_coder_for_schema
from apache_beam.runners.interactive.sql.utils import replace_single_pcoll_token
from apache_beam.runners.interactive.utils import obfuscate
from apache_beam.runners.interactive.utils import progress_indicated
from apache_beam.testing import test_stream
from apache_beam.testing.test_stream_service import TestStreamServiceController
from apache_beam.transforms.sql import SqlTransform
from IPython.core.magic import Magics
from IPython.core.magic import cell_magic
from IPython.core.magic import magics_class
# Module-level logger used by every helper in this module.
_LOGGER = logging.getLogger(__name__)

# Usage text for the beam_sql cell magic. The literal must be terminated
# here; otherwise it swallows the definitions that follow it.
_EXAMPLE_USAGE = """Usage:
%%%%beam_sql [output_name]
Calcite SQL statement
Please make sure that there is no conflicts between your variable names and
the SQL keywords, such as "SELECT", "FROM", "WHERE" and etc.
output_name is optional. If not supplied, a variable name is automatically
assigned to the output of the magic.
The output of the magic is usually a PCollection or similar PValue,
depending on the SQL statement executed.
"""
def on_error(error_msg, *args):
  """Logs the error and the usage example.

  Args:
    error_msg: The error message; may contain %-style placeholders.
    *args: Values the logger interpolates into error_msg.
  """
  _LOGGER.error(error_msg, *args)
  # Follow the error with the usage text, as the contract above states.
  _LOGGER.info(_EXAMPLE_USAGE)
@magics_class
class BeamSqlMagics(Magics):
  """IPython magics class that provides the beam_sql cell magic."""
  @cell_magic
  def beam_sql(self, line: str, cell: str) -> Union[None, PValue]:
    """The beam_sql cell magic that executes a Beam SQL.

    Args:
      line: (optional) the string on the same line after the beam_sql magic.
        Used as the output variable name in the __main__ module.
      cell: everything else in the same notebook cell as a string. Used as a
        Beam SQL query.

    Returns None if running into an error, otherwise a PValue as if a
    SqlTransform is applied.
    """
    # Validate and forward the same (stripped) name; forwarding the raw line
    # would let surrounding whitespace leak into the variable name.
    output_name = line.strip() if line else line
    if line and (not output_name.isidentifier() or
                 keyword.iskeyword(output_name)):
      on_error(
          'The output_name "%s" is not a valid identifier. Please supply a '
          'valid identifier that is not a Python keyword.',
          line)
      return None
    if not cell or cell.isspace():
      on_error('Please supply the sql to be executed.')
      return None
    found = find_pcolls(cell, pcolls_by_name())
    for pcoll in found.values():
      if not is_namedtuple(pcoll.element_type):
        on_error(
            'PCollection %s of type %s is not a NamedTuple. See '
            'https://beam.apache.org/documentation/programming-guide/#schemas '
            'for more details.',
            pcoll,
            pcoll.element_type)
        return None
      # Make sure a coder is registered for the NamedTuple schema.
      register_coder_for_schema(pcoll.element_type)

    # TODO(BEAM-10708): implicitly execute the pipeline and write output into
    # cache.
    return apply_sql(cell, output_name, found)
def apply_sql(
    query: str, output_name: Optional[str],
    found: Dict[str, beam.PCollection]) -> Optional[PValue]:
  """Applies a SqlTransform with the given sql and queried PCollections.

  Args:
    query: The SQL query executed in the magic.
    output_name: (optional) The output variable name in __main__ module.
    found: The PCollections with variable names found to be used in the query.

  Returns:
    A PValue, mostly a PCollection, depending on the query. None if applying
    the SqlTransform raised an error (the error is logged via on_error).
  """
  output_name = _generate_output_name(output_name, query, found)
  query, sql_source = _build_query_components(query, found)
  try:
    output = sql_source | SqlTransform(query)
    # Declare a variable with the output_name and output value in the
    # __main__ module so that the user can use the output smoothly.
    setattr(importlib.import_module('__main__'), output_name, output)
    ib.watch({output_name: output})
    _LOGGER.info(
        "The output PCollection variable is %s: %s", output_name, output)
    return output
  except (KeyboardInterrupt, SystemExit):
    # Never swallow interpreter-level interruptions.
    raise
  except Exception as e:
    on_error('Error when applying the Beam SQL: %s', e)
    return None
def pcoll_from_file_cache(
    query_pipeline: beam.Pipeline,
    pcoll: beam.PCollection,
    cache_manager: FileBasedCacheManager,
    key: str) -> beam.PCollection:
  """Reads PCollection cache from files.

  Args:
    query_pipeline: The beam.Pipeline object built by the magic to execute the
      SQL query.
    pcoll: The PCollection to read cache for.
    cache_manager: The file based cache manager that holds the PCollection
      cache.
    key: The key of the PCollection cache.

  Returns:
    A PCollection read from the cache.
  """
  schema = pcoll.element_type

  class Unreify(beam.DoFn):
    """Emits the windowed_value carried by each cached row."""
    def process(self, e):
      # Elements that are not rows carrying a windowed_value yield nothing.
      if isinstance(e, beam.Row) and hasattr(e, 'windowed_value'):
        yield e.windowed_value

  read_label = f'QuerySource{key}'
  unreify_label = f'Unreify{key}'
  # Tag the output with the original element type so the schema of the
  # source PCollection is carried over to the PCollection read from cache.
  return (
      query_pipeline
      | read_label >> cache.ReadCache(cache_manager, key)
      | unreify_label >> beam.ParDo(Unreify()).with_output_types(schema))
def pcolls_from_streaming_cache(
    user_pipeline: beam.Pipeline,
    query_pipeline: beam.Pipeline,
    name_to_pcoll: Dict[str, beam.PCollection],
    instrumentation: inst.PipelineInstrument,
    cache_manager: StreamingCache) -> Dict[str, beam.PCollection]:
  """Reads PCollection cache through the TestStream.

  When the user_pipeline has unbounded sources, we force all cache reads to go
  through the TestStream even if they are bounded sources.

  Args:
    user_pipeline: The beam.Pipeline object defined by the user.
    query_pipeline: The beam.Pipeline object built by the magic to execute the
      SQL query.
    name_to_pcoll: PCollections with variable names used in the SQL query.
    instrumentation: A pipeline_instrument.PipelineInstrument that helps
      calculate the cache key of a given PCollection.
    cache_manager: The streaming cache manager that holds the PCollection cache.

  Returns:
    A Dict[str, beam.PCollection], where each PCollection is tagged with
    their PCollection variable names, read from the cache.
  """
  def exception_handler(e):
    # Log instead of silently dropping the error.
    # NOTE(review): returning True presumably marks the error as handled for
    # the controller -- confirm against TestStreamServiceController.
    _LOGGER.error('%s', e)
    return True

  # Reuse the controller already associated with the user pipeline, if any.
  test_stream_service = ie.current_env().get_test_stream_service_controller(
      user_pipeline)
  if not test_stream_service:
    test_stream_service = TestStreamServiceController(
        cache_manager, exception_handler=exception_handler)
    test_stream_service.start()
    ie.current_env().set_test_stream_service_controller(
        user_pipeline, test_stream_service)

  # Map each cache key (used as the TestStream output tag) back to the
  # variable name of the PCollection it belongs to.
  tag_to_name = {
      instrumentation.cache_key(pcoll): name
      for name, pcoll in name_to_pcoll.items()
  }
  # NOTE(review): the TestStream arguments were restored from the Beam
  # TestStream API (output_tags / coder / endpoint); confirm against upstream.
  output_pcolls = query_pipeline | test_stream.TestStream(
      output_tags=set(tag_to_name),
      coder=cache_manager._default_pcoder,
      endpoint=test_stream_service.endpoint)
  sql_source = {}
  for tag, output in output_pcolls.items():
    name = tag_to_name[tag]
    # Must mark the element_type to avoid introducing pickled Python coder
    # to the Java expansion service.
    output.element_type = name_to_pcoll[name].element_type
    sql_source[name] = output
  return sql_source
def _generate_output_name(
    output_name: Optional[str], query: str,
    found: Dict[str, beam.PCollection]) -> str:
  """Generates a unique output name if None is provided.

  Otherwise, returns the given output name directly.
  The generated output name is sql_output_{uuid} where uuid is an obfuscated
  value from the query and PCollections found to be used in the query.

  Args:
    output_name: (optional) The output variable name supplied by the user.
    query: The SQL query executed in the magic.
    found: The PCollections with variable names found to be used in the query.

  Returns:
    The supplied output_name when it is non-empty, otherwise a generated name.
  """
  if output_name:
    return output_name
  # Only the first 12 characters of the obfuscated value are used.
  execution_id = obfuscate(query, found)[:12]
  return 'sql_output_' + execution_id
def _build_query_components(
    query: str, found: Dict[str, beam.PCollection]
) -> Tuple[str,
           Union[Dict[str, beam.PCollection], beam.PCollection, beam.Pipeline]]:
  """Builds necessary components needed to apply the SqlTransform.

  Args:
    query: The SQL query to be executed by the magic.
    found: The PCollections with variable names found to be used by the query.

  Returns:
    The processed query to be executed by the magic and a source to apply the
    SqlTransform to: a dictionary of tagged PCollections, or a single
    PCollection, or the pipeline to execute the query.
  """
  if not found:
    # The query reads no PCollection: apply it to a new pipeline directly.
    return query, beam.Pipeline()

  # The first PCollection found supplies the user pipeline.
  user_pipeline = next(iter(found.values())).pipeline
  cache_manager = ie.current_env().get_cache_manager(user_pipeline)
  instrumentation = inst.build_pipeline_instrument(user_pipeline)
  sql_pipeline = beam.Pipeline(options=user_pipeline._options)
  ie.current_env().add_derived_pipeline(user_pipeline, sql_pipeline)

  if instrumentation.has_unbounded_sources:
    sql_source = pcolls_from_streaming_cache(
        user_pipeline, sql_pipeline, found, instrumentation, cache_manager)
  else:
    # Bounded only: read every queried PCollection from the file cache.
    sql_source = {}
    for pcoll_name, pcoll in found.items():
      cache_key = instrumentation.cache_key(pcoll)
      sql_source[pcoll_name] = pcoll_from_file_cache(
          sql_pipeline, pcoll, cache_manager, cache_key)

  if len(sql_source) == 1:
    # With a single source, rewrite the query token for that PCollection and
    # apply the transform to the PCollection itself instead of a dict.
    query = replace_single_pcoll_token(query, next(iter(sql_source)))
    sql_source = next(iter(sql_source.values()))
  return query, sql_source
def load_ipython_extension(ipython):
"""Marks this module as an IPython extension.
To load this magic in an IPython environment, execute:
%load_ext apache_beam.runners.interactive.sql.beam_sql_magics.