From f49cc0faf17243b4d6d437d6ed1035f1116fa95a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Sun, 21 Jun 2026 11:12:56 +0200 Subject: [PATCH] Slim ServerMiddleware to `(ctx, call_next)` and add `OpenTelemetryMiddleware` Move `method` and `params` onto `ServerRequestContext` so context-tier middleware reads `ctx.method`/`ctx.params` instead of separate positional args. `CallNext` now takes the context, so middleware can rewrite the inbound message with `call_next(replace(ctx, params=...))`. Add a context-tier `OpenTelemetryMiddleware` alongside the existing dispatch-tier `otel_middleware`, which is left intact. --- docs/migration.md | 15 +-- src/mcp/server/_otel.py | 51 +++++++++ src/mcp/server/context.py | 34 +++--- src/mcp/server/lowlevel/server.py | 2 +- src/mcp/server/runner.py | 60 ++++++---- src/mcp/shared/_otel.py | 18 ++- tests/issues/test_176_progress_token.py | 1 + tests/server/mcpserver/test_server.py | 1 + tests/server/test_otel.py | 146 ++++++++++++++++++++++++ tests/server/test_runner.py | 49 ++++---- tests/shared/test_jsonrpc_dispatcher.py | 8 +- 11 files changed, 297 insertions(+), 88 deletions(-) create mode 100644 src/mcp/server/_otel.py create mode 100644 tests/server/test_otel.py diff --git a/docs/migration.md b/docs/migration.md index 02990d779a..7b19fc5807 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -720,7 +720,7 @@ ctx: ClientRequestContext server_ctx: ServerRequestContext[LifespanContextT, RequestT] ``` -`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus a new `protocol_version: str` field, so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`. +`ServerRequestContext` is now a standalone dataclass — it no longer subclasses `RequestContext[ServerSession]`. It carries the same fields (`session`, `request_id`, `meta`, `lifespan_context`, `request`, `close_sse_stream`, `close_standalone_sse_stream`) plus new `protocol_version: str`, `method: str`, and raw `params: Mapping[str, Any] | None` fields (the last two let middleware read and rewrite the inbound message), so handler code is unaffected, but `isinstance(ctx, RequestContext)` checks and `RequestContext[ServerSession]` annotations need updating to `ServerRequestContext`. The high-level `Context` class (injected into `@mcp.tool()` etc.) similarly dropped its `ServerSessionT` parameter: `Context[ServerSessionT, LifespanContextT, RequestT]` → `Context[LifespanContextT, RequestT]`. Both remaining parameters have defaults, so bare `Context` is usually sufficient: @@ -935,19 +935,16 @@ server.add_notification_handler("notifications/custom", MyNotifyParams, my_notif These were private, but some users subclassed `Server` and overrode them to intercept requests. Use middleware instead: ```python -from collections.abc import Mapping from typing import Any from mcp.server import Server, ServerRequestContext from mcp.server.context import CallNext, HandlerResult -async def logging_middleware( - ctx: ServerRequestContext[Any, Any], method: str, params: Mapping[str, Any] | None, call_next: CallNext -) -> HandlerResult: - print(f"handling {method}") - result = await call_next() - print(f"done {method}") +async def logging_middleware(ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult: + print(f"handling {ctx.method}") + result = await call_next(ctx) + print(f"done {ctx.method}") return result @@ -955,7 +952,7 @@ server = Server("my-server", on_call_tool=...) server.middleware.append(logging_middleware) ``` -Middleware runs before params validation, so `params` is the raw inbound mapping (or `None`), and it also wraps unknown methods. +The method and the raw inbound params are `ctx.method` and `ctx.params` (`params` is `None` when the message carries none). Middleware runs before params validation and also wraps unknown methods. To rewrite the method or params before the handler runs, pass an adjusted context through: `await call_next(replace(ctx, params=...))`. ### Lowlevel `Server.run(raise_exceptions=True)`: transport errors no longer re-raised diff --git a/src/mcp/server/_otel.py b/src/mcp/server/_otel.py new file mode 100644 index 0000000000..54a0e3ea8b --- /dev/null +++ b/src/mcp/server/_otel.py @@ -0,0 +1,51 @@ +from __future__ import annotations + +from typing import Any + +from opentelemetry.trace import SpanKind, StatusCode +from pydantic import ValidationError + +from mcp.server.context import CallNext, HandlerResult, ServerMiddleware, ServerRequestContext +from mcp.shared._otel import extract_trace_context, otel_span +from mcp.shared.exceptions import MCPError + + +class OpenTelemetryMiddleware(ServerMiddleware[Any]): + """Context-tier middleware that wraps each inbound message in an OpenTelemetry span. + + Span name `"MCP handle []"`, `mcp.method.name` attribute, W3C + trace context extracted from `params._meta` (SEP-414), and an ERROR status if + the handler raises. Requests and notifications both get a span; + `jsonrpc.request.id` is set only when `ctx.request_id` is present (notifications + have none). + """ + + async def __call__(self, ctx: ServerRequestContext[Any, Any], call_next: CallNext) -> HandlerResult: + name = ctx.params.get("name") if ctx.params else None + target = name if isinstance(name, str) else None + + attributes: dict[str, Any] = {"mcp.method.name": ctx.method} + if ctx.request_id is not None: + attributes["jsonrpc.request.id"] = str(ctx.request_id) + + with otel_span( + name=f"MCP handle {ctx.method}{f' {target}' if target else ''}", + kind=SpanKind.SERVER, + attributes=attributes, + context=extract_trace_context(ctx.meta or {}), + record_exception=False, + set_status_on_exception=False, + ) as span: + try: + return await call_next(ctx) + except MCPError as e: + span.set_status(StatusCode.ERROR, e.error.message) + raise + except ValidationError: + # Mirror the sanitized wire response; pydantic messages carry client input. + span.set_status(StatusCode.ERROR, "Invalid request parameters") + raise + except Exception as e: + span.record_exception(e) + span.set_status(StatusCode.ERROR, str(e)) + raise diff --git a/src/mcp/server/context.py b/src/mcp/server/context.py index eafb70d07c..151b0ac138 100644 --- a/src/mcp/server/context.py +++ b/src/mcp/server/context.py @@ -15,7 +15,7 @@ from mcp.shared.transport_context import TransportContext from mcp.types import LoggingLevel, RequestId, RequestParamsMeta -# Invariant: parameterizes a mutable dataclass field; dict default matches the default lifespan. +# Invariant: parametrizes a mutable dataclass field; dict default matches the default lifespan. LifespanContextT = TypeVar("LifespanContextT", default=dict[str, Any]) RequestT = TypeVar("RequestT", default=Any) @@ -33,6 +33,8 @@ class ServerRequestContext(Generic[LifespanContextT, RequestT]): session: ServerSession lifespan_context: LifespanContextT protocol_version: str + method: str + params: Mapping[str, Any] | None = None request_id: RequestId | None = None meta: RequestParamsMeta | None = None request: RequestT | None = None @@ -113,39 +115,41 @@ async def log(self, level: LoggingLevel, data: Any, logger: str | None = None, * """What a request handler (or middleware) may return. `ServerRunner` serializes all three to a result dict.""" -CallNext = Callable[[], Awaitable[HandlerResult]] +CallNext = Callable[["ServerRequestContext[Any, Any]"], Awaitable[HandlerResult]] +"""Invokes the rest of the chain. Pass the `ctx` through; rewrite `method` or +`params` with `dataclasses.replace(ctx, ...)` to alter what the handler sees.""" _MwLifespanT = TypeVar("_MwLifespanT") class ServerMiddleware(Protocol[_MwLifespanT]): - """Context-tier middleware: `(ctx, method, params, call_next) -> result`. + """Context-tier middleware: `(ctx, call_next) -> result`. Runs at the top of `ServerRunner._on_request` / `_on_notify` after `ctx` is built but before any validation, lookup, or handshake. Wraps every inbound request and notification: `initialize`, the pre-init gate, `METHOD_NOT_FOUND`, params validation, the handler call, and - `notifications/initialized` all run inside `call_next()`. + `notifications/initialized` all run inside `call_next(ctx)`. `notifications/cancelled` is observed too; the dispatcher applies the cancellation itself, then forwards the notification. A request-side failure reaches the middleware as a raised `MCPError` (or `ValidationError` for malformed params) so observation/logging middleware can record it. Listed outermost-first on `Server.middleware`. + The method and the raw inbound params are `ctx.method` and `ctx.params` (no + model validation has happened yet). To rewrite either before the handler + runs, pass an adjusted context: `await call_next(replace(ctx, params=...))`. `ctx.request_id is None` distinguishes a notification from a request. For - notifications `call_next()` returns `None` (a dropped or unhandled + notifications `call_next(ctx)` returns `None` (a dropped or unhandled notification also returns `None`) and the middleware's own return value is discarded. - `params` is the raw inbound mapping (no model validation has happened - yet). For typed inspection, validate against the model the middleware - expects. - - Warning: `initialize` is handled inline - the dispatcher does not read - further inbound messages until the middleware chain returns. Awaiting a - server-to-client request (`ctx.session.send_request`, `send_ping`, ...) - while handling `initialize` therefore deadlocks the connection: the - response can never be dequeued. Send-and-forget notifications are safe. + !!! warning + `initialize` is handled inline - the dispatcher does not read + further inbound messages until the middleware chain returns. Awaiting a + server-to-client request (`ctx.session.send_request`, `send_ping`, ...) + while handling `initialize` therefore deadlocks the connection: the + response can never be dequeued. Send-and-forget notifications are safe. `Server[L].middleware` holds `ServerMiddleware[L]`, so an app-specific middleware sees `ctx.lifespan_context: L`. While the context is the @@ -162,7 +166,5 @@ class ServerMiddleware(Protocol[_MwLifespanT]): async def __call__( self, ctx: ServerRequestContext[_MwLifespanT, Any], - method: str, - params: Mapping[str, Any] | None, call_next: CallNext, ) -> HandlerResult: ... diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index d2536189d0..8481fadb75 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -217,7 +217,7 @@ def __init__( self._session_manager: StreamableHTTPSessionManager | None = None # Context-tier middleware: wraps every inbound request (including # `initialize`, lookup, validation, handler) with - # `(ctx, method, params, call_next)`. Applied in `ServerRunner._on_request`. + # `(ctx, call_next)`. Applied in `ServerRunner._on_request`. # TODO(maxisbey): provisional - signature and semantics change with the # Context/middleware rework (covariant `Context[L]`, outbound seam) before # v2 final. diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index fcdcc68ced..a18a4bd125 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -16,7 +16,7 @@ from __future__ import annotations import logging -from collections.abc import Mapping +from collections.abc import Awaitable, Mapping from dataclasses import dataclass, field from functools import partial, reduce from typing import TYPE_CHECKING, Any, Generic, cast @@ -103,7 +103,7 @@ def _resolve_protocol_version( return "2025-11-25" -def otel_middleware(next_on_request: OnRequest) -> OnRequest: +def otel_middleware(call_next: OnRequest) -> OnRequest: """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. Mirrors the span shape of the existing `Server._handle_request`: span name @@ -139,7 +139,7 @@ async def wrapped( set_status_on_exception=False, ) as span: try: - return await next_on_request(dctx, method, params) + return await call_next(dctx, method, params) except MCPError as e: span.set_status(StatusCode.ERROR, e.error.message) raise @@ -169,6 +169,14 @@ def _dump_result(result: Any) -> dict[str, Any]: raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") +def _apply_middleware( + mw: ServerMiddleware[Any], call_next: CallNext, ctx: ServerRequestContext[Any, Any] +) -> Awaitable[HandlerResult]: + """Adapt one middleware to the `CallNext` shape: bind `call_next`, take + `ctx` at call time so a rewritten context flows down the chain.""" + return mw(ctx, call_next) + + @dataclass class ServerRunner(Generic[LifespanT]): """Per-connection orchestrator. One instance per client connection.""" @@ -244,15 +252,18 @@ async def _on_request( ) -> dict[str, Any]: meta = _extract_meta(params) version = _resolve_protocol_version(self.connection.protocol_version, meta, dctx.message_metadata) - ctx = self._make_context(dctx, meta, version) + ctx = self._make_context(dctx, method, params, meta, version) is_spec_method = method in _methods.SPEC_CLIENT_METHODS - async def _inner() -> HandlerResult: + async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> HandlerResult: + # Read method/params off `ctx` so a middleware that rewrote them via + # `call_next(replace(ctx, ...))` reaches lookup and the handler. + method, params = ctx.method, ctx.params # Pinned compat: spec methods are surface-validated before lookup, # so malformed params are INVALID_PARAMS even with no handler # registered. Custom methods miss the monolith map and fall through # to `entry.params_type` exactly as before. - if is_spec_method: + if method in _methods.SPEC_CLIENT_METHODS: try: _methods.validate_client_request(method, version, params) except KeyError: @@ -282,8 +293,8 @@ async def _inner() -> HandlerResult: raise MCPError.from_error_data(result) return result - call = self._compose_server_middleware(ctx, method, params, _inner) - result = _dump_result(await call()) + call = self._compose_server_middleware(_inner) + result = _dump_result(await call(ctx)) # TODO: reject resultType values outside {"complete", "input_required"} unless the # corresponding extension is in this request's _meta clientCapabilities.extensions; the # explicit MUST-reject is client-side (basic/index.mdx ResultType), this enforces it proactively. @@ -313,9 +324,10 @@ async def _on_notify( ) -> None: meta = _extract_meta(params) version = _resolve_protocol_version(self.connection.protocol_version, meta, dctx.message_metadata) - ctx = self._make_context(dctx, meta, version) + ctx = self._make_context(dctx, method, params, meta, version) - async def _inner() -> None: + async def _inner(ctx: ServerRequestContext[LifespanT, Any]) -> None: + method, params = ctx.method, ctx.params if method in _methods.SPEC_CLIENT_NOTIFICATION_METHODS: try: _methods.validate_client_notification(method, version, params) @@ -345,33 +357,33 @@ async def _inner() -> None: return await entry.handler(ctx, typed_params) - call = self._compose_server_middleware(ctx, method, params, _inner) + call = self._compose_server_middleware(_inner) try: - await call() + await call(ctx) except Exception: # A crashing handler must not cancel the dispatcher's task group; # middleware saw the raise out of call_next() first. logger.exception("notification handler for %r raised", method) - def _compose_server_middleware( - self, - ctx: ServerRequestContext[LifespanT, Any], - method: str, - params: Mapping[str, Any] | None, - inner: CallNext, - ) -> CallNext: + def _compose_server_middleware(self, inner: CallNext) -> CallNext: """Wrap `inner` in `Server.middleware`, outermost-first. Shared by `_on_request` and `_on_notify` so the same middleware chain - observes every inbound message. + observes every inbound message. The composed callable takes the `ctx` + at call time, so a middleware can rewrite it for the rest of the chain. """ - call = inner + call: CallNext = inner for mw in reversed(self.server.middleware): - call = partial(mw, ctx, method, params, call) + call = partial(_apply_middleware, mw, call) return call def _make_context( - self, dctx: DispatchContext[TransportContext], meta: RequestParamsMeta | None, protocol_version: str + self, + dctx: DispatchContext[TransportContext], + method: str, + params: Mapping[str, Any] | None, + meta: RequestParamsMeta | None, + protocol_version: str, ) -> ServerRequestContext[LifespanT, Any]: # TODO(maxisbey): remove for Context rework. Reads the SHTTP per-request # data off the raw `dctx.message_metadata` carrier; replace with the @@ -386,6 +398,8 @@ def _make_context( return ServerRequestContext( session=self.session, lifespan_context=self.lifespan_state, + method=method, + params=params, request_id=dctx.request_id, meta=meta, protocol_version=protocol_version, diff --git a/src/mcp/shared/_otel.py b/src/mcp/shared/_otel.py index 9b20024194..375182b0db 100644 --- a/src/mcp/shared/_otel.py +++ b/src/mcp/shared/_otel.py @@ -2,13 +2,14 @@ from __future__ import annotations -from collections.abc import Iterator +from collections.abc import Generator, Mapping from contextlib import contextmanager from typing import Any from opentelemetry.context import Context from opentelemetry.propagate import extract, inject from opentelemetry.trace import SpanKind, get_tracer +from opentelemetry.trace.span import Span _tracer = get_tracer("mcp-python-sdk") @@ -22,7 +23,7 @@ def otel_span( context: Context | None = None, record_exception: bool = True, set_status_on_exception: bool = True, -) -> Iterator[Any]: +) -> Generator[Span]: """Create an OTel span.""" with _tracer.start_as_current_span( name, @@ -40,13 +41,10 @@ def inject_trace_context(meta: dict[str, Any]) -> None: inject(meta) -def extract_trace_context(meta: dict[str, Any]) -> Context | None: - """Extract W3C trace context from a `_meta` dict. - - Returns `None` when the carrier is malformed; telemetry parsing must - never fail the request it annotates. - """ +def extract_trace_context(meta: Mapping[str, Any]) -> Context: + """Extract W3C trace context from a `_meta` dict.""" try: return extract(meta) - except (TypeError, ValueError): - return None + except (ValueError, TypeError): + # If the traceparent is malformed, degrade to no parent rather than failing the request. + return Context() diff --git a/tests/issues/test_176_progress_token.py b/tests/issues/test_176_progress_token.py index ddd9c67c1d..1ba2c8e118 100644 --- a/tests/issues/test_176_progress_token.py +++ b/tests/issues/test_176_progress_token.py @@ -19,6 +19,7 @@ async def test_progress_token_zero_first_call(): request_context = ServerRequestContext( request_id="test-request", session=mock_session, + method="tools/call", meta={"progress_token": 0}, lifespan_context=None, protocol_version="2025-11-25", diff --git a/tests/server/mcpserver/test_server.py b/tests/server/mcpserver/test_server.py index d1816e6400..554fe50215 100644 --- a/tests/server/mcpserver/test_server.py +++ b/tests/server/mcpserver/test_server.py @@ -1528,6 +1528,7 @@ async def test_report_progress_passes_related_request_id(): request_context = ServerRequestContext( request_id="req-abc-123", session=mock_session, + method="tools/call", meta={"progress_token": "tok-1"}, lifespan_context=None, protocol_version="2025-11-25", diff --git a/tests/server/test_otel.py b/tests/server/test_otel.py new file mode 100644 index 0000000000..3463890892 --- /dev/null +++ b/tests/server/test_otel.py @@ -0,0 +1,146 @@ +"""Tests for `OpenTelemetryMiddleware` (the context-tier OTel span middleware).""" + +from dataclasses import replace +from typing import Any + +import anyio +import pytest +from opentelemetry.trace import SpanKind, StatusCode + +from mcp.server._otel import OpenTelemetryMiddleware +from mcp.server.context import CallNext +from mcp.server.lowlevel.server import Server +from mcp.shared._otel import inject_trace_context +from mcp.shared.exceptions import MCPError +from mcp.types import CallToolRequestParams, ListToolsResult, NotificationParams, PaginatedRequestParams, Tool + +from .conftest import SpanCapture +from .test_runner import Ctx, SrvT, connected_runner + + +@pytest.fixture +def server() -> SrvT: + async def list_tools(ctx: Ctx, params: PaginatedRequestParams | None) -> ListToolsResult: + return ListToolsResult(tools=[Tool(name="t", input_schema={"type": "object"})]) + + return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) + + +async def _ok_tool(ctx: Ctx, params: CallToolRequestParams) -> dict[str, Any]: + return {"content": [], "isError": False} + + +@pytest.mark.anyio +async def test_emits_server_span_with_method_and_target(server: SrvT, spans: SpanCapture): + server.add_request_handler("tools/call", CallToolRequestParams, _ok_tool) + server.middleware.append(OpenTelemetryMiddleware()) + async with connected_runner(server) as (client, _): + spans.clear() + result = await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {}}) + assert result == {"content": [], "isError": False} + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.name == "MCP handle tools/call mytool" + assert span.attributes is not None + assert span.attributes["mcp.method.name"] == "tools/call" + assert isinstance(span.attributes["jsonrpc.request.id"], str) + assert span.status.status_code == StatusCode.UNSET + + +@pytest.mark.anyio +async def test_notification_span_omits_request_id(server: SrvT, spans: SpanCapture): + async def on_roots(ctx: Ctx, params: NotificationParams | None) -> None: + return None + + server.add_notification_handler("notifications/roots/list_changed", NotificationParams, on_roots) + server.middleware.append(OpenTelemetryMiddleware()) + async with connected_runner(server) as (client, _): + spans.clear() + await client.notify("notifications/roots/list_changed", None) + await anyio.wait_all_tasks_blocked() + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.name == "MCP handle notifications/roots/list_changed" + assert span.attributes is not None + assert span.attributes["mcp.method.name"] == "notifications/roots/list_changed" + assert "jsonrpc.request.id" not in span.attributes + + +@pytest.mark.anyio +async def test_extracts_trace_context_from_meta(server: SrvT, spans: SpanCapture): + meta: dict[str, Any] = {} + inject_trace_context(meta) + server.middleware.append(OpenTelemetryMiddleware()) + async with connected_runner(server) as (client, _): + spans.clear() + await client.send_raw_request("tools/list", {"_meta": meta}) + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.parent is not None + + +@pytest.mark.anyio +async def test_records_error_status_on_mcp_error(server: SrvT, spans: SpanCapture): + server.middleware.append(OpenTelemetryMiddleware()) + async with connected_runner(server) as (client, _): + spans.clear() + with pytest.raises(MCPError) as exc: + await client.send_raw_request("resources/list", None) + assert exc.value.error.code != 0 + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "Method not found" + assert not [e for e in span.events if e.name == "exception"] + + +@pytest.mark.anyio +async def test_validation_failure_sets_sanitized_status(server: SrvT, spans: SpanCapture): + server.add_request_handler("tools/call", CallToolRequestParams, _ok_tool) + server.middleware.append(OpenTelemetryMiddleware()) + async with connected_runner(server) as (client, _): + spans.clear() + with pytest.raises(MCPError): + await client.send_raw_request("tools/call", {"name": 123}) + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "Invalid request parameters" + assert not span.events + + +@pytest.mark.anyio +async def test_records_error_status_on_handler_exception(server: SrvT, spans: SpanCapture): + async def failing(ctx: Ctx, params: PaginatedRequestParams | None) -> Any: + raise ValueError("handler blew up") + + server.add_request_handler("tools/list", PaginatedRequestParams, failing) + server.middleware.append(OpenTelemetryMiddleware()) + async with connected_runner(server) as (client, _): + spans.clear() + with pytest.raises(MCPError): + await client.send_raw_request("tools/list", None) + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.status.status_code == StatusCode.ERROR + assert span.status.description == "handler blew up" + [event] = [e for e in span.events if e.name == "exception"] + assert event.attributes is not None + assert event.attributes["exception.type"] == "ValueError" + + +@pytest.mark.anyio +async def test_passes_rewritten_context_through(server: SrvT, spans: SpanCapture): + seen_arguments: dict[str, Any] = {} + + async def call_tool(ctx: Ctx, params: CallToolRequestParams) -> dict[str, Any]: + seen_arguments.update(params.arguments or {}) + return {"content": [], "isError": False} + + async def inject_arg(ctx: Ctx, call_next: CallNext) -> Any: + assert ctx.params is not None + arguments = {**ctx.params.get("arguments", {}), "injected": True} + return await call_next(replace(ctx, params={**ctx.params, "arguments": arguments})) + + server.add_request_handler("tools/call", CallToolRequestParams, call_tool) + server.middleware.extend([OpenTelemetryMiddleware(), inject_arg]) + async with connected_runner(server) as (client, _): + spans.clear() + await client.send_raw_request("tools/call", {"name": "mytool", "arguments": {"x": 1}}) + assert seen_arguments == {"x": 1, "injected": True} + [span] = [s for s in spans.finished() if s.kind == SpanKind.SERVER] + assert span.name == "MCP handle tools/call mytool" diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index c4d10aea08..5b1852bcb5 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -568,9 +568,9 @@ async def wrapped(dctx: Any, method: str, params: Any) -> Any: async def test_runner_server_middleware_wraps_every_request_including_initialize(server: SrvT): seen: list[tuple[str, Any]] = [] - async def ctx_mw(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: - seen.append((method, params)) - return await call_next() + async def ctx_mw(ctx: Ctx, call_next: Any) -> Any: + seen.append((ctx.method, ctx.params)) + return await call_next(ctx) server.middleware.append(ctx_mw) async with connected_runner(server) as (client, _): @@ -587,9 +587,9 @@ async def test_runner_middleware_raise_after_call_next_on_initialize_leaves_conn client as an error and skips the state commit: the pre-init gate stays closed and `connection.initialized` never fires.""" - async def reject_initialize(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: - result = await call_next() - if method == "initialize": + async def reject_initialize(ctx: Ctx, call_next: Any) -> Any: + result = await call_next(ctx) + if ctx.method == "initialize": raise MCPError(code=INTERNAL_ERROR, message="rejected by middleware") return result @@ -613,11 +613,11 @@ async def reject_initialize(ctx: Ctx, method: str, params: Any, call_next: Any) async def test_runner_server_middleware_observes_method_not_found_via_call_next_raise(server: SrvT): seen: list[tuple[str, type[BaseException] | None]] = [] - async def observe(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: + async def observe(ctx: Ctx, call_next: Any) -> Any: try: - return await call_next() + return await call_next(ctx) except MCPError as e: - seen.append((method, type(e))) + seen.append((ctx.method, type(e))) raise server.middleware.append(observe) @@ -635,9 +635,9 @@ async def test_runner_server_middleware_wraps_notifications(server: SrvT): `ctx.request_id is None`.""" seen: list[tuple[str, bool]] = [] - async def observe(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: - seen.append((method, ctx.request_id is None)) - return await call_next() + async def observe(ctx: Ctx, call_next: Any) -> Any: + seen.append((ctx.method, ctx.request_id is None)) + return await call_next(ctx) async def on_roots(ctx: Ctx, params: NotificationParams | None) -> None: return None @@ -739,9 +739,9 @@ async def test_runner_server_middleware_runs_outermost_first(server: SrvT): order: list[str] = [] def make_mw(tag: str) -> Any: - async def mw(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: + async def mw(ctx: Ctx, call_next: Any) -> Any: order.append(f"{tag}-in") - result = await call_next() + result = await call_next(ctx) order.append(f"{tag}-out") return result @@ -787,9 +787,9 @@ async def test_runner_server_middleware_observes_handler_error_data_as_mcp_error successful-looking `ErrorData` return.""" seen: list[MCPError] = [] - async def observe(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: + async def observe(ctx: Ctx, call_next: Any) -> Any: try: - return await call_next() + return await call_next(ctx) except MCPError as e: seen.append(e) raise @@ -811,7 +811,7 @@ async def test_runner_middleware_returning_error_data_produces_jsonrpc_error(ser """A middleware that short-circuits with an `ErrorData` return gets the same treatment as a handler return: the wire sees a JSON-RPC error.""" - async def short_circuit(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: + async def short_circuit(ctx: Ctx, call_next: Any) -> Any: return ErrorData(code=INVALID_PARAMS, message="denied") server.middleware.append(short_circuit) @@ -960,10 +960,10 @@ async def test_runner_middleware_short_circuit_on_a_wrong_version_spec_method_sk spec method absent at the negotiated version owns the result shape; the outbound sieve has no `(method, version)` row and must not raise.""" - async def short_circuit(ctx: Ctx, method: str, params: Any, call_next: Any) -> Any: - if method == "server/discover": + async def short_circuit(ctx: Ctx, call_next: Any) -> Any: + if ctx.method == "server/discover": return {"ok": True} - return await call_next() + return await call_next(ctx) server.middleware.append(short_circuit) async with connected_runner(server) as (client, runner): @@ -1051,13 +1051,12 @@ async def test_otel_trace_context_propagates_client_to_server(server: SrvT, span @pytest.mark.anyio async def test_otel_middleware_malformed_traceparent_degrades_to_no_parent(server: SrvT, spans: SpanCapture): - """A non-string traceparent in `_meta` must not fail the request; the - server span simply gets no parent.""" + """A non-string traceparent in `_meta` must not fail the request; the server span simply gets no parent.""" - def break_traceparent(next_on_request: OnRequest) -> OnRequest: - async def wrapped(dctx: DispatchContext[Any], method: str, params: Any) -> dict[str, Any]: + def break_traceparent(call_next: OnRequest) -> OnRequest: + async def wrapped(ctx: DispatchContext[Any], method: str, params: Any) -> dict[str, Any]: mangled = {"_meta": {"traceparent": 123}} if method == "tools/list" else params - return await next_on_request(dctx, method, mangled) + return await call_next(ctx, method, mangled) return wrapped diff --git a/tests/shared/test_jsonrpc_dispatcher.py b/tests/shared/test_jsonrpc_dispatcher.py index 2b828e5317..588c1dcc21 100644 --- a/tests/shared/test_jsonrpc_dispatcher.py +++ b/tests/shared/test_jsonrpc_dispatcher.py @@ -2315,11 +2315,11 @@ async def handle_call_tool(ctx: ServerRequestContext, params: CallToolRequestPar await anyio.sleep_forever() raise NotImplementedError - async def observe(ctx: Any, method: str, params: Mapping[str, Any] | None, call_next: Any) -> Any: - if method == "notifications/cancelled": - observed.append((method, dict(params or {}))) + async def observe(ctx: Any, call_next: Any) -> Any: + if ctx.method == "notifications/cancelled": + observed.append((ctx.method, dict(ctx.params or {}))) cancel_observed.set() - return await call_next() + return await call_next(ctx) server = Server("test-server", on_call_tool=handle_call_tool) server.middleware.append(observe)