| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import shutil |
| import tempfile |
| import unittest |
| import uuid |
| |
| from pypaimon.api.api_response import ConfigResponse |
| from pypaimon.api.auth import BearTokenAuthProvider |
| from pypaimon.api.rest_api import RESTApi, IllegalArgumentError |
| from pypaimon.catalog.catalog_exception import ( |
| FunctionNotExistException, |
| FunctionAlreadyExistException, |
| DefinitionAlreadyExistException, |
| DefinitionNotExistException, |
| ) |
| from pypaimon.catalog.catalog_context import CatalogContext |
| from pypaimon.catalog.rest.rest_catalog import RESTCatalog |
| from pypaimon.common.identifier import Identifier |
| from pypaimon.common.options import Options |
| from pypaimon.function.function import FunctionImpl |
| from pypaimon.function.function_change import FunctionChange |
| from pypaimon.function.function_definition import ( |
| FunctionDefinition, FunctionFileResource, |
| ) |
| from pypaimon.schema.data_types import AtomicType, DataField |
| from pypaimon.tests.rest.rest_server import RESTCatalogServer |
| |
| |
def _mock_function(identifier: Identifier) -> FunctionImpl:
    """Build the sample function shared by the tests below.

    The returned function is non-deterministic, has comment ``"comment"`` and
    no options. It declares two DOUBLE inputs (``length`` at id 0, ``width``
    at id 1) and one DOUBLE output (``area`` at id 0). It carries one
    definition per dialect:

    * ``flink`` - a Java file definition (``className.eval`` in ``/a/b/c.jar``)
    * ``spark`` - a Java lambda
    * ``trino`` - a SQL expression

    :param identifier: identifier to attach to the function.
    :return: a new ``FunctionImpl`` instance.
    """
    def _double_field(field_id, field_name):
        # Each field gets its own AtomicType instance, just like building
        # every DataField inline.
        return DataField(field_id, field_name, AtomicType("DOUBLE"))

    inputs = [_double_field(0, "length"), _double_field(1, "width")]
    outputs = [_double_field(0, "area")]

    jar = FunctionFileResource("jar", "/a/b/c.jar")
    dialect_definitions = {
        "flink": FunctionDefinition.file(
            file_resources=[jar],
            language="java",
            class_name="className",
            function_name="eval",
        ),
        "spark": FunctionDefinition.lambda_def(
            "(Double length, Double width) -> length * width", "java"
        ),
        "trino": FunctionDefinition.sql("length * width"),
    }

    return FunctionImpl(
        identifier=identifier,
        input_params=inputs,
        return_params=outputs,
        deterministic=False,
        definitions=dialect_definitions,
        comment="comment",
        options={},
    )
| |
| |
class RESTFunctionTest(unittest.TestCase):
    """Tests for function management through ``RESTCatalog``.

    Every test starts its own ``RESTCatalogServer`` whose data lives in a
    fresh temporary directory. It then connects a catalog client that
    authenticates with a random bearer token.
    """

    def setUp(self):
        """Start a mock REST catalog server and connect ``self.catalog`` to it.

        Each resource registers its own cleanup via ``addCleanup`` as soon as
        it is acquired. Unlike ``tearDown``, cleanups also run when ``setUp``
        itself raises. So a failure while building the catalog cannot leak
        the running server or the temporary directory.
        """
        self.temp_dir = tempfile.mkdtemp(prefix="function_test_")
        self.addCleanup(self._remove_temp_dir)

        self.config = ConfigResponse(defaults={"prefix": "mock-test"})
        self.token = str(uuid.uuid4())
        self.server = RESTCatalogServer(
            data_path=self.temp_dir,
            auth_provider=BearTokenAuthProvider(self.token),
            config=self.config,
            warehouse="warehouse",
        )
        self.server.start()
        # Cleanups run in LIFO order. The server is therefore shut down
        # before its data directory is removed.
        self.addCleanup(self.server.shutdown)

        options = Options({
            "metastore": "rest",
            "uri": f"http://localhost:{self.server.port}",
            "warehouse": "warehouse",
            "token.provider": "bear",
            "token": self.token,
        })
        self.catalog = RESTCatalog(CatalogContext.create_from_options(options))

    def _remove_temp_dir(self):
        """Delete the per-test data directory, ignoring errors."""
        import gc

        # Force a collection first. Presumably this releases objects that
        # still reference files under temp_dir before the tree is deleted.
        gc.collect()
        shutil.rmtree(self.temp_dir, ignore_errors=True)

    def test_function(self):
        """Create, list, get and drop a function, and reject invalid names."""
        self.catalog.create_database("rest_catalog_db", True)

        # Invalid function names: create and drop reject them with
        # IllegalArgumentError, while get reports the function as missing.
        for invalid_name in ("function/", "-"):
            with self.subTest(name=invalid_name):
                invalid_identifier = Identifier.create("rest_catalog_db", invalid_name)
                with self.assertRaises(IllegalArgumentError):
                    self.catalog.create_function(
                        invalid_identifier,
                        _mock_function(invalid_identifier),
                        False,
                    )
                with self.assertRaises(FunctionNotExistException):
                    self.catalog.get_function(invalid_identifier)
                with self.assertRaises(IllegalArgumentError):
                    self.catalog.drop_function(invalid_identifier, True)

        identifier = Identifier.from_string("rest_catalog_db.function.na_me-01")
        function = _mock_function(identifier)

        # Dropping a missing function with the flag set to True is a no-op.
        # With False it raises; that case is checked at the end of this test.
        self.catalog.drop_function(identifier, True)
        self.catalog.create_function(identifier, function, True)
        with self.assertRaises(FunctionAlreadyExistException):
            self.catalog.create_function(identifier, function, False)

        database_name = identifier.get_database_name()
        self.assertIn(function.name(), self.catalog.list_functions(database_name))

        fetched = self.catalog.get_function(identifier)
        self.assertEqual(fetched.name(), function.name())
        # Every dialect definition must survive the round trip through the server.
        for dialect in function.definitions().keys():
            self.assertEqual(fetched.definition(dialect), function.definition(dialect))

        self.catalog.drop_function(identifier, True)
        self.assertNotIn(function.name(), self.catalog.list_functions(database_name))

        with self.assertRaises(FunctionNotExistException):
            self.catalog.drop_function(identifier, False)
        with self.assertRaises(FunctionNotExistException):
            self.catalog.get_function(identifier)

    def test_list_functions(self):
        """Page through functions per database, globally, and with details."""
        db1 = "db_rest_catalog_db"
        db2 = "db2_rest_catalog"
        identifier = Identifier.create(db1, "list_function")
        identifier1 = Identifier.create(db1, "function")
        identifier2 = Identifier.create(db2, "list_function")
        identifier3 = Identifier.create(db2, "function")

        self.catalog.create_database(db1, True)
        self.catalog.create_database(db2, True)
        for ident in (identifier, identifier1, identifier2, identifier3):
            self.catalog.create_function(ident, _mock_function(ident), True)

        # Per-database listing. Judging by their use here, the arguments after
        # the database appear to be (max results, page token, name pattern).
        result = self.catalog.list_functions_paged(db1, None, None, None)
        self.assertEqual(
            set(result.elements),
            {identifier.get_object_name(), identifier1.get_object_name()},
        )

        result = self.catalog.list_functions_paged(db1, 1, None, None)
        self.assertEqual(len(result.elements), 1)
        self.assertIn(
            result.elements[0],
            [identifier.get_object_name(), identifier1.get_object_name()],
        )

        result = self.catalog.list_functions_paged(
            db1, 1, identifier1.get_object_name(), None)
        self.assertEqual(
            result.elements,
            [identifier.get_object_name()],
        )

        result = self.catalog.list_functions_paged(db1, None, None, "func%")
        self.assertEqual(result.elements, [identifier1.get_object_name()])

        # Global listing. The arguments appear to be
        # (database pattern, function pattern, max results, page token).
        result = self.catalog.list_functions_paged_globally("db2_rest%", "func%", None, None)
        self.assertEqual(len(result.elements), 1)
        self.assertEqual(result.elements[0].get_full_name(), identifier3.get_full_name())

        result = self.catalog.list_functions_paged_globally(
            "db2_rest%", None, 1, None)
        self.assertEqual(len(result.elements), 1)
        self.assertIn(
            result.elements[0].get_full_name(),
            [identifier2.get_full_name(), identifier3.get_full_name()],
        )

        result = self.catalog.list_functions_paged_globally(
            "db2_rest%", None, 1, identifier3.get_full_name())
        self.assertEqual(len(result.elements), 1)
        self.assertEqual(
            result.elements[0].get_full_name(),
            identifier2.get_full_name(),
        )

        # Detailed listing returns function objects rather than names.
        result = self.catalog.list_function_details_paged(db1, 1, None, None)
        self.assertEqual(len(result.elements), 1)
        self.assertIn(
            result.elements[0].full_name(),
            [identifier.get_full_name(), identifier1.get_full_name()],
        )

        result = self.catalog.list_function_details_paged(db2, 4, None, "func%")
        self.assertEqual(len(result.elements), 1)
        self.assertEqual(
            result.elements[0].full_name(), identifier3.get_full_name())

        result = self.catalog.list_function_details_paged(
            db2, 1, identifier3.get_object_name(), None)
        full_names = [f.full_name() for f in result.elements]
        self.assertIn(identifier2.get_full_name(), full_names)

    def test_alter_function(self):
        """Apply each kind of FunctionChange and check the persisted result."""
        identifier = Identifier.create("rest_catalog_db", "alter_function_name")
        self.catalog.create_database(identifier.get_database_name(), True)
        self.catalog.drop_function(identifier, True)
        function = _mock_function(identifier)
        definition = FunctionDefinition.sql("x * y + 1")
        add_definition = FunctionChange.add_definition("flink_1", definition)

        # Altering a missing function: ignored with the flag set to True,
        # raises FunctionNotExistException with False.
        self.catalog.alter_function(identifier, [add_definition], True)
        with self.assertRaises(FunctionNotExistException):
            self.catalog.alter_function(identifier, [add_definition], False)

        self.catalog.create_function(identifier, function, True)

        # set_option / remove_option
        key = str(uuid.uuid4())
        value = str(uuid.uuid4())
        set_option = FunctionChange.set_option(key, value)
        self.catalog.alter_function(identifier, [set_option], False)
        catalog_function = self.catalog.get_function(identifier)
        self.assertEqual(catalog_function.options().get(key), value)

        self.catalog.alter_function(identifier, [FunctionChange.remove_option(key)], False)
        catalog_function = self.catalog.get_function(identifier)
        self.assertNotIn(key, catalog_function.options())

        # update_comment
        new_comment = "new comment"
        self.catalog.alter_function(
            identifier, [FunctionChange.update_comment(new_comment)], False
        )
        catalog_function = self.catalog.get_function(identifier)
        self.assertEqual(catalog_function.comment(), new_comment)

        # add_definition: succeeds once, then the name is taken.
        self.catalog.alter_function(identifier, [add_definition], False)
        catalog_function = self.catalog.get_function(identifier)
        self.assertEqual(
            catalog_function.definition(add_definition.name),
            add_definition.definition,
        )

        with self.assertRaises(DefinitionAlreadyExistException):
            self.catalog.alter_function(identifier, [add_definition], False)

        # update_definition: use a definition that differs from the one just
        # added. Otherwise a server that silently ignored the update would
        # still pass the equality check below.
        updated_definition = FunctionDefinition.sql("x * y + 2")
        update_definition = FunctionChange.update_definition("flink_1", updated_definition)
        self.catalog.alter_function(identifier, [update_definition], False)
        catalog_function = self.catalog.get_function(identifier)
        self.assertEqual(
            catalog_function.definition(update_definition.name),
            update_definition.definition,
        )

        with self.assertRaises(DefinitionNotExistException):
            self.catalog.alter_function(
                identifier,
                [FunctionChange.update_definition("no_exist", definition)],
                False,
            )

        # drop_definition: succeeds once, then the definition is gone.
        drop_definition = FunctionChange.drop_definition(update_definition.name)
        self.catalog.alter_function(identifier, [drop_definition], False)
        catalog_function = self.catalog.get_function(identifier)
        self.assertIsNone(catalog_function.definition(update_definition.name))

        with self.assertRaises(DefinitionNotExistException):
            self.catalog.alter_function(identifier, [drop_definition], False)

    def test_validate_function_name(self):
        """Check RESTApi's function-name predicate and its raising variant."""
        for name in ("a", "a1_", "a-b_c", "a-b_c.1"):
            with self.subTest(valid=name):
                self.assertTrue(RESTApi.is_valid_function_name(name))

        for name in ("a\\/b", "a$?b", "a@b", "a*b", "123", "_-", "", None):
            with self.subTest(invalid=name):
                self.assertFalse(RESTApi.is_valid_function_name(name))

        for name in ("a\\/b", "123", "", None):
            with self.subTest(check=name):
                with self.assertRaises(IllegalArgumentError):
                    RESTApi.check_function_name(name)
| |
| |
if __name__ == "__main__":
    # Run the tests in this module when it is executed directly as a script.
    unittest.main()