| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################ |
| import json |
| import logging |
| import sys |
| |
| from pyflink.common import Row |
| from pyflink.table import (DataTypes, TableEnvironment, EnvironmentSettings, ExplainDetail) |
from pyflink.table.expressions import col, concat
| from pyflink.table.udf import udtf, udf, udaf, AggregateFunction, TableAggregateFunction, udtaf |
| |
| |
| def basic_operations(): |
| t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode()) |
| |
| # define the source |
| table = t_env.from_elements( |
| elements=[ |
| (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'), |
| (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'), |
| (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'), |
| (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}') |
| ], |
| schema=['id', 'data']) |
| |
| right_table = t_env.from_elements(elements=[(1, 18), (2, 30), (3, 25), (4, 10)], |
| schema=['id', 'age']) |
| |
| table = table.add_columns( |
| col('data').json_value('$.name', DataTypes.STRING()).alias('name'), |
| col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'), |
| col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country')) \ |
| .drop_columns(col('data')) |
| table.execute().print() |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | op | id | name | tel | country | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | +I | 1 | Flink | 123 | Germany | |
| # | +I | 2 | hello | 135 | China | |
| # | +I | 3 | world | 124 | USA | |
| # | +I | 4 | PyFlink | 32 | China | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| |
| # limit the number of outputs |
| table.limit(3).execute().print() |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | op | id | name | tel | country | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | +I | 1 | Flink | 123 | Germany | |
| # | +I | 2 | hello | 135 | China | |
| # | +I | 3 | world | 124 | USA | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| |
| # filter |
| table.filter(col('id') != 3).execute().print() |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | op | id | name | tel | country | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | +I | 1 | Flink | 123 | Germany | |
| # | +I | 2 | hello | 135 | China | |
| # | +I | 4 | PyFlink | 32 | China | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| |
| # aggregation |
| table.group_by(col('country')) \ |
| .select(col('country'), col('id').count, col('tel').cast(DataTypes.BIGINT()).max) \ |
| .execute().print() |
| # +----+--------------------------------+----------------------+----------------------+ |
| # | op | country | EXPR$0 | EXPR$1 | |
| # +----+--------------------------------+----------------------+----------------------+ |
| # | +I | Germany | 1 | 123 | |
| # | +I | USA | 1 | 124 | |
| # | +I | China | 1 | 135 | |
| # | -U | China | 1 | 135 | |
| # | +U | China | 2 | 135 | |
| # +----+--------------------------------+----------------------+----------------------+ |
| |
| # distinct |
| table.select(col('country')).distinct() \ |
| .execute().print() |
| # +----+--------------------------------+ |
| # | op | country | |
| # +----+--------------------------------+ |
| # | +I | Germany | |
| # | +I | China | |
| # | +I | USA | |
| # +----+--------------------------------+ |
| |
| # join |
    # Note that duplicate column names between the joined tables are still not supported
| table.join(right_table.rename_columns(col('id').alias('r_id')), col('id') == col('r_id')) \ |
| .execute().print() |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+ |
| # | op | id | name | tel | country | r_id | age | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+ |
| # | +I | 4 | PyFlink | 32 | China | 4 | 10 | |
| # | +I | 1 | Flink | 123 | Germany | 1 | 18 | |
| # | +I | 2 | hello | 135 | China | 2 | 30 | |
| # | +I | 3 | world | 124 | USA | 3 | 25 | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+----------------------+----------------------+ |
| |
| # join lateral |
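    # the UDTF below splits the name column on the letter 'i' and emits one row per fragment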
| @udtf(result_types=[DataTypes.STRING()]) |
| def split(r: Row): |
| for s in r.name.split("i"): |
| yield s |
| |
| table.join_lateral(split.alias('a')) \ |
| .execute().print() |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | op | id | name | tel | country | a | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | +I | 1 | Flink | 123 | Germany | Fl | |
| # | +I | 1 | Flink | 123 | Germany | nk | |
| # | +I | 2 | hello | 135 | China | hello | |
| # | +I | 3 | world | 124 | USA | world | |
| # | +I | 4 | PyFlink | 32 | China | PyFl | |
| # | +I | 4 | PyFlink | 32 | China | nk | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| |
| # show schema |
| table.print_schema() |
| # ( |
| # `id` BIGINT, |
| # `name` STRING, |
| # `tel` STRING, |
| # `country` STRING |
| # ) |
| |
    # show the execution plan
| print(table.join_lateral(split.alias('a')).explain()) |
| # == Abstract Syntax Tree == |
| # LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{}]) |
| # :- LogicalProject(id=[$0], name=[JSON_VALUE($1, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], tel=[JSON_VALUE($1, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], country=[JSON_VALUE($1, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))]) |
| # : +- LogicalTableScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]]) |
| # +- LogicalTableFunctionScan(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], rowType=[RecordType(VARCHAR(2147483647) a)], elementType=[class [Ljava.lang.Object;]) |
| # |
| # == Optimized Physical Plan == |
| # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], correlate=[table(split(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER]) |
| # +- Calc(select=[id, JSON_VALUE(data, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS name, JSON_VALUE(data, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS tel, JSON_VALUE(data, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS country]) |
| # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data]) |
| # |
| # == Optimized Execution Plan == |
| # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$1f0568d1f39bef59b4c969a5d620ba46*($0, $1, $2, $3)], correlate=[table(split(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER]) |
| # +- Calc(select=[id, JSON_VALUE(data, '$.name', NULL, ON EMPTY, NULL, ON ERROR) AS name, JSON_VALUE(data, '$.tel', NULL, ON EMPTY, NULL, ON ERROR) AS tel, JSON_VALUE(data, '$.addr.country', NULL, ON EMPTY, NULL, ON ERROR) AS country]) |
| # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_249535355, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data]) |
| |
    # show the execution plan with advice
| print(table.join_lateral(split.alias('a')).explain(ExplainDetail.PLAN_ADVICE)) |
| # == Abstract Syntax Tree == |
| # LogicalCorrelate(correlation=[$cor2], joinType=[inner], requiredColumns=[{}]) |
| # :- LogicalProject(id=[$0], name=[JSON_VALUE($1, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], tel=[JSON_VALUE($1, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))], country=[JSON_VALUE($1, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR))]) |
| # : +- LogicalTableScan(table=[[*anonymous_python-input-format$1*]]) |
| # +- LogicalTableFunctionScan(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*($0, $1, $2, $3)], rowType=[RecordType(VARCHAR(2147483647) a)]) |
| # |
| # == Optimized Physical Plan With Advice == |
| # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*($0, $1, $2, $3)], correlate=[table(*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER]) |
| # +- Calc(select=[id, JSON_VALUE(data, _UTF-16LE'$.name', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS name, JSON_VALUE(data, _UTF-16LE'$.tel', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS tel, JSON_VALUE(data, _UTF-16LE'$.addr.country', FLAG(NULL), FLAG(ON EMPTY), FLAG(NULL), FLAG(ON ERROR)) AS country]) |
| # +- TableSourceScan(table=[[*anonymous_python-input-format$1*]], fields=[id, data]) |
| # |
| # No available advice... |
| # |
| # == Optimized Execution Plan == |
| # PythonCorrelate(invocation=[*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*($0, $1, $2, $3)], correlate=[table(*org.apache.flink.table.functions.python.PythonTableFunction$720258394f6a31d31376164d23142f53*(id,name,tel,country))], select=[id,name,tel,country,a], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) name, VARCHAR(2147483647) tel, VARCHAR(2147483647) country, VARCHAR(2147483647) a)], joinType=[INNER]) |
| # +- Calc(select=[id, JSON_VALUE(data, '$.name', NULL, ON EMPTY, NULL, ON ERROR) AS name, JSON_VALUE(data, '$.tel', NULL, ON EMPTY, NULL, ON ERROR) AS tel, JSON_VALUE(data, '$.addr.country', NULL, ON EMPTY, NULL, ON ERROR) AS country]) |
| # +- TableSourceScan(table=[[*anonymous_python-input-format$1*]], fields=[id, data]) |
| |

def sql_operations():
| t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode()) |
| |
| # define the source |
| table = t_env.from_elements( |
| elements=[ |
| (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'), |
| (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'), |
| (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'), |
| (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}') |
| ], |
| schema=['id', 'data']) |
| |
| t_env.sql_query("SELECT * FROM %s" % table) \ |
| .execute().print() |
| # +----+----------------------+--------------------------------+ |
| # | op | id | data | |
| # +----+----------------------+--------------------------------+ |
| # | +I | 1 | {"name": "Flink", "tel": 12... | |
| # | +I | 2 | {"name": "hello", "tel": 13... | |
| # | +I | 3 | {"name": "world", "tel": 12... | |
| # | +I | 4 | {"name": "PyFlink", "tel": ... | |
| # +----+----------------------+--------------------------------+ |
| |
    # execute a sql statement that uses a python udtf
| @udtf(result_types=[DataTypes.STRING(), DataTypes.INT(), DataTypes.STRING()]) |
| def parse_data(data: str): |
| json_data = json.loads(data) |
| yield json_data['name'], json_data['tel'], json_data['addr']['country'] |
| |
| t_env.create_temporary_function('parse_data', parse_data) |
| t_env.execute_sql( |
| """ |
| SELECT * |
| FROM %s, LATERAL TABLE(parse_data(`data`)) t(name, tel, country) |
| """ % table |
| ).print() |
| # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+ |
| # | op | id | data | name | tel | country | |
| # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+ |
| # | +I | 1 | {"name": "Flink", "tel": 12... | Flink | 123 | Germany | |
| # | +I | 2 | {"name": "hello", "tel": 13... | hello | 135 | China | |
| # | +I | 3 | {"name": "world", "tel": 12... | world | 124 | USA | |
| # | +I | 4 | {"name": "PyFlink", "tel": ... | PyFlink | 32 | China | |
| # +----+----------------------+--------------------------------+--------------------------------+-------------+--------------------------------+ |
| |
| # explain sql plan |
| print(t_env.explain_sql( |
| """ |
| SELECT * |
| FROM %s, LATERAL TABLE(parse_data(`data`)) t(name, tel, country) |
| """ % table |
| )) |
| # == Abstract Syntax Tree == |
| # LogicalProject(id=[$0], data=[$1], name=[$2], tel=[$3], country=[$4]) |
| # +- LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{1}]) |
| # :- LogicalTableScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]]) |
| # +- LogicalTableFunctionScan(invocation=[parse_data($cor1.data)], rowType=[RecordType:peek_no_expand(VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)]) |
| # |
| # == Optimized Physical Plan == |
| # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER]) |
| # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data]) |
| # |
| # == Optimized Execution Plan == |
| # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER]) |
| # +- LegacyTableSourceScan(table=[[default_catalog, default_database, Unregistered_TableSource_734856049, source: [PythonInputFormatTableSource(id, data)]]], fields=[id, data]) |
| |
| # explain sql plan with advice |
| print(t_env.explain_sql( |
| """ |
| SELECT * |
| FROM %s, LATERAL TABLE(parse_data(`data`)) t(name, tel, country) |
| """ % table, ExplainDetail.PLAN_ADVICE |
| )) |
| # == Abstract Syntax Tree == |
| # LogicalProject(id=[$0], data=[$1], name=[$2], tel=[$3], country=[$4]) |
| # +- LogicalCorrelate(correlation=[$cor1], joinType=[inner], requiredColumns=[{1}]) |
| # :- LogicalTableScan(table=[[*anonymous_python-input-format$10*]]) |
| # +- LogicalTableFunctionScan(invocation=[parse_data($cor2.data)], rowType=[RecordType:peek_no_expand(VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)]) |
| # |
| # == Optimized Physical Plan With Advice == |
| # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER]) |
| # +- TableSourceScan(table=[[*anonymous_python-input-format$10*]], fields=[id, data]) |
| # |
| # No available advice... |
| # |
| # == Optimized Execution Plan == |
| # PythonCorrelate(invocation=[parse_data($1)], correlate=[table(parse_data(data))], select=[id,data,f0,f1,f2], rowType=[RecordType(BIGINT id, VARCHAR(2147483647) data, VARCHAR(2147483647) f0, INTEGER f1, VARCHAR(2147483647) f2)], joinType=[INNER]) |
| # +- TableSourceScan(table=[[*anonymous_python-input-format$10*]], fields=[id, data]) |
| |

def column_operations():
| t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode()) |
| |
| # define the source |
| table = t_env.from_elements( |
| elements=[ |
| (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'), |
| (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'), |
| (3, '{"name": "world", "tel": 124, "addr": {"country": "USA", "city": "NewYork"}}'), |
| (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}') |
| ], |
| schema=['id', 'data']) |
| |
| # add columns |
| table = table.add_columns( |
| col('data').json_value('$.name', DataTypes.STRING()).alias('name'), |
| col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'), |
| col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country')) |
| |
| table.execute().print() |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | op | id | data | name | tel | country | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | +I | 1 | {"name": "Flink", "tel": 12... | Flink | 123 | Germany | |
| # | +I | 2 | {"name": "hello", "tel": 13... | hello | 135 | China | |
| # | +I | 3 | {"name": "world", "tel": 12... | world | 124 | USA | |
| # | +I | 4 | {"name": "PyFlink", "tel": ... | PyFlink | 32 | China | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| |
| # drop columns |
| table = table.drop_columns(col('data')) |
| table.execute().print() |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | op | id | name | tel | country | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | +I | 1 | Flink | 123 | Germany | |
| # | +I | 2 | hello | 135 | China | |
| # | +I | 3 | world | 124 | USA | |
| # | +I | 4 | PyFlink | 32 | China | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| |
| # rename columns |
| table = table.rename_columns(col('tel').alias('telephone')) |
| table.execute().print() |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | op | id | name | telephone | country | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | +I | 1 | Flink | 123 | Germany | |
| # | +I | 2 | hello | 135 | China | |
| # | +I | 3 | world | 124 | USA | |
| # | +I | 4 | PyFlink | 32 | China | |
| # +----+----------------------+--------------------------------+--------------------------------+--------------------------------+ |
| |
| # replace columns |
| table = table.add_or_replace_columns( |
| concat(col('id').cast(DataTypes.STRING()), '_', col('name')).alias('id')) |
| table.execute().print() |
| # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | op | id | name | telephone | country | |
| # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| # | +I | 1_Flink | Flink | 123 | Germany | |
| # | +I | 2_hello | hello | 135 | China | |
| # | +I | 3_world | world | 124 | USA | |
| # | +I | 4_PyFlink | PyFlink | 32 | China | |
| # +----+--------------------------------+--------------------------------+--------------------------------+--------------------------------+ |
| |
| |
| def row_operations(): |
| t_env = TableEnvironment.create(EnvironmentSettings.in_streaming_mode()) |
| |
| # define the source |
| table = t_env.from_elements( |
| elements=[ |
| (1, '{"name": "Flink", "tel": 123, "addr": {"country": "Germany", "city": "Berlin"}}'), |
| (2, '{"name": "hello", "tel": 135, "addr": {"country": "China", "city": "Shanghai"}}'), |
| (3, '{"name": "world", "tel": 124, "addr": {"country": "China", "city": "NewYork"}}'), |
| (4, '{"name": "PyFlink", "tel": 32, "addr": {"country": "China", "city": "Hangzhou"}}') |
| ], |
| schema=['id', 'data']) |
| |
| # map operation |
| @udf(result_type=DataTypes.ROW([DataTypes.FIELD("id", DataTypes.BIGINT()), |
| DataTypes.FIELD("country", DataTypes.STRING())])) |
| def extract_country(input_row: Row): |
| data = json.loads(input_row.data) |
| return Row(input_row.id, data['addr']['country']) |
| |
| table.map(extract_country) \ |
| .execute().print() |
| # +----+----------------------+--------------------------------+ |
| # | op | id | country | |
| # +----+----------------------+--------------------------------+ |
| # | +I | 1 | Germany | |
| # | +I | 2 | China | |
| # | +I | 3 | China | |
| # | +I | 4 | China | |
| # +----+----------------------+--------------------------------+ |
| |
| # flat_map operation |
| @udtf(result_types=[DataTypes.BIGINT(), DataTypes.STRING()]) |
| def extract_city(input_row: Row): |
| data = json.loads(input_row.data) |
| yield input_row.id, data['addr']['city'] |
| |
| table.flat_map(extract_city) \ |
| .execute().print() |
| # +----+----------------------+--------------------------------+ |
| # | op | f0 | f1 | |
| # +----+----------------------+--------------------------------+ |
| # | +I | 1 | Berlin | |
| # | +I | 2 | Shanghai | |
| # | +I | 3 | NewYork | |
| # | +I | 4 | Hangzhou | |
| # +----+----------------------+--------------------------------+ |
| |
| # aggregate operation |
| class CountAndSumAggregateFunction(AggregateFunction): |
| |
| def get_value(self, accumulator): |
| return Row(accumulator[0], accumulator[1]) |
| |
| def create_accumulator(self): |
| return Row(0, 0) |
| |
| def accumulate(self, accumulator, input_row): |
| accumulator[0] += 1 |
| accumulator[1] += int(input_row.tel) |
| |
| def retract(self, accumulator, input_row): |
| accumulator[0] -= 1 |
| accumulator[1] -= int(input_row.tel) |
| |
| def merge(self, accumulator, accumulators): |
| for other_acc in accumulators: |
| accumulator[0] += other_acc[0] |
| accumulator[1] += other_acc[1] |
| |
| def get_accumulator_type(self): |
| return DataTypes.ROW( |
| [DataTypes.FIELD("cnt", DataTypes.BIGINT()), |
| DataTypes.FIELD("sum", DataTypes.BIGINT())]) |
| |
| def get_result_type(self): |
| return DataTypes.ROW( |
| [DataTypes.FIELD("cnt", DataTypes.BIGINT()), |
| DataTypes.FIELD("sum", DataTypes.BIGINT())]) |
| |
| count_sum = udaf(CountAndSumAggregateFunction()) |
| table.add_columns( |
| col('data').json_value('$.name', DataTypes.STRING()).alias('name'), |
| col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'), |
| col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country')) \ |
| .group_by(col('country')) \ |
| .aggregate(count_sum.alias("cnt", "sum")) \ |
| .select(col('country'), col('cnt'), col('sum')) \ |
| .execute().print() |
| # +----+--------------------------------+----------------------+----------------------+ |
| # | op | country | cnt | sum | |
| # +----+--------------------------------+----------------------+----------------------+ |
| # | +I | China | 3 | 291 | |
| # | +I | Germany | 1 | 123 | |
| # +----+--------------------------------+----------------------+----------------------+ |
| |
| # flat_aggregate operation |
| class Top2(TableAggregateFunction): |
| |
| def emit_value(self, accumulator): |
| for v in accumulator: |
                if v is not None:
| yield Row(v) |
| |
| def create_accumulator(self): |
| return [None, None] |
| |
| def accumulate(self, accumulator, input_row): |
| tel = int(input_row.tel) |
| if accumulator[0] is None or tel > accumulator[0]: |
| accumulator[1] = accumulator[0] |
| accumulator[0] = tel |
| elif accumulator[1] is None or tel > accumulator[1]: |
| accumulator[1] = tel |
| |
| def get_accumulator_type(self): |
| return DataTypes.ARRAY(DataTypes.BIGINT()) |
| |
| def get_result_type(self): |
| return DataTypes.ROW( |
| [DataTypes.FIELD("tel", DataTypes.BIGINT())]) |
| |
| top2 = udtaf(Top2()) |
| table.add_columns( |
| col('data').json_value('$.name', DataTypes.STRING()).alias('name'), |
| col('data').json_value('$.tel', DataTypes.STRING()).alias('tel'), |
| col('data').json_value('$.addr.country', DataTypes.STRING()).alias('country')) \ |
| .group_by(col('country')) \ |
| .flat_aggregate(top2) \ |
| .select(col('country'), col('tel')) \ |
| .execute().print() |
| # +----+--------------------------------+----------------------+ |
| # | op | country | tel | |
| # +----+--------------------------------+----------------------+ |
| # | +I | China | 135 | |
| # | +I | China | 124 | |
| # | +I | Germany | 123 | |
| # +----+--------------------------------+----------------------+ |
| |
| |
| if __name__ == '__main__': |
| logging.basicConfig(stream=sys.stdout, level=logging.INFO, format="%(message)s") |
| |
| basic_operations() |
| sql_operations() |
| column_operations() |
| row_operations() |