# blob: fd3f434827c14df2e14279cfa917376ca4eb9dd3 [file] [log] [blame]
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import json
from pathlib import Path
import pytest
from flink_agents.api.resource import Resource, ResourceType
from flink_agents.plan.resource_provider import PythonResourceProvider, ResourceProvider
# Directory holding this test module; the tests below resolve their JSON
# fixture (resources/resource_provider.json) relative to it.
current_dir = Path(__file__).parents[0]
class MockChatModelImpl(Resource):
    """Stand-in chat model resource referenced by the provider under test.

    The provider fixture in this module points at this class through its
    module and class name; the two fields below carry the same names as the
    keys of the ``kwargs`` that fixture supplies.
    """

    # Declared as plain strings; the fixture passes "8.8.8.8" for ``host``
    # and "mock chat model" for ``desc``.
    host: str
    desc: str

    @classmethod
    def resource_type(cls) -> ResourceType:
        """Return ``ResourceType.CHAT_MODEL`` as the type of this resource."""
        return ResourceType.CHAT_MODEL
@pytest.fixture(scope="module")
def resource_provider() -> ResourceProvider:
    """Create the module-scoped ``PythonResourceProvider`` used by the tests.

    The provider is named ``"mock"`` and describes ``MockChatModelImpl`` by
    its resource type, defining module and class name, together with the
    keyword arguments recorded for that class.
    """
    target_cls = MockChatModelImpl
    init_kwargs = {"host": "8.8.8.8", "desc": "mock chat model"}
    provider = PythonResourceProvider(
        name="mock",
        type=target_cls.resource_type(),
        module=target_cls.__module__,
        clazz=target_cls.__name__,
        kwargs=init_kwargs,
    )
    return provider
def test_python_resource_provider_serialize(
    resource_provider: ResourceProvider,
) -> None:
    """Check that the provider serializes to the stored reference JSON.

    The provider is dumped with ``model_dump_json(serialize_as_any=True)``
    and compared against ``resources/resource_provider.json``. Both sides
    are parsed with ``json.loads`` before comparing, so key order and
    whitespace in the reference file do not affect the result.

    Args:
        resource_provider: The module-scoped provider fixture.
    """
    json_value = resource_provider.model_dump_json(serialize_as_any=True)

    # Build the path with pathlib operators and decode explicitly as UTF-8 so
    # the test does not depend on the platform's default locale encoding.
    expected_path = current_dir / "resources" / "resource_provider.json"
    expected_json = expected_path.read_text(encoding="utf-8")

    actual = json.loads(json_value)
    expected = json.loads(expected_json)
    assert actual == expected
def test_python_resource_provider_deserialize(
    resource_provider: ResourceProvider,
) -> None:
    """Check that the stored reference JSON deserializes to the fixture.

    ``resources/resource_provider.json`` is validated into a
    ``PythonResourceProvider`` via ``model_validate_json`` and the result is
    compared for equality with the provider built by the fixture.

    Args:
        resource_provider: The module-scoped provider fixture.
    """
    # Build the path with pathlib operators and decode explicitly as UTF-8 so
    # the test does not depend on the platform's default locale encoding.
    expected_path = current_dir / "resources" / "resource_provider.json"
    expected_json = expected_path.read_text(encoding="utf-8")

    expected_resource_provider = PythonResourceProvider.model_validate_json(
        expected_json
    )
    assert resource_provider == expected_resource_provider