| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from pandas import DataFrame, Series, Timestamp |
| from pandas.testing import assert_frame_equal |
| from pytest import fixture, mark |
| |
| from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType |
| from superset.common.query_context import QueryContext |
| from superset.common.query_context_processor import ( |
| AGGREGATED_JOIN_COLUMN, |
| QueryContextProcessor, |
| ) |
| from superset.connectors.base.models import BaseDatasource |
| from superset.constants import TimeGrain |
| |
# Minimal query context (bare datasource, no queries, empty form data and
# cache values) wrapped in a processor that is shared by the tests in this
# module; the tests only call ``add_aggregated_join_column`` on it.
_bare_query_context = QueryContext(
    datasource=BaseDatasource(),
    queries=[],
    result_type=ChartDataResultType.COLUMNS,
    form_data={},
    slice_=None,
    result_format=ChartDataResultFormat.CSV,
    cache_values={},
)
query_context_processor = QueryContextProcessor(_bare_query_context)
| |
| |
@fixture
def make_join_column_producer():
    """Provide a join-column producer that ignores its inputs.

    The returned callable accepts a row and a column index and always
    yields the literal string ``"CUSTOM_FORMAT"``.
    """

    def _constant_producer(row: Series, column_index: int) -> str:
        # Both arguments are deliberately unused: the output is constant.
        return "CUSTOM_FORMAT"

    return _constant_producer
| |
| |
@mark.parametrize(
    ("time_grain", "expected"),
    [
        (TimeGrain.WEEK, "2020-W01"),
        (TimeGrain.MONTH, "2020-01"),
        (TimeGrain.QUARTER, "2020-Q1"),
        (TimeGrain.YEAR, "2020"),
    ],
)
def test_aggregated_join_column(time_grain: str, expected: str):
    """Check the join column added for 2020-01-07 under each time grain.

    The frame passed to ``add_aggregated_join_column`` is inspected after the
    call (the return value is not used), so the assertion covers the in-place
    addition of ``AGGREGATED_JOIN_COLUMN`` next to the original ``ds`` column.
    """
    sample_ts = Timestamp("2020-01-07")
    expected_frame = DataFrame(
        {"ds": [sample_ts], AGGREGATED_JOIN_COLUMN: [expected]}
    )

    actual_frame = DataFrame({"ds": [sample_ts]})
    query_context_processor.add_aggregated_join_column(actual_frame, time_grain)

    assert_frame_equal(actual_frame, expected_frame)
| |
| |
def test_aggregated_join_column_producer(make_join_column_producer):
    """Check that a supplied producer determines the join column's value.

    With the fixture's producer passed as the third argument, the join column
    is expected to hold ``"CUSTOM_FORMAT"`` even though the grain is YEAR.
    """
    sample_ts = Timestamp("2020-01-07")
    expected_frame = DataFrame(
        {"ds": [sample_ts], AGGREGATED_JOIN_COLUMN: ["CUSTOM_FORMAT"]}
    )

    actual_frame = DataFrame({"ds": [sample_ts]})
    query_context_processor.add_aggregated_join_column(
        actual_frame, TimeGrain.YEAR, make_join_column_producer
    )

    assert_frame_equal(actual_frame, expected_frame)