diff --git a/src/mcp/server/mcpserver/prompts/base.py b/src/mcp/server/mcpserver/prompts/base.py index b4810c100..e5b2af7d8 100644 --- a/src/mcp/server/mcpserver/prompts/base.py +++ b/src/mcp/server/mcpserver/prompts/base.py @@ -3,7 +3,6 @@ from __future__ import annotations import functools -import inspect from collections.abc import Awaitable, Callable, Sequence from typing import TYPE_CHECKING, Any, Literal @@ -13,6 +12,7 @@ from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.types import ContentBlock, Icon, TextContent if TYPE_CHECKING: @@ -157,8 +157,9 @@ async def render( # Add context to arguments if needed call_args = inject_context(self.fn, arguments or {}, context, self.context_kwarg) - if inspect.iscoroutinefunction(self.fn): - result = await self.fn(**call_args) + fn = self.fn + if is_async_callable(fn): + result = await fn(**call_args) else: result = await anyio.to_thread.run_sync(functools.partial(self.fn, **call_args)) diff --git a/src/mcp/server/mcpserver/resources/templates.py b/src/mcp/server/mcpserver/resources/templates.py index 542b5e6f8..f1ee29a37 100644 --- a/src/mcp/server/mcpserver/resources/templates.py +++ b/src/mcp/server/mcpserver/resources/templates.py @@ -3,7 +3,6 @@ from __future__ import annotations import functools -import inspect import re from collections.abc import Callable from typing import TYPE_CHECKING, Any @@ -15,6 +14,7 @@ from mcp.server.mcpserver.resources.types import FunctionResource, Resource from mcp.server.mcpserver.utilities.context_injection import find_context_parameter, inject_context from mcp.server.mcpserver.utilities.func_metadata import func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.types import Annotations, Icon if TYPE_CHECKING: @@ -112,8 +112,9 @@ async def create_resource( # Add context to params if needed params = inject_context(self.fn, params, context, self.context_kwarg) - if inspect.iscoroutinefunction(self.fn): - result = await self.fn(**params) + fn = self.fn + if is_async_callable(fn): + result = await fn(**params) else: result = await anyio.to_thread.run_sync(functools.partial(self.fn, **params)) diff --git a/src/mcp/server/mcpserver/resources/types.py b/src/mcp/server/mcpserver/resources/types.py index 04763be8b..d9e472e36 100644 --- a/src/mcp/server/mcpserver/resources/types.py +++ b/src/mcp/server/mcpserver/resources/types.py @@ -1,6 +1,7 @@ """Concrete resource implementations.""" -import inspect +from __future__ import annotations + import json from collections.abc import Callable from pathlib import Path @@ -14,6 +15,7 @@ from pydantic import Field, ValidationInfo, validate_call from mcp.server.mcpserver.resources.base import Resource +from mcp.shared._callable_inspection import is_async_callable from mcp.types import Annotations, Icon @@ -55,8 +57,9 @@ class FunctionResource(Resource): async def read(self) -> str | bytes: """Read the resource by calling the wrapped function.""" try: - if inspect.iscoroutinefunction(self.fn): - result = await self.fn() + fn = self.fn + if is_async_callable(fn): + result = await fn() else: result = await anyio.to_thread.run_sync(self.fn) @@ -83,7 +86,7 @@ def from_function( icons: list[Icon] | None = None, annotations: Annotations | None = None, meta: dict[str, Any] | None = None, - ) -> "FunctionResource": + ) -> FunctionResource: """Create a FunctionResource from a function.""" func_name = name or fn.__name__ if func_name == "": # pragma: no cover diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index dc65be988..754313eb8 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -1,7 +1,5 @@ from __future__ import annotations -import functools -import inspect from collections.abc import Callable from functools import cached_property from typing import TYPE_CHECKING, Any @@ -11,6 +9,7 @@ from mcp.server.mcpserver.exceptions import ToolError from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata +from mcp.shared._callable_inspection import is_async_callable from mcp.shared.exceptions import UrlElicitationRequiredError from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations @@ -63,7 +62,7 @@ def from_function( raise ValueError("You must provide a name for lambda functions") func_doc = description or fn.__doc__ or "" - is_async = _is_async_callable(fn) + is_async = is_async_callable(fn) if context_kwarg is None: # pragma: no branch context_kwarg = find_context_parameter(fn) @@ -118,12 +117,3 @@ async def run( raise except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e - - -def _is_async_callable(obj: Any) -> bool: - while isinstance(obj, functools.partial): # pragma: lax no cover - obj = obj.func - - return inspect.iscoroutinefunction(obj) or ( - callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) - ) diff --git a/src/mcp/shared/_callable_inspection.py b/src/mcp/shared/_callable_inspection.py new file mode 100644 index 000000000..0e89e446f --- /dev/null +++ b/src/mcp/shared/_callable_inspection.py @@ -0,0 +1,33 @@ +"""Callable inspection utilities. + +Adapted from Starlette's `is_async_callable` implementation. +https://github.com/encode/starlette/blob/main/starlette/_utils.py +""" + +from __future__ import annotations + +import functools +import inspect +from collections.abc import Awaitable, Callable +from typing import Any, TypeGuard, TypeVar, overload + +T = TypeVar("T") + +AwaitableCallable = Callable[..., Awaitable[T]] + + +@overload +def is_async_callable(obj: AwaitableCallable[T]) -> TypeGuard[AwaitableCallable[T]]: ... + + +@overload +def is_async_callable(obj: Any) -> TypeGuard[AwaitableCallable[Any]]: ... + + +def is_async_callable(obj: Any) -> Any: + while isinstance(obj, functools.partial): # pragma: lax no cover + obj = obj.func + + return inspect.iscoroutinefunction(obj) or ( + callable(obj) and inspect.iscoroutinefunction(getattr(obj, "__call__", None)) + )