| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| from datetime import datetime |
| from importlib.util import find_spec |
| |
| import pandas as pd |
| import pytest |
| |
| from superset.exceptions import InvalidPostProcessingError |
| from superset.utils.core import DTTM_ALIAS |
| from superset.utils.pandas_postprocessing import prophet |
| from tests.unit_tests.fixtures.dataframes import prophet_df |
| |
| |
def test_prophet_valid():
    """Forecast with ``prophet`` and check output columns and the date range.

    Covers a monthly grain with two forecast horizons, then several weekly
    grain spellings: a plain ``P1W`` duration and ``P1W`` durations carrying
    an explicit start (``<datetime>/P1W``) or end (``P1W/<datetime>``) anchor.
    """
    # ``prophet`` is an optional dependency: without it the operation raises
    # InvalidPostProcessingError (exercised by test_prophet_import), so skip
    # here instead of reporting a spurious failure.
    pytest.importorskip("prophet")

    df = prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
    assert set(df.columns) == {
        DTTM_ALIAS,
        "a__yhat",
        "a__yhat_upper",
        "a__yhat_lower",
        "a",
        "b__yhat",
        "b__yhat_upper",
        "b__yhat_lower",
        "b",
    }
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 3, 31)
    assert len(df) == 7

    df = prophet(df=prophet_df, time_grain="P1M", periods=5, confidence_interval=0.9)
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2022, 5, 31)
    assert len(df) == 9

    # Each case: (time grain, two observed timestamps one week apart,
    # expected last timestamp after forecasting one extra period).
    weekly_cases = [
        (
            "P1W",
            datetime(2022, 1, 2),
            datetime(2022, 1, 9),
            datetime(2022, 1, 16),
        ),
        (
            "1969-12-28T00:00:00Z/P1W",
            datetime(2022, 1, 2),
            datetime(2022, 1, 9),
            datetime(2022, 1, 16),
        ),
        (
            "1969-12-29T00:00:00Z/P1W",
            datetime(2022, 1, 3),
            datetime(2022, 1, 10),
            datetime(2022, 1, 17),
        ),
        (
            "P1W/1970-01-03T00:00:00Z",
            datetime(2022, 1, 8),
            datetime(2022, 1, 15),
            datetime(2022, 1, 22),
        ),
    ]
    for time_grain, first, second, expected_last in weekly_cases:
        df = prophet(
            df=pd.DataFrame({"__timestamp": [first, second], "x": [1, 1]}),
            time_grain=time_grain,
            periods=1,
            confidence_interval=0.9,
        )
        assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == expected_last, time_grain
        # Two observed rows plus the single forecast row.
        assert len(df) == 3, time_grain
| |
| |
def test_prophet_valid_zero_periods():
    """With ``periods=0`` the output keeps only the observed date range.

    All forecast columns are still produced, but no rows are appended past the
    last observed timestamp (2018-12-31 .. 2021-12-31, four rows).
    """
    # Optional dependency: without ``prophet`` the operation raises
    # InvalidPostProcessingError, so skip rather than fail.
    pytest.importorskip("prophet")

    df = prophet(df=prophet_df, time_grain="P1M", periods=0, confidence_interval=0.9)
    assert set(df.columns) == {
        DTTM_ALIAS,
        "a__yhat",
        "a__yhat_upper",
        "a__yhat_lower",
        "a",
        "b__yhat",
        "b__yhat_upper",
        "b__yhat_lower",
        "b",
    }
    assert df[DTTM_ALIAS].iloc[0].to_pydatetime() == datetime(2018, 12, 31)
    assert df[DTTM_ALIAS].iloc[-1].to_pydatetime() == datetime(2021, 12, 31)
    assert len(df) == 4
| |
| |
def test_prophet_import():
    """Without the optional ``prophet`` package the operation fails cleanly.

    The missing-dependency path can only be exercised when ``prophet`` is not
    importable. When it is installed, skip explicitly instead of reporting a
    pass for a test that asserted nothing.
    """
    if find_spec("prophet") is not None:
        pytest.skip("prophet is installed; missing-dependency path not exercised")

    with pytest.raises(InvalidPostProcessingError):
        prophet(df=prophet_df, time_grain="P1M", periods=3, confidence_interval=0.9)
| |
| |
def test_prophet_missing_temporal_column():
    """Input lacking the ``DTTM_ALIAS`` temporal column is rejected."""
    without_temporal = prophet_df.drop(columns=[DTTM_ALIAS])

    with pytest.raises(InvalidPostProcessingError):
        prophet(
            df=without_temporal,
            time_grain="P1M",
            periods=3,
            confidence_interval=0.9,
        )
| |
| |
def test_prophet_incorrect_confidence_interval():
    """The boundary values 0.0 and 1.0 for ``confidence_interval`` are rejected."""
    for boundary in (0.0, 1.0):
        with pytest.raises(InvalidPostProcessingError):
            prophet(
                df=prophet_df,
                time_grain="P1M",
                periods=3,
                confidence_interval=boundary,
            )
| |
| |
def test_prophet_incorrect_periods():
    """A negative forecast horizon (``periods=-1``) is rejected."""
    call_kwargs = {
        "df": prophet_df,
        "time_grain": "P1M",
        "periods": -1,
        "confidence_interval": 0.8,
    }
    with pytest.raises(InvalidPostProcessingError):
        prophet(**call_kwargs)
| |
| |
def test_prophet_incorrect_time_grain():
    """A free-form time grain such as ``"yearly"`` is rejected."""
    call_kwargs = {
        "df": prophet_df,
        "time_grain": "yearly",
        "periods": 10,
        "confidence_interval": 0.8,
    }
    with pytest.raises(InvalidPostProcessingError):
        prophet(**call_kwargs)