blob: 9987ead0f9c15c7457665b0d2f6c30e49d4b5f6b [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import pandas as pd
import pyarrow as pa
import pytest # noqa: F401
from pandas.api.types import is_datetime64_any_dtype
from superset.utils import csv, json
from superset.utils.core import GenericDataType
from superset.utils.csv import (
df_to_escaped_csv,
get_chart_dataframe,
)
def test_escape_value():
result = csv.escape_value("value")
assert result == "value"
result = csv.escape_value("-10")
assert result == "-10"
result = csv.escape_value("@value")
assert result == "'@value"
result = csv.escape_value("+value")
assert result == "'+value"
result = csv.escape_value("-value")
assert result == "'-value"
result = csv.escape_value("=value")
assert result == "'=value"
result = csv.escape_value("|value")
assert result == r"'\|value"
result = csv.escape_value("%value")
assert result == "'%value"
result = csv.escape_value("=cmd|' /C calc'!A0")
assert result == r"'=cmd\|' /C calc'!A0"
result = csv.escape_value('""=10+2')
assert result == '\'""=10+2'
result = csv.escape_value(" =10+2")
assert result == "' =10+2"
def fake_get_chart_csv_data_none(chart_url, auth_cookies=None):
return None
def fake_get_chart_csv_data_empty(chart_url, auth_cookies=None):
# Return JSON with empty data so that the resulting DataFrame is empty
fake_result = {
"result": [{"data": {}, "coltypes": [], "colnames": [], "indexnames": []}]
}
return json.dumps(fake_result).encode("utf-8")
def fake_get_chart_csv_data_valid(chart_url, auth_cookies=None):
# Return JSON with non-temporal data and valid indexnames so that they are used.
fake_result = {
"result": [
{
"data": {"col1": [1, 2], "col2": ["a", "b"]},
"coltypes": [GenericDataType.NUMERIC, GenericDataType.STRING],
"colnames": ["col1", "col2"],
# Provide two index names so that a MultiIndex is built.
"indexnames": ["idx1", "idx2"],
}
]
}
return json.dumps(fake_result).encode("utf-8")
def fake_get_chart_csv_data_temporal(chart_url, auth_cookies=None):
"""
Return JSON with a temporal column and valid indexnames
so that a MultiIndex is built.
"""
fake_result = {
"result": [
{
"data": {"date": [1609459200000, 1612137600000], "val": [10, 20]},
"coltypes": [GenericDataType.TEMPORAL, GenericDataType.NUMERIC],
"colnames": ["date", "val"],
# Provide two index names so a MultiIndex is built.
"indexnames": [0, 1],
}
]
}
return json.dumps(fake_result).encode("utf-8")
def fake_get_chart_csv_data_hierarchical(chart_url, auth_cookies=None):
# Return JSON with hierarchical column (list-based) and matching index names.
fake_result = {
"result": [
{
"data": {"a": [1, 2]},
"coltypes": [GenericDataType.NUMERIC],
"colnames": [["level1", "a"]],
# Provide two index tuples for two rows
"indexnames": [["idx"], ["idx"]],
}
]
}
return json.dumps(fake_result).encode("utf-8")
def test_df_to_escaped_csv():
df = pd.DataFrame(
data={
"value": [
"a",
"col_a",
"=func()",
"-10",
"=cmd|' /C calc'!A0",
'""=b',
" =a",
"\x00",
]
}
)
escaped_csv_str = df_to_escaped_csv(
df,
encoding="utf8",
index=False,
header=False,
)
escaped_csv_rows = [row.split(",") for row in escaped_csv_str.strip().split("\n")]
assert escaped_csv_rows == [
["a"],
["col_a"],
["'=func()"],
["-10"],
[r"'=cmd\\|' /C calc'!A0"],
['"\'""""=b"'],
["' =a"],
["\x00"],
]
df = pa.array([1, None]).to_pandas(integer_object_nulls=True).to_frame()
assert df_to_escaped_csv(df, encoding="utf8", index=False) == '0\n1\n""\n'
def test_get_chart_dataframe_returns_none_when_no_content(
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setattr(csv, "get_chart_csv_data", fake_get_chart_csv_data_none)
result = get_chart_dataframe("http://dummy-url")
assert result is None
def test_get_chart_dataframe_returns_none_for_empty_data(
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setattr(csv, "get_chart_csv_data", fake_get_chart_csv_data_empty)
result = get_chart_dataframe("http://dummy-url")
# When data is empty, the function should return None
assert result is None
def test_get_chart_dataframe_valid_non_temporal(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(csv, "get_chart_csv_data", fake_get_chart_csv_data_valid)
df = get_chart_dataframe("http://dummy-url")
assert df is not None
expected_columns = pd.MultiIndex.from_tuples([("col1",), ("col2",)])
pd.testing.assert_index_equal(df.columns, expected_columns)
expected_index = pd.MultiIndex.from_tuples([("idx1",), ("idx2",)])
pd.testing.assert_index_equal(df.index, expected_index)
pd.testing.assert_series_equal(
df[("col1",)], pd.Series([1, 2], name=("col1",), index=df.index)
)
pd.testing.assert_series_equal(
df[("col2",)], pd.Series(["a", "b"], name=("col2",), index=df.index)
)
markdown_str = df.to_markdown()
expected_markdown_str = """
| | ('col1',) | ('col2',) |
|:----------|------------:|:------------|
| ('idx1',) | 1 | a |
| ('idx2',) | 2 | b |
"""
assert markdown_str.strip() == expected_markdown_str.strip()
def test_get_chart_dataframe_valid_temporal(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(csv, "get_chart_csv_data", fake_get_chart_csv_data_temporal)
df = get_chart_dataframe("http://dummy-url")
expected_columns = pd.MultiIndex.from_tuples([("date",), ("val",)])
assert df is not None
pd.testing.assert_index_equal(df.columns, expected_columns)
expected_index = pd.MultiIndex.from_tuples([(0,), (1,)])
pd.testing.assert_index_equal(df.index, expected_index)
assert is_datetime64_any_dtype(df[("date",)])
expected_dates = pd.to_datetime([1609459200000, 1612137600000], unit="ms").astype(
"datetime64[ms]"
)
actual_dates = df[("date",)].reset_index(drop=True)
pd.testing.assert_series_equal(
actual_dates, pd.Series(expected_dates, name=("date",)), check_names=False
)
pd.testing.assert_series_equal(
df[("val",)], pd.Series([10, 20], name=("val",), index=df.index)
)
markdown_str = df.to_markdown()
expected_markdown_str = """
| | ('date',) | ('val',) |
|:-----|:--------------------|-----------:|
| (0,) | 2021-01-01 00:00:00 | 10 |
| (1,) | 2021-02-01 00:00:00 | 20 |
"""
assert markdown_str.strip() == expected_markdown_str.strip()
def test_get_chart_dataframe_with_hierarchical_columns(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(csv, "get_chart_csv_data", fake_get_chart_csv_data_hierarchical)
df = get_chart_dataframe("http://dummy-url")
assert df is not None
expected_columns = pd.MultiIndex.from_tuples([("level1", "a")])
pd.testing.assert_index_equal(df.columns, expected_columns)
expected_index = pd.MultiIndex.from_tuples([("idx",)] * len(df))
pd.testing.assert_index_equal(df.index, expected_index)
pd.testing.assert_series_equal(
df[("level1", "a")], pd.Series([1, 2], name=("level1", "a"), index=df.index)
)
markdown_str = df.to_markdown()
expected_markdown_str = """
| | ('level1', 'a') |
|:---------|------------------:|
| ('idx',) | 1 |
| ('idx',) | 2 |
"""
assert markdown_str.strip() == expected_markdown_str.strip()