# blob: 81784d48badd1cb162e2e28a99e82ad5b82ea8e1
################################################################################
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#################################################################################
import typing
from inspect import signature
from typing import Any, Callable, Dict, Optional, Type, Union
from docstring_parser import parse
from mcp import types
from pydantic import BaseModel, create_model
from pydantic.fields import Field, FieldInfo
def create_schema_from_function(name: str, func: Callable) -> Type[BaseModel]:
    """Create a pydantic schema from a function's signature.

    Every parameter of ``func`` becomes one field of the generated model:

    * The field type is the parameter annotation, or ``typing.Any`` when the
      parameter is not annotated. For ``Annotated[T, ...]`` the field type
      is ``T``.
    * The field description starts as ``"Parameter: <name>"``, is replaced by
      the matching parameter entry of the function docstring when that entry
      has a description, and is replaced again by the first ``Annotated``
      metadata item when that item is a string or a ``FieldInfo`` with a
      description.
    * A parameter without a default is a required field. A parameter whose
      default is a ``FieldInfo`` uses that object unchanged. Any other
      default value becomes the field default.

    Args:
        name: Name of the generated model class.
        func: Function whose signature and docstring are inspected.

    Returns:
        The model class produced by ``pydantic.create_model``.
    """
    # ``__doc__`` is None for functions without a docstring.
    docstring = parse(func.__doc__ or "")
    doc_params = {doc_param.arg_name: doc_param for doc_param in docstring.params}

    fields = {}
    for param_name, param in signature(func).parameters.items():
        param_type = param.annotation
        param_default = param.default

        # Never let a missing description (None) replace a usable one.
        description = f"Parameter: {param_name}"
        doc_param = doc_params.get(param_name)
        if doc_param is not None and doc_param.description is not None:
            description = doc_param.description

        if typing.get_origin(param_type) is typing.Annotated:
            args = typing.get_args(param_type)
            param_type = args[0]
            metadata = args[1]
            if isinstance(metadata, str):
                description = metadata
            elif (
                isinstance(metadata, FieldInfo)
                and metadata.description is not None
            ):
                description = metadata.description

        if param_type is param.empty:
            param_type = typing.Any

        if param_default is param.empty:
            # Required field
            fields[param_name] = (param_type, FieldInfo(description=description))
        elif isinstance(param_default, FieldInfo):
            # Field with pydantic.Field as default value
            fields[param_name] = (param_type, param_default)
        else:
            fields[param_name] = (
                param_type,
                FieldInfo(default=param_default, description=description),
            )
    return create_model(name, **fields)
# Python type used to represent each JSON Schema primitive type name.
TYPE_MAPPING: dict[str, type] = dict(
    string=str,
    integer=int,
    number=float,
    boolean=bool,
    object=dict,
    array=list,
    null=type(None),
)
# JSON Schema constraint keyword -> keyword argument name passed to
# pydantic's ``Field``. Several schema keywords share one target; insertion
# order is kept so that later keywords win when both are present.
CONSTRAINT_MAPPING: dict[str, str] = dict(
    minimum="ge",
    maximum="le",
    exclusiveMinimum="gt",
    exclusiveMaximum="lt",
    inclusiveMinimum="ge",
    inclusiveMaximum="le",
    minItems="min_length",
    maxItems="max_length",
)
def __get_field_params_from_field_schema(field_schema: dict) -> dict:
    """Translate one JSON schema property into keyword arguments for ``Field``.

    Constraint keywords listed in ``CONSTRAINT_MAPPING`` are renamed to their
    mapped names; ``description`` and ``default`` are copied through under the
    same name. Keys that are absent from ``field_schema`` are omitted.
    """
    # Constraints first, in mapping order; a later keyword that shares a
    # target name overrides an earlier one.
    params = {
        target: field_schema[keyword]
        for keyword, target in CONSTRAINT_MAPPING.items()
        if keyword in field_schema
    }
    for passthrough in ("description", "default"):
        if passthrough in field_schema:
            params[passthrough] = field_schema[passthrough]
    return params
def create_model_from_schema(name: str, schema: dict) -> type[BaseModel]:
    """Create Pydantic model from a JSON schema generated by
    BaseModel.model_json_schema().

    Entries of ``schema["$defs"]`` are converted first, in the order they
    appear. Object definitions become nested models; any other definition
    (for example an enum of strings) is registered as its resolved Python
    type. The properties of ``schema`` itself are then turned into the
    fields of the returned model.

    Type resolution rules:

    * ``$ref`` resolves to the definition with the same trailing path
      segment, or ``Any`` when that definition has not been created (unknown
      name, or a reference to a definition that appears later in ``$defs``).
    * ``anyOf`` and list-valued ``type`` resolve to a ``Union`` of the
      member types, wrapped in ``Optional`` when ``null`` is a member.
    * ``array`` resolves to ``list[<item type>]`` and ``object`` to
      ``dict[str, <additionalProperties type>]``.
    * Anything else is looked up in ``TYPE_MAPPING`` and falls back to
      ``Any``.

    The schema's ``required`` list is not consulted: a field gets a default
    only when its own schema carries a ``default`` key.

    Args:
        name: Name of the generated top-level model class.
        schema: JSON schema dictionary describing an object.

    Returns:
        The generated model class.
    """
    # Definition name -> generated model class or resolved plain type.
    models: dict[str, typing.Any] = {}

    def resolve_union(member_schemas: list) -> typing.Any:
        """Resolve each member schema and combine them into one annotation."""
        members = []
        nullable = False
        for member_schema in member_schemas:
            member_type = resolve_field_type(member_schema)
            if member_type is type(None):
                nullable = True
            else:
                members.append(member_type)
        if not members:
            return type(None) if nullable else Any
        # ``Union`` of a one-element tuple collapses to that element.
        union = Union[tuple(members)]  # noqa: UP007
        return Optional[union] if nullable else union  # noqa: UP007

    def resolve_field_type(field_schema: typing.Any) -> typing.Any:
        """Resolves field type, including optional types and nullability."""
        if not isinstance(field_schema, dict):
            # Boolean schemas (e.g. ``additionalProperties: true``) and other
            # non-object forms carry no type information.
            return Any

        if "$ref" in field_schema:
            model_reference = field_schema["$ref"].split("/")[-1]
            return models.get(model_reference, Any)

        if "anyOf" in field_schema:
            return resolve_union(field_schema["anyOf"])

        schema_type = field_schema.get("type")

        if isinstance(schema_type, list):
            # ``"type": ["string", "null"]`` is a union spelled as a list;
            # keep sibling keys such as ``items`` for every member.
            return resolve_union(
                [{**field_schema, "type": type_name} for type_name in schema_type]
            )

        # Handle arrays (lists)
        if schema_type == "array":
            item_type = resolve_field_type(field_schema.get("items", {}))
            return list[item_type]  # type: ignore[valid-type]

        # Handle objects (dicts with specified value types)
        if schema_type == "object":
            value_type = resolve_field_type(
                field_schema.get("additionalProperties")
            )
            return dict[str, value_type]  # type: ignore[valid-type]

        return TYPE_MAPPING.get(schema_type, Any)  # type: ignore[arg-type]

    def build_fields(properties: dict) -> dict:
        """Build the ``create_model`` field definitions for a property map."""
        fields = {}
        for field_name, field_schema in properties.items():
            field_type = resolve_field_type(field_schema)
            field_params = __get_field_params_from_field_schema(
                field_schema=field_schema
            )
            fields[field_name] = (field_type, Field(**field_params))
        return fields

    # First, create models for definitions
    for model_name, model_schema in schema.get("$defs", {}).items():
        is_object = (
            "properties" in model_schema or model_schema.get("type") == "object"
        )
        if not is_object:
            # Non-object definitions (e.g. enums) are not models; referencing
            # fields use the underlying plain type instead.
            models[model_name] = resolve_field_type(model_schema)
            continue
        models[model_name] = create_model(
            model_name,
            **build_fields(model_schema.get("properties", {})),
            __doc__=model_schema.get("description", ""),
        )  # type: ignore[call-overload]

    # Now, create the main model, resolving references
    main_fields = build_fields(schema.get("properties", {}))
    return create_model(name, **main_fields, __doc__=schema.get("description", ""))
def extract_mcp_content_item(content_item: Any) -> Dict[str, Any] | str:
    """Extract and normalize a single MCP content item.

    Args:
        content_item: A single MCP content item (TextContent, ImageContent, etc.)

    Returns:
        * ``TextContent``: the text itself, as a string.
        * ``ImageContent``: a dict with ``type``, ``data`` and ``mimeType``.
        * ``EmbeddedResource`` holding text or blob contents: a dict with
          ``type``, ``uri`` and either ``text`` or ``blob``.
        * Anything else: ``content_item.model_dump()`` when the item has that
          method, otherwise ``str(content_item)``.

    Raises:
        ImportError: If MCP types are not available
    """
    if types is None:
        err_msg = "MCP types not available. Please install the mcp package."
        raise ImportError(err_msg)

    if isinstance(content_item, types.TextContent):
        return content_item.text

    if isinstance(content_item, types.ImageContent):
        return {
            "type": "image",
            "data": content_item.data,
            "mimeType": content_item.mimeType,
        }

    if isinstance(content_item, types.EmbeddedResource):
        resource = content_item.resource
        if isinstance(resource, types.TextResourceContents):
            return {
                "type": "resource",
                "uri": resource.uri,
                "text": resource.text,
            }
        if isinstance(resource, types.BlobResourceContents):
            return {
                "type": "resource",
                "uri": resource.uri,
                "blob": resource.blob,
            }
        # An embedded resource of any other kind falls through to the
        # generic handling below instead of implicitly returning None.

    # Handle unknown content types as generic dict
    if hasattr(content_item, "model_dump"):
        return content_item.model_dump()
    return str(content_item)