# blob: b0e4c69a35eac772702e8782883bea722d5129c9 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import unittest
import pandas as pd
import numpy as np
import re
from pyspark import pandas as ps
from pyspark.testing.pandasutils import PandasOnSparkTestCase
from pyspark.testing.sqlutils import SQLTestUtils
class SeriesStringOpsAdvMixin:
    """String-accessor (``Series.str``) tests for pandas-on-Spark Series.

    Most tests apply one callable to a pandas Series and to the pandas-on-Spark
    Series built from it, then assert that the two results are equal.

    This mixin defines no assertion helpers itself: ``assert_eq`` and
    ``assertRaises`` are expected to be supplied by the test-case class it is
    combined with.
    """

    @property
    def pser(self):
        """Return a newly built pandas Series of sample values.

        The values cover plain words, digit-only strings, an empty string,
        strings with leading/trailing whitespace, ``None`` and ``np.nan``.
        """
        return pd.Series(
            [
                "apples",
                "Bananas",
                "carrots",
                "1",
                "100",
                "",
                "\nleading-whitespace",
                "trailing-Whitespace \t",
                None,
                np.nan,
            ]
        )

    def check_func(self, func, almost=False):
        """Run :meth:`check_func_on_series` with ``func`` on ``self.pser``."""
        self.check_func_on_series(func, self.pser, almost=almost)

    def check_func_on_series(self, func, pser, almost=False):
        """Assert ``func`` gives equal results on pandas-on-Spark and pandas.

        ``func`` is called on ``ps.from_pandas(pser)`` and on ``pser``; both
        results are handed to ``self.assert_eq`` together with ``almost``.
        Both calls are evaluated before the comparison, so an exception raised
        by either one propagates to the caller.
        """
        self.assert_eq(func(ps.from_pandas(pser)), func(pser), almost=almost)

    def test_string_decode(self):
        """``str.decode`` on a pandas-on-Spark Series raises NotImplementedError."""
        psser = ps.from_pandas(self.pser)
        with self.assertRaises(NotImplementedError):
            psser.str.decode("utf-8")

    def test_string_encode(self):
        """``str.encode`` on a pandas-on-Spark Series raises NotImplementedError."""
        psser = ps.from_pandas(self.pser)
        with self.assertRaises(NotImplementedError):
            psser.str.encode("utf-8")

    def test_string_extract(self):
        """``str.extract`` on a pandas-on-Spark Series raises NotImplementedError."""
        psser = ps.from_pandas(self.pser)
        with self.assertRaises(NotImplementedError):
            psser.str.extract("pat")

    def test_string_extractall(self):
        """``str.extractall`` on a pandas-on-Spark Series raises NotImplementedError."""
        psser = ps.from_pandas(self.pser)
        with self.assertRaises(NotImplementedError):
            psser.str.extractall("pat")

    def test_string_find(self):
        """Compare ``str.find`` with default, ``start`` and ``start``/``end`` arguments."""
        self.check_func(lambda x: x.str.find("a"))
        self.check_func(lambda x: x.str.find("a", start=3))
        self.check_func(lambda x: x.str.find("a", start=0, end=1))

    def test_string_findall(self):
        """Compare ``str.findall`` results, with and without regex flags.

        Each result element is converted with ``str`` before comparison, and
        the last sample value (``np.nan``) is left out of the input.
        """
        self.check_func_on_series(lambda x: x.str.findall("es|as").apply(str), self.pser[:-1])
        self.check_func_on_series(
            lambda x: x.str.findall("wh.*", flags=re.IGNORECASE).apply(str), self.pser[:-1]
        )

    def test_string_index(self):
        """Compare ``str.index`` and check that failing lookups raise."""
        pser = pd.Series(["tea", "eat"])
        self.check_func_on_series(lambda x: x.str.index("ea"), pser)
        # NOTE(review): check_func_on_series evaluates the callable on both the
        # pandas-on-Spark and the pandas Series, so an exception from either
        # side satisfies the assertRaises(Exception) checks below.
        with self.assertRaises(Exception):
            self.check_func_on_series(lambda x: x.str.index("ea", start=0, end=2), pser)
        with self.assertRaises(Exception):
            self.check_func(lambda x: x.str.index("not-found"))

    def test_string_join(self):
        """Compare ``str.join`` on a Series of lists and on the string samples."""
        pser = pd.Series([["a", "b", "c"], ["xx", "yy", "zz"]])
        self.check_func_on_series(lambda x: x.str.join("-"), pser)
        self.check_func(lambda x: x.str.join("-"))

    def test_string_len(self):
        """Compare ``str.len`` on the string samples and on a Series of lists."""
        self.check_func(lambda x: x.str.len())
        pser = pd.Series([["a", "b", "c"], ["xx"], []])
        self.check_func_on_series(lambda x: x.str.len(), pser)

    def test_string_ljust(self):
        """Compare ``str.ljust`` for several widths and a custom fill character."""
        self.check_func(lambda x: x.str.ljust(0))
        self.check_func(lambda x: x.str.ljust(10))
        self.check_func(lambda x: x.str.ljust(30, "x"))

    def test_string_match(self):
        """Compare ``str.match`` with ``na``, ``case`` and ``flags`` arguments."""
        self.check_func(lambda x: x.str.match("in"))
        self.check_func(lambda x: x.str.match("apples|carrots", na=False))
        self.check_func(lambda x: x.str.match("White", case=True))
        self.check_func(lambda x: x.str.match("BANANAS", flags=re.IGNORECASE))

    def test_string_normalize(self):
        """Compare ``str.normalize`` for the NFC and NFKD forms."""
        self.check_func(lambda x: x.str.normalize("NFC"))
        self.check_func(lambda x: x.str.normalize("NFKD"))

    def test_string_pad(self):
        """Compare ``str.pad`` with default, ``side`` and ``fillchar`` arguments."""
        self.check_func(lambda x: x.str.pad(10))
        self.check_func(lambda x: x.str.pad(10, side="both"))
        self.check_func(lambda x: x.str.pad(10, side="right", fillchar="-"))

    def test_string_partition(self):
        """``str.partition`` is expected to raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            self.check_func(lambda x: x.str.partition())

    def test_string_repeat(self):
        """Compare ``str.repeat`` with an int; a list of repeats must raise TypeError."""
        self.check_func(lambda x: x.str.repeat(repeats=3))
        with self.assertRaises(TypeError):
            self.check_func(lambda x: x.str.repeat(repeats=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]))

    def test_string_replace(self):
        """Compare ``str.replace`` with literal, regex, callable and compiled patterns."""
        self.check_func(lambda x: x.str.replace("a.", "xx", regex=True))
        self.check_func(lambda x: x.str.replace("a.", "xx", regex=False))
        self.check_func(lambda x: x.str.replace("ing", "0", flags=re.IGNORECASE))

        # reverse every lowercase word
        def repl(m):
            return m.group(0)[::-1]

        self.check_func(lambda x: x.str.replace("[a-z]+", repl, regex=True))
        # compiled regex with flags
        regex_pat = re.compile("WHITESPACE", flags=re.IGNORECASE)
        self.check_func(lambda x: x.str.replace(regex_pat, "---", regex=True))

    def test_string_rfind(self):
        """Compare ``str.rfind`` with default, ``start`` and ``start``/``end`` arguments."""
        self.check_func(lambda x: x.str.rfind("a"))
        self.check_func(lambda x: x.str.rfind("a", start=3))
        self.check_func(lambda x: x.str.rfind("a", start=0, end=1))

    def test_string_rindex(self):
        """Compare ``str.rindex`` and check that failing lookups raise."""
        pser = pd.Series(["teatea", "eateat"])
        self.check_func_on_series(lambda x: x.str.rindex("ea"), pser)
        # NOTE(review): as the callable is evaluated on both the pandas-on-Spark
        # and the pandas Series, an exception from either side satisfies the
        # assertRaises(Exception) checks below.
        with self.assertRaises(Exception):
            self.check_func_on_series(lambda x: x.str.rindex("ea", start=0, end=2), pser)
        with self.assertRaises(Exception):
            self.check_func(lambda x: x.str.rindex("not-found"))

    def test_string_rjust(self):
        """Compare ``str.rjust`` for several widths and a custom fill character."""
        self.check_func(lambda x: x.str.rjust(0))
        self.check_func(lambda x: x.str.rjust(10))
        self.check_func(lambda x: x.str.rjust(30, "x"))

    def test_string_rpartition(self):
        """``str.rpartition`` is expected to raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            self.check_func(lambda x: x.str.rpartition())

    def test_string_slice(self):
        """Compare ``str.slice`` with ``start``, ``stop`` and ``step`` arguments."""
        self.check_func(lambda x: x.str.slice(start=1))
        self.check_func(lambda x: x.str.slice(stop=3))
        self.check_func(lambda x: x.str.slice(step=2))
        self.check_func(lambda x: x.str.slice(start=0, stop=5, step=3))

    def test_string_slice_replace(self):
        """Compare ``str.slice_replace`` with ``start``, ``stop`` and both."""
        self.check_func(lambda x: x.str.slice_replace(1, repl="X"))
        self.check_func(lambda x: x.str.slice_replace(stop=2, repl="X"))
        self.check_func(lambda x: x.str.slice_replace(start=1, stop=3, repl="X"))

    def test_string_split(self):
        """Compare ``str.split`` with and without ``pat``, ``n`` and ``expand``.

        Most cases compare the ``repr`` of the results. ``expand=True`` without
        ``n`` is expected to raise NotImplementedError.
        """
        self.check_func_on_series(lambda x: repr(x.str.split()), self.pser[:-1])
        self.check_func_on_series(lambda x: repr(x.str.split(r"p*")), self.pser[:-1])
        pser = pd.Series(["This is a sentence.", "This-is-a-long-word."])
        self.check_func_on_series(lambda x: repr(x.str.split(n=2)), pser)
        self.check_func_on_series(lambda x: repr(x.str.split(pat="-", n=2)), pser)
        self.check_func_on_series(lambda x: x.str.split(n=2, expand=True), pser, almost=True)
        with self.assertRaises(NotImplementedError):
            self.check_func(lambda x: x.str.split(expand=True))

        self.check_func_on_series(lambda x: repr(x.str.split("-", n=1, expand=True)), pser)

    def test_string_rsplit(self):
        """Compare ``str.rsplit`` with and without ``pat``, ``n`` and ``expand``.

        Most cases compare the ``repr`` of the results. ``expand=True`` without
        ``n`` is expected to raise NotImplementedError.
        """
        self.check_func_on_series(lambda x: repr(x.str.rsplit()), self.pser[:-1])
        self.check_func_on_series(lambda x: repr(x.str.rsplit(r"p*")), self.pser[:-1])
        pser = pd.Series(["This is a sentence.", "This-is-a-long-word."])
        self.check_func_on_series(lambda x: repr(x.str.rsplit(n=2)), pser)
        self.check_func_on_series(lambda x: repr(x.str.rsplit(pat="-", n=2)), pser)
        self.check_func_on_series(lambda x: x.str.rsplit(n=2, expand=True), pser, almost=True)
        with self.assertRaises(NotImplementedError):
            self.check_func(lambda x: x.str.rsplit(expand=True))

        self.check_func_on_series(lambda x: repr(x.str.rsplit("-", n=1, expand=True)), pser)

    def test_string_translate(self):
        """Compare ``str.translate`` using a table that maps and deletes characters."""
        m = str.maketrans({"a": "X", "e": "Y", "i": None})
        self.check_func(lambda x: x.str.translate(m))

    def test_string_wrap(self):
        """Compare ``str.wrap`` at width 5 with each keyword option switched off."""
        self.check_func(lambda x: x.str.wrap(5))
        self.check_func(lambda x: x.str.wrap(5, expand_tabs=False))
        self.check_func(lambda x: x.str.wrap(5, replace_whitespace=False))
        self.check_func(lambda x: x.str.wrap(5, drop_whitespace=False))
        self.check_func(lambda x: x.str.wrap(5, break_long_words=False))
        self.check_func(lambda x: x.str.wrap(5, break_on_hyphens=False))

    def test_string_zfill(self):
        """Compare ``str.zfill`` with a width of 10."""
        self.check_func(lambda x: x.str.zfill(10))

    def test_string_get_dummies(self):
        """``str.get_dummies`` is expected to raise NotImplementedError."""
        with self.assertRaises(NotImplementedError):
            self.check_func(lambda x: x.str.get_dummies())
class SeriesStringOpsAdvTests(
    SeriesStringOpsAdvMixin,
    PandasOnSparkTestCase,
    SQLTestUtils,
):
    """Concrete test case: the string-ops mixin combined with the Spark test bases."""

    pass
# Script entry point: run this module's tests when the file is executed directly.
if __name__ == "__main__":
    import unittest
    from pyspark.pandas.tests.series.test_string_ops_adv import *  # noqa: F401

    # Use the XML-reporting runner when xmlrunner is installed; otherwise pass
    # None so unittest.main falls back to its default runner.
    try:
        import xmlrunner

        testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
    except ImportError:
        testRunner = None
    unittest.main(testRunner=testRunner, verbosity=2)