| ################################################################################ |
| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| ################################################################################# |
| import typing |
| from inspect import signature |
| from typing import Any, Callable, Dict, Optional, Type, Union |
| |
| from docstring_parser import parse |
| from mcp import types |
| from pydantic import BaseModel, create_model |
| from pydantic.fields import Field, FieldInfo |
| |
| |
def create_schema_from_function(name: str, func: Callable) -> Type[BaseModel]:
    """Create a pydantic schema from a function's signature.

    Every parameter of ``func`` becomes one field of the generated model.
    The field description is chosen with the following precedence:

    1. a string or a ``FieldInfo`` *with a description* given as the first
       metadata item of a ``typing.Annotated`` annotation,
    2. the parameter description found in the function's docstring,
    3. the placeholder ``"Parameter: <name>"``.

    Args:
        name: Name of the generated model class.
        func: Function whose signature and docstring are inspected.

    Returns:
        A dynamically created ``BaseModel`` subclass.
    """
    # ``__doc__`` is None for undocumented functions; parse an empty string
    # instead so the lookup below simply finds no documented parameters.
    doc_params = {
        doc_param.arg_name: doc_param for doc_param in parse(func.__doc__ or "").params
    }

    fields = {}
    for param_name, param in signature(func).parameters.items():
        param_type = param.annotation
        param_default = param.default

        doc_param = doc_params.get(param_name)
        if doc_param is not None and doc_param.description:
            description = doc_param.description
        else:
            # Undocumented (or documented without text): use a placeholder
            # rather than leaving the description empty.
            description = f"Parameter: {param_name}"

        if typing.get_origin(param_type) is typing.Annotated:
            args = typing.get_args(param_type)
            param_type = args[0]
            metadata = args[1]
            if isinstance(metadata, str):
                description = metadata
            elif isinstance(metadata, FieldInfo) and metadata.description is not None:
                # Only override when the FieldInfo actually carries a
                # description; otherwise keep the docstring/placeholder text
                # instead of clobbering it with None.
                description = metadata.description

        if param_type is param.empty:
            # No annotation: accept any value.
            param_type = typing.Any

        if param_default is param.empty:
            # Required field
            fields[param_name] = (param_type, FieldInfo(description=description))
        elif isinstance(param_default, FieldInfo):
            # Field with pydantic.Field as default value; used as-is.
            fields[param_name] = (param_type, param_default)
        else:
            fields[param_name] = (
                param_type,
                FieldInfo(default=param_default, description=description),
            )

    return create_model(name, **fields)
| |
| |
# JSON-schema primitive type name -> Python type used for the model field.
TYPE_MAPPING: dict[str, type] = {
    # scalars
    "string": str,
    "integer": int,
    "number": float,
    "boolean": bool,
    # containers (element/value types are refined by the schema resolver)
    "object": dict,
    "array": list,
    # JSON null
    "null": type(None),
}

# JSON-schema constraint keyword -> keyword argument of ``pydantic.Field``.
# Several keywords map to the same argument; entries are applied in this
# order, so a later keyword overrides an earlier one for the same argument.
CONSTRAINT_MAPPING: dict[str, str] = {
    # numeric bounds
    "minimum": "ge",
    "maximum": "le",
    "exclusiveMinimum": "gt",
    "exclusiveMaximum": "lt",
    "inclusiveMinimum": "ge",
    "inclusiveMaximum": "le",
    # array size bounds
    "minItems": "min_length",
    "maxItems": "max_length",
}
| |
| |
def __get_field_params_from_field_schema(field_schema: dict) -> dict:
    """Translate a JSON-schema field into ``pydantic.Field`` keyword arguments.

    Constraint keywords listed in ``CONSTRAINT_MAPPING`` are renamed to their
    pydantic counterparts; ``description`` and ``default`` are copied through
    unchanged when present.

    Args:
        field_schema: JSON-schema fragment describing a single property.

    Returns:
        Keyword arguments suitable for ``Field(**kwargs)``.
    """
    # Rename the supported constraint keywords. When two schema keywords map
    # to the same pydantic argument, the one listed later in the mapping wins.
    kwargs = {
        pydantic_key: field_schema[schema_key]
        for schema_key, pydantic_key in CONSTRAINT_MAPPING.items()
        if schema_key in field_schema
    }

    # Keys that pydantic understands under the same name.
    for passthrough_key in ("description", "default"):
        if passthrough_key in field_schema:
            kwargs[passthrough_key] = field_schema[passthrough_key]

    return kwargs
| |
| |
def create_model_from_schema(name: str, schema: dict) -> type[BaseModel]:
    """Create Pydantic model from a JSON schema generated by
    BaseModel.model_json_schema().

    Models are first created for every entry of ``$defs`` and then for the
    top-level ``properties``; ``$ref`` values are resolved against the models
    built so far, using the last path segment of the reference as the name.

    NOTE: definitions are built in the order they appear in ``$defs``, so a
    reference to a definition that has not been built yet (or does not exist)
    resolves to ``Any``.

    Args:
        name: Name of the generated top-level model class.
        schema: JSON schema describing the model.

    Returns:
        A dynamically created ``BaseModel`` subclass.
    """
    models: dict[str, type[BaseModel]] = {}
    none_type = type(None)

    def resolve_field_type(field_schema: dict) -> type[typing.Any]:
        """Resolves field type, including optional types and nullability."""
        if "$ref" in field_schema:
            model_reference = field_schema["$ref"].split("/")[-1]
            return models.get(model_reference, Any)  # type: ignore[return-value]

        if "anyOf" in field_schema:
            # Resolve every member recursively so that references, arrays and
            # objects inside the union keep their full type. Members carrying
            # neither a type nor a reference are ignored.
            member_types = [
                resolve_field_type(member)
                for member in field_schema["anyOf"]
                if "$ref" in member or member.get("type")
            ]
            nullable = none_type in member_types
            non_null = [member for member in member_types if member is not none_type]
            if not non_null:
                # Nothing but "null" (or nothing usable at all) in the union.
                return none_type if nullable else Any  # type: ignore[return-value]
            union_type = (
                non_null[0] if len(non_null) == 1 else Union[tuple(non_null)]  # noqa: UP007
            )
            # Optional[...] only accepts a single type, so the union has to be
            # built first and wrapped afterwards.
            return Optional[union_type] if nullable else union_type  # type: ignore[return-value]  # noqa: UP007

        schema_type = field_schema.get("type")

        # JSON schema also allows a list of type names, e.g.
        # ["string", "null"]; treat it like the equivalent anyOf.
        if isinstance(schema_type, list):
            return resolve_field_type(
                {"anyOf": [{**field_schema, "type": member} for member in schema_type]}
            )

        # Handle arrays (lists)
        if schema_type == "array":
            items = field_schema.get("items", {})
            # "items" may legally be a boolean or a list; only a schema
            # object can be resolved to an element type.
            item_type = resolve_field_type(items) if isinstance(items, dict) else typing.Any
            return list[item_type]  # type: ignore[valid-type]

        # Handle objects (dicts with specified value types)
        if schema_type == "object":
            additional_props = field_schema.get("additionalProperties")
            # "additionalProperties" may be a plain boolean; only a non-empty
            # schema object describes the value type.
            value_type = (
                resolve_field_type(additional_props)
                if isinstance(additional_props, dict) and additional_props
                else typing.Any
            )
            return dict[str, value_type]  # type: ignore[valid-type]

        return TYPE_MAPPING.get(schema_type, typing.Any)  # type: ignore[arg-type,return-value]

    def build_fields(properties: dict) -> dict:
        """Builds the ``create_model`` field definitions for a properties map."""
        return {
            field_name: (
                resolve_field_type(field_schema=field_schema),
                Field(
                    **__get_field_params_from_field_schema(field_schema=field_schema)
                ),
            )
            for field_name, field_schema in properties.items()
        }

    # First, create models for definitions
    for model_name, model_schema in schema.get("$defs", {}).items():
        models[model_name] = create_model(
            model_name,
            **build_fields(model_schema.get("properties", {})),
            __doc__=model_schema.get("description", ""),
        )  # type: ignore[call-overload]

    # Now, create the main model, resolving references
    return create_model(
        name,
        **build_fields(schema.get("properties", {})),
        __doc__=schema.get("description", ""),
    )  # type: ignore[call-overload]
| |
| |
def extract_mcp_content_item(content_item: Any) -> Dict[str, Any] | str:
    """Extract and normalize a single MCP content item.

    Args:
        content_item: A single MCP content item (TextContent, ImageContent, etc.)

    Returns:
        - the plain text for ``TextContent``;
        - a dict with ``type``/``data``/``mimeType`` for ``ImageContent``;
        - a dict with ``type``/``uri`` and either ``text`` or ``blob`` for an
          ``EmbeddedResource`` holding text or blob resource contents;
        - for anything else, ``model_dump()`` of the item when available,
          otherwise its ``str()`` representation.

    Raises:
        ImportError: If MCP types are not available
    """
    if types is None:
        err_msg = "MCP types not available. Please install the mcp package."
        raise ImportError(err_msg)

    if isinstance(content_item, types.TextContent):
        return content_item.text

    if isinstance(content_item, types.ImageContent):
        return {
            "type": "image",
            "data": content_item.data,
            "mimeType": content_item.mimeType,
        }

    if isinstance(content_item, types.EmbeddedResource):
        resource = content_item.resource
        if isinstance(resource, types.TextResourceContents):
            return {
                "type": "resource",
                "uri": resource.uri,
                "text": resource.text,
            }
        if isinstance(resource, types.BlobResourceContents):
            return {
                "type": "resource",
                "uri": resource.uri,
                "blob": resource.blob,
            }
        # Unrecognised resource payload: fall through to the generic
        # handling below instead of implicitly returning None.

    # Handle unknown content types as generic dict
    if hasattr(content_item, "model_dump"):
        return content_item.model_dump()
    return str(content_item)