blob: 8effacf2494cbce650e08166d3a49fad6938aa9a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from pandas import DataFrame, Series, Timestamp
from pandas.testing import assert_frame_equal
from pytest import fixture, mark
from superset.common.chart_data import ChartDataResultFormat, ChartDataResultType
from superset.common.query_context import QueryContext
from superset.common.query_context_processor import (
AGGREGATED_JOIN_COLUMN,
QueryContextProcessor,
)
from superset.connectors.base.models import BaseDatasource
from superset.constants import TimeGrain
# Shared processor instance used by the tests in this module. The wrapped
# QueryContext is intentionally bare-bones (a default BaseDatasource, no
# queries, empty form data and cache values); presumably that is enough for
# add_aggregated_join_column, which does not appear to need query state.
_query_context = QueryContext(
    datasource=BaseDatasource(),
    queries=[],
    result_type=ChartDataResultType.COLUMNS,
    form_data={},
    slice_=None,
    result_format=ChartDataResultFormat.CSV,
    cache_values={},
)
query_context_processor = QueryContextProcessor(_query_context)
@fixture
def make_join_column_producer():
    """Provide a join-column producer that always yields ``"CUSTOM_FORMAT"``.

    The fixture value is the producer callable itself. Its signature mirrors
    ``(row, column_index) -> str`` so it can stand in for the default,
    grain-based join-key formatter.
    """

    def _constant_producer(row: Series, column_index: int) -> str:
        # Inputs are ignored on purpose: a fixed marker makes it obvious in
        # the assertion whether this producer was actually used.
        return "CUSTOM_FORMAT"

    return _constant_producer
# (time grain, expected join key) pairs for the timestamp 2020-01-07.
_GRAIN_CASES = [
    (TimeGrain.WEEK, "2020-W01"),
    (TimeGrain.MONTH, "2020-01"),
    (TimeGrain.QUARTER, "2020-Q1"),
    (TimeGrain.YEAR, "2020"),
]


@mark.parametrize(("time_grain", "expected"), _GRAIN_CASES)
def test_aggregated_join_column(time_grain: str, expected: str):
    """Check the join key derived from the ``ds`` column for each time grain."""
    stamp = Timestamp("2020-01-07")
    frame = DataFrame({"ds": [stamp]})

    # The processor's return value is ignored: the test expects ``frame``
    # itself to gain the AGGREGATED_JOIN_COLUMN column.
    query_context_processor.add_aggregated_join_column(frame, time_grain)

    expected_frame = DataFrame({"ds": [stamp], AGGREGATED_JOIN_COLUMN: [expected]})
    assert_frame_equal(frame, expected_frame)
def test_aggregated_join_column_producer(make_join_column_producer):
    """A caller-supplied producer determines the join key instead of the grain."""
    stamp = Timestamp("2020-01-07")
    frame = DataFrame({"ds": [stamp]})

    # With the YEAR grain alone the key for this date would be "2020"; the
    # injected producer always returns "CUSTOM_FORMAT", so that marker is
    # what must end up in the join column.
    query_context_processor.add_aggregated_join_column(
        frame, TimeGrain.YEAR, make_join_column_producer
    )

    expected_frame = DataFrame(
        {"ds": [stamp], AGGREGATED_JOIN_COLUMN: ["CUSTOM_FORMAT"]}
    )
    assert_frame_equal(frame, expected_frame)