diff --git a/.github/actions/conformance/expected-failures.2026-07-28.yml b/.github/actions/conformance/expected-failures.2026-07-28.yml index 14e85f7a95..b49626d0d6 100644 --- a/.github/actions/conformance/expected-failures.2026-07-28.yml +++ b/.github/actions/conformance/expected-failures.2026-07-28.yml @@ -69,9 +69,6 @@ server: - json-schema-2020-12 # --- Draft scenarios (same failures and reasons as the `--suite draft` leg) --- - # SEP-2575 (stateless HTTP / _meta envelope): server has no stateless mode, - # _meta-derived capabilities, error-code mappings, or server/discover yet. - - server-stateless # SEP-2322 (multi-round-trip requests / IncompleteResult): not implemented. - input-required-result-basic-elicitation - input-required-result-basic-sampling @@ -83,14 +80,12 @@ server: - input-required-result-result-type - input-required-result-tampered-state - input-required-result-capability-check - - input-required-result-validate-input - # SEP-2243 (HTTP header standardization): -32020 HeaderMismatch handling and - # case-insensitive/whitespace-trimmed header validation not implemented. + # SEP-2243 (HTTP header standardization): Mcp-Method / Mcp-Name cross-check + # against the request body is not implemented. - http-header-validation - - # --- WARNING-only entries --- - # These scenarios emit no FAILURE checks, only SHOULD-level WARNINGs, but - # the expected-failures evaluator counts WARNINGs as failures. Same entries - # as the draft suite in expected-failures.yml. - # SEP-2322 SHOULD-level behaviour (re-request missing inputResponses). + # WARNING-only entries: these scenarios emit no FAILURE checks but the + # expected-failures evaluator counts WARNINGs as failures (the summary line + # only shows passed/failed, not warnings, so a local re-probe can mis-read + # these as stale). - input-required-result-missing-input-response + - input-required-result-validate-input diff --git a/.github/actions/conformance/expected-failures.yml b/.github/actions/conformance/expected-failures.yml index 816723b2ff..4234a6d4aa 100644 --- a/.github/actions/conformance/expected-failures.yml +++ b/.github/actions/conformance/expected-failures.yml @@ -34,12 +34,7 @@ client: server: # --- Draft-spec scenarios (in `--suite draft`; the `active` suite is green) --- - # SEP-2575 (stateless HTTP / _meta envelope): server has no stateless mode, - # _meta-derived capabilities, error-code mappings, or server/discover yet. - - server-stateless - # SEP-2322 (multi-round-trip requests / IncompleteResult): not implemented; - # most scenarios currently fail early with "Missing session ID" because - # mcp-everything-server only runs in stateful mode. + # SEP-2322 (multi-round-trip requests / IncompleteResult): not implemented. - input-required-result-basic-elicitation - input-required-result-basic-sampling - input-required-result-basic-list-roots @@ -50,17 +45,12 @@ server: - input-required-result-result-type - input-required-result-tampered-state - input-required-result-capability-check - # SEP-2243 (HTTP header standardization): -32020 HeaderMismatch handling and - # case-insensitive/whitespace-trimmed header validation not implemented. + # SEP-2243 (HTTP header standardization): Mcp-Method / Mcp-Name cross-check + # against the request body is not implemented. - http-header-validation - # WARNING-only entries: these scenarios emit no FAILURE checks, only SHOULD-level - # WARNINGs, but the expected-failures evaluator counts WARNINGs as failures. - # SEP-2322 SHOULD-level behaviour (re-request missing inputResponses). + # WARNING-only entries: these scenarios emit no FAILURE checks but the + # expected-failures evaluator counts WARNINGs as failures (the summary line + # only shows passed/failed, not warnings, so a local re-probe can mis-read + # these as stale). - input-required-result-missing-input-response - # SEP-2322 negative-case scenarios: input-required-result-validate-input is - # now baselined (added when the stateless path landed — the stateless server - # reaches the handler, so the previous accidental pass via -32600 "Missing - # session ID" no longer applies). input-required-result-unsupported-methods - # is intentionally NOT baselined: it still passes for now; add it once it - # starts failing for real. - input-required-result-validate-input diff --git a/docs/migration.md b/docs/migration.md index 02990d779a..5eb12d3e30 100644 --- a/docs/migration.md +++ b/docs/migration.md @@ -19,6 +19,20 @@ If you call `MCPServer.call_tool()` directly, read `.content` and `.structured_content` off the returned `CallToolResult` instead of branching on the result type. +### `MCPError` raised from an `@mcp.tool()` handler now surfaces as a JSON-RPC error + +Raising `MCPError` (or a subclass such as `UrlElicitationRequiredError`) inside +an `@mcp.tool()` handler now produces a top-level JSON-RPC error response with +the raised `code`, `message`, and `data` intact. Previously the tool wrapper +caught it like any other exception and returned `CallToolResult(isError=True)`, +which discarded the error code and structured `data`. + +`MCPError` carries `ErrorData` and is the SDK's protocol-error type — raise it +when the request itself should be rejected (missing client capability, +elicitation required, invalid parameters). For tool *execution* failures the +calling LLM should see and react to, raise any other exception or return +`CallToolResult(is_error=True, ...)` directly; that path is unchanged. + ### `streamablehttp_client` removed The deprecated `streamablehttp_client` function has been removed. Use `streamable_http_client` instead. @@ -487,6 +501,18 @@ app = Starlette(routes=[Mount("/", app=mcp.streamable_http_app(json_response=Tru If you were mutating these via `mcp.settings` after construction (e.g., `mcp.settings.port = 9000`), pass them to `run()` / `sse_app()` / `streamable_http_app()` instead — these fields no longer exist on `Settings`. The `debug` and `log_level` parameters remain on the constructor. +### Streamable HTTP: lifespan now entered once at manager startup + +When serving streamable HTTP (stateful or `stateless_http=True`), the server's `lifespan` context manager is now entered once when `StreamableHTTPSessionManager.run()` starts, and the resulting state is shared across all sessions and requests. Previously each session (stateful) or each request (stateless) entered and exited `lifespan` independently. + +Lifespans that set up process-wide state (connection pools, caches, background tasks) are unaffected — they now run once instead of per session/request. If your lifespan was acquiring per-connection resources, move that acquisition into the handler body; per-connection cleanup belongs on the connection's `exit_stack` (the public surface for reaching it from high-level `@mcp.tool()` handlers is being finalised as part of the public-surface review). + +### `Server.run()` no longer takes a `stateless` flag; `StatelessModeNotSupported` removed + +The `stateless: bool` parameter on the lowlevel `Server.run()` has been removed. Stateless serving is now a property of how the connection is constructed (the streamable-HTTP manager builds a born-ready `Connection` per request), not a flag the loop driver inspects. + +`StatelessModeNotSupported` has been removed. Server-initiated requests that have no channel to travel on now raise `NoBackChannelError` (an `MCPError` subclass) — the same exception regardless of why the channel is absent. If you were catching `StatelessModeNotSupported`, catch `NoBackChannelError` instead. + ### `MCPServer.get_context()` removed `MCPServer.get_context()` has been removed. Context is now injected by the framework and passed explicitly — there is no ambient ContextVar to read from. @@ -1202,8 +1228,8 @@ from mcp.server import ServerRequestContext session = ServerSession(read_stream, write_stream, init_options, stateless=False) # After (v2) -session = ServerSession(dispatcher, connection, stateless=False) -# where `dispatcher` is a JSONRPCDispatcher and `connection` is a Connection +session = ServerSession(request_outbound, connection) +# where `request_outbound` is an Outbound and `connection` is a Connection ``` In practice, replace direct `ServerSession` use with `Server.run(read_stream, write_stream, init_options)` and let the framework wire it up. diff --git a/examples/servers/everything-server/mcp_everything_server/server.py b/examples/servers/everything-server/mcp_everything_server/server.py index 01baa56340..b12290a6d8 100644 --- a/examples/servers/everything-server/mcp_everything_server/server.py +++ b/examples/servers/everything-server/mcp_everything_server/server.py @@ -14,6 +14,7 @@ from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.prompts.base import UserMessage from mcp.server.streamable_http import EventCallback, EventMessage, EventStore +from mcp.shared.exceptions import MCPError from mcp.types import ( AudioContent, Completion, @@ -32,6 +33,7 @@ TextResourceContents, UnsubscribeRequestParams, ) +from mcp.types.jsonrpc import MISSING_REQUIRED_CLIENT_CAPABILITY from pydantic import BaseModel, Field logger = logging.getLogger(__name__) @@ -311,6 +313,26 @@ def test_error_handling() -> str: raise RuntimeError("This tool intentionally returns an error for testing") +@mcp.tool() +async def test_missing_capability(ctx: Context) -> str: + """Tests that a handler-raised MISSING_REQUIRED_CLIENT_CAPABILITY surfaces as a top-level JSON-RPC error. + + Requires the client to declare the ``sampling`` capability. When absent, raises + `MCPError` (which the tool dispatch re-raises rather than wrapping in + ``CallToolResult.isError``) so the conformance harness observes a protocol-level + error response with ``data.requiredCapabilities``. + """ + client_params = ctx.session.client_params + sampling_declared = client_params is not None and client_params.capabilities.sampling is not None + if not sampling_declared: + raise MCPError( + code=MISSING_REQUIRED_CLIENT_CAPABILITY, + message="This tool requires the client 'sampling' capability", + data={"requiredCapabilities": ["sampling"]}, + ) + return "Client declared sampling capability; proceeding." + + @mcp.tool() async def test_reconnection(ctx: Context) -> str: """Tests SSE polling by closing stream mid-call (SEP-1699)""" diff --git a/src/mcp/client/streamable_http.py b/src/mcp/client/streamable_http.py index fdb127ca0a..a703a48afb 100644 --- a/src/mcp/client/streamable_http.py +++ b/src/mcp/client/streamable_http.py @@ -314,16 +314,29 @@ async def _handle_post_request(self, ctx: RequestContext) -> None: logger.debug("Received 202 Accepted") return - if response.status_code == 404: - if isinstance(message, JSONRPCRequest): - error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") - session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) - await ctx.read_stream_writer.send(session_message) - return - if response.status_code >= 400: if isinstance(message, JSONRPCRequest): - error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") + # A spec-correct server may return the JSON-RPC error in the + # body at a non-2xx status (e.g. 400 for INVALID_PARAMS, 404 + # for METHOD_NOT_FOUND). Surface that error rather than the + # status-derived stand-in below. + if response.headers.get("content-type", "").lower().startswith("application/json"): + try: + body = await response.aread() + parsed = jsonrpc_message_adapter.validate_json(body, by_name=False) + if isinstance(parsed, JSONRPCError): + # The server may have set `id: null` (request rejected before its + # id was parsed); use this request's id so correlation works. + reply = JSONRPCError(jsonrpc="2.0", id=message.id, error=parsed.error) + await ctx.read_stream_writer.send(SessionMessage(reply)) + return + except (httpx.StreamError, ValidationError): + pass + logger.debug("Non-2xx body was not a JSON-RPC error; using fallback") + if response.status_code == 404: + error_data = ErrorData(code=INVALID_REQUEST, message="Session terminated") + else: + error_data = ErrorData(code=INTERNAL_ERROR, message="Server returned an error response") session_message = SessionMessage(JSONRPCError(jsonrpc="2.0", id=message.id, error=error_data)) await ctx.read_stream_writer.send(session_message) return diff --git a/src/mcp/server/_streamable_http_modern.py b/src/mcp/server/_streamable_http_modern.py index 051265a84d..6a81786b3b 100644 --- a/src/mcp/server/_streamable_http_modern.py +++ b/src/mcp/server/_streamable_http_modern.py @@ -12,33 +12,32 @@ from __future__ import annotations +import json import logging from collections.abc import Mapping from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, TypeVar import anyio -import anyio.abc -from pydantic import ValidationError +from pydantic import BaseModel, ValidationError from starlette.requests import Request from starlette.responses import Response from starlette.types import Receive, Scope, Send -from mcp.server.runner import ( - _EXIT_STACK_CLOSE_TIMEOUT, # type: ignore[reportPrivateUsage] - ServerRunner, - otel_middleware, -) +from mcp.server.connection import Connection +from mcp.server.runner import serve_one from mcp.server.transport_security import TransportSecurityMiddleware, TransportSecuritySettings -from mcp.shared.dispatcher import CallOptions, OnNotify, OnRequest -from mcp.shared.exceptions import MCPError, NoBackChannelError +from mcp.shared.dispatcher import CallOptions +from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.inbound import ERROR_CODE_HTTP_STATUS, InboundLadderRejection, classify_inbound_request from mcp.shared.message import MessageMetadata, ServerMessageMetadata from mcp.shared.transport_context import TransportContext from mcp.types import ( - INTERNAL_ERROR, - INVALID_PARAMS, + INVALID_REQUEST, PARSE_ERROR, + ClientCapabilities, ErrorData, + Implementation, JSONRPCError, JSONRPCRequest, JSONRPCResponse, @@ -50,6 +49,10 @@ logger = logging.getLogger(__name__) +_ModelT = TypeVar("_ModelT", bound=BaseModel) + +_OK_STATUS = 200 + @dataclass class _SingleExchangeDispatchContext: @@ -64,7 +67,7 @@ class _SingleExchangeDispatchContext: request_id: RequestId message_metadata: MessageMetadata cancel_requested: anyio.Event = field(default_factory=anyio.Event) - can_send_request: bool = False + can_send_request: bool = field(default=False, init=False) async def send_raw_request( self, @@ -75,91 +78,51 @@ async def send_raw_request( raise NoBackChannelError(method) async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + # TODO(D-005a): buffer and stream as SSE once the JSON-vs-SSE response mode lands. return None async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: - # TODO: no progressToken plumbing yet. + # TODO(D-005a): no progressToken plumbing yet; ships with the SSE response mode. return None -class SingleExchangeDispatcher: - """Dispatcher for exactly one inbound JSON-RPC request over a single HTTP POST. +def _typed(model: type[_ModelT], raw: Any) -> _ModelT | None: + """Validate the classifier's raw envelope value into a typed model. - The exception->wire boundary lives here (mirrors `JSONRPCDispatcher`'s - role). Implements the `Dispatcher` Protocol so `ServerRunner` / - `Connection` / `ServerSession` accept it; `run()` is never driven. + Rung 1 guarantees the envelope key was present; a ``null`` or mis-shaped + value falls through to ``ValidationError`` and is treated as not supplied + so the request still routes. """ - - def __init__(self, request: Request) -> None: - self._request = request - self._tctx = TransportContext( - kind="streamable-http", - can_send_request=False, - headers=request.headers, - ) - - async def send_raw_request( - self, - method: str, - params: Mapping[str, Any] | None, - opts: CallOptions | None = None, - *, - _related_request_id: RequestId | None = None, - ) -> dict[str, Any]: - raise NoBackChannelError(method) - - async def notify( - self, - method: str, - params: Mapping[str, Any] | None, - *, - _related_request_id: RequestId | None = None, - ) -> None: - # TODO: buffer and stream as SSE once the response-mode design lands. + try: + return model.model_validate(raw, by_name=False) + except ValidationError: return None - async def run( - self, - on_request: OnRequest, - on_notify: OnNotify, - *, - task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, - ) -> None: - raise RuntimeError("SingleExchangeDispatcher.run() is never driven; use handle()") - - async def handle(self, req: JSONRPCRequest, on_request: OnRequest) -> JSONRPCResponse | JSONRPCError: - """Dispatch one request and map any exception to a `JSONRPCError`.""" - dctx = _SingleExchangeDispatchContext( - transport=self._tctx, - request_id=req.id, - message_metadata=ServerMessageMetadata(request_context=self._request), - ) - try: - result = await on_request(dctx, req.method, req.params) - return JSONRPCResponse(jsonrpc="2.0", id=req.id, result=result) - except MCPError as e: - return JSONRPCError(jsonrpc="2.0", id=req.id, error=e.error) - except ValidationError: - return JSONRPCError( - jsonrpc="2.0", - id=req.id, - error=ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data=""), - ) - # TODO: consolidate the three exception->ErrorData copies once the - # code=0 compat pin in JSONRPCDispatcher is lifted. - except Exception: - logger.exception("handler for %r raised", req.method) - return JSONRPCError( - jsonrpc="2.0", - id=req.id, - error=ErrorData(code=INTERNAL_ERROR, message="Internal server error"), - ) + +async def _write( + msg: JSONRPCResponse | JSONRPCError, + scope: Scope, + receive: Receive, + send: Send, +) -> None: + """Serialise a JSON-RPC reply with the table-mapped HTTP status.""" + status = ERROR_CODE_HTTP_STATUS.get(msg.error.code, _OK_STATUS) if isinstance(msg, JSONRPCError) else _OK_STATUS + body = msg.model_dump(mode="json", by_alias=True, exclude_none=True) + if isinstance(msg, JSONRPCError) and msg.id is None: + # JSON-RPC requires `id: null` to appear on the wire when the request + # id couldn't be parsed; `exclude_none` would otherwise drop it. + body["id"] = None + await Response( + json.dumps(body, separators=(",", ":")), + status_code=status, + media_type="application/json", + )(scope, receive, send) async def handle_modern_request( app: Server[Any], security_settings: TransportSecuritySettings | None, - protocol_version: str, + lifespan_state: Any, scope: Scope, receive: Receive, send: Send, @@ -167,8 +130,9 @@ async def handle_modern_request( """ASGI handler for a single stateless-era POST. Called from `StreamableHTTPSessionManager.handle_request` when the - `MCP-Protocol-Version` header is in `MODERN_PROTOCOL_VERSIONS`; the header - value is passed as `protocol_version`. Never sets `Mcp-Session-Id`. + `MCP-Protocol-Version` header names a modern revision; the manager enters + `app.lifespan` once at startup and passes the state in. Never sets + `Mcp-Session-Id`. """ request = Request(scope, receive) @@ -178,54 +142,56 @@ async def handle_modern_request( await err(scope, receive, send) return - # TODO: validate Accept header once the JSON-vs-SSE response-mode design is settled. + # TODO(D-005a): validate Accept once the JSON-vs-SSE response mode is settled. if request.method != "POST": - # TODO: GET/DELETE rejection (405 + -32601) lands with the validation ladder. + # HTTP-layer rejection (Allow accompanies 405 per RFC 9110) — happens + # before JSON-RPC parsing, so it doesn't go through `_write`. await Response(status_code=405, headers={"Allow": "POST"})(scope, receive, send) return body = await request.body() try: - req = JSONRPCRequest.model_validate_json(body) + decoded = json.loads(body) + except json.JSONDecodeError: + rej = JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=PARSE_ERROR, message="Parse error")) + await _write(rej, scope, receive, send) + return + try: + req = JSONRPCRequest.model_validate(decoded) except ValidationError: - msg = JSONRPCError(jsonrpc="2.0", id=None, error=ErrorData(code=PARSE_ERROR, message="Parse error")) - await Response( - msg.model_dump_json(by_alias=True), - status_code=400, - media_type="application/json", - )(scope, receive, send) + # Well-formed JSON that isn't a single request object. The transport + # spec permits notification POSTs and gives the server two responses + # (202 accept / 4xx cannot-accept; streamable-http §Sending Messages + # item 5). The core protocol defines no client→server notifications + # over HTTP at 2026-07-28 (cancellation is SSE-stream close), so this + # entry takes the cannot-accept branch. TODO(L57): S4 owns the + # strict-vs-lenient choice. + rej = JSONRPCError( + jsonrpc="2.0", + id=None, + error=ErrorData(code=INVALID_REQUEST, message="Body must be a single JSON-RPC request object"), + ) + await _write(rej, scope, receive, send) return - dispatcher = SingleExchangeDispatcher(request) - # TODO: per-request lifespan re-entry matches stateless_http=True today; revisit in #2893. - async with app.lifespan(app) as lifespan_state: - runner = ServerRunner( - server=app, - dispatcher=dispatcher, - lifespan_state=lifespan_state, - has_standalone_channel=False, - stateless=True, - dispatch_middleware=[otel_middleware], + verdict = classify_inbound_request(decoded, headers=dict(request.headers)) + if isinstance(verdict, InboundLadderRejection): + rej = JSONRPCError( + jsonrpc="2.0", id=req.id, error=ErrorData(code=verdict.code, message=verdict.message, data=verdict.data) ) - runner.connection.protocol_version = protocol_version - try: - msg = await dispatcher.handle(req, runner._compose_on_request()) # type: ignore[reportPrivateUsage] - finally: - with anyio.move_on_after(_EXIT_STACK_CLOSE_TIMEOUT, shield=True) as cancel_scope: - try: - await runner.connection.exit_stack.aclose() - except Exception: - logger.exception("connection exit_stack cleanup raised") - if cancel_scope.cancelled_caught: - logger.warning( - "connection exit_stack cleanup exceeded %s seconds; abandoning remaining callbacks", - _EXIT_STACK_CLOSE_TIMEOUT, - ) - - # TODO: error.code -> HTTP status mapping is a follow-up; 200 for all JSONRPCError bodies for now. - await Response( - msg.model_dump_json(by_alias=True, exclude_none=True), - status_code=200, - media_type="application/json", - )(scope, receive, send) + await _write(rej, scope, receive, send) + return + + connection = Connection.from_envelope( + verdict.protocol_version, + _typed(Implementation, verdict.client_info), + _typed(ClientCapabilities, verdict.client_capabilities), + ) + dctx = _SingleExchangeDispatchContext( + transport=TransportContext(kind="streamable-http", can_send_request=False, headers=request.headers), + request_id=req.id, + message_metadata=ServerMessageMetadata(request_context=request), + ) + msg = await serve_one(app, req, connection=connection, dctx=dctx, lifespan_state=lifespan_state) + await _write(msg, scope, receive, send) diff --git a/src/mcp/server/connection.py b/src/mcp/server/connection.py index 8a8034e37e..f5bfc18dfb 100644 --- a/src/mcp/server/connection.py +++ b/src/mcp/server/connection.py @@ -1,18 +1,24 @@ """`Connection` - per-client connection state and the standalone outbound channel. Always present on `Context` (never `None`), even in stateless deployments. -Holds peer info populated at `initialize` time, per-connection scratch -`state` and an `exit_stack` for teardown, and an `Outbound` for the -standalone stream (the SSE GET stream in streamable HTTP, or the single duplex -stream in stdio). +Holds peer info, per-connection scratch `state` and an `exit_stack` for +teardown, and an `Outbound` for the standalone stream (the SSE GET stream in +streamable HTTP, or the single duplex stream in stdio). + +Construct via the factories: `Connection.from_envelope` for the 2026-era +single-exchange path (born ready, no back-channel) and `Connection.for_loop` +for the handshake-driven loop path. Both populate `protocol_version` so the +kernel reads it as a fact. `notify` is best-effort: it never raises. If there's no standalone channel -(stateless HTTP) or the stream has been dropped, the notification is -debug-logged and silently discarded - server-initiated notifications are -inherently advisory. `send_raw_request` *does* raise `NoBackChannelError` when -there's no channel; `ping` is the only spec-sanctioned standalone request. +or the stream has been dropped, the notification is debug-logged and silently +discarded - server-initiated notifications are inherently advisory. +`send_raw_request` raises `NoBackChannelError` when there's no channel; `ping` +is the only spec-sanctioned standalone request. """ +from __future__ import annotations + import logging from collections.abc import Mapping from contextlib import AsyncExitStack @@ -26,12 +32,14 @@ from mcp.shared.exceptions import MCPDeprecationWarning, NoBackChannelError from mcp.shared.peer import Meta, dump_params from mcp.types import ( + LATEST_PROTOCOL_VERSION, ClientCapabilities, CreateMessageRequest, CreateMessageResult, ElicitRequest, ElicitResult, EmptyResult, + Implementation, InitializeRequestParams, ListRootsRequest, ListRootsResult, @@ -68,32 +76,57 @@ def _notification_params(payload: dict[str, Any] | None, meta: Meta | None) -> d return out +class _NoChannelOutbound: + """Connection-scoped `Outbound` for the no-back-channel case. + + The structural answer to "this connection cannot push to its peer": + `send_raw_request` raises `NoBackChannelError`; `notify` drops with a + debug log. `Connection.from_envelope` installs this so the modern + single-exchange path never needs a mode flag - the channel itself says no. + """ + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + raise NoBackChannelError(method) + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + logger.debug("dropped %s: no standalone channel", method) + + +_NO_CHANNEL = _NoChannelOutbound() + + class Connection: """Per-client connection state and standalone-stream `Outbound`. - Constructed by `ServerRunner` once per connection. The peer-info fields - are `None` until `initialize` completes; `initialized` is set later, when - the client's `notifications/initialized` follow-up arrives. In stateless - deployments the runner sets `initialized` immediately and peer-info - remains `None` (no handshake reaches a stateless connection). + Construct via `from_envelope` (modern single-exchange: born ready, no + back-channel) or `for_loop` (handshake-driven: ready once the client's + `notifications/initialized` arrives). Either way `protocol_version` is + populated at construction. """ - has_standalone_channel: bool + outbound: Outbound + """The connection-scoped channel for server-initiated messages.""" + session_id: str | None client_params: InitializeRequestParams | None - """The full `initialize` request params; `None` before initialization.""" + """The full `initialize` request params, or the equivalent built from the + 2026-era envelope. `None` when no client info was supplied.""" - protocol_version: str | None - """The protocol version negotiated during `initialize`; `None` before - initialization. Stateless connections don't require the handshake, so this - normally stays `None` there (a client that sends `initialize` anyway still - commits it). For the per-request value, read `ctx.protocol_version`.""" + protocol_version: str + """The protocol version this connection speaks. Populated at construction + by the factory and overwritten by `_handle_initialize` once the handshake + commits on the loop path.""" initialized: anyio.Event """Set when `notifications/initialized` arrives (matches TS `oninitialized`); the point from which the spec permits server-initiated requests beyond - ping/logging. Pre-set on stateless connections.""" + ping/logging. Pre-set on connections built via `from_envelope`.""" state: dict[str, Any] """Per-connection scratch state; persists across requests on this connection.""" @@ -103,24 +136,83 @@ class Connection: closes. Push cleanup from handlers or middleware; exceptions are logged and swallowed.""" - def __init__(self, outbound: Outbound, *, has_standalone_channel: bool, session_id: str | None = None) -> None: - self._outbound = outbound - self.has_standalone_channel = has_standalone_channel + def __init__( + self, + outbound: Outbound, + *, + protocol_version: str, + session_id: str | None = None, + client_params: InitializeRequestParams | None = None, + ) -> None: + self.outbound = outbound + self.protocol_version = protocol_version self.session_id = session_id - - self.client_params = None - self.protocol_version = None + self.client_params = client_params self.initialized = anyio.Event() - self.state = {} - self.exit_stack = AsyncExitStack() + @classmethod + def from_envelope( + cls, + protocol_version: str, + client_info: Implementation | None, + client_capabilities: ClientCapabilities | None, + *, + outbound: Outbound = _NO_CHANNEL, + ) -> Connection: + """A born-ready connection populated from a request's `_meta` envelope. + + `initialized` is set and the envelope's client info/capabilities (when + both supplied) are recorded as `client_params` so capability checks + work. `outbound` defaults to the no-channel sentinel for the + single-exchange HTTP path; duplex modern transports (e.g. stdio) pass + the dispatcher so server-initiated messages have a back-channel. + """ + client_params = None + if client_info is not None and client_capabilities is not None: + client_params = InitializeRequestParams( + protocol_version=protocol_version, + capabilities=client_capabilities, + client_info=client_info, + ) + connection = cls(outbound, protocol_version=protocol_version, client_params=client_params) + connection.initialized.set() + return connection + + @classmethod + def for_loop( + cls, + outbound: Outbound, + *, + session_id: str | None = None, + protocol_version_hint: str | None = None, + ) -> Connection: + """A connection for the handshake-driven loop path. + + Not born-ready: `initialized` is set later by the kernel when + `notifications/initialized` arrives. `protocol_version` is seeded from + the transport hint (or `LATEST_PROTOCOL_VERSION`) so it's never `None`; + the handshake overwrites it once negotiated. + """ + return cls( + outbound, + protocol_version=protocol_version_hint if protocol_version_hint is not None else LATEST_PROTOCOL_VERSION, + session_id=session_id, + ) + + @property + def has_standalone_channel(self) -> bool: + """Whether this connection has a real back-channel for server-initiated + messages. Derived from `outbound` - the no-channel sentinel is the only + case that doesn't.""" + return self.outbound is not _NO_CHANNEL + @property def initialize_accepted(self) -> bool: """True once the inbound request gate is open: `initialize` recorded the - peer info, or the handshake completed outright (stateless birth, or a - bare `notifications/initialized`). Derived, never stored.""" + peer info, or the handshake completed outright (born-ready, or a bare + `notifications/initialized`). Derived, never stored.""" return self.client_params is not None or self.initialized.is_set() async def send_raw_request( @@ -140,9 +232,7 @@ async def send_raw_request( MCPError: The peer responded with an error. NoBackChannelError: `has_standalone_channel` is `False`. """ - if not self.has_standalone_channel: - raise NoBackChannelError(method) - return await self._outbound.send_raw_request(method, params, opts) + return await self.outbound.send_raw_request(method, params, opts) @overload async def send_request( @@ -177,11 +267,9 @@ async def send_request( KeyError: `result_type` omitted for a non-spec request type. """ raw = await self.send_raw_request(req.method, dump_params(req.params), opts) - # Literal fallback covers pre-handshake and stateless; matches runner.py. - version = self.protocol_version or "2025-11-25" if req.method in _methods.MONOLITH_REQUESTS: try: - _methods.validate_client_result(req.method, version, raw) + _methods.validate_client_result(req.method, self.protocol_version, raw) except KeyError: pass cls = result_type if result_type is not None else _RESULT_FOR[type(req)] @@ -193,11 +281,8 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: Never raises. If there's no standalone channel or the stream is broken, the notification is dropped and debug-logged. """ - if not self.has_standalone_channel: - logger.debug("dropped %s: no standalone channel", method) - return try: - await self._outbound.notify(method, params) + await self.outbound.notify(method, params) except (anyio.BrokenResourceError, anyio.ClosedResourceError): logger.debug("dropped %s: standalone stream closed", method) @@ -233,9 +318,9 @@ async def send_resource_updated(self, uri: str, *, meta: Meta | None = None) -> def check_capability(self, capability: ClientCapabilities) -> bool: """Return whether the connected client declared the given capability. - Returns `False` if `initialize` hasn't completed yet. + Returns `False` when no client info has been recorded. """ - # TODO: redesign - mirrors v1 ServerSession.check_client_capability + # TODO(L53): redesign - mirrors v1 ServerSession.check_client_capability # verbatim for parity. if self.client_params is None: return False diff --git a/src/mcp/server/lowlevel/server.py b/src/mcp/server/lowlevel/server.py index d2536189d0..22ce0dca50 100644 --- a/src/mcp/server/lowlevel/server.py +++ b/src/mcp/server/lowlevel/server.py @@ -58,14 +58,13 @@ async def main(): from mcp.server.auth.settings import AuthSettings from mcp.server.context import HandlerResult, ServerMiddleware, ServerRequestContext from mcp.server.models import InitializationOptions -from mcp.server.runner import ServerRunner, otel_middleware +from mcp.server.runner import serve_loop from mcp.server.streamable_http import EventStore from mcp.server.streamable_http_manager import StreamableHTTPASGIApp, StreamableHTTPSessionManager from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._stream_protocols import ReadStream, WriteStream -from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher from mcp.shared.message import SessionMessage -from mcp.shared.transport_context import TransportContext +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS logger = logging.getLogger(__name__) @@ -117,6 +116,15 @@ async def _ping_handler(ctx: ServerRequestContext[Any], params: types.RequestPar return types.EmptyResult() +def _package_version(package: str) -> str: + try: + return importlib_version(package) + except Exception: # pragma: no cover + pass + + return "unknown" # pragma: no cover + + class Server(Generic[LifespanResultT]): def __init__( self, @@ -218,7 +226,7 @@ def __init__( # Context-tier middleware: wraps every inbound request (including # `initialize`, lookup, validation, handler) with # `(ctx, method, params, call_next)`. Applied in `ServerRunner._on_request`. - # TODO(maxisbey): provisional - signature and semantics change with the + # TODO(L54): provisional - signature and semantics change with the # Context/middleware rework (covariant `Context[L]`, outbound seam) before # v2 final. self.middleware: list[ServerMiddleware[LifespanResultT]] = [] @@ -226,6 +234,7 @@ def __init__( _spec_requests: list[tuple[str, type[BaseModel], RequestHandler[LifespanResultT, Any] | None]] = [ ("ping", types.RequestParams, on_ping), + ("server/discover", types.RequestParams, self._handle_discover), ("prompts/list", types.PaginatedRequestParams, on_list_prompts), ("prompts/get", types.GetPromptRequestParams, on_get_prompt), ("resources/list", types.PaginatedRequestParams, on_list_resources), @@ -298,7 +307,7 @@ def get_notification_handler(self, method: str) -> HandlerEntry[LifespanResultT] """Return the registered entry for a notification method, or `None`.""" return self._notification_handlers.get(method) - # TODO: Rethink capabilities API. Currently capabilities are derived from registered + # TODO(L53): Rethink capabilities API. Currently capabilities are derived from registered # handlers but require NotificationOptions to be passed externally for list_changed # flags, and experimental_capabilities as a separate dict. Consider deriving capabilities # entirely from server state (e.g. constructor params for list_changed) instead of @@ -309,18 +318,9 @@ def create_initialization_options( experimental_capabilities: dict[str, dict[str, Any]] | None = None, ) -> InitializationOptions: """Create initialization options from this server instance.""" - - def pkg_version(package: str) -> str: - try: - return importlib_version(package) - except Exception: # pragma: no cover - pass - - return "unknown" # pragma: no cover - return InitializationOptions( server_name=self.name, - server_version=self.version if self.version else pkg_version("mcp"), + server_version=self.version if self.version else _package_version("mcp"), title=self.title, description=self.description, capabilities=self.get_capabilities( @@ -334,10 +334,11 @@ def pkg_version(package: str) -> str: def get_capabilities( self, - notification_options: NotificationOptions, - experimental_capabilities: dict[str, dict[str, Any]], + notification_options: NotificationOptions | None = None, + experimental_capabilities: dict[str, dict[str, Any]] | None = None, ) -> types.ServerCapabilities: """Convert existing handlers to a ServerCapabilities object.""" + notification_options = notification_options or NotificationOptions() prompts_capability = None resources_capability = None tools_capability = None @@ -377,6 +378,40 @@ def get_capabilities( ) return capabilities + @property + def server_info(self) -> types.Implementation: + """The `serverInfo` block describing this implementation. + + Derived from the constructor's identity fields. `version` falls back to + the installed `mcp` package version when not supplied explicitly. + """ + return types.Implementation( + name=self.name, + version=self.version if self.version else _package_version("mcp"), + title=self.title, + description=self.description, + website_url=self.website_url, + icons=self.icons, + ) + + async def _handle_discover( + self, ctx: ServerRequestContext[LifespanResultT], params: types.RequestParams | None + ) -> types.DiscoverResult: + """Default `server/discover` handler. + + Auto-derived from server state at call time, so capabilities reflect + whatever has been registered (constructor `on_*` kwargs and later + `add_request_handler` calls). Operators can replace it wholesale via + `add_request_handler("server/discover", ...)`. Reachability for legacy + peers is decided at the boundary (`types.methods`), not here. + """ + return types.DiscoverResult( + supported_versions=list(MODERN_PROTOCOL_VERSIONS), + capabilities=self.get_capabilities(), + server_info=self.server_info, + instructions=self.instructions, + ) + @property def session_manager(self) -> StreamableHTTPSessionManager: """Get the StreamableHTTP session manager. @@ -401,36 +436,22 @@ async def run( # but also make tracing exceptions much easier during testing and when using # in-process servers. raise_exceptions: bool = False, - # When True, the server is stateless and - # clients can perform initialization with any node. The client must still follow - # the initialization lifecycle, but can do so with any available node - # rather than requiring initialization for each connection. - stateless: bool = False, ) -> None: + """Serve a single connection over the given streams until the read side closes. + + Thin wrapper over `serve_loop`: enters the server lifespan, + then drives the loop. Transports with their own lifespan owner + (the streamable-HTTP manager) call `serve_loop` directly instead. + """ async with self.lifespan(self) as lifespan_context: - dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + await serve_loop( + self, read_stream, write_stream, - raise_handler_exceptions=raise_exceptions, - # Handle `initialize` inline so a client that pipelines it with - # the next request (spec says SHOULD NOT, not MUST NOT) sees - # the initialized state instead of failing the init-gate. - inline_methods=frozenset({"initialize"}), - ) - runner = ServerRunner( - server=self, - dispatcher=dispatcher, lifespan_state=lifespan_context, init_options=initialization_options, - # Stateless HTTP has no standalone GET stream, so server-initiated - # requests on `runner.connection` must fail fast with - # `NoBackChannelError` rather than write to a channel that will - # never deliver a response. - has_standalone_channel=not stateless, - stateless=stateless, - dispatch_middleware=[otel_middleware], + raise_exceptions=raise_exceptions, ) - await runner.run() def streamable_http_app( self, diff --git a/src/mcp/server/mcpserver/tools/base.py b/src/mcp/server/mcpserver/tools/base.py index 754313eb8a..29894d7d1d 100644 --- a/src/mcp/server/mcpserver/tools/base.py +++ b/src/mcp/server/mcpserver/tools/base.py @@ -10,7 +10,7 @@ from mcp.server.mcpserver.utilities.context_injection import find_context_parameter from mcp.server.mcpserver.utilities.func_metadata import FuncMetadata, func_metadata from mcp.shared._callable_inspection import is_async_callable -from mcp.shared.exceptions import UrlElicitationRequiredError +from mcp.shared.exceptions import MCPError from mcp.shared.tool_name_validation import validate_and_warn_tool_name from mcp.types import Icon, ToolAnnotations @@ -111,9 +111,12 @@ async def run( result = self.fn_metadata.convert_result(result) return result - except UrlElicitationRequiredError: - # Re-raise UrlElicitationRequiredError so it can be properly handled - # as an MCP error response with code -32042 + except MCPError: + # `MCPError` (and subclasses such as `UrlElicitationRequiredError`) + # carries a JSON-RPC `ErrorData(code, message, data)` and means + # "respond with a protocol error" - re-raise so the kernel surfaces + # it as a top-level JSON-RPC error rather than wrapping it as a + # `CallToolResult(isError=True)` execution failure. raise except Exception as e: raise ToolError(f"Error executing tool {self.name}: {e}") from e diff --git a/src/mcp/server/runner.py b/src/mcp/server/runner.py index fcdcc68ced..6b64ce9c49 100644 --- a/src/mcp/server/runner.py +++ b/src/mcp/server/runner.py @@ -1,14 +1,12 @@ -"""`ServerRunner` - per-connection orchestrator over a `Dispatcher`. +"""`ServerRunner` - the per-connection handler kernel. -`ServerRunner` is the bridge between the dispatcher layer (`on_request` / -`on_notify`, untyped dicts) and the user's handler layer (typed `Context`, -typed params). One instance per client connection. It: - -* handles the `initialize` handshake and populates `Connection` -* gates requests until initialized (`ping` exempt) -* looks up the handler in the server's registry, validates params, builds - `Context`, runs the middleware chain, returns the result dict -* drives `dispatcher.run()` and the per-connection lifespan +`ServerRunner` bridges the dispatch layer (`on_request` / `on_notify`, untyped +dicts) and the user's handler layer (typed `Context`, typed params). It is a +pure kernel: it holds a pre-populated `Connection` and reads +`connection.protocol_version` / `connection.outbound` as facts. Driving a +dispatcher loop and tearing down the connection live in the free-function +drivers (`serve_connection`, `serve_loop`, `serve_one`); the entry constructs +the `Connection`, the driver tears it down. `ServerRunner` holds a `Server` directly - `Server` is the registry. """ @@ -16,11 +14,12 @@ from __future__ import annotations import logging -from collections.abc import Mapping -from dataclasses import dataclass, field -from functools import partial, reduce +from collections.abc import Awaitable, Mapping, Sequence +from dataclasses import KW_ONLY, dataclass +from functools import cached_property, partial, reduce from typing import TYPE_CHECKING, Any, Generic, cast +import anyio import anyio.abc from opentelemetry.trace import SpanKind, StatusCode from pydantic import BaseModel, ValidationError @@ -31,9 +30,11 @@ from mcp.server.models import InitializationOptions from mcp.server.session import ServerSession from mcp.shared._otel import extract_trace_context, otel_span -from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnRequest +from mcp.shared._stream_protocols import ReadStream, WriteStream +from mcp.shared.dispatcher import DispatchContext, Dispatcher, DispatchMiddleware, OnNotify, OnRequest from mcp.shared.exceptions import MCPError -from mcp.shared.message import MessageMetadata, ServerMessageMetadata +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher, handler_exception_to_error_data +from mcp.shared.message import ServerMessageMetadata, SessionMessage from mcp.shared.transport_context import TransportContext from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( @@ -41,11 +42,14 @@ INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, - PROTOCOL_VERSION_META_KEY, ErrorData, Implementation, InitializeRequestParams, InitializeResult, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, + RequestId, RequestParams, RequestParamsMeta, ) @@ -54,7 +58,17 @@ if TYPE_CHECKING: from mcp.server.lowlevel.server import Server -__all__ = ["CallNext", "ServerMiddleware", "ServerRunner", "otel_middleware"] +__all__ = [ + "CallNext", + "ServerMiddleware", + "ServerRunner", + "aclose_shielded", + "otel_middleware", + "serve_connection", + "serve_loop", + "serve_one", + "to_jsonrpc_response", +] logger = logging.getLogger(__name__) @@ -64,8 +78,8 @@ _INIT_EXEMPT: frozenset[str] = frozenset({"ping"}) _EXIT_STACK_CLOSE_TIMEOUT: float = 5 -"""Bound for the shielded exit-stack unwind in `run()`; a hung cleanup -callback must not wedge shutdown.""" +"""Bound for `aclose_shielded`'s exit-stack unwind; a hung cleanup callback +must not wedge shutdown.""" def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None: @@ -79,30 +93,6 @@ def _extract_meta(params: Mapping[str, Any] | None) -> RequestParamsMeta | None: return None -def _resolve_protocol_version( - negotiated: str | None, - meta: RequestParamsMeta | None, - md: MessageMetadata, -) -> str: - """Resolve the protocol version for this inbound message. - - Handshake-committed value wins; else per-request `_meta`, else the - transport hint. Unsupported values fall through so surface validation - never sees them. - """ - if negotiated is not None: - return negotiated - if meta is not None: - v = meta.get(PROTOCOL_VERSION_META_KEY) - if isinstance(v, str) and v in SUPPORTED_PROTOCOL_VERSIONS: - return v - if isinstance(md, ServerMessageMetadata): - hint = md.protocol_version - if hint is not None and hint in SUPPORTED_PROTOCOL_VERSIONS: - return hint - return "2025-11-25" - - def otel_middleware(next_on_request: OnRequest) -> OnRequest: """Dispatch-tier middleware that wraps each request in an OpenTelemetry span. @@ -169,73 +159,73 @@ def _dump_result(result: Any) -> dict[str, Any]: raise TypeError(f"handler returned {type(result).__name__}; expected BaseModel, dict, or None") +async def aclose_shielded(connection: Connection) -> None: + """Unwind ``connection.exit_stack`` under a shielded, bounded scope. + + Called from a driver's ``finally``: the shield lets per-connection cleanup + callbacks run even when the driver itself is being cancelled, the + `_EXIT_STACK_CLOSE_TIMEOUT` bound stops a hung callback wedging shutdown, + and a raising callback is logged-and-swallowed so it never masks the + driver's own exception. + """ + with anyio.move_on_after(_EXIT_STACK_CLOSE_TIMEOUT, shield=True) as scope: + try: + await connection.exit_stack.aclose() + except Exception: + logger.exception("connection exit_stack cleanup raised") + if scope.cancelled_caught: + logger.warning( + "connection exit_stack cleanup exceeded %s seconds; abandoning remaining callbacks", + _EXIT_STACK_CLOSE_TIMEOUT, + ) + + +async def to_jsonrpc_response(request_id: RequestId, coro: Awaitable[dict[str, Any]]) -> JSONRPCResponse | JSONRPCError: + """Await ``coro`` and wrap its outcome as the JSON-RPC reply for ``request_id``. + + The exception-to-wire boundary for the request-per-call drivers + (`serve_one`, the modern HTTP entry). `MCPError` and `ValidationError` + map via the shared `handler_exception_to_error_data` ladder; any other + exception is logged and surfaced as `INTERNAL_ERROR` so handler internals + never reach the wire. + """ + try: + result = await coro + except Exception as exc: + error = handler_exception_to_error_data(exc) + if error is None: + logger.exception("request handler raised") + error = ErrorData(code=INTERNAL_ERROR, message="Internal server error") + return JSONRPCError(jsonrpc="2.0", id=request_id, error=error) + return JSONRPCResponse(jsonrpc="2.0", id=request_id, result=result) + + @dataclass class ServerRunner(Generic[LifespanT]): - """Per-connection orchestrator. One instance per client connection.""" + """Per-connection handler kernel. One instance per client connection.""" server: Server[LifespanT] - dispatcher: Dispatcher[Any] + connection: Connection lifespan_state: LifespanT - has_standalone_channel: bool + _: KW_ONLY init_options: InitializationOptions | None = None """`InitializeResult` payload. Defaults to `server.create_initialization_options()`.""" - session_id: str | None = None - stateless: bool = False - dispatch_middleware: list[DispatchMiddleware] = field(default_factory=list[DispatchMiddleware]) - - connection: Connection = field(init=False) - session: ServerSession = field(init=False) - """Connection-scoped: the same instance reaches every request as `ctx.session`.""" - - def __post_init__(self) -> None: - if self.init_options is None: - self.init_options = self.server.create_initialization_options() - self.connection = Connection( - self.dispatcher, has_standalone_channel=self.has_standalone_channel, session_id=self.session_id - ) - if self.stateless: - # No handshake ever arrives on a stateless connection; born ready. - self.connection.initialized.set() - self.session = ServerSession(self.dispatcher, self.connection, stateless=self.stateless) - - async def run(self, *, task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED) -> None: - """Drive the dispatcher until the underlying channel closes. - - Composes `dispatch_middleware` over `_on_request` and hands the result - to `dispatcher.run()`. `task_status.started()` is forwarded so callers - can `await tg.start(runner.run)` and resume once the dispatcher is - ready to accept requests. Once the dispatcher exits, - `connection.exit_stack` is unwound (shielded from outer cancellation, - bounded by `_EXIT_STACK_CLOSE_TIMEOUT`) so any per-connection cleanup - registered by handlers or middleware gets a chance to run without a - misbehaving callback hanging shutdown indefinitely. - """ - try: - await self.dispatcher.run(self._compose_on_request(), self._on_notify, task_status=task_status) - finally: - with anyio.move_on_after(_EXIT_STACK_CLOSE_TIMEOUT, shield=True) as scope: - try: - await self.connection.exit_stack.aclose() - except Exception: - # Raising here would mask dispatcher.run()'s exception and - # crash stdio servers on normal disconnect. - logger.exception("connection exit_stack cleanup raised") - if scope.cancelled_caught: - logger.warning( - "connection exit_stack cleanup exceeded %s seconds; abandoning remaining callbacks", - _EXIT_STACK_CLOSE_TIMEOUT, - ) - - def _compose_on_request(self) -> OnRequest: - """Wrap `_on_request` in `dispatch_middleware`, outermost-first. - - Dispatch-tier middleware sees raw `(dctx, method, params) -> dict` - and wraps everything - initialize, METHOD_NOT_FOUND, validation - failures included. `run()` calls this once and hands the result to - `dispatcher.run()`. + dispatch_middleware: Sequence[DispatchMiddleware] = (otel_middleware,) + + @cached_property + def on_request(self) -> OnRequest: + """`_on_request` wrapped in `dispatch_middleware`, outermost-first. + + Dispatch-tier middleware sees raw `(dctx, method, params) -> dict` and + wraps everything - initialize, METHOD_NOT_FOUND, validation failures + included. """ return reduce(lambda h, mw: mw(h), reversed(self.dispatch_middleware), self._on_request) + @cached_property + def on_notify(self) -> OnNotify: + return self._on_notify + async def _on_request( self, dctx: DispatchContext[TransportContext], @@ -243,7 +233,7 @@ async def _on_request( params: Mapping[str, Any] | None, ) -> dict[str, Any]: meta = _extract_meta(params) - version = _resolve_protocol_version(self.connection.protocol_version, meta, dctx.message_metadata) + version = self.connection.protocol_version ctx = self._make_context(dctx, meta, version) is_spec_method = method in _methods.SPEC_CLIENT_METHODS @@ -257,7 +247,7 @@ async def _inner() -> HandlerResult: _methods.validate_client_request(method, version, params) except KeyError: raise MCPError(code=METHOD_NOT_FOUND, message="Method not found", data=method) from None - # TODO(maxisbey): the 2026-07-28 spec drops the handshake; this branch and + # TODO(L29): the 2026-07-28 spec drops the handshake; this branch and # the gate become a per-version legacy path then. Initialize runs inline # (read loop parked), so awaiting the peer anywhere on this path deadlocks. if method == "initialize": @@ -284,7 +274,7 @@ async def _inner() -> HandlerResult: call = self._compose_server_middleware(ctx, method, params, _inner) result = _dump_result(await call()) - # TODO: reject resultType values outside {"complete", "input_required"} unless the + # TODO(L56): reject resultType values outside {"complete", "input_required"} unless the # corresponding extension is in this request's _meta clientCapabilities.extensions; the # explicit MUST-reject is client-side (basic/index.mdx ResultType), this enforces it proactively. if is_spec_method: @@ -312,7 +302,7 @@ async def _on_notify( params: Mapping[str, Any] | None, ) -> None: meta = _extract_meta(params) - version = _resolve_protocol_version(self.connection.protocol_version, meta, dctx.message_metadata) + version = self.connection.protocol_version ctx = self._make_context(dctx, meta, version) async def _inner() -> None: @@ -373,7 +363,7 @@ def _compose_server_middleware( def _make_context( self, dctx: DispatchContext[TransportContext], meta: RequestParamsMeta | None, protocol_version: str ) -> ServerRequestContext[LifespanT, Any]: - # TODO(maxisbey): remove for Context rework. Reads the SHTTP per-request + # TODO(L54): remove for Context rework. Reads the SHTTP per-request # data off the raw `dctx.message_metadata` carrier; replace with the # per-transport context once that lands. md = dctx.message_metadata @@ -383,8 +373,12 @@ def _make_context( close_standalone_sse_stream = md.close_standalone_sse_stream else: request = close_sse_stream = close_standalone_sse_stream = None + # Per-request session: `dctx` is the request-scoped channel (auto-threads + # its own request_id on streamable HTTP); the standalone channel is read + # off `connection.outbound`. `related_request_id` on the public API selects. + session = ServerSession(dctx, self.connection) return ServerRequestContext( - session=self.session, + session=session, lifespan_context=self.lifespan_state, request_id=dctx.request_id, meta=meta, @@ -405,8 +399,7 @@ def _negotiate_initialize(params: Mapping[str, Any] | None) -> tuple[InitializeR def _handle_initialize(self, params: Mapping[str, Any] | None) -> InitializeResult: """Build the `initialize` result; state commits later in `_on_request`.""" _, negotiated = self._negotiate_initialize(params) - assert self.init_options is not None - opts = self.init_options + opts = self.init_options if self.init_options is not None else self.server.create_initialization_options() return InitializeResult( protocol_version=negotiated, capabilities=opts.capabilities, @@ -420,3 +413,82 @@ def _handle_initialize(self, params: Mapping[str, Any] | None) -> InitializeResu ), instructions=opts.instructions, ) + + +async def serve_connection( + server: Server[LifespanT], + dispatcher: Dispatcher[Any], + *, + connection: Connection, + lifespan_state: LifespanT, + init_options: InitializationOptions | None = None, + task_status: anyio.abc.TaskStatus[None] = anyio.TASK_STATUS_IGNORED, +) -> None: + """Drive ``dispatcher`` until the underlying channel closes. + + The loop-mode driver: builds the kernel, hands `on_request`/`on_notify` + to `dispatcher.run()`, and tears down `connection.exit_stack` (shielded) + on the way out. The entry constructs the `Connection`; this only consumes + it. + """ + runner = ServerRunner(server, connection, lifespan_state, init_options=init_options) + try: + await dispatcher.run(runner.on_request, runner.on_notify, task_status=task_status) + finally: + await aclose_shielded(connection) + + +async def serve_loop( + server: Server[LifespanT], + read_stream: ReadStream[SessionMessage | Exception], + write_stream: WriteStream[SessionMessage], + *, + lifespan_state: LifespanT, + session_id: str | None = None, + init_options: InitializationOptions | None = None, + raise_exceptions: bool = False, +) -> None: + """Drive ``server`` in loop mode over a stream pair until the channel closes. + + Builds the loop-mode `JSONRPCDispatcher` + `Connection` and hands them to + `serve_connection`, so loop-mode callers share one dispatcher-construction + recipe (notably the `inline_methods={"initialize"}` rule). Callers that own + a lifespan (the streamable-HTTP manager) pass it in; callers that don't + (`Server.run` for stdio/memory) enter the lifespan and then call this. + """ + dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + read_stream, + write_stream, + raise_handler_exceptions=raise_exceptions, + # Handle `initialize` inline so a client that pipelines it with the + # next request (spec: SHOULD NOT, not MUST NOT) sees the initialized + # state instead of failing the init-gate. + inline_methods=frozenset({"initialize"}), + ) + connection = Connection.for_loop(dispatcher, session_id=session_id) + await serve_connection( + server, dispatcher, connection=connection, lifespan_state=lifespan_state, init_options=init_options + ) + + +async def serve_one( + server: Server[LifespanT], + request: JSONRPCRequest, + *, + connection: Connection, + dctx: DispatchContext[TransportContext], + lifespan_state: LifespanT, +) -> JSONRPCResponse | JSONRPCError: + """Handle a single ``request`` and return its JSON-RPC reply. + + The single-exchange driver: builds the kernel, runs `on_request` once for + `request` under `dctx`, maps the outcome to a `JSONRPCResponse` / + `JSONRPCError` via `to_jsonrpc_response`, and tears down + `connection.exit_stack` (shielded) on the way out. The entry constructs + the (born-ready) `Connection` and the `dctx`; this only consumes them. + """ + runner = ServerRunner(server, connection, lifespan_state) + try: + return await to_jsonrpc_response(request.id, runner.on_request(dctx, request.method, request.params)) + finally: + await aclose_shielded(connection) diff --git a/src/mcp/server/session.py b/src/mcp/server/session.py index 5aad5602ac..be4c1805a9 100644 --- a/src/mcp/server/session.py +++ b/src/mcp/server/session.py @@ -1,15 +1,12 @@ """`ServerSession`: server-to-client requests and notifications. -A thin proxy over `JSONRPCDispatcher` and `Connection`. One instance per -client connection (built by `ServerRunner`). Handlers reach it as -`ctx.session` and use the typed helpers (`create_message`, `elicit_form`, +A per-request proxy built by the kernel for each inbound request. Exposes the +request-scoped outbound channel and the connection's standalone channel. +Handlers reach it as `ctx.session` and use the typed helpers (`elicit_form`, `send_log_message`, ...) to call back to the client. - -The receive-loop, initialize handling, and per-request task isolation that -used to live here are now owned by `JSONRPCDispatcher` and `ServerRunner`. """ -from typing import Any, TypeVar, cast, overload +from typing import Any, TypeVar, overload from pydantic import AnyUrl, BaseModel from typing_extensions import deprecated @@ -17,8 +14,8 @@ from mcp import types from mcp.server.connection import Connection from mcp.server.validation import validate_sampling_tools, validate_tool_use_result_messages -from mcp.shared.dispatcher import CallOptions, Dispatcher, ProgressFnT -from mcp.shared.exceptions import MCPDeprecationWarning, NoBackChannelError, StatelessModeNotSupported +from mcp.shared.dispatcher import CallOptions, Outbound, ProgressFnT +from mcp.shared.exceptions import MCPDeprecationWarning from mcp.shared.message import ServerMessageMetadata from mcp.types import methods as _methods @@ -28,35 +25,32 @@ class ServerSession: - """Connection-scoped proxy for server-to-client requests and notifications. - - `send_request` / `send_notification` model-dump their argument and forward - to the dispatcher; the typed helpers below are unchanged from the previous - implementation and only call those two methods. + """Per-request proxy for server-to-client requests and notifications. + + Built once per inbound request by the kernel's `_make_context`. Holds two + `Outbound` channels: the request-scoped one (the per-request + `DispatchContext`, which on streamable HTTP routes onto the originating + POST's response stream) and the connection's standalone channel + (`connection.outbound`). `related_request_id` on the public methods is the + selector — present means request-scoped, absent means standalone — and + never crosses the `Outbound` Protocol. """ - def __init__( - self, - dispatcher: Dispatcher[Any], - connection: Connection, - *, - stateless: bool = False, - ) -> None: - self._dispatcher = dispatcher + def __init__(self, request_outbound: Outbound, connection: Connection) -> None: + self._request_outbound = request_outbound self._connection = connection - self._stateless = stateless @property def client_params(self) -> types.InitializeRequestParams | None: - """The client's `initialize` request params; `None` before initialization.""" + """The client's `initialize` request params; `None` when no client info was supplied.""" return self._connection.client_params @property - def protocol_version(self) -> str | None: - """The protocol version negotiated during `initialize`. + def protocol_version(self) -> str: + """The protocol version this connection speaks. - `None` before initialization, and normally `None` on stateless - connections. For the per-request value, read `ctx.protocol_version`. + Populated at `Connection` construction and overwritten once the + handshake commits on the loop path; never `None`. """ return self._connection.protocol_version @@ -70,46 +64,23 @@ async def send_request( ) -> ResultT: """Send a typed server-to-client request and validate the result. - `metadata.related_request_id` (when supplied) routes the outgoing - message onto the originating request's response stream over - streamable HTTP; it is the only metadata field honored here. - Raises: MCPError: The peer responded with an error. - NoBackChannelError: If there is no related request to ride on and - the connection has no standalone channel (stateless HTTP), so - a response could never arrive. + NoBackChannelError: The connection has no back-channel for + server-initiated requests (raised by the held `Outbound`). pydantic.ValidationError: The peer's result does not match `result_type`. """ + related = metadata.related_request_id if metadata is not None else None + channel = self._request_outbound if related is not None else self._connection.outbound data = request.model_dump(by_alias=True, mode="json", exclude_none=True) opts: CallOptions = {} if request_read_timeout_seconds is not None: opts["timeout"] = request_read_timeout_seconds if progress_callback is not None: opts["on_progress"] = progress_callback - related = metadata.related_request_id if metadata is not None else None - if related is None and not self._connection.has_standalone_channel: - # Fail fast instead of parking forever on a response that cannot - # arrive; matches `Connection.send_raw_request`. - raise NoBackChannelError(data["method"]) - # TODO: _related_request_id is not on the Dispatcher Protocol (and must not - # be — it's transport-specific). The fix is to give `ctx.session` a per-request - # Outbound (the DispatchContext, which threads its own request_id) alongside - # the connection-level one, with `related_request_id` as the selector; that - # belongs with the ServerSession/Context rework, not here. - result = cast( - "dict[str, Any]", - await self._dispatcher.send_raw_request( - data["method"], - data.get("params"), - opts or None, - _related_request_id=related, # type: ignore[call-arg] - ), - ) - # Literal fallback covers pre-handshake and stateless; matches runner.py. - version = self.protocol_version or "2025-11-25" + result = await channel.send_raw_request(data["method"], data.get("params"), opts or None) try: - _methods.validate_client_result(request.method, version, result) + _methods.validate_client_result(request.method, self.protocol_version, result) except KeyError: pass return result_type.model_validate(result, by_name=False) @@ -120,8 +91,9 @@ async def send_notification( related_request_id: types.RequestId | None = None, ) -> None: """Send a typed server-to-client notification.""" + channel = self._request_outbound if related_request_id is not None else self._connection.outbound data = notification.model_dump(by_alias=True, mode="json", exclude_none=True) - await self._dispatcher.notify(data["method"], data.get("params"), _related_request_id=related_request_id) # type: ignore[call-arg] + await channel.notify(data["method"], data.get("params")) def check_client_capability(self, capability: types.ClientCapabilities) -> bool: """Check if the client supports a specific capability.""" @@ -236,10 +208,9 @@ async def create_message( Raises: MCPError: If tools are provided but client doesn't support them. ValueError: If tool_use or tool_result message structure is invalid. - StatelessModeNotSupported: If called in stateless HTTP mode. + NoBackChannelError: The connection has no back-channel for + server-initiated requests. """ - if self._stateless: - raise StatelessModeNotSupported(method="sampling") client_caps = self.client_params.capabilities if self.client_params else None validate_sampling_tools(client_caps, tools, tool_choice) validate_tool_use_result_messages(messages) @@ -274,9 +245,12 @@ async def create_message( @deprecated("The roots capability is deprecated as of 2026-07-28 (SEP-2577).", category=MCPDeprecationWarning) async def list_roots(self) -> types.ListRootsResult: - """Send a roots/list request.""" - if self._stateless: - raise StatelessModeNotSupported(method="list_roots") + """Send a roots/list request. + + Raises: + NoBackChannelError: The connection has no back-channel for + server-initiated requests. + """ return await self.send_request( types.ListRootsRequest(), types.ListRootsResult, @@ -321,10 +295,9 @@ async def elicit_form( The client's response with form data. Raises: - StatelessModeNotSupported: If called in stateless HTTP mode. + NoBackChannelError: The connection has no back-channel for + server-initiated requests. """ - if self._stateless: - raise StatelessModeNotSupported(method="elicitation") return await self.send_request( types.ElicitRequest( params=types.ElicitRequestFormParams( @@ -358,10 +331,9 @@ async def elicit_url( The client's response indicating acceptance, decline, or cancellation. Raises: - StatelessModeNotSupported: If called in stateless HTTP mode. + NoBackChannelError: The connection has no back-channel for + server-initiated requests. """ - if self._stateless: - raise StatelessModeNotSupported(method="elicitation") return await self.send_request( types.ElicitRequest( params=types.ElicitRequestURLParams( diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 9103996a52..c1c8a0f619 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -28,7 +28,7 @@ from mcp.shared._context_streams import ContextReceiveStream, ContextSendStream, create_context_streams from mcp.shared._stream_protocols import ReadStream, WriteStream from mcp.shared.message import ServerMessageMetadata, SessionMessage -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS, is_version_at_least +from mcp.shared.version import is_version_at_least from mcp.types import ( DEFAULT_NEGOTIATED_VERSION, INTERNAL_ERROR, @@ -248,12 +248,11 @@ async def close_standalone_stream_callback() -> None: metadata = ServerMessageMetadata( request_context=request, - protocol_version=protocol_version, close_sse_stream=close_stream_callback, close_standalone_sse_stream=close_standalone_stream_callback, ) else: - metadata = ServerMessageMetadata(request_context=request, protocol_version=protocol_version) + metadata = ServerMessageMetadata(request_context=request) return SessionMessage(message, metadata=metadata) @@ -507,10 +506,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re await response(scope, receive, send) # Process the message after sending the response - metadata = ServerMessageMetadata( - request_context=request, - protocol_version=request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION), - ) + metadata = ServerMessageMetadata(request_context=request) session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) @@ -533,7 +529,7 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re if self.is_json_response_enabled: # Process the message - metadata = ServerMessageMetadata(request_context=request, protocol_version=protocol_version) + metadata = ServerMessageMetadata(request_context=request) session_message = SessionMessage(message, metadata=metadata) await writer.send(session_message) try: @@ -818,11 +814,10 @@ async def _handle_unsupported_request(self, request: Request, send: Send) -> Non await response(request.scope, request.receive, send) async def _validate_request_headers(self, request: Request, send: Send) -> bool: - if not await self._validate_session(request, send): - return False - if not await self._validate_protocol_version(request, send): - return False - return True + # Protocol-version validation lives in the manager's era-routing: only + # values in `SUPPORTED_PROTOCOL_VERSIONS` (or no header at all) reach + # this transport, so the legacy version-gate is gone. + return await self._validate_session(request, send) async def _validate_session(self, request: Request, send: Send) -> bool: """Validate the session ID in the request.""" @@ -853,28 +848,6 @@ async def _validate_session(self, request: Request, send: Send) -> bool: return True - async def _validate_protocol_version(self, request: Request, send: Send) -> bool: - """Validate the protocol version header in the request.""" - # Get the protocol version from the request headers - protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER) - - # If no protocol version provided, assume default version - if protocol_version is None: - protocol_version = DEFAULT_NEGOTIATED_VERSION - - # Check if the protocol version is supported - if protocol_version not in SUPPORTED_PROTOCOL_VERSIONS: - supported_versions = ", ".join(SUPPORTED_PROTOCOL_VERSIONS) - response = self._create_error_response( - f"Bad Request: Unsupported protocol version: {protocol_version}. " - + f"Supported versions: {supported_versions}", - HTTPStatus.BAD_REQUEST, - ) - await response(request.scope, request.receive, send) - return False - - return True - async def _replay_events(self, last_event_id: str, request: Request, send: Send) -> None: """Replays events that would have been sent after the specified event ID. @@ -894,7 +867,7 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - # Get protocol version from header (already validated in _validate_protocol_version) + # The manager only routes supported (or absent) header values to this transport replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) # Create SSE stream for replay diff --git a/src/mcp/server/streamable_http_manager.py b/src/mcp/server/streamable_http_manager.py index 5c6ea531d0..648dcc827f 100644 --- a/src/mcp/server/streamable_http_manager.py +++ b/src/mcp/server/streamable_http_manager.py @@ -16,15 +16,20 @@ from mcp.server._streamable_http_modern import handle_modern_request from mcp.server.auth.middleware.bearer_auth import AuthenticatedUser, AuthorizationContext, authorization_context +from mcp.server.connection import Connection +from mcp.server.runner import serve_connection, serve_loop from mcp.server.streamable_http import ( + MCP_PROTOCOL_VERSION_HEADER, MCP_SESSION_ID_HEADER, EventStore, StreamableHTTPServerTransport, ) from mcp.server.transport_security import TransportSecuritySettings from mcp.shared._compat import resync_tracer -from mcp.shared.version import MODERN_PROTOCOL_VERSIONS -from mcp.types import INVALID_REQUEST, ErrorData, JSONRPCError +from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.transport_context import TransportContext +from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.types import DEFAULT_NEGOTIATED_VERSION, INVALID_REQUEST, ErrorData, JSONRPCError if TYPE_CHECKING: from mcp.server.lowlevel.server import Server @@ -96,8 +101,9 @@ def __init__( # session must present the same credential. self._session_owners: dict[str, AuthorizationContext] = {} - # The task group will be set during lifespan + # The task group and lifespan state are set during run() self._task_group = None + self._lifespan_state: Any = None # Thread-safe tracking of run() calls self._run_lock = anyio.Lock() self._has_started = False @@ -128,8 +134,11 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: ) self._has_started = True - async with anyio.create_task_group() as tg: - # Store the task group for later use + async with self.app.lifespan(self.app) as lifespan_state, anyio.create_task_group() as tg: + # Store for handle_request: lifespan is entered once for the + # manager's lifetime, not per request (per-connection cleanup + # belongs on `connection.exit_stack`). + self._lifespan_state = lifespan_state self._task_group = tg logger.info("StreamableHTTP session manager started") try: @@ -139,6 +148,7 @@ async def lifespan(app: Starlette) -> AsyncIterator[None]: # Cancel task group to stop all spawned tasks tg.cancel_scope.cancel() self._task_group = None + self._lifespan_state = None # Clear any remaining server instances self._server_instances.clear() self._session_owners.clear() @@ -152,20 +162,26 @@ async def handle_request(self, scope: Scope, receive: Receive, send: Send) -> No if self._task_group is None: raise RuntimeError("Task group is not initialized. Make sure to use run().") - # TODO: header-only routing for now; body-primary classification - # (per SEP-2575) is a follow-up. 2025 paths below remain unchanged. - pv = next((v.decode("latin-1") for k, v in scope["headers"] if k == b"mcp-protocol-version"), None) - if pv in MODERN_PROTOCOL_VERSIONS: - await handle_modern_request(self.app, self.security_settings, pv, scope, receive, send) + # TODO(L49): header-only era-routing for now; body-primary classification + # is a follow-up. The legacy paths below own only the known + # initialize-handshake versions; anything else (including unknown + # values) goes to the modern entry so the classifier can validate it + # and return a structured rejection. 2025 paths below remain unchanged. + header = MCP_PROTOCOL_VERSION_HEADER.encode("ascii") + pv = next((v.decode("latin-1") for k, v in scope["headers"] if k == header), None) + if pv is not None and pv not in SUPPORTED_PROTOCOL_VERSIONS: + await handle_modern_request(self.app, self.security_settings, self._lifespan_state, scope, receive, send) return # Dispatch to the appropriate handler if self.stateless: - await self._handle_stateless_request(scope, receive, send) + await self._handle_stateless_request(pv, scope, receive, send) else: await self._handle_stateful_request(scope, receive, send) - async def _handle_stateless_request(self, scope: Scope, receive: Receive, send: Send) -> None: + async def _handle_stateless_request( + self, protocol_version_hint: str | None, scope: Scope, receive: Receive, send: Send + ) -> None: """Process request in stateless mode - creating a new transport for each request.""" logger.debug("Stateless mode: Creating new transport for this request") # No session ID needed in stateless mode @@ -181,12 +197,29 @@ async def run_stateless_server(*, task_status: TaskStatus[None] = anyio.TASK_STA async with http_transport.connect() as streams: read_stream, write_stream = streams task_status.started() + dispatcher: JSONRPCDispatcher[TransportContext] = JSONRPCDispatcher( + read_stream, + write_stream, + inline_methods=frozenset({"initialize"}), + # No session ID means a server-to-client request can be + # written to this POST's response stream, but the client's + # reply has nowhere to land — `can_send_request=False` + # makes the per-request channel raise `NoBackChannelError` + # for requests while still allowing notifications. + transport_builder=lambda _md: TransportContext(kind="streamable-http", can_send_request=False), + ) + # Born-ready, no standalone channel: the legacy stateless path + # never opens a GET stream and need not see `initialize`. The + # header (or the spec's default-absent value) seeds + # `ctx.protocol_version`. + connection = Connection.from_envelope( + protocol_version_hint if protocol_version_hint is not None else DEFAULT_NEGOTIATED_VERSION, + None, + None, + ) try: - await self.app.run( - read_stream, - write_stream, - self.app.create_initialization_options(), - stateless=True, + await serve_connection( + self.app, dispatcher, connection=connection, lifespan_state=self._lifespan_state ) except Exception: # pragma: lax no cover logger.exception("Stateless session crashed") @@ -263,7 +296,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE task_status.started() try: # Use a cancel scope for idle timeout — when the - # deadline passes the scope cancels app.run() and + # deadline passes the scope cancels the loop and # execution continues after the ``with`` block. # Incoming requests push the deadline forward. idle_scope = anyio.CancelScope() @@ -272,11 +305,15 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE http_transport.idle_scope = idle_scope with idle_scope: - await self.app.run( + # Drive via `serve_loop` (not `Server.run()`) so the + # manager's already-entered lifespan is reused + # rather than re-entered per session. + await serve_loop( + self.app, read_stream, write_stream, - self.app.create_initialization_options(), - stateless=False, + lifespan_state=self._lifespan_state, + session_id=http_transport.mcp_session_id, ) if idle_scope.cancelled_caught: @@ -309,7 +346,7 @@ async def run_server(*, task_status: TaskStatus[None] = anyio.TASK_STATUS_IGNORE await http_transport.handle_request(scope, receive, send) else: # Unknown or expired session ID - return 404 per MCP spec - # TODO: Align error code once spec clarifies + # TODO(L62): Align error code once spec clarifies # See: https://github.com/modelcontextprotocol/python-sdk/issues/1821 logger.info(f"Rejected request with unknown or expired session ID: {request_mcp_session_id[:64]}") body = JSONRPCError( diff --git a/src/mcp/shared/exceptions.py b/src/mcp/shared/exceptions.py index 9c70588022..7d17d59e13 100644 --- a/src/mcp/shared/exceptions.py +++ b/src/mcp/shared/exceptions.py @@ -71,23 +71,6 @@ def __init__(self, method: str): self.method = method -class StatelessModeNotSupported(RuntimeError): - """Raised when attempting to use a method that is not supported in stateless mode. - - Server-to-client requests (sampling, elicitation, list_roots) are not - supported in stateless HTTP mode because there is no persistent connection - for bidirectional communication. - """ - - def __init__(self, method: str): - super().__init__( - f"Cannot use {method} in stateless HTTP mode. " - "Stateless mode does not support server-to-client requests. " - "Use stateful mode (stateless_http=False) to enable this feature." - ) - self.method = method - - class UrlElicitationRequiredError(MCPError): """Specialized error for when a tool requires URL mode elicitation(s) before proceeding. diff --git a/src/mcp/shared/inbound.py b/src/mcp/shared/inbound.py new file mode 100644 index 0000000000..04aa93c141 --- /dev/null +++ b/src/mcp/shared/inbound.py @@ -0,0 +1,153 @@ +"""Inbound request classification for the modern per-request-envelope path. + +Pure module: no I/O, no transport, no ``mcp.server`` imports. Runs the +validation ladder against a decoded JSON-RPC body and returns either an +:class:`InboundModernRoute` (every rung passed) or an +:class:`InboundLadderRejection` (the first rung that failed). Callers map a +rejection's ``code`` through :data:`ERROR_CODE_HTTP_STATUS` to pick the HTTP +status. +""" + +from collections.abc import Mapping, Sequence +from dataclasses import dataclass +from types import MappingProxyType +from typing import Any, Final + +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS +from mcp.types import ( + CLIENT_CAPABILITIES_META_KEY, + CLIENT_INFO_META_KEY, + PROTOCOL_VERSION_META_KEY, + UnsupportedProtocolVersionErrorData, +) +from mcp.types.jsonrpc import ( + HEADER_MISMATCH, + INVALID_PARAMS, + INVALID_REQUEST, + METHOD_NOT_FOUND, + MISSING_REQUIRED_CLIENT_CAPABILITY, + PARSE_ERROR, + UNSUPPORTED_PROTOCOL_VERSION, +) + +__all__ = [ + "ERROR_CODE_HTTP_STATUS", + "InboundLadderRejection", + "InboundModernRoute", + "MCP_PROTOCOL_VERSION_HEADER", + "classify_inbound_request", +] + +MCP_PROTOCOL_VERSION_HEADER: Final = "mcp-protocol-version" +"""Canonical lowercase name of the HTTP header carrying the MCP protocol version.""" + +# INTERNAL_ERROR is deliberately unmapped (→ HTTP 200): the spec assigns no status to +# -32603, and whether handler-origin errors get 5xx is an open S4 question — see TODO(L66). +ERROR_CODE_HTTP_STATUS: Final[Mapping[int, int]] = MappingProxyType( + { + PARSE_ERROR: 400, + INVALID_REQUEST: 400, + INVALID_PARAMS: 400, + HEADER_MISMATCH: 400, + MISSING_REQUIRED_CLIENT_CAPABILITY: 400, + UNSUPPORTED_PROTOCOL_VERSION: 400, + METHOD_NOT_FOUND: 404, + } +) +"""HTTP status to send for a JSON-RPC ``error.code``. + +Consulted for classifier-origin *and* handler-origin errors, so one table +decides the wire status regardless of where the error was produced. Unmapped +codes fall back to the caller's default (typically 200). +""" + + +@dataclass(frozen=True) +class InboundModernRoute: + """A modern-protocol request whose envelope passed every ladder rung. + + ``client_info`` and ``client_capabilities`` are the raw envelope values; + the classifier checks presence only, not shape. Method existence is not a + ladder rung — kernel dispatch is the single source of truth for that. + """ + + protocol_version: str + client_info: Any + client_capabilities: Any + + +@dataclass(frozen=True) +class InboundLadderRejection: + """The first ladder rung that failed, as JSON-RPC error fields.""" + + code: int + message: str + data: Any = None + + +def classify_inbound_request( + body: Mapping[str, Any], + *, + headers: Mapping[str, str] | None = None, + supported_modern_versions: Sequence[str] = MODERN_PROTOCOL_VERSIONS, +) -> InboundModernRoute | InboundLadderRejection: + """Run the modern-protocol validation ladder over a decoded JSON-RPC body. + + Rungs, in order — first failure wins: + + 1. ``params._meta`` is a mapping carrying every reserved envelope key + (protocol version, client info, client capabilities) → else + :data:`~mcp.types.jsonrpc.INVALID_PARAMS`. + 2. When ``headers`` is given, its ``MCP-Protocol-Version`` entry equals + the envelope's protocol version → else + :data:`~mcp.types.jsonrpc.HEADER_MISMATCH`. Runs before the + supported-version rung so a client that disagrees with itself is told + so, rather than told the body's version is unsupported. + 3. The envelope's protocol version is in ``supported_modern_versions`` → + else :data:`~mcp.types.jsonrpc.UNSUPPORTED_PROTOCOL_VERSION` with + ``data = {"supported": [...], "requested": }``. + + Method existence is *not* a rung: kernel dispatch owns that decision so + custom-registered methods route and the answer lives in one place. + + Args: + body: The decoded JSON-RPC request mapping. Envelope shape + (``jsonrpc`` / ``id``) is not checked here. + headers: Transport headers keyed by lowercase name, or ``None`` to + skip the header rung (non-HTTP callers). + supported_modern_versions: Modern protocol revisions this server + accepts on the per-request-envelope path. + """ + try: + meta = body["params"]["_meta"] + protocol_version = meta[PROTOCOL_VERSION_META_KEY] + client_info = meta[CLIENT_INFO_META_KEY] + client_capabilities = meta[CLIENT_CAPABILITIES_META_KEY] + except (KeyError, TypeError): + return InboundLadderRejection( + code=INVALID_PARAMS, + message="params._meta must carry the reserved protocol-version, client-info and " + "client-capabilities envelope keys", + ) + + # TODO(L59): also validate Mcp-Method / Mcp-Name per SEP-2243 §Server Validation + if headers is not None and headers.get(MCP_PROTOCOL_VERSION_HEADER) != protocol_version: + return InboundLadderRejection( + code=HEADER_MISMATCH, + message=f"{MCP_PROTOCOL_VERSION_HEADER} header does not match the request envelope's protocol version", + ) + + if protocol_version not in supported_modern_versions: + return InboundLadderRejection( + code=UNSUPPORTED_PROTOCOL_VERSION, + message="Unsupported protocol version", + data=UnsupportedProtocolVersionErrorData( + supported=list(supported_modern_versions), requested=protocol_version + ).model_dump(mode="json"), + ) + + return InboundModernRoute( + protocol_version=protocol_version, + client_info=client_info, + client_capabilities=client_capabilities, + ) diff --git a/src/mcp/shared/jsonrpc_dispatcher.py b/src/mcp/shared/jsonrpc_dispatcher.py index a59cd119dc..7fabafff65 100644 --- a/src/mcp/shared/jsonrpc_dispatcher.py +++ b/src/mcp/shared/jsonrpc_dispatcher.py @@ -49,7 +49,7 @@ RequestId, ) -__all__ = ["JSONRPCDispatcher"] +__all__ = ["JSONRPCDispatcher", "handler_exception_to_error_data"] logger = logging.getLogger(__name__) @@ -67,6 +67,23 @@ the handler's scope; `"signal"` only sets `ctx.cancel_requested`.""" +def handler_exception_to_error_data(exc: BaseException) -> ErrorData | None: + """Map a handler-raised exception to its wire `ErrorData`. + + The two rungs every dispatcher shares: an `MCPError` carries its own + `ErrorData`; a pydantic `ValidationError` is the spec's INVALID_PARAMS + with empty ``data`` (no pydantic text on the wire). Returns ``None`` for + any other exception so each caller applies its own catch-all - + `JSONRPCDispatcher` currently pins ``code=0`` for v1 compat, + `to_jsonrpc_response` uses `INTERNAL_ERROR`. + """ + if isinstance(exc, MCPError): + return exc.error + if isinstance(exc, ValidationError): + return ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + return None + + def _coerce_id(request_id: RequestId) -> RequestId: """Coerce a stringified int request ID back to int so a peer-echoed ID still correlates (matches the TS SDK).""" if isinstance(request_id, str): @@ -667,7 +684,7 @@ async def _handle_request( # anyio absorbs the scope's own cancel at __exit__, and # `cancelled_caught` (unlike `cancel_called`) guarantees the # result write above did not happen - no double response. - # TODO(maxisbey): spec says SHOULD NOT respond after cancel; + # TODO(L38): spec says SHOULD NOT respond after cancel; # the existing server always has, so match that for now. answer_write_started = True await self._write_error(req.id, ErrorData(code=0, message="Request cancelled")) @@ -684,21 +701,17 @@ async def _handle_request( describe=f"shutdown error response for request {req.id!r}", ) raise - except MCPError as e: - await self._write_error(req.id, e.error) - except ValidationError: - # TODO(maxisbey): data="" pins existing-server compat (no pydantic - # text on the wire); revisit per the suite's divergence entry. - await self._write_error( - req.id, ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") - ) except Exception as e: - logger.exception("handler for %r raised", req.method) - # TODO(maxisbey): code=0 pins existing-server compat; JSON-RPC says - # INTERNAL_ERROR. Revisit per the suite's divergence entry. - await self._write_error(req.id, ErrorData(code=0, message=str(e))) - if self._raise_handler_exceptions: - raise + error = handler_exception_to_error_data(e) + if error is not None: + await self._write_error(req.id, error) + else: + logger.exception("handler for %r raised", req.method) + # TODO(L58): code=0 pins existing-server compat; JSON-RPC says + # INTERNAL_ERROR. Revisit per the suite's divergence entry. + await self._write_error(req.id, ErrorData(code=0, message=str(e))) + if self._raise_handler_exceptions: + raise # No `_in_flight` pop here: the inner finally covers every path, and a late pop could evict a reused id. def _allocate_id(self) -> int: diff --git a/tests/client/test_list_roots_callback.py b/tests/client/test_list_roots_callback.py index f597ef7c09..423ab967bb 100644 --- a/tests/client/test_list_roots_callback.py +++ b/tests/client/test_list_roots_callback.py @@ -4,7 +4,8 @@ from mcp import Client from mcp.client import ClientRequestContext from mcp.server.mcpserver import Context, MCPServer -from mcp.types import ListRootsResult, Root, TextContent +from mcp.shared.exceptions import MCPError +from mcp.types import INVALID_REQUEST, ListRootsResult, Root, TextContent @pytest.mark.anyio @@ -37,10 +38,10 @@ async def test_list_roots(context: Context, message: str): assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" - # Test without list_roots callback + # Without a list_roots callback the client responds with an MCPError, which the + # tool body doesn't catch — the wrapper re-raises it as a top-level JSON-RPC + # error rather than wrapping it as an isError result. async with Client(server) as client: - # Make a request to trigger sampling callback - result = await client.call_tool("test_list_roots", {"message": "test message"}) - assert result.is_error is True - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Error executing tool test_list_roots: List roots not supported" + with pytest.raises(MCPError) as exc_info: + await client.call_tool("test_list_roots", {"message": "test message"}) + assert exc_info.value.error.code == INVALID_REQUEST diff --git a/tests/client/test_notification_response.py b/tests/client/test_notification_response.py index 69c8afeb84..4dbd78dbbe 100644 --- a/tests/client/test_notification_response.py +++ b/tests/client/test_notification_response.py @@ -201,3 +201,62 @@ async def test_invalid_json_response_sends_jsonrpc_error() -> None: with pytest.raises(MCPError, match="Failed to parse JSON response"): # pragma: no branch await session.list_tools() + + +def _create_non_2xx_json_body_app(status: int, body: bytes) -> Starlette: + """Server that returns a fixed non-2xx status + ``application/json`` body for non-init requests.""" + + async def handle_mcp_request(request: Request) -> Response: + data = json.loads(await request.body()) + if data.get("method") == "initialize": + return _init_json_response(data) + if "id" not in data: + return Response(status_code=202) + return Response(content=body, status_code=status, media_type="application/json") + + return Starlette(debug=True, routes=[Route("/mcp", handle_mcp_request, methods=["POST"])]) + + +async def test_client_surfaces_jsonrpc_error_from_non_2xx_body_with_correlated_id() -> None: + """SDK-defined: a JSON-RPC error in a non-2xx body is surfaced verbatim even when the + server set ``id: null`` — the client rewraps it under the pending request's id, so + the awaiting call resolves with the server's error code instead of the generic fallback.""" + body = json.dumps( + {"jsonrpc": "2.0", "id": None, "error": {"code": types.METHOD_NOT_FOUND, "message": "nope"}} + ).encode() + app = _create_non_2xx_json_body_app(400, body) + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + with pytest.raises(MCPError) as exc: + await session.list_tools() + assert exc.value.error.code == types.METHOD_NOT_FOUND + + +async def test_client_falls_back_to_generic_error_when_non_2xx_body_is_a_jsonrpc_result() -> None: + """SDK-defined: a non-2xx response whose JSON body parses as a JSON-RPC *result* (not an + error) falls through to the generic ``INTERNAL_ERROR`` fallback rather than being + treated as the request's reply.""" + app = _create_non_2xx_json_body_app(400, b'{"jsonrpc":"2.0","id":1,"result":{}}') + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + with pytest.raises(MCPError) as exc: + await session.list_tools() + assert exc.value.error.code == types.INTERNAL_ERROR + + +async def test_client_falls_back_to_session_terminated_when_404_body_is_malformed_json() -> None: + """SDK-defined: an unparseable ``application/json`` body on a 404 response is swallowed + and the status-derived ``INVALID_REQUEST`` (session-terminated) fallback resolves the + pending request — the parse failure never propagates.""" + app = _create_non_2xx_json_body_app(404, b"not valid json{{{") + async with httpx.AsyncClient(transport=httpx.ASGITransport(app=app)) as client: + async with streamable_http_client("http://localhost/mcp", http_client=client) as (read_stream, write_stream): + async with ClientSession(read_stream, write_stream) as session: # pragma: no branch + await session.initialize() + with pytest.raises(MCPError) as exc: + await session.list_tools() + assert exc.value.error.code == types.INVALID_REQUEST diff --git a/tests/client/test_sampling_callback.py b/tests/client/test_sampling_callback.py index 5163ef043a..901caa69f8 100644 --- a/tests/client/test_sampling_callback.py +++ b/tests/client/test_sampling_callback.py @@ -3,7 +3,9 @@ from mcp import Client from mcp.client import ClientRequestContext from mcp.server.mcpserver import Context, MCPServer +from mcp.shared.exceptions import MCPError from mcp.types import ( + INVALID_REQUEST, CreateMessageRequestParams, CreateMessageResult, CreateMessageResultWithTools, @@ -47,13 +49,13 @@ async def test_sampling_tool(message: str, ctx: Context) -> bool: assert isinstance(result.content[0], TextContent) assert result.content[0].text == "true" - # Test without sampling callback + # Without a sampling callback the client responds with an MCPError, which the + # tool body doesn't catch — the wrapper re-raises it as a top-level JSON-RPC + # error rather than wrapping it as an isError result. async with Client(server) as client: - # Make a request to trigger sampling callback - result = await client.call_tool("test_sampling", {"message": "Test message for sampling"}) - assert result.is_error is True - assert isinstance(result.content[0], TextContent) - assert result.content[0].text == "Error executing tool test_sampling: Sampling not supported" + with pytest.raises(MCPError) as exc_info: + await client.call_tool("test_sampling", {"message": "Test message for sampling"}) + assert exc_info.value.error.code == INVALID_REQUEST @pytest.mark.anyio diff --git a/tests/interaction/_requirements.py b/tests/interaction/_requirements.py index 1bd766f54b..9aee73b29b 100644 --- a/tests/interaction/_requirements.py +++ b/tests/interaction/_requirements.py @@ -97,9 +97,9 @@ ) _MODERN_NOTIFY_DROP = ( - "SingleExchangeDispatcher.notify() no-ops on the modern streamable-http driver; handler-emitted " - "logging/progress notifications never reach the per-request SSE response. Passes once SSE " - "response mode lands." + "The modern single-exchange dispatch context no-ops notify() on the streamable-http driver; " + "handler-emitted logging/progress notifications never reach the per-request SSE response. " + "Passes once SSE response mode lands." ) @@ -3046,6 +3046,50 @@ def __post_init__(self) -> None: transports=("streamable-http",), note="Only observable over streamable HTTP: the modern entry's exception-to-JSONRPCError boundary.", ), + "hosting:http:modern:discover-response-shape": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/index", + behavior=( + "A 2026-07-28 server/discover response carries supportedVersions, capabilities, and " + "serverInfo, with supportedVersions naming the modern protocol revisions the server accepts." + ), + added_in="2026-07-28", + transports=("streamable-http",), + note="Only observable over streamable HTTP: the raw result body is asserted at the wire.", + ), + "hosting:http:modern:removed-method-status-404": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/index", + behavior=( + "A method that exists at earlier protocol revisions but is removed at 2026-07-28 is " + "answered METHOD_NOT_FOUND, and the modern entry maps that error code to HTTP 404." + ), + added_in="2026-07-28", + transports=("streamable-http",), + note=( + "Only observable over streamable HTTP: the HTTP status is the assertion. Kernel-origin " + "METHOD_NOT_FOUND travels through the same status table as classifier-origin errors." + ), + ), + "hosting:http:modern:envelope-missing-key-status-400": Requirement( + source=f"{SPEC_2026_BASE_URL}/basic/transports/streamable-http", + behavior=( + "A 2026-07-28 request whose params._meta envelope omits a required reserved key is " + "rejected as INVALID_PARAMS at HTTP 400 before kernel dispatch." + ), + added_in="2026-07-28", + transports=("streamable-http",), + note="Only observable over streamable HTTP: the HTTP status is the assertion.", + ), + "hosting:http:modern:handler-error-status-via-table": Requirement( + source="sdk", + behavior=( + "A handler-raised MCPError on the 2026-07-28 entry reaches the wire as a top-level " + "JSON-RPC error with its data preserved, and the HTTP status is the error-code table " + "entry for that code (handler-origin and classifier-origin errors share one table)." + ), + added_in="2026-07-28", + transports=("streamable-http",), + note="Only observable over streamable HTTP: the modern entry's JSONRPCError-to-HTTP-status mapping.", + ), # ═══════════════════════════════════════════════════════════════════════════ # Client transport: streamable HTTP # ═══════════════════════════════════════════════════════════════════════════ diff --git a/tests/interaction/transports/test_hosting_http.py b/tests/interaction/transports/test_hosting_http.py index 9b46dc533d..5b5c7085b3 100644 --- a/tests/interaction/transports/test_hosting_http.py +++ b/tests/interaction/transports/test_hosting_http.py @@ -14,9 +14,14 @@ from mcp.server import Server, ServerRequestContext from mcp.server.transport_security import TransportSecuritySettings +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS from mcp.types import ( + CLIENT_CAPABILITIES_META_KEY, + CLIENT_INFO_META_KEY, INVALID_PARAMS, PARSE_ERROR, + PROTOCOL_VERSION_META_KEY, + UNSUPPORTED_PROTOCOL_VERSION, CallToolRequestParams, CallToolResult, EmptyResult, @@ -155,7 +160,12 @@ async def test_malformed_and_batched_bodies_return_400() -> None: @requirement("hosting:http:protocol-version-400") @requirement("hosting:http:protocol-version-default") async def test_protocol_version_header_is_validated() -> None: - """An unsupported MCP-Protocol-Version header returns 400; an absent header is accepted as the default.""" + """An unsupported MCP-Protocol-Version header returns 400; an absent header is accepted as the default. + + An unrecognised header value routes to the modern entry (which owns rejection of unknown + versions), and a request without the per-request envelope is rejected at the first ladder + rung. Only known initialize-handshake versions and an absent header reach the legacy path. + """ async with mounted_app(_server()) as (http, _): session_id = await initialize_via_http(http) @@ -172,9 +182,7 @@ async def test_protocol_version_header_is_validated() -> None: ) assert bad.status_code == 400 - assert JSONRPCError.model_validate_json(bad.text).error.message.startswith( - "Bad Request: Unsupported protocol version: 1991-01-01." - ) + assert JSONRPCError.model_validate_json(bad.text).error.code == INVALID_PARAMS # 202 proves the request was accepted under the assumed default version (2025-03-26). assert defaulted.status_code == 202 @@ -185,18 +193,27 @@ async def test_unsupported_protocol_version_rejection_body_contains_the_sniffed_ SDK-defined: other SDKs detect this rejection by substring-matching ``Unsupported protocol version`` in the response body, so the literal must survive any rewording of the surrounding - message. Asserted at the wire because the SDK client never surfaces the rejection body. + message. The unsupported value must appear in both the header and the envelope so the + classifier reaches its version-supported rung rather than reporting a header mismatch first. """ + bad = "1991-01-01" + meta = { + PROTOCOL_VERSION_META_KEY: bad, + CLIENT_INFO_META_KEY: {"name": "t", "version": "0"}, + CLIENT_CAPABILITIES_META_KEY: {}, + } async with mounted_app(_server()) as (http, _): - session_id = await initialize_via_http(http) response = await http.post( "/mcp", - json={"jsonrpc": "2.0", "id": 2, "method": "ping"}, - headers=base_headers(session_id=session_id) | {"mcp-protocol-version": "1991-01-01"}, + json={"jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {"_meta": meta}}, + headers=base_headers() | {"mcp-protocol-version": bad}, ) assert response.status_code == 400 + error = JSONRPCError.model_validate_json(response.text).error + assert error.code == UNSUPPORTED_PROTOCOL_VERSION assert "Unsupported protocol version" in response.text + assert error.data == {"supported": list(MODERN_PROTOCOL_VERSIONS), "requested": bad} @requirement("hosting:http:json-response-mode") diff --git a/tests/interaction/transports/test_hosting_http_modern.py b/tests/interaction/transports/test_hosting_http_modern.py index 1f043510fe..f943f9e89e 100644 --- a/tests/interaction/transports/test_hosting_http_modern.py +++ b/tests/interaction/transports/test_hosting_http_modern.py @@ -16,23 +16,29 @@ import pytest from inline_snapshot import snapshot +from mcp import MCPError from mcp.client.session import ClientSession from mcp.client.streamable_http import streamable_http_client from mcp.server import Server, ServerRequestContext from mcp.types import ( + CLIENT_CAPABILITIES_META_KEY, INTERNAL_ERROR, + INVALID_PARAMS, METHOD_NOT_FOUND, + MISSING_REQUIRED_CLIENT_CAPABILITY, CallToolRequestParams, CallToolResult, + EmptyResult, Implementation, JSONRPCError, JSONRPCResponse, ListToolsResult, PaginatedRequestParams, + RequestParams, TextContent, Tool, ) -from tests.interaction._connect import BASE_URL, base_headers, initialize_body, initialize_via_http, mounted_app +from tests.interaction._connect import BASE_URL, base_headers, initialize_via_http, mounted_app from tests.interaction._requirements import requirement pytestmark = pytest.mark.anyio @@ -133,28 +139,30 @@ async def test_modern_response_carries_no_session_id_header() -> None: @requirement("hosting:http:modern:initialize-removed") async def test_modern_initialize_is_method_not_found() -> None: - """A 2026-07-28 initialize request is answered with METHOD_NOT_FOUND. - - Spec-mandated under the draft: initialize is not a defined method at 2026-07-28, so the - method/version gate rejects it before any handler runs. Asserted at the wire because the SDK - client at 2026-07-28 never sends initialize, so only a raw POST can drive the negative. + """A 2026-07-28 initialize request that carries a valid envelope is answered METHOD_NOT_FOUND at HTTP 404. + + Spec-mandated under the draft: initialize is not a defined method at 2026-07-28, so the kernel's + method/version gate rejects it before any handler runs. The body must carry the per-request + ``_meta`` envelope so the classifier ladder admits it as far as kernel dispatch -- without the + envelope the request is INVALID_PARAMS at rung 1, never METHOD_NOT_FOUND. Asserted at the wire + because the SDK client at 2026-07-28 never sends initialize, so only a raw POST can drive the + negative. """ + body = {"jsonrpc": "2.0", "id": 1, "method": "initialize", "params": {"_meta": _meta_envelope()}} async with mounted_app(_server()) as (http, _): - response = await http.post("/mcp", json=initialize_body(), headers=_modern_headers(method="initialize")) + response = await http.post("/mcp", json=body, headers=_modern_headers(method="initialize")) - assert response.status_code == 200 + assert response.status_code == 404 assert JSONRPCError.model_validate(response.json()).error.code == METHOD_NOT_FOUND @requirement("hosting:http:modern:legacy-fallthrough") -async def test_non_modern_version_header_falls_through_to_legacy_transport_unchanged() -> None: - """The 2026-07-28 routing branch fires only on its exact header; everything else reaches legacy. - - SDK-defined under the draft versioning rules: the modern entry must not change any 2025-era - byte. A 2025-era initialize on the same endpoint still completes (legacy serves it), and an - unrecognised ``MCP-Protocol-Version`` still falls through to the legacy gate and produces the - ``Unsupported protocol version`` literal that peer SDKs substring-sniff. Asserted at the wire - because the literal is only observable in the raw response body. +async def test_legacy_version_header_falls_through_and_unrecognised_header_routes_to_modern() -> None: + """SDK-defined under the draft versioning rules: only the known initialize-handshake protocol + versions reach the legacy transport, so a 2025-era ``initialize`` on the same endpoint still + completes unchanged. Any other ``MCP-Protocol-Version`` value routes to the modern entry, + where the validation ladder rejects it (a request without the per-request envelope fails the + first rung). The modern entry is therefore the single owner of unknown-version rejection. """ async with mounted_app(_server()) as (http, _): # 2025-era initialize through the same endpoint: the modern branch must not intercept it. @@ -166,7 +174,7 @@ async def test_non_modern_version_header_falls_through_to_legacy_transport_uncha ) assert unrecognised.status_code == 400 - assert "Unsupported protocol version" in unrecognised.text + assert JSONRPCError.model_validate_json(unrecognised.text).error.code == INVALID_PARAMS @requirement("hosting:http:modern:handler-exception-internal-error") @@ -198,6 +206,92 @@ async def call_tool(ctx: ServerRequestContext, params: CallToolRequestParams) -> assert "kaboom" not in error.message +@requirement("hosting:http:modern:discover-response-shape") +async def test_modern_server_discover_returns_capabilities_and_supported_versions() -> None: + """A 2026-07-28 server/discover POST returns capabilities, serverInfo, and supportedVersions. + + Spec-mandated under the draft: server/discover is the 2026 advertisement method that replaces + the initialize-response payload, and ``supportedVersions`` is the field a client picks its + per-request envelope version from. Asserted at the wire because the SDK client never exposes + the raw result body. + """ + body = {"jsonrpc": "2.0", "id": 1, "method": "server/discover", "params": {"_meta": _meta_envelope()}} + async with mounted_app(_server()) as (http, _): + response = await http.post("/mcp", json=body, headers=_modern_headers(method="server/discover")) + + assert response.status_code == 200 + result = JSONRPCResponse.model_validate(response.json()).result + assert result["supportedVersions"] == snapshot(["2026-07-28"]) + assert result["serverInfo"]["name"] == "modern" + assert "capabilities" in result + + +@requirement("hosting:http:modern:removed-method-status-404") +async def test_modern_removed_method_is_method_not_found_at_http_404() -> None: + """A 2026-07-28 ping (removed at 2026) is answered METHOD_NOT_FOUND and the HTTP status is 404. + + Spec-mandated for the error code: ping is not a defined method at 2026-07-28 so the kernel's + method/version gate rejects it. SDK-defined for the HTTP status: kernel-origin METHOD_NOT_FOUND + travels through the same error-code-to-status table as classifier-origin errors. Asserted at the + wire because the HTTP status is the assertion. + """ + body = {"jsonrpc": "2.0", "id": 1, "method": "ping", "params": {"_meta": _meta_envelope()}} + async with mounted_app(_server()) as (http, _): + response = await http.post("/mcp", json=body, headers=_modern_headers(method="ping")) + + assert response.status_code == 404 + assert JSONRPCError.model_validate(response.json()).error.code == METHOD_NOT_FOUND + + +@requirement("hosting:http:modern:envelope-missing-key-status-400") +async def test_modern_envelope_missing_required_meta_key_is_invalid_params_at_http_400() -> None: + """A 2026-07-28 request whose ``_meta`` envelope omits a required key is INVALID_PARAMS at HTTP 400. + + Spec-mandated under the draft transport: the per-request envelope must carry every reserved key, + so a missing ``clientCapabilities`` fails the classifier's first rung before any kernel dispatch. + Asserted at the wire because the HTTP status is the assertion. + """ + incomplete = _meta_envelope() + del incomplete[CLIENT_CAPABILITIES_META_KEY] + body = {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {"_meta": incomplete}} + async with mounted_app(_server()) as (http, _): + response = await http.post("/mcp", json=body, headers=_modern_headers(method="tools/list")) + + assert response.status_code == 400 + assert JSONRPCError.model_validate(response.json()).error.code == INVALID_PARAMS + + +@requirement("hosting:http:modern:handler-error-status-via-table") +async def test_modern_handler_raised_mcperror_maps_to_status_via_error_code_table() -> None: + """A handler-raised ``MCPError`` reaches the wire as a top-level JSON-RPC error at the table-mapped HTTP status. + + SDK-defined for the HTTP status: the modern entry maps every JSON-RPC ``error.code`` -- whether + classifier-origin or handler-origin -- through one error-code-to-status table, so a handler + raising ``MISSING_REQUIRED_CLIENT_CAPABILITY`` produces HTTP 400 with ``error.data`` preserved. + Spec-mandated for the error code: the named code and its ``requiredCapabilities`` data shape are + the spec's capability-gating contract. Registered via the low-level ``add_request_handler`` so + the high-level tool wrapper's error-swallowing is not on the path. + """ + + async def cap_check(ctx: ServerRequestContext, params: RequestParams) -> EmptyResult: + raise MCPError( + code=MISSING_REQUIRED_CLIENT_CAPABILITY, + message="sampling required", + data={"requiredCapabilities": ["sampling"]}, + ) + + server = _server() + server.add_request_handler("test/cap-check", RequestParams, cap_check) + body = {"jsonrpc": "2.0", "id": 1, "method": "test/cap-check", "params": {"_meta": _meta_envelope()}} + async with mounted_app(server) as (http, _): + response = await http.post("/mcp", json=body, headers=_modern_headers(method="test/cap-check")) + + assert response.status_code == 400 + error = JSONRPCError.model_validate(response.json()).error + assert error.code == MISSING_REQUIRED_CLIENT_CAPABILITY + assert error.data == {"requiredCapabilities": ["sampling"]} + + @requirement("hosting:http:modern:tools-call-stateless") @requirement("lifecycle:stateless:request-envelope") @requirement("lifecycle:stateless:caller-meta-preserved") diff --git a/tests/interaction/transports/test_streamable_http.py b/tests/interaction/transports/test_streamable_http.py index bf5a32f5ba..cb63e389ca 100644 --- a/tests/interaction/transports/test_streamable_http.py +++ b/tests/interaction/transports/test_streamable_http.py @@ -15,7 +15,9 @@ from mcp.client import ClientRequestContext from mcp.server.elicitation import AcceptedElicitation from mcp.server.mcpserver import Context, MCPServer +from mcp.shared.exceptions import MCPError from mcp.types import ( + INVALID_REQUEST, CallToolResult, ElicitRequestParams, ElicitResult, @@ -92,15 +94,15 @@ async def test_tool_calls_over_stateless_streamable_http() -> None: @requirement("transport:streamable-http:stateless-restrictions") async def test_stateless_streamable_http_rejects_server_initiated_requests() -> None: - """A handler that tries to call back to the client in stateless mode fails: there is no session.""" + """A handler that tries to call back to the client in stateless mode fails: there is no + back-channel for server-initiated requests. The resulting ``NoBackChannelError`` is an + ``MCPError``, so it surfaces as a top-level JSON-RPC error rather than an + ``isError`` result.""" async with connect_over_streamable_http(_smoke_server(), stateless_http=True) as client: - result = await client.call_tool("ask", {}) + with pytest.raises(MCPError) as exc_info: + await client.call_tool("ask", {}) - assert result.is_error is True - assert isinstance(result.content[0], TextContent) - # The exact message is the StatelessModeNotSupported exception text wrapped by the tool-error - # path; pin the stable prefix rather than the full exception prose. - assert result.content[0].text.startswith("Error executing tool ask:") + assert exc_info.value.error.code == INVALID_REQUEST @requirement("transport:streamable-http:notifications") diff --git a/tests/server/lowlevel/test_server_discover.py b/tests/server/lowlevel/test_server_discover.py new file mode 100644 index 0000000000..ed26e7244d --- /dev/null +++ b/tests/server/lowlevel/test_server_discover.py @@ -0,0 +1,151 @@ +"""Direct-handler tests for the auto-derived `server/discover` handler. + +These call the registered handler via the public `Server.get_request_handler` +accessor without spinning up a `ServerRunner` or any transport, so they verify +the handler's contract in isolation from the dispatch pipeline. +""" + +import importlib.metadata +from typing import Any, cast + +import pytest + +from mcp import types +from mcp.server import Server, ServerRequestContext +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS + +# `Server._handle_discover` ignores its `ctx` argument entirely (it derives the +# result from server state), so a sentinel keeps the call site type-correct +# without dragging session machinery into a unit test. +_UNUSED_CTX = cast("ServerRequestContext[Any]", None) + + +async def _discover(server: Server[Any]) -> types.DiscoverResult: + entry = server.get_request_handler("server/discover") + assert entry is not None + result = await entry.handler(_UNUSED_CTX, types.RequestParams()) + assert isinstance(result, types.DiscoverResult) + return result + + +def test_registered_by_default() -> None: + """SDK-defined: a bare `Server` registers a `server/discover` handler out of + the box, typed for the base `RequestParams`.""" + server = Server("test-server") + entry = server.get_request_handler("server/discover") + assert entry is not None + assert entry.params_type is types.RequestParams + + +@pytest.mark.anyio +async def test_supported_versions_is_modern_set() -> None: + """`supportedVersions` is exactly the modern envelope set, not the full + legacy-compat list (D-008).""" + result = await _discover(Server("test-server")) + assert result.supported_versions == list(MODERN_PROTOCOL_VERSIONS) + + +@pytest.mark.anyio +async def test_server_info_reflects_constructor_fields() -> None: + """SDK-defined: `serverInfo` is built field-for-field from the `Server` + constructor arguments.""" + icons = [types.Icon(src="https://example.test/icon.png")] + server = Server( + "info-server", + version="9.9.9", + title="Info Server", + description="A server for testing discover.", + website_url="https://example.test", + icons=icons, + ) + result = await _discover(server) + assert result.server_info == types.Implementation( + name="info-server", + version="9.9.9", + title="Info Server", + description="A server for testing discover.", + website_url="https://example.test", + icons=icons, + ) + + +@pytest.mark.anyio +async def test_server_info_version_falls_back_to_package() -> None: + """SDK-defined: when no explicit version is supplied, `serverInfo.version` + falls back to the installed `mcp` package version.""" + result = await _discover(Server("unversioned")) + assert result.server_info.version == importlib.metadata.version("mcp") + + +@pytest.mark.anyio +async def test_instructions_threaded_through() -> None: + """SDK-defined: the `instructions` constructor argument is passed through + verbatim, defaulting to `None` when omitted.""" + server = Server("inst-server", instructions="Read the docs first.") + result = await _discover(server) + assert result.instructions == "Read the docs first." + + bare = await _discover(Server("bare")) + assert bare.instructions is None + + +@pytest.mark.anyio +async def test_capabilities_derived_from_registered_handlers() -> None: + """SDK-defined: capabilities are computed at handler call time from the + live registry, so post-construction `add_request_handler` calls are + reflected.""" + + async def list_tools( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None + ) -> types.ListToolsResult: + raise NotImplementedError + + async def list_prompts( + ctx: ServerRequestContext[Any], params: types.PaginatedRequestParams | None + ) -> types.ListPromptsResult: + raise NotImplementedError + + server = Server("cap-server", on_list_tools=list_tools) + + before = await _discover(server) + assert before.capabilities.tools is not None + assert before.capabilities.prompts is None + + server.add_request_handler("prompts/list", types.PaginatedRequestParams, list_prompts) + + after = await _discover(server) + assert after.capabilities.tools is not None + assert after.capabilities.prompts is not None + + +@pytest.mark.anyio +async def test_discover_result_defaults_to_immediately_stale_private_cache() -> None: + """SDK-defined: `DiscoverResult` is cacheable; the auto-derived handler + relies on the model defaults (immediately-stale, private).""" + result = await _discover(Server("cache-server")) + assert result.ttl_ms == 0 + assert result.cache_scope == "private" + + +@pytest.mark.anyio +async def test_overridable_via_add_request_handler() -> None: + """SDK-defined: a custom `server/discover` handler registered via + `add_request_handler` replaces the auto-derived default wholesale.""" + server = Server("custom-server", version="1.0.0") + custom = types.DiscoverResult( + supported_versions=list(MODERN_PROTOCOL_VERSIONS), + capabilities=types.ServerCapabilities(), + server_info=types.Implementation(name="custom-server", version="1.0.0"), + instructions="overridden", + ttl_ms=60_000, + cache_scope="public", + ) + + async def custom_discover( + ctx: ServerRequestContext[Any], params: types.RequestParams | None + ) -> types.DiscoverResult: + return custom + + server.add_request_handler("server/discover", types.RequestParams, custom_discover) + result = await _discover(server) + assert result is custom diff --git a/tests/server/mcpserver/tools/test_base.py b/tests/server/mcpserver/tools/test_base.py index 22d5f973e9..5e20f61ad3 100644 --- a/tests/server/mcpserver/tools/test_base.py +++ b/tests/server/mcpserver/tools/test_base.py @@ -1,5 +1,9 @@ -from mcp.server.mcpserver import Context +import pytest + +from mcp import Client, types +from mcp.server.mcpserver import Context, MCPServer from mcp.server.mcpserver.tools.base import Tool +from mcp.shared.exceptions import MCPError def test_context_detected_in_union_annotation(): @@ -8,3 +12,45 @@ def my_tool(x: int, ctx: Context | None) -> str: tool = Tool.from_function(my_tool) assert tool.context_kwarg == "ctx" + + +@pytest.mark.anyio +async def test_mcperror_raised_from_a_tool_surfaces_as_a_top_level_jsonrpc_error_with_code_and_data_intact(): + """SDK-defined: ``MCPError`` carries JSON-RPC ``ErrorData(code, message, data)`` + and means "respond with a protocol error". The tool wrapper re-raises it so + the kernel writes a top-level JSON-RPC error - ``code`` and ``data`` survive + the round-trip rather than being flattened into ``CallToolResult(isError=True)``.""" + mcp = MCPServer(name="srv") + + @mcp.tool() + async def needs_sampling() -> str: + raise MCPError( + types.MISSING_REQUIRED_CLIENT_CAPABILITY, + "sampling capability required", + data={"requiredCapabilities": ["sampling"]}, + ) + + async with Client(mcp) as client: + with pytest.raises(MCPError) as exc_info: + await client.call_tool("needs_sampling", {}) + + assert exc_info.value.error.code == types.MISSING_REQUIRED_CLIENT_CAPABILITY + assert exc_info.value.error.data == {"requiredCapabilities": ["sampling"]} + + +@pytest.mark.anyio +async def test_non_mcperror_exception_raised_from_a_tool_is_wrapped_as_an_is_error_result(): + """SDK-defined: ordinary exceptions from a tool body are execution failures + the LLM should see, so they become ``CallToolResult(isError=True)`` rather + than a protocol-level JSON-RPC error. Pins the other arm of the same branch.""" + mcp = MCPServer(name="srv") + + @mcp.tool() + async def boom() -> str: + raise RuntimeError("execution failure") + + async with Client(mcp) as client: + result = await client.call_tool("boom", {}) + + assert isinstance(result, types.CallToolResult) + assert result.is_error is True diff --git a/tests/server/test_connection.py b/tests/server/test_connection.py index e5d60994c9..8ca1ae8a7a 100644 --- a/tests/server/test_connection.py +++ b/tests/server/test_connection.py @@ -1,9 +1,11 @@ """Tests for `Connection`. -`Connection` wraps an `Outbound` (the standalone stream). Its `notify` is -best-effort (never raises); `send_raw_request` is gated on -`has_standalone_channel`. Tested with a stub `Outbound` so we can assert wire -shape and inject failures. +`Connection` wraps an `Outbound` (the standalone stream). Construct it via the +`from_envelope` / `for_loop` factories so `protocol_version` is always +populated and `has_standalone_channel` is derived from the held outbound. Its +`notify` is best-effort (never raises); `send_raw_request` raises +`NoBackChannelError` structurally from the no-channel sentinel. Tested with a +stub `Outbound` so we can assert wire shape and inject failures. """ import logging @@ -17,6 +19,7 @@ from mcp.server.connection import Connection from mcp.shared.dispatcher import CallOptions from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, @@ -25,7 +28,6 @@ ElicitationCapability, EmptyResult, Implementation, - InitializeRequestParams, ListRootsRequest, ListRootsResult, PingRequest, @@ -37,13 +39,8 @@ SamplingToolsCapability, ) - -def _client_params(capabilities: ClientCapabilities) -> InitializeRequestParams: - return InitializeRequestParams( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=capabilities, - client_info=Implementation(name="t", version="0"), - ) +_CLIENT_INFO = Implementation(name="t", version="0") +_MODERN = MODERN_PROTOCOL_VERSIONS[0] class StubOutbound: @@ -67,10 +64,83 @@ async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: self.notifications.append((method, params)) +# --- factories ----------------------------------------------------------------- + + +def test_from_envelope_is_born_ready_with_no_back_channel(): + """SDK-defined: `from_envelope` populates `protocol_version`, sets `initialized`, + and holds the no-channel sentinel so `has_standalone_channel` derives False.""" + conn = Connection.from_envelope(_MODERN, None, None) + assert conn.protocol_version == _MODERN + assert conn.initialized.is_set() + assert conn.initialize_accepted is True + assert conn.has_standalone_channel is False + assert conn.client_params is None + assert conn.session_id is None + + +def test_from_envelope_records_client_params_when_both_info_and_caps_supplied(): + """SDK-defined: when both client info and capabilities are supplied, + `from_envelope` synthesizes `client_params` so capability checks can run.""" + caps = ClientCapabilities(sampling=SamplingCapability()) + conn = Connection.from_envelope(_MODERN, _CLIENT_INFO, caps) + assert conn.client_params is not None + assert conn.client_params.client_info.name == "t" + assert conn.client_params.capabilities.sampling is not None + assert conn.client_params.protocol_version == _MODERN + + +@pytest.mark.parametrize( + ("info", "caps"), + [(None, ClientCapabilities()), (_CLIENT_INFO, None)], +) +def test_from_envelope_leaves_client_params_none_when_either_is_missing( + info: Implementation | None, caps: ClientCapabilities | None +): + """SDK-defined: `client_params` is only synthesized when both info and + caps are present; either missing leaves it `None`.""" + conn = Connection.from_envelope(_MODERN, info, caps) + assert conn.client_params is None + + +def test_from_envelope_with_explicit_outbound_has_standalone_channel(): + """SDK-defined: duplex modern transports pass an outbound; `has_standalone_channel` + derives True since the held outbound is not the no-channel sentinel.""" + out = StubOutbound() + conn = Connection.from_envelope(_MODERN, None, None, outbound=out) + assert conn.has_standalone_channel is True + assert conn.outbound is out + assert conn.initialized.is_set() + + +def test_for_loop_seeds_version_from_hint_or_latest_and_is_not_born_ready(): + """SDK-defined: `for_loop` seeds `protocol_version` from the hint when given, + else `LATEST_PROTOCOL_VERSION`; the connection awaits the initialize handshake.""" + out = StubOutbound() + conn = Connection.for_loop(out) + assert conn.protocol_version == LATEST_PROTOCOL_VERSION + assert conn.has_standalone_channel is True + assert not conn.initialized.is_set() + assert conn.initialize_accepted is False + assert conn.client_params is None + + hinted = Connection.for_loop(out, protocol_version_hint=_MODERN) + assert hinted.protocol_version == _MODERN + + +def test_for_loop_records_session_id_when_supplied(): + """SDK-defined: `for_loop` stores the `session_id` kwarg verbatim.""" + conn = Connection.for_loop(StubOutbound(), session_id="sess-1") + assert conn.session_id == "sess-1" + + +# --- outbound channel ---------------------------------------------------------- + + @pytest.mark.anyio async def test_connection_notify_forwards_to_outbound(): out = StubOutbound() - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) await conn.notify("notifications/message", {"level": "info", "data": "hi"}) assert out.notifications == [("notifications/message", {"level": "info", "data": "hi"})] @@ -79,24 +149,24 @@ async def test_connection_notify_forwards_to_outbound(): async def test_connection_notify_swallows_broken_stream_and_debug_logs(caplog: pytest.LogCaptureFixture): caplog.set_level(logging.DEBUG, logger="mcp.server.connection") out = StubOutbound(raise_on_send=anyio.BrokenResourceError) - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) await conn.notify("notifications/message", {"data": "x"}) # must not raise assert "stream closed" in caplog.text.lower() @pytest.mark.anyio async def test_connection_notify_drops_when_no_standalone_channel(caplog: pytest.LogCaptureFixture): + """SDK-defined: the no-channel sentinel debug-logs and drops; `notify` never raises.""" caplog.set_level(logging.DEBUG, logger="mcp.server.connection") - out = StubOutbound() - conn = Connection(out, has_standalone_channel=False) + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) await conn.notify("notifications/message", {"data": "x"}) # must not raise - assert out.notifications == [] assert "no standalone channel" in caplog.text.lower() @pytest.mark.anyio async def test_connection_send_raw_request_raises_nobackchannel_when_no_standalone_channel(): - conn = Connection(StubOutbound(), has_standalone_channel=False) + """SDK-defined: the no-channel sentinel raises structurally; `Connection` does no pre-check.""" + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) with pytest.raises(NoBackChannelError): await conn.send_raw_request("ping", None) @@ -104,7 +174,7 @@ async def test_connection_send_raw_request_raises_nobackchannel_when_no_standalo @pytest.mark.anyio async def test_connection_send_raw_request_forwards_when_standalone_channel_present(): out = StubOutbound() - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) result = await conn.send_raw_request("ping", None) assert out.requests == [("ping", None)] assert result == {} @@ -113,7 +183,7 @@ async def test_connection_send_raw_request_forwards_when_standalone_channel_pres @pytest.mark.anyio async def test_connection_send_request_with_spec_type_infers_result_type(): out = StubOutbound(result={"roots": [{"uri": "file:///ws"}]}) - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) result = await conn.send_request(ListRootsRequest()) method, _ = out.requests[0] assert method == "roots/list" @@ -126,7 +196,7 @@ async def test_connection_send_request_validates_result_alias_only(): """Peer results validate alias-only; a snake_case key from the wire is ignored as extra, not populated by Python field name.""" snake = {"role": "assistant", "content": {"type": "text", "text": "x"}, "model": "m", "stop_reason": "endTurn"} - conn = Connection(StubOutbound(result=snake), has_standalone_channel=True) + conn = Connection.for_loop(StubOutbound(result=snake)) result = await conn.send_request(CreateMessageRequest(params=CreateMessageRequestParams(messages=[], max_tokens=1))) assert result.stop_reason is None @@ -134,14 +204,14 @@ async def test_connection_send_request_validates_result_alias_only(): @pytest.mark.anyio async def test_connection_send_request_with_result_type_kwarg_validates_custom_type(): out = StubOutbound(result={}) - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) result = await conn.send_request(PingRequest(), result_type=EmptyResult) assert isinstance(result, EmptyResult) @pytest.mark.anyio async def test_connection_send_request_nonconforming_result_raises_validation_error(): - conn = Connection(StubOutbound(result={"bogus": 1}), has_standalone_channel=True) + conn = Connection.for_loop(StubOutbound(result={"bogus": 1})) with pytest.raises(ValidationError): await conn.send_request(ListRootsRequest()) @@ -150,7 +220,7 @@ async def test_connection_send_request_nonconforming_result_raises_validation_er async def test_send_request_validates_the_client_result_against_the_surface_schema(): """A spec-method result that fails the per-version surface schema raises `ValidationError` even when the caller's `result_type` would accept it.""" - conn = Connection(StubOutbound(result={"roots": "nope"}), has_standalone_channel=True) + conn = Connection.for_loop(StubOutbound(result={"roots": "nope"})) with pytest.raises(ValidationError): await conn.send_request(ListRootsRequest(), result_type=EmptyResult) @@ -158,8 +228,8 @@ async def test_send_request_validates_the_client_result_against_the_surface_sche @pytest.mark.anyio async def test_send_request_passes_a_spec_valid_client_result(): """A spec-valid client result passes the surface gate and parses to the typed model.""" - conn = Connection(StubOutbound(result={"roots": [{"uri": "file:///ws"}]}), has_standalone_channel=True) - conn.protocol_version = "2025-11-25" + conn = Connection.for_loop(StubOutbound(result={"roots": [{"uri": "file:///ws"}]})) + assert conn.protocol_version == LATEST_PROTOCOL_VERSION result = await conn.send_request(ListRootsRequest()) assert isinstance(result, ListRootsResult) assert str(result.roots[0].uri) == "file:///ws" @@ -178,8 +248,7 @@ class _CustomResult(BaseModel): async def test_send_request_skips_the_surface_gate_when_method_absent_at_version(): """Surface row absent for the negotiated version: gate is bypassed and only the inferred result type validates.""" - conn = Connection(StubOutbound(result={}), has_standalone_channel=True) - conn.protocol_version = "2026-07-28" + conn = Connection.for_loop(StubOutbound(result={}), protocol_version_hint=_MODERN) result = await conn.send_request(PingRequest()) assert isinstance(result, EmptyResult) @@ -187,8 +256,7 @@ async def test_send_request_skips_the_surface_gate_when_method_absent_at_version @pytest.mark.anyio async def test_send_request_with_a_custom_method_skips_the_surface_gate(): """Non-spec methods are not blocked by the surface gate; `result_type` validates.""" - conn = Connection(StubOutbound(result={"value": 7}), has_standalone_channel=True) - conn.protocol_version = "2025-11-25" + conn = Connection.for_loop(StubOutbound(result={"value": 7})) result = await conn.send_request(_CustomRequest(), result_type=_CustomResult) assert isinstance(result, _CustomResult) assert result.value == 7 @@ -197,7 +265,7 @@ async def test_send_request_with_a_custom_method_skips_the_surface_gate(): @pytest.mark.anyio async def test_connection_ping_sends_ping_on_standalone(): out = StubOutbound() - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) await conn.ping() assert out.requests == [("ping", None)] @@ -205,7 +273,7 @@ async def test_connection_ping_sends_ping_on_standalone(): @pytest.mark.anyio async def test_connection_log_sends_logging_message_notification(): out = StubOutbound() - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) await conn.log("info", {"k": "v"}, logger="my.logger") # pyright: ignore[reportDeprecated] method, params = out.notifications[0] assert method == "notifications/message" @@ -218,7 +286,7 @@ async def test_connection_log_sends_logging_message_notification(): @pytest.mark.anyio async def test_connection_log_with_meta_includes_meta_in_params(): out = StubOutbound() - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) await conn.log("info", "x", meta={"traceId": "abc"}) # pyright: ignore[reportDeprecated] _, params = out.notifications[0] assert params is not None @@ -228,7 +296,7 @@ async def test_connection_log_with_meta_includes_meta_in_params(): @pytest.mark.anyio async def test_connection_list_changed_notifications_send_correct_methods(): out = StubOutbound() - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) await conn.send_tool_list_changed() await conn.send_prompt_list_changed() await conn.send_resource_list_changed() @@ -246,14 +314,21 @@ async def test_connection_list_changed_notifications_send_correct_methods(): @pytest.mark.anyio async def test_connection_send_tool_list_changed_with_meta_includes_meta_only_params(): out = StubOutbound() - conn = Connection(out, has_standalone_channel=True) + conn = Connection.for_loop(out) await conn.send_tool_list_changed(meta={"k": 1}) assert out.notifications == [("notifications/tools/list_changed", {"_meta": {"k": 1}})] -def test_connection_check_capability_false_before_initialized(): - conn = Connection(StubOutbound(), has_standalone_channel=True) +# --- check_capability ---------------------------------------------------------- + + +def test_connection_check_capability_false_when_no_client_params_recorded(): + """SDK-defined: `check_capability` returns False when no `client_params` + were recorded, regardless of which factory built the connection.""" + conn = Connection.for_loop(StubOutbound()) assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is False + # Same for a born-ready connection that supplied neither info nor caps. + assert Connection.from_envelope(_MODERN, None, None).check_capability(ClientCapabilities()) is False @pytest.mark.parametrize( @@ -281,6 +356,11 @@ def test_connection_check_capability_false_before_initialized(): ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())), True, ), + ( + ClientCapabilities(sampling=SamplingCapability(context=SamplingContextCapability())), + ClientCapabilities(sampling=SamplingCapability(context=SamplingContextCapability())), + True, + ), (ClientCapabilities(experimental=None), ClientCapabilities(experimental={"a": {}}), False), (ClientCapabilities(experimental={"a": {}}), ClientCapabilities(experimental={"b": {}}), False), (ClientCapabilities(experimental={"a": {"x": 1}}), ClientCapabilities(experimental={"a": {"x": 2}}), False), @@ -288,17 +368,16 @@ def test_connection_check_capability_false_before_initialized(): ], ) def test_check_capability_per_field_branches(have: ClientCapabilities, want: ClientCapabilities, expected: bool): - conn = Connection(StubOutbound(), has_standalone_channel=True) - conn.client_params = _client_params(have) + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, _CLIENT_INFO, have) assert conn.check_capability(want) is expected def test_connection_check_capability_true_when_client_declares_it(): - conn = Connection(StubOutbound(), has_standalone_channel=True) - conn.client_params = _client_params( - ClientCapabilities(sampling=SamplingCapability(), roots=RootsCapability(list_changed=True)) + conn = Connection.from_envelope( + LATEST_PROTOCOL_VERSION, + _CLIENT_INFO, + ClientCapabilities(sampling=SamplingCapability(), roots=RootsCapability(list_changed=True)), ) - conn.initialized.set() assert conn.check_capability(ClientCapabilities(sampling=SamplingCapability())) is True assert conn.check_capability(ClientCapabilities(roots=RootsCapability(list_changed=True))) is True assert conn.check_capability(ClientCapabilities(elicitation=ElicitationCapability())) is False diff --git a/tests/server/test_lowlevel_exception_handling.py b/tests/server/test_lowlevel_exception_handling.py index 015a5cbafa..15df7f1cef 100644 --- a/tests/server/test_lowlevel_exception_handling.py +++ b/tests/server/test_lowlevel_exception_handling.py @@ -39,7 +39,7 @@ async def test_server_run_exits_cleanly_when_transport_yields_exception_then_clo with anyio.fail_after(5): # stateless=True so server.run doesn't wait for initialize handshake. # Before the fix, this raised ExceptionGroup(ClosedResourceError). - await server.run(read_recv, write_send, server.create_initialization_options(), stateless=True) + await server.run(read_recv, write_send, server.create_initialization_options()) # write_send was closed inside run's `async with`; receive_nowait raises # EndOfStream iff the buffer is empty (i.e., server wrote nothing). diff --git a/tests/server/test_runner.py b/tests/server/test_runner.py index c4d10aea08..72b3ad17b5 100644 --- a/tests/server/test_runner.py +++ b/tests/server/test_runner.py @@ -1,49 +1,63 @@ -"""Tests for `ServerRunner`. +"""Tests for `ServerRunner` and the free-function drivers. -End-to-end over `JSONRPCDispatcher` with a real lowlevel `Server` as the -registry. The `connected_runner` helper starts both sides and (by default) -performs the initialize handshake, so each test exercises only the behaviour -under test. +The kernel tests run end-to-end over `JSONRPCDispatcher` with a real lowlevel +`Server` as the registry. The `connected_runner` helper starts both sides and +(by default) performs the initialize handshake, so each test exercises only the +behaviour under test. Driver tests (`serve_connection`, `serve_one`, +`to_jsonrpc_response`, `aclose_shielded`) follow at the bottom. """ -from collections.abc import AsyncIterator +from collections.abc import AsyncIterator, Mapping from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from functools import partial from typing import Any, cast import anyio +import anyio.abc import pytest from opentelemetry.trace import SpanKind, StatusCode import mcp.server.runner +from mcp.server.connection import Connection from mcp.server.context import ServerRequestContext from mcp.server.lowlevel.server import NotificationOptions, Server from mcp.server.models import InitializationOptions -from mcp.server.runner import ServerRunner, _extract_meta, _resolve_protocol_version, otel_middleware +from mcp.server.runner import ( + ServerRunner, + _extract_meta, + aclose_shielded, + otel_middleware, + serve_connection, + serve_one, + to_jsonrpc_response, +) from mcp.server.session import ServerSession -from mcp.shared.dispatcher import DispatchContext, DispatchMiddleware, OnRequest +from mcp.shared.dispatcher import CallOptions, DispatchContext, DispatchMiddleware, OnRequest from mcp.shared.exceptions import MCPError from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -from mcp.shared.message import ClientMessageMetadata, ServerMessageMetadata +from mcp.shared.message import MessageMetadata from mcp.shared.peer import dump_params from mcp.shared.transport_context import TransportContext -from mcp.shared.version import SUPPORTED_PROTOCOL_VERSIONS +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS, SUPPORTED_PROTOCOL_VERSIONS from mcp.types import ( INTERNAL_ERROR, INVALID_PARAMS, LATEST_PROTOCOL_VERSION, METHOD_NOT_FOUND, - PROTOCOL_VERSION_META_KEY, CallToolRequestParams, ClientCapabilities, ErrorData, Implementation, InitializeRequestParams, + JSONRPCError, + JSONRPCRequest, + JSONRPCResponse, ListToolsResult, NotificationParams, PaginatedRequestParams, ProgressNotificationParams, RequestParams, - RequestParamsMeta, SetLevelRequestParams, Tool, ) @@ -84,36 +98,46 @@ async def connected_runner( server: SrvT, *, initialized: bool = True, - stateless: bool = False, - has_standalone_channel: bool = True, init_options: InitializationOptions | None = None, - session_id: str | None = None, dispatch_middleware: list[DispatchMiddleware] | None = None, + connection: Connection | None = None, ) -> AsyncIterator[tuple[JSONRPCDispatcher[TransportContext], ServerRunner[dict[str, Any]]]]: """Yield `(client, runner)` running over an in-memory JSON-RPC dispatcher pair. - Starts the client (echo handlers) and `runner.run()` in a task group, wraps - the body in `anyio.fail_after(5)`, and cancels on exit. When - `initialized` is true the helper performs the real `initialize` request - before yielding, so tests start past the init-gate via the public path. + Starts the client (echo handlers) and the server-side dispatcher loop + (kernel `on_request`/`on_notify` + `aclose_shielded` teardown - the + `serve_connection` shape) in a task group, wraps the body in + `anyio.fail_after(5)`, and cancels on exit. When `initialized` is true the + helper performs the real `initialize` request before yielding, so tests + start past the init-gate via the public path. + + `connection` defaults to `Connection.for_loop(server_dispatcher)`. Pass a + factory-built connection (e.g. `Connection.from_envelope(...)`) to exercise + the born-ready path; the kernel reads it as a fact and is mode-agnostic. """ client, server_d, close = jsonrpc_pair() assert isinstance(client, JSONRPCDispatcher) and isinstance(server_d, JSONRPCDispatcher) + if connection is None: + connection = Connection.for_loop(server_d) runner = ServerRunner( server=server, - dispatcher=server_d, + connection=connection, lifespan_state={}, - has_standalone_channel=has_standalone_channel, init_options=init_options, - session_id=session_id, - stateless=stateless, dispatch_middleware=dispatch_middleware or [], ) c_req, c_notify = echo_handlers(Recorder()) body_exc: BaseException | None = None + + async def _drive(*, task_status: anyio.abc.TaskStatus[None]) -> None: + try: + await server_d.run(runner.on_request, runner.on_notify, task_status=task_status) + finally: + await aclose_shielded(connection) + async with anyio.create_task_group() as tg: await tg.start(client.run, c_req, c_notify) - await tg.start(runner.run) + await tg.start(_drive) try: with anyio.fail_after(5): if initialized: @@ -219,11 +243,21 @@ async def test_runner_routes_to_handler_and_builds_context(server: SrvT): assert isinstance(ctx, ServerRequestContext) assert ctx.lifespan_context == {} assert isinstance(ctx.session, ServerSession) - assert ctx.session is runner.session + assert ctx.session.protocol_version == runner.connection.protocol_version assert ctx.request_id is not None assert ctx.protocol_version == LATEST_PROTOCOL_VERSION +@pytest.mark.anyio +async def test_runner_builds_a_fresh_session_per_request(server: SrvT): + """`ctx.session` is built per-request from the per-request `DispatchContext` + and the connection's standalone outbound; it is not connection-scoped.""" + async with connected_runner(server) as (client, _): + await client.send_raw_request("tools/list", None) + await client.send_raw_request("tools/list", None) + assert _seen_ctx[0].session is not _seen_ctx[1].session + + @pytest.mark.anyio async def test_runner_spec_method_with_no_handler_raises_method_not_found(server: SrvT): async with connected_runner(server) as (client, _): @@ -605,7 +639,6 @@ async def reject_initialize(ctx: Ctx, method: str, params: Any, call_next: Any) assert await client.send_raw_request("ping", None) == {} assert runner.connection.initialize_accepted is False assert runner.connection.client_params is None - assert runner.connection.protocol_version is None assert not runner.connection.initialized.is_set() @@ -656,65 +689,6 @@ async def on_roots(ctx: Ctx, params: NotificationParams | None) -> None: ] -def test_resolve_protocol_version_handshake_committed_value_wins(): - md = ServerMessageMetadata(protocol_version="2025-03-26") - meta: RequestParamsMeta = {PROTOCOL_VERSION_META_KEY: "2025-03-26"} - assert _resolve_protocol_version("2025-06-18", meta, md) == "2025-06-18" - - -def test_resolve_protocol_version_reads_per_request_meta_when_no_handshake(): - md = ServerMessageMetadata(protocol_version="2025-03-26") - meta: RequestParamsMeta = {PROTOCOL_VERSION_META_KEY: "2025-06-18"} - assert _resolve_protocol_version(None, meta, md) == "2025-06-18" - - -def test_resolve_protocol_version_skips_unsupported_meta_value(): - md = ServerMessageMetadata(protocol_version="2025-03-26") - meta: RequestParamsMeta = {PROTOCOL_VERSION_META_KEY: "1900-01-01"} - assert _resolve_protocol_version(None, meta, md) == "2025-03-26" - - -def test_resolve_protocol_version_skips_non_string_meta_value(): - md = ServerMessageMetadata(protocol_version="2025-03-26") - meta: RequestParamsMeta = {PROTOCOL_VERSION_META_KEY: 42} - assert _resolve_protocol_version(None, meta, md) == "2025-03-26" - - -def test_resolve_protocol_version_reads_transport_hint_when_no_handshake_or_meta(): - md = ServerMessageMetadata(protocol_version="2025-06-18") - assert _resolve_protocol_version(None, None, md) == "2025-06-18" - assert _resolve_protocol_version(None, {}, md) == "2025-06-18" - - -def test_resolve_protocol_version_skips_unsupported_transport_hint(): - """The `initialize` params version reaches the metadata unvalidated; surface validation must never see it.""" - md = ServerMessageMetadata(protocol_version="1900-01-01") - assert _resolve_protocol_version(None, None, md) == "2025-11-25" - - -def test_resolve_protocol_version_terminal_default_with_no_signals(): - assert _resolve_protocol_version(None, None, None) == "2025-11-25" - assert _resolve_protocol_version(None, None, ServerMessageMetadata()) == "2025-11-25" - assert _resolve_protocol_version(None, None, ClientMessageMetadata()) == "2025-11-25" - - -@pytest.mark.anyio -async def test_runner_ctx_protocol_version_is_terminal_default_on_stateless_in_memory(server: SrvT): - async with connected_runner(server, initialized=False, stateless=True) as (client, runner): - await client.send_raw_request("tools/list", None) - ctx = _seen_ctx[0] - assert ctx.protocol_version == "2025-11-25" - assert ctx.session.protocol_version is None - assert runner.connection.protocol_version is None - - -@pytest.mark.anyio -async def test_runner_ctx_protocol_version_tracks_per_request_meta_on_stateless(server: SrvT): - async with connected_runner(server, initialized=False, stateless=True) as (client, _): - await client.send_raw_request("tools/list", {"_meta": {PROTOCOL_VERSION_META_KEY: "2025-06-18"}}) - assert _seen_ctx[0].protocol_version == "2025-06-18" - - def test_extract_meta_returns_none_for_absent_or_malformed(): """Context construction is independent of `_meta` validity; the params validation inside `call_next()` is what surfaces the error.""" @@ -837,21 +811,16 @@ async def bad_return(ctx: Ctx, params: PaginatedRequestParams | None) -> int: @pytest.mark.anyio -async def test_runner_stateless_skips_init_gate(server: SrvT): - async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (client, _): - result = await client.send_raw_request("tools/list", None) - assert result["tools"][0]["name"] == "t" - - -@pytest.mark.anyio -async def test_runner_stateless_connection_initialized_event_set_on_construction(server: SrvT): - """`connection.initialized` mirrors the gate flag in stateless mode so - `await connection.initialized.wait()` does not hang when no handshake - arrives.""" - async with connected_runner(server, initialized=False, stateless=True, has_standalone_channel=False) as (_, runner): +async def test_runner_with_born_ready_connection_skips_init_gate(server: SrvT): + """A `Connection.from_envelope` connection is born ready: the kernel's + init-gate is open without any handshake. The kernel is mode-agnostic - the + same `on_request` reads `connection.initialize_accepted` as a fact.""" + born_ready = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + async with connected_runner(server, initialized=False, connection=born_ready) as (client, runner): assert runner.connection.initialize_accepted is True assert runner.connection.initialized.is_set() - await runner.connection.initialized.wait() + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" @pytest.mark.anyio @@ -954,6 +923,33 @@ async def discover(ctx: Ctx, params: RequestParams) -> Any: assert exc.value.error == ErrorData(code=METHOD_NOT_FOUND, message="Method not found", data="server/discover") +@pytest.mark.anyio +async def test_on_request_rejects_initialize_at_modern_version_with_method_not_found(server: SrvT): + """Spec-mandated: `initialize` has no `CLIENT_REQUESTS` row at the modern + version; kernel dispatch (not the inbound classifier) rejects it.""" + born_ready = Connection.from_envelope(MODERN_PROTOCOL_VERSIONS[0], None, None) + async with connected_runner(server, initialized=False, connection=born_ready) as (client, runner): + assert runner.connection.protocol_version == MODERN_PROTOCOL_VERSIONS[0] + with pytest.raises(MCPError) as exc: + await client.send_raw_request("initialize", _initialize_params()) + assert exc.value.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_on_request_dispatches_custom_method_registered_via_add_request_handler(server: SrvT): + """SDK-defined: a method outside `SPEC_CLIENT_METHODS` skips the version + gate and reaches its registered handler at any negotiated version.""" + + async def echo(ctx: Ctx, params: RequestParams) -> dict[str, Any]: + return {"echoed": True} + + server.add_request_handler("myorg/echo", RequestParams, echo) + born_ready = Connection.from_envelope(MODERN_PROTOCOL_VERSIONS[0], None, None) + async with connected_runner(server, initialized=False, connection=born_ready) as (client, _): + result = await client.send_raw_request("myorg/echo", None) + assert result == {"echoed": True} + + @pytest.mark.anyio async def test_runner_middleware_short_circuit_on_a_wrong_version_spec_method_skips_the_sieve(server: SrvT): """A server-tier middleware that returns without calling `call_next` for a @@ -1205,3 +1201,194 @@ async def _append(i: int) -> None: await client.send_raw_request("tools/list", None) assert cleaned == [2, 1] assert "abandoning remaining callbacks" not in caplog.text + + +# --- to_jsonrpc_response ------------------------------------------------------- + + +@pytest.mark.anyio +async def test_to_jsonrpc_response_wraps_success_as_jsonrpc_response(): + """SDK-defined: a handler coroutine resolving to a result dict is wrapped as a + `JSONRPCResponse` carrying the supplied id and the dict verbatim as `result`.""" + + async def ok() -> dict[str, Any]: + return {"k": "v"} + + reply = await to_jsonrpc_response(7, ok()) + assert isinstance(reply, JSONRPCResponse) + assert reply.id == 7 + assert reply.result == {"k": "v"} + + +@pytest.mark.anyio +async def test_to_jsonrpc_response_maps_mcp_error_to_jsonrpc_error(): + """SDK-defined: an `MCPError` raised by the handler coroutine is wrapped as a + `JSONRPCError` whose `error` carries the same code, message, and data.""" + + async def fail() -> dict[str, Any]: + raise MCPError(code=METHOD_NOT_FOUND, message="nope", data="x") + + reply = await to_jsonrpc_response("rid", fail()) + assert isinstance(reply, JSONRPCError) + assert reply.id == "rid" + assert reply.error == ErrorData(code=METHOD_NOT_FOUND, message="nope", data="x") + + +@pytest.mark.anyio +async def test_to_jsonrpc_response_maps_validation_error_to_invalid_params(): + """SDK-defined: a pydantic `ValidationError` escaping the handler coroutine is + mapped to `INVALID_PARAMS` with a generic message (validator detail does not + reach the wire).""" + + async def fail() -> dict[str, Any]: + Tool.model_validate({"name": 123}) # raises ValidationError + raise NotImplementedError + + reply = await to_jsonrpc_response(1, fail()) + assert isinstance(reply, JSONRPCError) + assert reply.error == ErrorData(code=INVALID_PARAMS, message="Invalid request parameters", data="") + + +@pytest.mark.anyio +async def test_to_jsonrpc_response_maps_unmapped_exception_to_internal_error_and_logs( + caplog: pytest.LogCaptureFixture, +): + """SDK-defined: an unmapped exception is logged server-side and surfaced as + `INTERNAL_ERROR` with a generic message; the exception text never reaches the + wire.""" + + async def fail() -> dict[str, Any]: + raise RuntimeError("boom") + + reply = await to_jsonrpc_response(1, fail()) + assert isinstance(reply, JSONRPCError) + assert reply.error.code == INTERNAL_ERROR + # Handler internals never reach the wire. + assert "boom" not in reply.error.message + assert "request handler raised" in caplog.text + + +# --- aclose_shielded ----------------------------------------------------------- + + +@pytest.mark.anyio +async def test_aclose_shielded_runs_callbacks_under_outer_cancellation(): + """The shield lets per-connection cleanup run even when the enclosing scope + is being cancelled.""" + cleaned: list[int] = [] + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + + async def _append() -> None: + await anyio.sleep(0) + cleaned.append(1) + + conn.exit_stack.push_async_callback(_append) + with anyio.CancelScope() as scope: + scope.cancel() + await aclose_shielded(conn) + assert cleaned == [1] + + +# --- serve_one / serve_connection --------------------------------------------- + + +@dataclass +class _StubDispatchContext: + """Minimal `DispatchContext` for `serve_one` driver tests. + + The modern entry hands a per-request context to `serve_one`; this stub + satisfies the protocol structurally with no real back-channel. + """ + + request_id: int | str | None + transport: TransportContext = field(default_factory=lambda: TransportContext(kind="direct", can_send_request=False)) + message_metadata: MessageMetadata = None + cancel_requested: anyio.Event = field(default_factory=anyio.Event) + can_send_request: bool = False + + async def send_raw_request( + self, method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None + ) -> dict[str, Any]: + raise NotImplementedError + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + raise NotImplementedError + + async def progress(self, progress: float, total: float | None = None, message: str | None = None) -> None: + raise NotImplementedError + + +async def _append_async(dst: list[int], v: int) -> None: + dst.append(v) + + +_LIFESPAN: dict[str, Any] = {} + + +@pytest.mark.anyio +async def test_serve_one_runs_handler_and_returns_jsonrpc_response(server: SrvT): + """The single-exchange driver: builds the kernel, runs `on_request` once, + wraps via `to_jsonrpc_response`, and tears down `connection.exit_stack`.""" + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + cleaned: list[int] = [] + conn.exit_stack.push_async_callback(_append_async, cleaned, 1) + request = JSONRPCRequest(jsonrpc="2.0", id=9, method="tools/list", params=None) + reply = await serve_one(server, request, connection=conn, dctx=_StubDispatchContext(9), lifespan_state=_LIFESPAN) + assert isinstance(reply, JSONRPCResponse) + assert reply.id == 9 + assert reply.result["tools"][0]["name"] == "t" + assert cleaned == [1] + ctx = _seen_ctx[0] + assert ctx.protocol_version == LATEST_PROTOCOL_VERSION + + +@pytest.mark.anyio +async def test_serve_one_maps_error_to_jsonrpc_error_and_still_closes_exit_stack(server: SrvT): + """SDK-defined: a kernel-produced error (here `METHOD_NOT_FOUND` for an + unregistered method) is wrapped as a `JSONRPCError`, and the per-request + exit stack is closed on the error path too.""" + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + cleaned: list[int] = [] + conn.exit_stack.push_async_callback(_append_async, cleaned, 1) + request = JSONRPCRequest(jsonrpc="2.0", id=2, method="resources/list", params=None) + reply = await serve_one(server, request, connection=conn, dctx=_StubDispatchContext(2), lifespan_state=_LIFESPAN) + assert isinstance(reply, JSONRPCError) + assert reply.error.code == METHOD_NOT_FOUND + assert cleaned == [1] + + +@pytest.mark.anyio +async def test_serve_one_reads_connection_protocol_version_as_a_fact(server: SrvT): + """`serve_one` builds the kernel over the entry's `Connection`; the kernel + reads `connection.protocol_version` for the version gate. A `from_envelope` + connection at a modern version rejects a method absent there.""" + conn = Connection.from_envelope(MODERN_PROTOCOL_VERSIONS[0], None, None) + request = JSONRPCRequest(jsonrpc="2.0", id=1, method="logging/setLevel", params={"level": "info"}) + reply = await serve_one(server, request, connection=conn, dctx=_StubDispatchContext(1), lifespan_state=_LIFESPAN) + assert isinstance(reply, JSONRPCError) + assert reply.error.code == METHOD_NOT_FOUND + + +@pytest.mark.anyio +async def test_serve_connection_drives_dispatcher_loop_and_tears_down(server: SrvT): + """The loop-mode driver: `serve_connection` builds the kernel, hands + `on_request`/`on_notify` to `dispatcher.run()`, and `aclose_shielded`s the + connection on the way out.""" + client, server_d, close = jsonrpc_pair() + assert isinstance(client, JSONRPCDispatcher) and isinstance(server_d, JSONRPCDispatcher) + conn = Connection.for_loop(server_d) + cleaned: list[int] = [] + conn.exit_stack.push_async_callback(_append_async, cleaned, 1) + c_req, c_notify = echo_handlers(Recorder()) + async with anyio.create_task_group() as tg: + await tg.start(client.run, c_req, c_notify) + await tg.start(partial(serve_connection, server, server_d, connection=conn, lifespan_state=_LIFESPAN)) + with anyio.fail_after(5): + await client.send_raw_request("initialize", _initialize_params()) + result = await client.send_raw_request("tools/list", None) + assert result["tools"][0]["name"] == "t" + assert cleaned == [] + close() + assert cleaned == [1] + assert conn.protocol_version == LATEST_PROTOCOL_VERSION + assert conn.client_params is not None diff --git a/tests/server/test_server_context.py b/tests/server/test_server_context.py index 5665d2ff77..9a9eaa3d97 100644 --- a/tests/server/test_server_context.py +++ b/tests/server/test_server_context.py @@ -31,21 +31,20 @@ class _Lifespan: @pytest.mark.anyio async def test_context_exposes_lifespan_and_connection_and_forwards_base_context(): captured: list[Context[_Lifespan]] = [] - conn = Connection.__new__(Connection) # placeholder until running_pair gives us the dispatcher + conn_holder: list[Connection] = [] async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=conn) + ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=conn_holder[0]) captured.append(ctx) return {} async with running_pair(direct_pair, server_on_request=server_on_request) as (client, server, *_): - # Now we have the server dispatcher; build the real Connection bound to it. - conn.__init__(server, has_standalone_channel=True, session_id="sess-1") + conn_holder.append(Connection.for_loop(server, session_id="sess-1")) with anyio.fail_after(5): await client.send_raw_request("t", None) ctx = captured[0] assert ctx.lifespan.name == "app" - assert ctx.connection is conn + assert ctx.connection is conn_holder[0] assert ctx.transport.kind == "direct" assert ctx.can_send_request is True assert ctx.session_id == "sess-1" @@ -58,9 +57,7 @@ async def test_context_log_sends_request_scoped_message_notification(): _, c_notify = echo_handlers(crec) async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan] = Context( - dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) - ) + ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=Connection.for_loop(dctx)) await ctx.log("debug", "hello") # pyright: ignore[reportDeprecated] return {} @@ -82,9 +79,7 @@ async def test_context_log_includes_logger_and_meta_when_supplied(): _, c_notify = echo_handlers(crec) async def server_on_request(dctx: DCtx, method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - ctx: Context[_Lifespan] = Context( - dctx, lifespan=_Lifespan("app"), connection=Connection(dctx, has_standalone_channel=True) - ) + ctx: Context[_Lifespan] = Context(dctx, lifespan=_Lifespan("app"), connection=Connection.for_loop(dctx)) await ctx.log("info", "x", logger="my.log", meta={"traceId": "t"}) # pyright: ignore[reportDeprecated] return {} diff --git a/tests/server/test_session.py b/tests/server/test_session.py index cb664d5b4e..84d5e3aa93 100644 --- a/tests/server/test_session.py +++ b/tests/server/test_session.py @@ -1,41 +1,38 @@ """Tests for `ServerSession`. -`ServerSession` is a thin proxy over a dispatcher and a `Connection`. Tested -with a stub dispatcher so we can assert what reaches the wire (method, params, -`CallOptions`, related-request-id) without standing up a full transport. +`ServerSession` is a thin per-request proxy over two `Outbound` channels and a +`Connection`. Tested with stub outbounds so we can assert what reaches the wire +(method, params, `CallOptions`) and which channel it routed to, without standing +up a transport. """ from collections.abc import Mapping -from typing import Any, cast +from typing import Any import pytest from pydantic import ValidationError from mcp import types -from mcp.server import Server, ServerRequestContext from mcp.server.connection import Connection from mcp.server.session import ServerSession -from mcp.shared.dispatcher import CallOptions -from mcp.shared.exceptions import NoBackChannelError -from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher +from mcp.shared.dispatcher import CallOptions, Outbound from mcp.shared.message import ServerMessageMetadata +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS from mcp.types import ( LATEST_PROTOCOL_VERSION, ClientCapabilities, Implementation, - InitializeRequestParams, SamplingCapability, SamplingToolsCapability, ) -from .test_runner import connected_runner - -class StubDispatcher: +class StubOutbound: """Records `send_raw_request` / `notify` calls and returns a canned result.""" def __init__(self, result: dict[str, Any] | None = None) -> None: - self.requests: list[tuple[str, Mapping[str, Any] | None, CallOptions | None, Any]] = [] + self.requests: list[tuple[str, Mapping[str, Any] | None, CallOptions | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] self.result = result if result is not None else {} async def send_raw_request( @@ -43,40 +40,36 @@ async def send_raw_request( method: str, params: Mapping[str, Any] | None, opts: CallOptions | None = None, - *, - _related_request_id: Any = None, ) -> dict[str, Any]: - self.requests.append((method, params, opts, _related_request_id)) + self.requests.append((method, params, opts)) return self.result async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: - raise NotImplementedError + self.notifications.append((method, params)) def _make_session( - dispatcher: StubDispatcher, + outbound: StubOutbound, *, capabilities: ClientCapabilities | None = None, - has_standalone_channel: bool = True, - protocol_version: str | None = None, + protocol_version: str = LATEST_PROTOCOL_VERSION, ) -> ServerSession: - conn = Connection(dispatcher, has_standalone_channel=has_standalone_channel) - conn.protocol_version = protocol_version - if capabilities is not None: - conn.client_params = InitializeRequestParams( - protocol_version=LATEST_PROTOCOL_VERSION, - capabilities=capabilities, - client_info=Implementation(name="c", version="0"), - ) - # cast: `ServerSession` is typed to take `JSONRPCDispatcher` but only ever - # calls `send_raw_request` / `notify`, so the stub is structurally sufficient. - return ServerSession(cast("JSONRPCDispatcher[Any]", dispatcher), conn) + """Single-channel session: the stub is both request and standalone outbound.""" + client_info = Implementation(name="c", version="0") if capabilities is not None else None + conn = Connection.from_envelope(protocol_version, client_info, capabilities, outbound=outbound) + return ServerSession(outbound, conn) + + +def _two_channel_session(request_ch: Outbound, standalone_ch: Outbound) -> ServerSession: + """Distinct request/standalone outbounds so routing assertions can tell the channels apart.""" + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None, outbound=standalone_ch) + return ServerSession(request_ch, conn) @pytest.mark.anyio async def test_send_request_forwards_timeout_and_progress_callback_as_call_options(): - dispatcher = StubDispatcher(result={"roots": []}) - session = _make_session(dispatcher) + outbound = StubOutbound(result={"roots": []}) + session = _make_session(outbound) async def on_progress(progress: float, total: float | None, message: str | None) -> None: raise NotImplementedError @@ -85,57 +78,78 @@ async def on_progress(progress: float, total: float | None, message: str | None) types.ListRootsRequest(), types.ListRootsResult, request_read_timeout_seconds=2.5, - metadata=ServerMessageMetadata(related_request_id=7), progress_callback=on_progress, ) assert isinstance(result, types.ListRootsResult) - method, _params, opts, related = dispatcher.requests[0] + method, _params, opts = outbound.requests[0] assert method == "roots/list" assert opts == {"timeout": 2.5, "on_progress": on_progress} - assert related == 7 @pytest.mark.anyio async def test_send_request_omits_call_options_when_none_given(): - dispatcher = StubDispatcher(result={"roots": []}) - session = _make_session(dispatcher) + outbound = StubOutbound(result={"roots": []}) + session = _make_session(outbound) await session.send_request(types.ListRootsRequest(), types.ListRootsResult) - _method, _params, opts, related = dispatcher.requests[0] + _method, _params, opts = outbound.requests[0] assert opts is None - assert related is None @pytest.mark.anyio async def test_send_request_timeout_zero_is_forwarded(): """0 is a real timeout (fail at the first checkpoint, `anyio.fail_after(0)` - semantics) and must reach the dispatcher; only `None` means "no timeout".""" - dispatcher = StubDispatcher(result={}) - session = _make_session(dispatcher) + semantics) and must reach the channel; only `None` means "no timeout".""" + outbound = StubOutbound(result={}) + session = _make_session(outbound) await session.send_request(types.PingRequest(), types.EmptyResult, request_read_timeout_seconds=0.0) - assert dispatcher.requests[0][2] == {"timeout": 0.0} + assert outbound.requests[0][2] == {"timeout": 0.0} @pytest.mark.anyio -async def test_send_request_without_back_channel_or_related_id_fails_fast(): - """No standalone channel and no related request to ride on: raise instead - of parking forever on a response that cannot arrive.""" - dispatcher = StubDispatcher(result={}) - session = _make_session(dispatcher, has_standalone_channel=False) - with pytest.raises(NoBackChannelError): - await session.send_request(types.PingRequest(), types.EmptyResult) - assert dispatcher.requests == [] - # With a related request id the message rides that request's stream. - await session.send_request( - types.PingRequest(), types.EmptyResult, metadata=ServerMessageMetadata(related_request_id=3) +async def test_send_request_without_related_id_routes_to_standalone_channel(): + """SDK-defined: no `related_request_id` routes the request onto the connection's standalone channel.""" + request_ch = StubOutbound() + standalone_ch = StubOutbound(result={"roots": []}) + session = _two_channel_session(request_ch, standalone_ch) + await session.send_request(types.ListRootsRequest(), types.ListRootsResult) + assert request_ch.requests == [] + assert standalone_ch.requests[0][0] == "roots/list" + + +@pytest.mark.anyio +async def test_send_request_with_related_id_routes_to_request_channel(): + """SDK-defined: with `related_request_id` the request rides the per-request channel + (the originating POST's response stream over streamable HTTP).""" + request_ch = StubOutbound(result={"action": "cancel"}) + standalone_ch = StubOutbound() + session = _two_channel_session(request_ch, standalone_ch) + result = await session.send_request( + types.ElicitRequest(params=types.ElicitRequestFormParams(message="q", requested_schema={})), + types.ElicitResult, + metadata=ServerMessageMetadata(related_request_id=7), ) - assert dispatcher.requests[0][3] == 3 + assert isinstance(result, types.ElicitResult) + assert standalone_ch.requests == [] + assert request_ch.requests[0][0] == "elicitation/create" + + +@pytest.mark.anyio +async def test_send_notification_routes_by_related_request_id(): + """SDK-defined: notifications select channel by `related_request_id` exactly like requests.""" + request_ch = StubOutbound() + standalone_ch = StubOutbound() + session = _two_channel_session(request_ch, standalone_ch) + await session.send_tool_list_changed() + await session.send_progress_notification("tok", 0.5, related_request_id="req-1") + assert [m for m, _ in standalone_ch.notifications] == ["notifications/tools/list_changed"] + assert [m for m, _ in request_ch.notifications] == ["notifications/progress"] @pytest.mark.anyio async def test_send_request_validates_the_client_result_against_the_surface_schema(): """A spec-method result that fails the per-version surface schema raises `ValidationError` even when the caller's `result_type` would accept it.""" - session = _make_session(StubDispatcher(result={"roots": "nope"})) + session = _make_session(StubOutbound(result={"roots": "nope"})) with pytest.raises(ValidationError): await session.send_request(types.ListRootsRequest(), types.EmptyResult) @@ -143,7 +157,7 @@ async def test_send_request_validates_the_client_result_against_the_surface_sche @pytest.mark.anyio async def test_send_request_passes_a_spec_valid_client_result(): """A spec-valid client result passes the surface gate and parses to the typed model.""" - session = _make_session(StubDispatcher(result={"roots": [{"uri": "file:///ws"}]})) + session = _make_session(StubOutbound(result={"roots": [{"uri": "file:///ws"}]})) result = await session.send_request(types.ListRootsRequest(), types.ListRootsResult) assert isinstance(result, types.ListRootsResult) assert str(result.roots[0].uri) == "file:///ws" @@ -151,9 +165,9 @@ async def test_send_request_passes_a_spec_valid_client_result(): @pytest.mark.anyio async def test_send_request_skips_the_surface_gate_when_method_absent_at_version(): - """Surface row absent for the negotiated version: gate is bypassed and only + """Surface row absent for the connection's version: gate is bypassed and only `result_type` validates.""" - session = _make_session(StubDispatcher(result={}), protocol_version="2026-07-28") + session = _make_session(StubOutbound(result={}), protocol_version=MODERN_PROTOCOL_VERSIONS[0]) result = await session.send_request(types.PingRequest(), types.EmptyResult) assert isinstance(result, types.EmptyResult) @@ -163,7 +177,7 @@ async def test_send_request_validates_result_alias_only(): """Peer results validate alias-only; a snake_case key from the wire is ignored as extra, not populated by Python field name.""" snake = {"role": "assistant", "content": {"type": "text", "text": "x"}, "model": "m", "stop_reason": "endTurn"} - session = _make_session(StubDispatcher(result=snake)) + session = _make_session(StubOutbound(result=snake)) request = types.CreateMessageRequest(params=types.CreateMessageRequestParams(messages=[], max_tokens=1)) result = await session.send_request(request, types.CreateMessageResult) assert result.stop_reason is None @@ -171,9 +185,9 @@ async def test_send_request_validates_result_alias_only(): @pytest.mark.anyio async def test_create_message_with_tools_returns_with_tools_result(): - dispatcher = StubDispatcher(result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}) + outbound = StubOutbound(result={"role": "assistant", "content": [{"type": "text", "text": "ok"}], "model": "m"}) session = _make_session( - dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) + outbound, capabilities=ClientCapabilities(sampling=SamplingCapability(tools=SamplingToolsCapability())) ) result = await session.create_message( # pyright: ignore[reportDeprecated] messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hi"))], @@ -181,74 +195,22 @@ async def test_create_message_with_tools_returns_with_tools_result(): tools=[types.Tool(name="t", input_schema={"type": "object"})], ) assert isinstance(result, types.CreateMessageResultWithTools) - method, params, _opts, _related = dispatcher.requests[0] + method, params, _opts = outbound.requests[0] assert method == "sampling/createMessage" assert params is not None and params["tools"][0]["name"] == "t" def test_check_client_capability_delegates_to_connection(): - dispatcher = StubDispatcher() - session = _make_session(dispatcher, capabilities=ClientCapabilities(sampling=SamplingCapability())) + outbound = StubOutbound() + session = _make_session(outbound, capabilities=ClientCapabilities(sampling=SamplingCapability())) assert session.check_client_capability(ClientCapabilities(sampling=SamplingCapability())) is True assert session.check_client_capability(ClientCapabilities(experimental={"x": {}})) is False -def _runner_server(seen_versions: list[str | None]) -> Server[dict[str, Any]]: - """A lowlevel Server whose tools/list handler records `ctx.session.protocol_version`.""" - - async def list_tools( - ctx: ServerRequestContext[dict[str, Any], Any], params: types.PaginatedRequestParams | None - ) -> types.ListToolsResult: - seen_versions.append(ctx.session.protocol_version) - return types.ListToolsResult(tools=[]) - - return Server(name="test-server", version="0.0.1", on_list_tools=list_tools) - - -def _init_params(protocol_version: str) -> dict[str, Any]: - return InitializeRequestParams( - protocol_version=protocol_version, - capabilities=ClientCapabilities(), - client_info=Implementation(name="test-client", version="1.0"), - ).model_dump(by_alias=True, exclude_none=True) - - -@pytest.mark.anyio -async def test_protocol_version_is_none_before_initialize(): - """No negotiated version is readable before the initialize handshake.""" - async with connected_runner(_runner_server([]), initialized=False) as (_client, runner): - assert runner.session.protocol_version is None - - -@pytest.mark.anyio -async def test_protocol_version_is_negotiated_version_after_initialize(): - """A supported requested version is echoed back and readable on the session, - both directly and from inside a handler via `ctx.session`.""" - seen: list[str | None] = [] - async with connected_runner(_runner_server(seen), initialized=False) as (client, runner): - result = await client.send_raw_request("initialize", _init_params("2025-03-26")) - assert result["protocolVersion"] == "2025-03-26" - assert runner.session.protocol_version == "2025-03-26" - await client.send_raw_request("tools/list", None) - assert seen == ["2025-03-26"] - - -@pytest.mark.anyio -async def test_protocol_version_reads_latest_when_requested_version_unsupported(): - """An unsupported requested version negotiates down to LATEST_PROTOCOL_VERSION.""" - async with connected_runner(_runner_server([]), initialized=False) as (client, runner): - result = await client.send_raw_request("initialize", _init_params("1999-01-01")) - assert result["protocolVersion"] == LATEST_PROTOCOL_VERSION - assert runner.session.protocol_version == LATEST_PROTOCOL_VERSION - - -@pytest.mark.anyio -async def test_protocol_version_is_none_on_stateless_connection(): - """Stateless connections never see a handshake: requests flow, but the - negotiated version legitimately stays None.""" - seen: list[str | None] = [] - async with connected_runner(_runner_server(seen), initialized=False, stateless=True) as (client, runner): - result = await client.send_raw_request("tools/list", None) - assert result == {"tools": []} - assert seen == [None] - assert runner.session.protocol_version is None +def test_protocol_version_proxies_connection(): + """SDK-defined: `session.protocol_version` reads through to the held `Connection`.""" + _ARBITRARY_VERSION = "sentinel-version" # identity-only: any string the connection holds + conn = Connection.from_envelope(_ARBITRARY_VERSION, None, None) + session = ServerSession(StubOutbound(), conn) + assert session.protocol_version == _ARBITRARY_VERSION + assert session.client_params is None diff --git a/tests/server/test_stateless_mode.py b/tests/server/test_stateless_mode.py index 7002fe1cf4..91d344253a 100644 --- a/tests/server/test_stateless_mode.py +++ b/tests/server/test_stateless_mode.py @@ -1,187 +1,170 @@ -"""Tests for stateless HTTP mode limitations. +"""Tests for the no-back-channel path (stateless HTTP). -Stateless HTTP mode does not support server-to-client requests because there -is no persistent connection for bidirectional communication. These tests verify -that appropriate errors are raised when attempting to use unsupported features. +A `Connection.from_envelope(...)` connection installs the no-channel sentinel +as its standalone outbound, so server-to-client requests with no related +request to ride on raise `NoBackChannelError` from the channel itself. See: https://github.com/modelcontextprotocol/python-sdk/issues/1097 """ +from collections.abc import Mapping from typing import Any -from unittest.mock import Mock -import anyio import pytest from mcp import types from mcp.server.connection import Connection -from mcp.server.context import ServerRequestContext -from mcp.server.lowlevel.server import Server from mcp.server.session import ServerSession -from mcp.shared.exceptions import NoBackChannelError, StatelessModeNotSupported -from mcp.shared.jsonrpc_dispatcher import JSONRPCDispatcher -from mcp.shared.message import SessionMessage -from mcp.types import JSONRPCRequest, JSONRPCResponse, ListToolsResult, PaginatedRequestParams - - -def _make_session(*, stateless: bool) -> ServerSession: - """A `ServerSession` with a mock dispatcher; the stateless guard fires before any send.""" - return ServerSession( - Mock(spec=JSONRPCDispatcher), - Connection(Mock(), has_standalone_channel=False), - stateless=stateless, - ) +from mcp.shared.dispatcher import CallOptions +from mcp.shared.exceptions import NoBackChannelError +from mcp.types import LATEST_PROTOCOL_VERSION + + +class StubOutbound: + """Records `send_raw_request` / `notify` calls and returns a canned result.""" + + def __init__(self, result: dict[str, Any] | None = None) -> None: + self.requests: list[tuple[str, Mapping[str, Any] | None, CallOptions | None]] = [] + self.notifications: list[tuple[str, Mapping[str, Any] | None]] = [] + self.result = result if result is not None else {} + + async def send_raw_request( + self, + method: str, + params: Mapping[str, Any] | None, + opts: CallOptions | None = None, + ) -> dict[str, Any]: + self.requests.append((method, params, opts)) + return self.result + + async def notify(self, method: str, params: Mapping[str, Any] | None) -> None: + self.notifications.append((method, params)) + + +def _no_channel_session(request_ch: StubOutbound | None = None) -> tuple[ServerSession, StubOutbound]: + """A session whose standalone channel is the connection's no-channel + sentinel; the request channel is a working stub.""" + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + assert conn.has_standalone_channel is False + request = request_ch if request_ch is not None else StubOutbound() + return ServerSession(request, conn), request @pytest.fixture -def stateless_session() -> ServerSession: - return _make_session(stateless=True) +def no_channel_session() -> ServerSession: + session, _ = _no_channel_session() + return session @pytest.mark.anyio -async def test_list_roots_fails_in_stateless_mode(stateless_session: ServerSession): - """Test that list_roots raises StatelessModeNotSupported in stateless mode.""" - with pytest.raises(StatelessModeNotSupported, match="list_roots"): - await stateless_session.list_roots() # pyright: ignore[reportDeprecated] +async def test_list_roots_raises_no_back_channel(no_channel_session: ServerSession): + """SDK-defined: `list_roots` has no `related_request_id` so it always rides + the standalone channel, which raises here.""" + with pytest.raises(NoBackChannelError) as exc: + await no_channel_session.list_roots() # pyright: ignore[reportDeprecated] + assert exc.value.method == "roots/list" @pytest.mark.anyio -async def test_create_message_fails_in_stateless_mode(stateless_session: ServerSession): - """Test that create_message raises StatelessModeNotSupported in stateless mode.""" - with pytest.raises(StatelessModeNotSupported, match="sampling"): - await stateless_session.create_message( # pyright: ignore[reportDeprecated] - messages=[ - types.SamplingMessage( - role="user", - content=types.TextContent(type="text", text="hello"), - ) - ], - max_tokens=100, - ) +async def test_send_ping_raises_no_back_channel(no_channel_session: ServerSession): + """SDK-defined: `send_ping` rides the standalone channel and raises when there is none.""" + with pytest.raises(NoBackChannelError) as exc: + await no_channel_session.send_ping() + assert exc.value.method == "ping" @pytest.mark.anyio -async def test_elicit_form_fails_in_stateless_mode(stateless_session: ServerSession): - """Test that elicit_form raises StatelessModeNotSupported in stateless mode.""" - with pytest.raises(StatelessModeNotSupported, match="elicitation"): - await stateless_session.elicit_form( - message="Please provide input", - requested_schema={"type": "object", "properties": {}}, +async def test_create_message_raises_no_back_channel_without_related_id(no_channel_session: ServerSession): + """SDK-defined: `create_message` without a related id rides the standalone channel and raises.""" + with pytest.raises(NoBackChannelError) as exc: + await no_channel_session.create_message( # pyright: ignore[reportDeprecated] + messages=[types.SamplingMessage(role="user", content=types.TextContent(type="text", text="hi"))], + max_tokens=100, ) + assert exc.value.method == "sampling/createMessage" @pytest.mark.anyio -async def test_elicit_url_fails_in_stateless_mode(stateless_session: ServerSession): - """Test that elicit_url raises StatelessModeNotSupported in stateless mode.""" - with pytest.raises(StatelessModeNotSupported, match="elicitation"): - await stateless_session.elicit_url( - message="Please authenticate", - url="https://example.com/auth", - elicitation_id="test-123", - ) +async def test_elicit_form_raises_no_back_channel_without_related_id(no_channel_session: ServerSession): + """SDK-defined: `elicit_form` without a related id rides the standalone channel and raises.""" + with pytest.raises(NoBackChannelError) as exc: + await no_channel_session.elicit_form(message="m", requested_schema={"type": "object", "properties": {}}) + assert exc.value.method == "elicitation/create" @pytest.mark.anyio -async def test_elicit_deprecated_fails_in_stateless_mode(stateless_session: ServerSession): - """Test that the deprecated elicit method also fails in stateless mode.""" - with pytest.raises(StatelessModeNotSupported, match="elicitation"): - await stateless_session.elicit( - message="Please provide input", - requested_schema={"type": "object", "properties": {}}, - ) +async def test_elicit_url_raises_no_back_channel_without_related_id(no_channel_session: ServerSession): + """SDK-defined: `elicit_url` without a related id rides the standalone channel and raises.""" + with pytest.raises(NoBackChannelError) as exc: + await no_channel_session.elicit_url(message="m", url="https://example.com/auth", elicitation_id="e-1") + assert exc.value.method == "elicitation/create" @pytest.mark.anyio -async def test_stateless_error_message_is_actionable(stateless_session: ServerSession): - """Test that the error message provides actionable guidance.""" - with pytest.raises(StatelessModeNotSupported) as exc_info: - await stateless_session.list_roots() # pyright: ignore[reportDeprecated] - - error_message = str(exc_info.value) - # Should mention it's stateless mode - assert "stateless HTTP mode" in error_message - # Should explain why it doesn't work - assert "server-to-client requests" in error_message - # Should tell user how to fix it - assert "stateless_http=False" in error_message +async def test_elicit_deprecated_raises_no_back_channel_without_related_id(no_channel_session: ServerSession): + """SDK-defined: the deprecated `elicit` alias routes the same as `elicit_form` and raises.""" + with pytest.raises(NoBackChannelError) as exc: + await no_channel_session.elicit(message="m", requested_schema={"type": "object", "properties": {}}) + assert exc.value.method == "elicitation/create" @pytest.mark.anyio -async def test_exception_has_method_attribute(stateless_session: ServerSession): - """Test that the exception has a method attribute for programmatic access.""" - with pytest.raises(StatelessModeNotSupported) as exc_info: - await stateless_session.list_roots() # pyright: ignore[reportDeprecated] +async def test_send_request_raises_no_back_channel_without_related_id(no_channel_session: ServerSession): + """SDK-defined: the generic `send_request` path with no metadata routes standalone and raises.""" + with pytest.raises(NoBackChannelError) as exc: + await no_channel_session.send_request(types.ListRootsRequest(), types.ListRootsResult) + assert exc.value.method == "roots/list" - assert exc_info.value.method == "list_roots" - -@pytest.fixture -def stateful_session() -> ServerSession: - return _make_session(stateless=False) +@pytest.mark.anyio +async def test_elicit_form_with_related_id_rides_the_request_channel(): + """SDK-defined: with a related request the message rides the per-request + channel, so the no-channel standalone is never touched and the call succeeds.""" + session, request_ch = _no_channel_session(StubOutbound(result={"action": "cancel"})) + result = await session.elicit_form( + message="m", requested_schema={"type": "object", "properties": {}}, related_request_id=3 + ) + assert isinstance(result, types.ElicitResult) + assert request_ch.requests[0][0] == "elicitation/create" @pytest.mark.anyio -async def test_stateful_mode_does_not_raise_stateless_error( - stateful_session: ServerSession, monkeypatch: pytest.MonkeyPatch -): - """Test that StatelessModeNotSupported is not raised in stateful mode. - - We mock send_request to avoid blocking on I/O while still verifying - that the stateless check passes. - """ - send_request_called = False +async def test_send_log_message_with_related_id_rides_the_request_channel(): + """SDK-defined: the deprecated ``send_log_message`` notification with a related id + rides the per-request channel, so it is delivered even with no standalone back-channel.""" + session, request_ch = _no_channel_session() + await session.send_log_message( # pyright: ignore[reportDeprecated] + level="info", data="hello", logger="test", related_request_id=3 + ) + assert request_ch.notifications == [("notifications/message", {"level": "info", "data": "hello", "logger": "test"})] - async def mock_send_request(*_: Any, **__: Any) -> types.ListRootsResult: - nonlocal send_request_called - send_request_called = True - return types.ListRootsResult(roots=[]) - monkeypatch.setattr(stateful_session, "send_request", mock_send_request) +@pytest.mark.anyio +async def test_unrelated_notification_is_dropped_silently(): + """SDK-defined: notifications on the no-channel standalone are best-effort — dropped, never raised.""" + session, request_ch = _no_channel_session() + await session.send_tool_list_changed() + assert request_ch.notifications == [] - # This should NOT raise StatelessModeNotSupported - result = await stateful_session.list_roots() # pyright: ignore[reportDeprecated] - assert send_request_called +@pytest.mark.anyio +async def test_loop_connection_outbound_does_not_raise_no_back_channel(): + """SDK-defined: a `for_loop` connection holds a real outbound, so the + standalone path reaches the channel rather than raising.""" + standalone = StubOutbound(result={"roots": []}) + conn = Connection.for_loop(standalone) + assert conn.has_standalone_channel is True + session = ServerSession(StubOutbound(), conn) + result = await session.list_roots() # pyright: ignore[reportDeprecated] assert isinstance(result, types.ListRootsResult) + assert standalone.requests[0][0] == "roots/list" @pytest.mark.anyio -async def test_server_run_stateless_wires_no_standalone_channel(): - """`Server.run(stateless=True)` must wire `Connection.has_standalone_channel=False`. - - Stateless HTTP has no standalone GET stream, so server-initiated requests on - the connection must fail fast with `NoBackChannelError` rather than write to - a channel that will never deliver a response. The `ServerSession` typed - helpers carry their own stateless guard (tested above); this pins the - `Connection` wiring that `Server.run` produces. - """ - captured: list[Connection] = [] - - async def list_tools(ctx: ServerRequestContext[Any], params: PaginatedRequestParams | None) -> ListToolsResult: - # `ServerRequestContext` doesn't expose `connection` directly yet (it - # will after the Context rework); reach it via the session for now. - captured.append(ctx.session._connection) # pyright: ignore[reportPrivateUsage] - return ListToolsResult(tools=[]) - - server: Server[Any] = Server("test", on_list_tools=list_tools) - - to_server, server_read = anyio.create_memory_object_stream[SessionMessage | Exception](10) - server_write, from_server = anyio.create_memory_object_stream[SessionMessage](10) - - async def run_server() -> None: - await server.run(server_read, server_write, server.create_initialization_options(), stateless=True) - - async with anyio.create_task_group() as tg, to_server, server_read, server_write, from_server: - tg.start_soon(run_server) - # stateless=True skips the init gate, so tools/list routes immediately. - await to_server.send(SessionMessage(JSONRPCRequest(jsonrpc="2.0", id=1, method="tools/list"))) - with anyio.fail_after(5): - response = (await from_server.receive()).message - assert isinstance(response, JSONRPCResponse) - tg.cancel_scope.cancel() - - assert len(captured) == 1 - conn = captured[0] - assert conn.has_standalone_channel is False - with pytest.raises(NoBackChannelError): +async def test_from_envelope_connection_ping_raises_no_back_channel(): + """SDK-defined: `Connection`'s own helpers route through the same sentinel, + so `ping` on a `from_envelope` connection raises.""" + conn = Connection.from_envelope(LATEST_PROTOCOL_VERSION, None, None) + with pytest.raises(NoBackChannelError) as exc: await conn.ping() + assert exc.value.method == "ping" diff --git a/tests/server/test_streamable_http_manager.py b/tests/server/test_streamable_http_manager.py index f02e520eea..0e8afed509 100644 --- a/tests/server/test_streamable_http_manager.py +++ b/tests/server/test_streamable_http_manager.py @@ -103,11 +103,12 @@ async def running_manager(): @pytest.mark.anyio async def test_stateful_session_cleanup_on_graceful_exit(running_manager: tuple[StreamableHTTPSessionManager, Server]): - manager, app = running_manager + manager, _app = running_manager - mock_mcp_run = AsyncMock(return_value=None) - # This will be called by StreamableHTTPSessionManager's run_server -> self.app.run - app.run = mock_mcp_run + # The manager's `run_server` task drives `serve_loop` directly (the manager + # owns lifespan); patch that seam so the loop returns immediately and we + # can observe the cleanup that follows. + mock_serve = AsyncMock(return_value=None) sent_messages: list[Message] = [] @@ -125,7 +126,8 @@ async def mock_receive(): # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} # Trigger session creation - await manager.handle_request(scope, mock_receive, mock_send) + with patch("mcp.server.streamable_http_manager.serve_loop", mock_serve): + await manager.handle_request(scope, mock_receive, mock_send) # Extract session ID from response headers session_id = None @@ -140,10 +142,9 @@ async def mock_receive(): # pragma: no cover assert session_id is not None, "Session ID not found in response headers" - # Ensure MCPServer.run was called - mock_mcp_run.assert_called_once() + mock_serve.assert_called_once() - # At this point, mock_mcp_run has completed, and the finally block in + # At this point, mock_serve has completed, and the finally block in # StreamableHTTPSessionManager's run_server should have executed. # To ensure the task spawned by handle_request finishes and cleanup occurs: @@ -158,10 +159,9 @@ async def mock_receive(): # pragma: no cover @pytest.mark.anyio async def test_stateful_session_cleanup_on_exception(running_manager: tuple[StreamableHTTPSessionManager, Server]): - manager, app = running_manager + manager, _app = running_manager - mock_mcp_run = AsyncMock(side_effect=TestException("Simulated crash")) - app.run = mock_mcp_run + mock_serve = AsyncMock(side_effect=TestException("Simulated crash")) sent_messages: list[Message] = [] @@ -184,7 +184,8 @@ async def mock_receive(): # pragma: no cover return {"type": "http.request", "body": b"", "more_body": False} # Trigger session creation - await manager.handle_request(scope, mock_receive, mock_send) + with patch("mcp.server.streamable_http_manager.serve_loop", mock_serve): + await manager.handle_request(scope, mock_receive, mock_send) session_id = None for msg in sent_messages: # pragma: no branch @@ -198,7 +199,7 @@ async def mock_receive(): # pragma: no cover assert session_id is not None, "Session ID not found in response headers" - mock_mcp_run.assert_called_once() + mock_serve.assert_called_once() # Give other tasks a chance to run to ensure the finally block executes await anyio.sleep(0.01) @@ -229,9 +230,6 @@ def track_transport(*args: Any, **kwargs: Any) -> StreamableHTTPServerTransport: with patch.object(streamable_http_manager, "StreamableHTTPServerTransport", side_effect=track_transport): async with manager.run(): - # Mock app.run to complete immediately - app.run = AsyncMock(return_value=None) - # Send a simple request sent_messages: list[Message] = [] diff --git a/tests/server/test_streamable_http_modern.py b/tests/server/test_streamable_http_modern.py index ce62d44cec..35ee17f3d6 100644 --- a/tests/server/test_streamable_http_modern.py +++ b/tests/server/test_streamable_http_modern.py @@ -2,61 +2,38 @@ The interaction suite under ``tests/interaction/transports/test_hosting_http_modern.py`` pins the wire contract end to end; these tests cover the module's internal seams directly -- -the closed back-channel on the dispatcher and dispatch context, the exception-to-error -mapping in ``handle()``, and the request-validation ladder in ``handle_modern_request``. +the closed back-channel on the dispatch context, and the request-validation ladder in +``handle_modern_request``. """ import logging -from collections.abc import Mapping from typing import Any import anyio import httpx import pytest -from starlette.requests import Request from starlette.types import Receive, Scope, Send -import mcp.server._streamable_http_modern as modern -from mcp.server import Server, ServerRequestContext -from mcp.server._streamable_http_modern import ( - SingleExchangeDispatcher, - _SingleExchangeDispatchContext, - handle_modern_request, -) +from mcp.server import Server, ServerRequestContext, runner +from mcp.server._streamable_http_modern import _SingleExchangeDispatchContext, handle_modern_request from mcp.server.transport_security import TransportSecuritySettings -from mcp.shared.dispatcher import DispatchContext from mcp.shared.exceptions import NoBackChannelError +from mcp.shared.inbound import MCP_PROTOCOL_VERSION_HEADER from mcp.shared.transport_context import TransportContext -from mcp.types import INVALID_PARAMS, PARSE_ERROR, JSONRPCError, JSONRPCRequest, ListToolsResult, PaginatedRequestParams +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS +from mcp.types import ( + CLIENT_CAPABILITIES_META_KEY, + CLIENT_INFO_META_KEY, + INVALID_REQUEST, + PARSE_ERROR, + PROTOCOL_VERSION_META_KEY, + ListToolsResult, + PaginatedRequestParams, +) pytestmark = pytest.mark.anyio -def _request() -> Request: - return Request({"type": "http", "method": "POST", "headers": []}) - - -async def test_single_exchange_dispatcher_has_no_back_channel_and_is_never_driven() -> None: - """The dispatcher refuses server-initiated requests, drops notifications, and is not run-driven. - - A 2026-07-28 POST has no channel for the server to push to the client, and ``ServerRunner`` - never calls ``run()`` on this dispatcher -- ``handle()`` is invoked directly per request. - """ - dispatcher = SingleExchangeDispatcher(_request()) - with pytest.raises(NoBackChannelError): - await dispatcher.send_raw_request("sampling/createMessage", None) - assert await dispatcher.notify("notifications/message", None) is None - - async def on_request(ctx: DispatchContext[Any], method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - raise AssertionError("unreachable") # pragma: no cover - - async def on_notify(ctx: DispatchContext[Any], method: str, params: Mapping[str, Any] | None) -> None: - raise AssertionError("unreachable") # pragma: no cover - - with pytest.raises(RuntimeError, match="never driven"): - await dispatcher.run(on_request, on_notify) - - async def test_single_exchange_dispatch_context_has_no_back_channel() -> None: """The per-request dispatch context refuses server-initiated requests and drops notify/progress.""" dctx = _SingleExchangeDispatchContext( @@ -71,41 +48,49 @@ async def test_single_exchange_dispatch_context_has_no_back_channel() -> None: assert await dctx.progress(0.5, total=1.0, message="half") is None -async def test_handle_maps_validation_error_to_invalid_params() -> None: - """A handler raising ``ValidationError`` is mapped to a ``-32602`` JSON-RPC error. - - Mirrors ``JSONRPCDispatcher``'s exception-to-wire boundary: a Pydantic validation failure - inside the handler becomes ``INVALID_PARAMS`` rather than the generic internal error. - """ - - async def on_request(ctx: DispatchContext[Any], method: str, params: Mapping[str, Any] | None) -> dict[str, Any]: - JSONRPCRequest.model_validate({}) # raises ValidationError - raise AssertionError("unreachable") # pragma: no cover - - dispatcher = SingleExchangeDispatcher(_request()) - msg = await dispatcher.handle(JSONRPCRequest(jsonrpc="2.0", id=7, method="tools/call", params={}), on_request) - assert isinstance(msg, JSONRPCError) - assert msg.id == 7 - assert msg.error.code == INVALID_PARAMS - - def _asgi_client(server: Server[Any], security_settings: TransportSecuritySettings | None = None) -> httpx.AsyncClient: async def app(scope: Scope, receive: Receive, send: Send) -> None: - await handle_modern_request(server, security_settings, "2026-07-28", scope, receive, send) + async with server.lifespan(server) as lifespan_state: + await handle_modern_request(server, security_settings, lifespan_state, scope, receive, send) - return httpx.AsyncClient(transport=httpx.ASGITransport(app=app), base_url="http://testserver") + return httpx.AsyncClient( + transport=httpx.ASGITransport(app=app), + base_url="http://testserver", + headers={MCP_PROTOCOL_VERSION_HEADER: MODERN_PROTOCOL_VERSIONS[0]}, + ) -async def test_handle_modern_request_rejects_non_post_with_405() -> None: - """A GET on the 2026-07-28 entry is answered with 405 before any body is read.""" +async def test_handle_modern_request_rejects_non_post_with_http_405_and_allow_header() -> None: + """SDK-defined: a GET on the modern entry is an HTTP-verb mismatch — 405 Method Not + Allowed with ``Allow: POST`` per RFC 9110. This is HTTP-layer (before JSON-RPC parsing) + so there is no JSON-RPC body.""" async with _asgi_client(Server("test")) as http: response = await http.get("/mcp") assert response.status_code == 405 assert response.headers["allow"] == "POST" + assert response.content == b"" + + +async def test_handle_modern_request_rejects_a_notification_body_with_invalid_request() -> None: + """SDK-defined: well-formed JSON that isn't a single JSON-RPC request object (e.g. a + notification, which lacks ``id``) is ``INVALID_REQUEST`` — distinct from ``PARSE_ERROR``, + which is for malformed JSON.""" + async with _asgi_client(Server("test")) as http: + response = await http.post( + "/mcp", + content=b'{"jsonrpc":"2.0","method":"notifications/cancelled","params":{"requestId":1}}', + headers={"content-type": "application/json"}, + ) + assert response.status_code == 400 + assert response.json()["error"]["code"] == INVALID_REQUEST async def test_handle_modern_request_rejects_malformed_body_with_parse_error() -> None: - """A POST whose body is not a valid ``JSONRPCRequest`` returns 400 with ``-32700``.""" + """An unparseable POST body yields HTTP 400 with a ``PARSE_ERROR`` JSON-RPC error envelope. + + SDK-defined: the 400 status comes from the SDK's error-code→HTTP-status table; spec-mandated: the + body is a full JSON-RPC error object with ``id: null`` and code ``-32700``. + """ async with _asgi_client(Server("test")) as http: response = await http.post("/mcp", content=b"not json", headers={"content-type": "application/json"}) assert response.status_code == 400 @@ -113,7 +98,7 @@ async def test_handle_modern_request_rejects_malformed_body_with_parse_error() - assert response.json() == { "jsonrpc": "2.0", "id": None, - "error": {"code": PARSE_ERROR, "message": "Parse error", "data": None}, + "error": {"code": PARSE_ERROR, "message": "Parse error"}, } @@ -129,20 +114,45 @@ async def test_handle_modern_request_returns_transport_security_error_response() def _list_tools_body() -> dict[str, Any]: """A minimal valid 2026-07-28 ``tools/list`` request body, including the required ``_meta`` envelope.""" meta = { - "io.modelcontextprotocol/protocolVersion": "2026-07-28", - "io.modelcontextprotocol/clientInfo": {"name": "raw", "version": "0.0.0"}, - "io.modelcontextprotocol/clientCapabilities": {}, + PROTOCOL_VERSION_META_KEY: MODERN_PROTOCOL_VERSIONS[0], + CLIENT_INFO_META_KEY: {"name": "raw", "version": "0.0.0"}, + CLIENT_CAPABILITIES_META_KEY: {}, } return {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {"_meta": meta}} +async def test_handle_modern_request_routes_with_mis_shaped_envelope_client_info() -> None: + """SDK-defined: a mis-shaped ``clientInfo`` envelope value is treated as not supplied — + the request still routes (200 + result) and the handler observes ``client_params is None`` + rather than the request being rejected at the validation ladder. A non-spec method is + used so the kernel's per-method params validation does not re-reject the envelope.""" + seen: list[object] = [] + + async def greet(ctx: ServerRequestContext, params: PaginatedRequestParams) -> dict[str, Any]: + seen.append(ctx.session.client_params) + return {"ok": True} + + server: Server[Any] = Server("test") + server.add_request_handler("custom/greet", PaginatedRequestParams, greet) + + body = _list_tools_body() + body["method"] = "custom/greet" + body["params"]["_meta"][CLIENT_INFO_META_KEY] = "not-an-object" + async with _asgi_client(server) as http: + response = await http.post("/mcp", json=body, headers={"content-type": "application/json"}) + assert response.status_code == 200 + assert response.json()["result"] == {"ok": True} + assert seen == [None] + + async def test_handle_modern_request_sends_response_when_exit_stack_cleanup_raises( caplog: pytest.LogCaptureFixture, ) -> None: """A raising ``connection.exit_stack`` callback is logged and swallowed; the computed result still ships. - The exit-stack guard mirrors ``ServerRunner.run``: cleanup runs in a ``finally`` after the - handler, and an exception there must not displace the JSON-RPC response that was already built. + The exit-stack guard is `aclose_shielded`: cleanup runs in `serve_one`'s ``finally`` after + the handler, and an exception there must not displace the JSON-RPC response that was already + built. """ async def boom() -> None: @@ -152,7 +162,7 @@ async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | ctx.session._connection.exit_stack.push_async_callback(boom) return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") - with caplog.at_level(logging.ERROR, logger=modern.__name__): + with caplog.at_level(logging.ERROR, logger=runner.__name__): async with _asgi_client(Server("test", on_list_tools=list_tools)) as http: response = await http.post("/mcp", json=_list_tools_body(), headers={"content-type": "application/json"}) @@ -170,7 +180,7 @@ async def test_handle_modern_request_sends_response_when_exit_stack_cleanup_hang blocker at its first checkpoint, the abandonment warning is logged, and the JSON-RPC response that was built before cleanup is sent unchanged. """ - monkeypatch.setattr(modern, "_EXIT_STACK_CLOSE_TIMEOUT", 0) + monkeypatch.setattr(runner, "_EXIT_STACK_CLOSE_TIMEOUT", 0) async def block() -> None: await anyio.Event().wait() @@ -180,7 +190,7 @@ async def list_tools(ctx: ServerRequestContext, params: PaginatedRequestParams | ctx.session._connection.exit_stack.push_async_callback(block) return ListToolsResult(tools=[], ttl_ms=0, cache_scope="public") - with anyio.fail_after(5), caplog.at_level(logging.WARNING, logger=modern.__name__): + with anyio.fail_after(5), caplog.at_level(logging.WARNING, logger=runner.__name__): async with _asgi_client(Server("test", on_list_tools=list_tools)) as http: response = await http.post("/mcp", json=_list_tools_body(), headers={"content-type": "application/json"}) # coverage.py on Python 3.11 misreports the lines below as unhit (the test passes there); diff --git a/tests/shared/test_inbound.py b/tests/shared/test_inbound.py new file mode 100644 index 0000000000..75ed93e99a --- /dev/null +++ b/tests/shared/test_inbound.py @@ -0,0 +1,221 @@ +"""Pure-function tests of :mod:`mcp.shared.inbound`. + +Independent verifier of the classifier: every ladder rung is exercised +pass+fail with no ``mcp.server`` / transport imports and no inlined error-code +or protocol-version literals — all facts are imported from their one source. +""" + +import dataclasses +from typing import Any + +import pytest + +from mcp.shared.inbound import ( + ERROR_CODE_HTTP_STATUS, + MCP_PROTOCOL_VERSION_HEADER, + InboundLadderRejection, + InboundModernRoute, + classify_inbound_request, +) +from mcp.shared.version import MODERN_PROTOCOL_VERSIONS +from mcp.types import ( + CLIENT_CAPABILITIES_META_KEY, + CLIENT_INFO_META_KEY, + LATEST_PROTOCOL_VERSION, + PROTOCOL_VERSION_META_KEY, +) +from mcp.types.jsonrpc import ( + HEADER_MISMATCH, + INVALID_PARAMS, + INVALID_REQUEST, + METHOD_NOT_FOUND, + MISSING_REQUIRED_CLIENT_CAPABILITY, + PARSE_ERROR, + UNSUPPORTED_PROTOCOL_VERSION, +) + +MODERN = MODERN_PROTOCOL_VERSIONS[0] +"""The modern protocol-version string, read from the registry — never inlined here.""" + +CLIENT_INFO = {"name": "t", "version": "0"} +CLIENT_CAPS: dict[str, Any] = {} + + +def envelope( + method: str = "tools/list", + *, + version: str = MODERN, + drop: frozenset[str] = frozenset(), +) -> dict[str, Any]: + """Build a JSON-RPC body carrying a complete modern ``_meta`` envelope. + + ``drop`` removes named envelope keys so rung-1 failures are driven from one + table instead of repeating reserved-key constants per call site. + """ + meta: dict[str, Any] = { + PROTOCOL_VERSION_META_KEY: version, + CLIENT_INFO_META_KEY: CLIENT_INFO, + CLIENT_CAPABILITIES_META_KEY: CLIENT_CAPS, + } + for key in drop: + del meta[key] + return {"jsonrpc": "2.0", "id": 1, "method": method, "params": {"_meta": meta}} + + +def assert_rejected(result: object, code: int) -> InboundLadderRejection: + assert isinstance(result, InboundLadderRejection) + assert result.code == code + return result + + +# --- rung 1: envelope-three-keys ----------------------------------------------- + + +@pytest.mark.parametrize( + "body", + [ + pytest.param({"jsonrpc": "2.0", "id": 1, "method": "tools/list"}, id="no-params"), + pytest.param({"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {}}, id="no-meta"), + pytest.param(envelope(drop=frozenset({PROTOCOL_VERSION_META_KEY})), id="meta-missing-version"), + pytest.param(envelope(drop=frozenset({CLIENT_INFO_META_KEY})), id="meta-missing-client-info"), + pytest.param(envelope(drop=frozenset({CLIENT_CAPABILITIES_META_KEY})), id="meta-missing-client-caps"), + ], +) +def test_envelope_rung_rejects_missing_keys(body: dict[str, Any]) -> None: + """Spec-mandated: a modern request lacking any of the three reserved ``_meta`` keys is rejected INVALID_PARAMS.""" + rejection = assert_rejected(classify_inbound_request(body), INVALID_PARAMS) + assert rejection.data is None + + +@pytest.mark.parametrize( + "body", + [ + pytest.param({"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": None}, id="params-none"), + pytest.param({"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {"_meta": None}}, id="meta-none"), + pytest.param( + {"jsonrpc": "2.0", "id": 1, "method": "tools/list", "params": {"_meta": 0}}, id="meta-non-mapping" + ), + ], +) +def test_envelope_rung_rejects_non_mapping_shapes(body: dict[str, Any]) -> None: + """Spec-mandated: non-mapping ``params`` / ``_meta`` cannot carry the envelope and reject INVALID_PARAMS.""" + assert_rejected(classify_inbound_request(body), INVALID_PARAMS) + + +# --- rung 2: protocol-version-supported ---------------------------------------- + + +def test_version_rung_rejects_unsupported_with_data_shape() -> None: + """Spec-mandated: an envelope version outside the modern set rejects with the ``supported``/``requested`` data.""" + rejection = assert_rejected( + classify_inbound_request(envelope(version=LATEST_PROTOCOL_VERSION)), + UNSUPPORTED_PROTOCOL_VERSION, + ) + assert rejection.data == { + "supported": list(MODERN_PROTOCOL_VERSIONS), + "requested": LATEST_PROTOCOL_VERSION, + } + + +def test_version_rung_data_reflects_supplied_supported_list() -> None: + """SDK-defined: the caller-supplied ``supported_modern_versions`` is what rejection ``data.supported`` echoes.""" + custom = (LATEST_PROTOCOL_VERSION,) + rejection = assert_rejected( + classify_inbound_request(envelope(), supported_modern_versions=custom), + UNSUPPORTED_PROTOCOL_VERSION, + ) + assert rejection.data == {"supported": list(custom), "requested": MODERN} + + +# --- rung 3: header ↔ envelope agreement --------------------------------------- + + +def test_header_rung_does_not_reject_when_headers_arg_is_none() -> None: + """SDK-defined: ``headers=None`` (non-HTTP transports) means rung 3 has nothing to check and the ladder proceeds.""" + result = classify_inbound_request(envelope(), headers=None) + assert isinstance(result, InboundModernRoute) + + +def test_header_rung_passes_when_header_matches_envelope() -> None: + """Spec-mandated: an HTTP version header equal to the envelope version passes rung 3.""" + result = classify_inbound_request(envelope(), headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) + assert isinstance(result, InboundModernRoute) + + +@pytest.mark.parametrize( + "headers", + [ + pytest.param({MCP_PROTOCOL_VERSION_HEADER: LATEST_PROTOCOL_VERSION}, id="mismatch"), + pytest.param({}, id="header-absent"), + ], +) +def test_header_rung_rejects_on_disagreement(headers: dict[str, str]) -> None: + """Spec-mandated: an absent or mismatched HTTP version header rejects HEADER_MISMATCH.""" + assert_rejected(classify_inbound_request(envelope(), headers=headers), HEADER_MISMATCH) + + +# --- all rungs pass ------------------------------------------------------------ + + +def test_all_rungs_pass_yields_route() -> None: + """Spec-mandated: a complete envelope at a supported version with agreeing header routes, surfacing the envelope.""" + result = classify_inbound_request(envelope(), headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) + assert isinstance(result, InboundModernRoute) + assert result.protocol_version == MODERN + assert result.client_info == CLIENT_INFO + assert result.client_capabilities == CLIENT_CAPS + + +@pytest.mark.parametrize("method", ["initialize", "myorg/custom", "does/not/exist"]) +def test_classifier_passes_unknown_method_through_to_route(method: str) -> None: + """SDK-defined: the classifier does not gate on method — kernel dispatch is the single owner of that decision.""" + result = classify_inbound_request(envelope(method), headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) + assert isinstance(result, InboundModernRoute) + + +def test_ladder_first_failure_wins() -> None: + """Spec-mandated: rungs evaluate in order — header-mismatch and version-unsupported + would both fail; the header rung fires first so an inconsistent client is told it + disagrees with itself rather than that its body version is unsupported.""" + body = envelope(version=LATEST_PROTOCOL_VERSION) + result = classify_inbound_request(body, headers={MCP_PROTOCOL_VERSION_HEADER: MODERN}) + assert_rejected(result, HEADER_MISMATCH) + + +# --- ERROR_CODE_HTTP_STATUS ---------------------------------------------------- + + +@pytest.mark.parametrize( + ("code", "status"), + [ + (PARSE_ERROR, 400), + (INVALID_REQUEST, 400), + (INVALID_PARAMS, 400), + (HEADER_MISMATCH, 400), + (MISSING_REQUIRED_CLIENT_CAPABILITY, 400), + (UNSUPPORTED_PROTOCOL_VERSION, 400), + (METHOD_NOT_FOUND, 404), + ], +) +def test_error_code_http_status_table(code: int, status: int) -> None: + """SDK-defined: pins the JSON-RPC error code → HTTP status mapping the streamable transport reads.""" + assert ERROR_CODE_HTTP_STATUS[code] == status + + +def test_error_code_http_status_covers_every_ladder_code() -> None: + """SDK-defined: every code the ladder can emit has an HTTP-status entry, so the transport never has to default.""" + ladder_codes = {INVALID_PARAMS, UNSUPPORTED_PROTOCOL_VERSION, HEADER_MISMATCH} + assert ladder_codes <= ERROR_CODE_HTTP_STATUS.keys() + + +# --- shape invariants ---------------------------------------------------------- + + +def test_verdict_dataclasses_are_frozen() -> None: + """SDK-defined: both verdict dataclasses are frozen so a route/rejection cannot be mutated after classification.""" + route = classify_inbound_request(envelope()) + assert isinstance(route, InboundModernRoute) + rejection = InboundLadderRejection(code=METHOD_NOT_FOUND, message="m") + for verdict in (route, rejection): + with pytest.raises(dataclasses.FrozenInstanceError): + setattr(verdict, "message", "mutated") diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index a3273add58..7ceac8e869 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -50,6 +50,8 @@ from mcp.shared.session import RequestResponder from mcp.types import ( DEFAULT_NEGOTIATED_VERSION, + INVALID_PARAMS, + INVALID_REQUEST, CallToolRequestParams, CallToolResult, InitializeResult, @@ -1045,8 +1047,10 @@ async def test_streamable_http_client_session_termination(basic_app: Starlette) ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Attempt to make a request after termination - with pytest.raises(MCPError, match="Session terminated"): # pragma: no branch + with pytest.raises(MCPError) as exc_info: # pragma: no branch await session.list_tools() + assert exc_info.value.error.code == INVALID_REQUEST + assert "terminated" in exc_info.value.error.message.lower() @pytest.mark.anyio @@ -1106,8 +1110,10 @@ async def mock_delete(self: httpx.AsyncClient, *args: Any, **kwargs: Any) -> htt ): async with ClientSession(read_stream, write_stream) as session: # pragma: no branch # Attempt to make a request after termination - with pytest.raises(MCPError, match="Session terminated"): # pragma: no branch + with pytest.raises(MCPError) as exc_info: # pragma: no branch await session.list_tools() + assert exc_info.value.error.code == INVALID_REQUEST + assert "terminated" in exc_info.value.error.message.lower() @pytest.mark.anyio @@ -1388,7 +1394,7 @@ async def test_streamablehttp_stateless_ctx_protocol_version_tracks_the_header( assert response.status_code == 200 echoed = json.loads(first_sse_data(response)["result"]["content"][0]["text"]) assert echoed["protocol_version"] == expected - assert echoed["session_protocol_version"] is None + assert echoed["session_protocol_version"] == expected @pytest.mark.anyio @@ -1503,7 +1509,8 @@ async def test_server_validates_protocol_version_header(basic_app: Starlette) -> session_id = init_response.headers.get(MCP_SESSION_ID_HEADER) assert session_id is not None - # Test request with invalid protocol version (should fail) + # An unrecognised header value routes to the modern entry, where the + # validation ladder rejects an envelope-less body at rung 1. response = await client.post( "/mcp", headers={ @@ -1515,21 +1522,7 @@ async def test_server_validates_protocol_version_header(basic_app: Starlette) -> json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-2"}, ) assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() - - # Test request with unsupported protocol version (should fail) - response = await client.post( - "/mcp", - headers={ - "Accept": "application/json, text/event-stream", - "Content-Type": "application/json", - MCP_SESSION_ID_HEADER: session_id, - MCP_PROTOCOL_VERSION_HEADER: "1999-01-01", # Very old unsupported version - }, - json={"jsonrpc": "2.0", "method": "tools/list", "id": "test-3"}, - ) - assert response.status_code == 400 - assert MCP_PROTOCOL_VERSION_HEADER in response.text or "protocol version" in response.text.lower() + assert response.json()["error"]["code"] == INVALID_PARAMS # Test request with valid protocol version (should succeed) negotiated_version = extract_protocol_version_from_sse(init_response)