diff --git a/src/mcp/server/streamable_http.py b/src/mcp/server/streamable_http.py index 9103996a52..5fc05051c9 100644 --- a/src/mcp/server/streamable_http.py +++ b/src/mcp/server/streamable_http.py @@ -12,8 +12,9 @@ from collections.abc import AsyncGenerator, Awaitable, Callable from contextlib import asynccontextmanager from dataclasses import dataclass +from functools import partial from http import HTTPStatus -from typing import Any +from typing import Any, Final import anyio import pydantic_core @@ -59,6 +60,11 @@ # Special key for the standalone GET stream GET_STREAM_KEY = "_GET_stream" +# Buffer for the per-request `_request_streams` so the serial `message_router` +# can deposit a response and move on instead of head-of-line blocking the +# whole session on a lazily-started `sse_writer`. See #1764. +REQUEST_STREAM_BUFFER_SIZE: Final = 16 + # Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E) # Pattern ensures entire string contains only valid characters by using ^ and $ anchors SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$") @@ -66,6 +72,8 @@ # Type aliases StreamId = str EventId = str +# An SSE event-dict as accepted by sse-starlette (`event`, `data`, `id`, `retry`). +SSEEvent = dict[str, Any] @dataclass @@ -169,7 +177,7 @@ def __init__( MemoryObjectReceiveStream[EventMessage], ], ] = {} - self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {} + self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[SSEEvent]] = {} self._terminated = False # Idle timeout cancel scope; managed by the session manager. self.idle_scope: anyio.CancelScope | None = None @@ -257,31 +265,48 @@ async def close_standalone_stream_callback() -> None: return SessionMessage(message, metadata=metadata) - async def _maybe_send_priming_event( - self, - request_id: RequestId, - sse_stream_writer: MemoryObjectSendStream[dict[str, Any]], - protocol_version: str, - ) -> None: - """Send priming event for SSE resumability if event_store is configured. + async def _mint_priming_event(self, stream_id: StreamId, protocol_version: str) -> SSEEvent | None: + """Store the priming cursor for `stream_id` and return its SSE wire form. - Only sends priming events to clients with protocol version >= 2025-11-25, - which includes the fix for handling empty SSE data. Older clients would - crash trying to parse empty data as JSON. + Called before the request is dispatched so the priming row precedes + anything `message_router` can store for this stream. Returns `None` + when no event store is configured or the client predates 2025-11-25 + (older clients cannot parse the empty-data event). """ if not self._event_store: - return - # Priming events have empty data which older clients cannot handle. + return None if not is_version_at_least(protocol_version, "2025-11-25"): - return - priming_event_id = await self._event_store.store_event( - str(request_id), # Convert RequestId to StreamId (str) - None, # Priming event has no payload - ) - priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""} + return None + priming_event_id = await self._event_store.store_event(stream_id, None) + priming_event: SSEEvent = {"id": priming_event_id, "data": ""} if self._retry_interval is not None: priming_event["retry"] = self._retry_interval - await sse_stream_writer.send(priming_event) + return priming_event + + async def _run_sse_writer( + self, + request_id: RequestId, + sse_stream_writer: MemoryObjectSendStream[SSEEvent], + request_stream_reader: MemoryObjectReceiveStream[EventMessage], + priming_event: SSEEvent | None, + ) -> None: + """Forward `_request_streams[request_id]` onto the SSE wire for one POST.""" + try: + async with sse_stream_writer, request_stream_reader: + if priming_event is not None: + await sse_stream_writer.send(priming_event) + async for event_message in request_stream_reader: + await sse_stream_writer.send(self._create_event_data(event_message)) + if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): + break + except anyio.ClosedResourceError: # pragma: lax no cover + logger.debug("SSE stream closed by close_sse_stream()") + except Exception: # pragma: lax no cover + logger.exception("Error in SSE writer") + finally: + logger.debug("Closing SSE writer") + self._sse_stream_writers.pop(request_id, None) + await self._clean_up_memory_streams(request_id) def _create_error_response( self, @@ -335,7 +360,7 @@ def _get_session_id(self, request: Request) -> str | None: """Extract the session ID from request headers.""" return request.headers.get(MCP_SESSION_ID_HEADER) - def _create_event_data(self, event_message: EventMessage) -> dict[str, str]: + def _create_event_data(self, event_message: EventMessage) -> SSEEvent: """Create event data dictionary from an EventMessage.""" event_data = { "event": "message", @@ -525,13 +550,13 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) ) - # Extract the request ID outside the try block for proper scope request_id = str(message.id) - # Register this stream for the request ID - self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0) - request_stream_reader = self._request_streams[request_id][1] if self.is_json_response_enabled: + self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage]( + REQUEST_STREAM_BUFFER_SIZE + ) + request_stream_reader = self._request_streams[request_id][1] # Process the message metadata = ServerMessageMetadata(request_context=request, protocol_version=protocol_version) session_message = SessionMessage(message, metadata=metadata) @@ -575,41 +600,19 @@ async def _handle_post_request(self, scope: Scope, request: Request, receive: Re finally: await self._clean_up_memory_streams(request_id) else: - # Create SSE stream - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + # Mint the priming event before any per-request state exists: + # `EventStore.store_event` is user code and may raise, in which + # case the outer handler returns a 500 with nothing to clean up. + # Still strictly precedes dispatch, so storage order == wire order. + priming_event = await self._mint_priming_event(request_id, protocol_version) - # Store writer reference so close_sse_stream() can close it + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0) self._sse_stream_writers[request_id] = sse_stream_writer + self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage]( + REQUEST_STREAM_BUFFER_SIZE + ) + request_stream_reader = self._request_streams[request_id][1] - async def sse_writer(): - # Get the request ID from the incoming request message - try: - async with sse_stream_writer, request_stream_reader: - # Send priming event for SSE resumability - await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version) - - # Process messages from the request-specific stream - async for event_message in request_stream_reader: - # Build the event data - event_data = self._create_event_data(event_message) - await sse_stream_writer.send(event_data) - - # If response, remove from pending streams and close - if isinstance(event_message.message, JSONRPCResponse | JSONRPCError): - break - except anyio.ClosedResourceError: # pragma: lax no cover - # Expected when close_sse_stream() is called - logger.debug("SSE stream closed by close_sse_stream()") - except Exception: # pragma: lax no cover - logger.exception("Error in SSE writer") - finally: - logger.debug("Closing SSE writer") - self._sse_stream_writers.pop(request_id, None) - await self._clean_up_memory_streams(request_id) - - # Create and start EventSourceResponse - # SSE stream mode (original behavior) - # Set up headers headers = { "Cache-Control": "no-cache, no-transform", "Connection": "keep-alive", @@ -618,7 +621,9 @@ async def sse_writer(): } response = EventSourceResponse( content=sse_stream_reader, - data_sender_callable=sse_writer, + data_sender_callable=partial( + self._run_sse_writer, request_id, sse_stream_writer, request_stream_reader, priming_event + ), headers=headers, ) @@ -637,10 +642,7 @@ async def sse_writer(): finally: await sse_stream_reader.aclose() - except Exception as err: # pragma: lax no cover - # Reached only when something raises during POST handling outside - # the per-SSE-stream guard above; whether tests reach this depends - # on client teardown timing. + except Exception as err: logger.exception("Error handling POST request") response = self._create_error_response( f"Error handling POST request: {err}", @@ -648,9 +650,8 @@ async def sse_writer(): INTERNAL_ERROR, ) await response(scope, receive, send) - if writer: - await writer.send(Exception(err)) - return # pragma: no cover + await writer.send(Exception(err)) + return async def _handle_get_request(self, request: Request, send: Send) -> None: """Handle GET request to establish SSE. @@ -701,13 +702,15 @@ async def _handle_get_request(self, request: Request, send: Send) -> None: return # Create SSE stream - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0) async def standalone_sse_writer(): try: # Create a standalone message stream for server-initiated messages - self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0) + self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage]( + REQUEST_STREAM_BUFFER_SIZE + ) standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1] async with sse_stream_writer, standalone_stream_reader: @@ -894,11 +897,10 @@ async def _replay_events(self, last_event_id: str, request: Request, send: Send) if self.mcp_session_id: # pragma: no branch headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id - # Get protocol version from header (already validated in _validate_protocol_version) replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION) # Create SSE stream for replay - sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0) + sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0) async def replay_sender(): try: @@ -916,11 +918,17 @@ async def send_event(event_message: EventMessage) -> None: # Register SSE writer so close_sse_stream() can close it self._sse_stream_writers[stream_id] = sse_stream_writer - # Send priming event for this new connection - await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version) + # Prime the resumed connection so the client sees the stream + # is re-registered. The replay→live-tail ordering window here + # is pre-existing and tracked separately. + priming_event = await self._mint_priming_event(stream_id, replay_protocol_version) + if priming_event is not None: + await sse_stream_writer.send(priming_event) # Create new request streams for this connection - self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0) + self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage]( + REQUEST_STREAM_BUFFER_SIZE + ) msg_reader = self._request_streams[stream_id][1] # Forward messages to SSE diff --git a/tests/interaction/transports/test_hosting_resume.py b/tests/interaction/transports/test_hosting_resume.py index b22df0ff2b..47e7781f85 100644 --- a/tests/interaction/transports/test_hosting_resume.py +++ b/tests/interaction/transports/test_hosting_resume.py @@ -182,6 +182,46 @@ async def count(ctx: Context) -> str: ) +@requirement("hosting:resume:priming") +async def test_a_pre_2025_11_25_reconnect_replays_without_minting_a_priming_event() -> None: + """A pre-2025-11-25 client reconnecting via Last-Event-ID gets the replay with no priming row. + + The store-length assertion is the load-bearing proof that no priming cursor was minted. + """ + release = anyio.Event() + store = SequencedEventStore() + mcp = MCPServer("resumable") + + @mcp.tool() + async def count(ctx: Context) -> str: + await ctx.info("tick 1") # pyright: ignore[reportDeprecated] + await release.wait() + await ctx.info("tick 2") # pyright: ignore[reportDeprecated] + return "counted" + + async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _): + session_id = await initialize_via_http(http) + with anyio.fail_after(5): + async with http.stream( + "POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id) + ) as response: + _, first = await _read_events(response, 2) + release.set() + await store.wait_until_stored(6) + old_client_headers = base_headers(session_id=session_id) | { + "mcp-protocol-version": "2025-06-18", + "last-event-id": first.id, + } + async with http.stream("GET", "/mcp", headers=old_client_headers) as replay: # pragma: no branch + assert replay.status_code == 200 + missed = await _read_events(replay, 2) + + assert [(event.id, bool(event.data)) for event in missed] == snapshot([("5", True), ("6", True)]) + # No priming cursor was minted on reconnect: the store still holds only the six rows + # written before the GET (init priming+response, POST priming, tick 1, tick 2, result). + assert len(store._events) == 6 + + @requirement("hosting:resume:bad-event-id") async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None: """A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error. diff --git a/tests/server/test_streamable_http_router.py b/tests/server/test_streamable_http_router.py new file mode 100644 index 0000000000..c3b03a1f95 --- /dev/null +++ b/tests/server/test_streamable_http_router.py @@ -0,0 +1,158 @@ +"""Regression coverage for the StreamableHTTP per-session response router.""" + +import anyio +import pytest +from starlette.types import Message, Scope + +from mcp.server.streamable_http import ( + REQUEST_STREAM_BUFFER_SIZE, + EventCallback, + EventId, + EventMessage, + EventStore, + StreamableHTTPServerTransport, + StreamId, +) +from mcp.shared.message import SessionMessage +from mcp.types import JSONRPCMessage, JSONRPCResponse + + +class _OrderTrackingStore(EventStore): + def __init__(self) -> None: + self.stored: list[tuple[StreamId, JSONRPCMessage | None]] = [] + + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + self.stored.append((stream_id, message)) + return str(len(self.stored)) + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + raise NotImplementedError + + +class _PrimingFailingStore(EventStore): + async def store_event(self, stream_id: StreamId, message: JSONRPCMessage | None) -> EventId: + raise RuntimeError("backend unavailable") + + async def replay_events_after(self, last_event_id: EventId, send_callback: EventCallback) -> StreamId | None: + raise NotImplementedError + + +@pytest.mark.anyio +async def test_router_unconsumed_request_stream_does_not_block_siblings() -> None: + """A response whose `sse_writer` is not yet receiving must not park the router (#1764). + + Drives the routing layer directly (the production race does not reproduce + on loopback), so this pins the router semantics, not the call sites. + """ + transport = StreamableHTTPServerTransport(mcp_session_id="sid", is_json_response_enabled=False) + streams = transport._request_streams + async with transport.connect() as (_read_stream, write_stream): + # Model two concurrent POSTs at the point _handle_post_request has + # registered the per-request stream but A's sse_writer has not yet + # reached its first receive(). + streams["A"] = anyio.create_memory_object_stream[EventMessage](REQUEST_STREAM_BUFFER_SIZE) + streams["B"] = anyio.create_memory_object_stream[EventMessage](REQUEST_STREAM_BUFFER_SIZE) + a_send, a_recv = streams["A"] + b_reader = streams["B"][1] + b_received = anyio.Event() + + async def consume_b() -> None: + async with b_reader: + await b_reader.receive() + b_received.set() + + async def server_writes() -> None: + await write_stream.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id="A", result={}))) + await write_stream.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id="B", result={}))) + + async with anyio.create_task_group() as tg: + tg.start_soon(consume_b) + tg.start_soon(server_writes) + with anyio.fail_after(5): + await b_received.wait() + # A's response was buffered for its (late) consumer, not dropped. + assert a_send.statistics().current_buffer_used == 1 + await a_recv.aclose() + await a_send.aclose() + + +@pytest.mark.anyio +async def test_priming_event_is_stored_before_any_routed_message() -> None: + """`_mint_priming_event` is awaited before the request is dispatched, so the + priming row precedes every `message_router` store for that stream regardless + of when `sse_writer` is scheduled. + """ + store = _OrderTrackingStore() + transport = StreamableHTTPServerTransport(mcp_session_id="sid", is_json_response_enabled=False, event_store=store) + streams = transport._request_streams + + async with transport.connect() as (_read_stream, write_stream): + # POST handler step: mint priming for "A" before dispatch. + priming = await transport._mint_priming_event("A", "2025-11-25") + assert priming is not None + streams["A"] = anyio.create_memory_object_stream[EventMessage](REQUEST_STREAM_BUFFER_SIZE) + a_send, a_recv = streams["A"] + + # Server emits 5 messages for "A" with no sse_writer scheduled. Each + # write_stream.send() rendezvous-hands to message_router, which stores + # then deposits into A's buffer; reading them back proves the router + # has finished storing. + for i in range(5): + await write_stream.send(SessionMessage(JSONRPCResponse(jsonrpc="2.0", id="A", result={"n": i}))) + with anyio.fail_after(5): + for _ in range(5): + await a_recv.receive() + await a_recv.aclose() + await a_send.aclose() + + assert store.stored[0] == ("A", None) + assert [sid for sid, _ in store.stored] == ["A"] * 6 + assert all(msg is not None for _, msg in store.stored[1:]) + + +@pytest.mark.anyio +async def test_priming_store_failure_leaves_no_per_request_state() -> None: + """`EventStore.store_event` raising on the priming row must not leak per-request entries.""" + transport = StreamableHTTPServerTransport( + mcp_session_id=None, + is_json_response_enabled=False, + event_store=_PrimingFailingStore(), + ) + + body = b'{"jsonrpc":"2.0","id":"req-1","method":"tools/list","params":{}}' + scope: Scope = { + "type": "http", + "method": "POST", + "path": "/", + "query_string": b"", + "headers": [ + (b"accept", b"application/json, text/event-stream"), + (b"content-type", b"application/json"), + (b"mcp-protocol-version", b"2025-11-25"), + ], + } + body_sent = False + + async def receive() -> Message: + nonlocal body_sent + if not body_sent: + body_sent = True + return {"type": "http.request", "body": body, "more_body": False} + raise NotImplementedError + + sent: list[Message] = [] + + async def asgi_send(message: Message) -> None: + sent.append(message) + + async with transport.connect() as (read_stream, _write_stream): + async with anyio.create_task_group() as tg: + tg.start_soon(transport.handle_request, scope, receive, asgi_send) + with anyio.fail_after(5): + forwarded = await read_stream.receive() + assert isinstance(forwarded, Exception) + + assert transport._request_streams == {} + assert transport._sse_stream_writers == {} + assert sent[0]["type"] == "http.response.start" + assert sent[0]["status"] == 500 diff --git a/tests/shared/test_streamable_http.py b/tests/shared/test_streamable_http.py index a3273add58..823639e143 100644 --- a/tests/shared/test_streamable_http.py +++ b/tests/shared/test_streamable_http.py @@ -1643,80 +1643,30 @@ async def test_handle_sse_event_skips_empty_data() -> None: @pytest.mark.anyio async def test_priming_event_not_sent_for_old_protocol_version() -> None: - """_maybe_send_priming_event skips for old protocol versions (backwards compat).""" - # Create a transport with an event store - transport = StreamableHTTPServerTransport( - "/mcp", - event_store=SimpleEventStore(), - ) - - # Create a mock stream writer - write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) - - try: - # Call _maybe_send_priming_event with OLD protocol version - should NOT send - await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-06-18") - - # Nothing should have been written to the stream - assert write_stream.statistics().current_buffer_used == 0 - - # Now test with NEW protocol version - should send - await transport._maybe_send_priming_event("test-request-id-2", write_stream, "2025-11-25") - - # Should have written a priming event - assert write_stream.statistics().current_buffer_used == 1 - finally: - await write_stream.aclose() - await read_stream.aclose() + """`_mint_priming_event` skips for old protocol versions (backwards compat).""" + transport = StreamableHTTPServerTransport("/mcp", event_store=SimpleEventStore()) + assert await transport._mint_priming_event("test-request-id", "2025-06-18") is None + assert await transport._mint_priming_event("test-request-id-2", "2025-11-25") is not None @pytest.mark.anyio async def test_priming_event_not_sent_without_event_store() -> None: - """_maybe_send_priming_event returns early when no event_store is configured.""" - # Create a transport WITHOUT an event store + """`_mint_priming_event` returns `None` when no event_store is configured.""" transport = StreamableHTTPServerTransport("/mcp") - - # Create a mock stream writer - write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) - - try: - # Call _maybe_send_priming_event - should return early without sending - await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25") - - # Nothing should have been written to the stream - assert write_stream.statistics().current_buffer_used == 0 - finally: - await write_stream.aclose() - await read_stream.aclose() + assert await transport._mint_priming_event("test-request-id", "2025-11-25") is None @pytest.mark.anyio async def test_priming_event_includes_retry_interval() -> None: - """_maybe_send_priming_event includes the retry field when retry_interval is set.""" - # Create a transport with an event store AND retry_interval + """`_mint_priming_event` includes the retry field when `retry_interval` is set.""" transport = StreamableHTTPServerTransport( "/mcp", event_store=SimpleEventStore(), retry_interval=5000, ) - - # Create a mock stream writer - write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) - - try: - # Call _maybe_send_priming_event with new protocol version - await transport._maybe_send_priming_event("test-request-id", write_stream, "2025-11-25") - - # Should have written a priming event with retry field - assert write_stream.statistics().current_buffer_used == 1 - - # Read the event and verify it has retry field - event = await read_stream.receive() - assert "retry" in event - assert event["retry"] == 5000 - finally: - await write_stream.aclose() - await read_stream.aclose() + event = await transport._mint_priming_event("test-request-id", "2025-11-25") + assert event is not None + assert event["retry"] == 5000 @pytest.mark.anyio @@ -1753,26 +1703,13 @@ async def test_close_sse_stream_callback_not_provided_for_old_protocol_version() @pytest.mark.anyio async def test_priming_event_not_sent_for_unknown_protocol_version() -> None: - """_maybe_send_priming_event treats unrecognized version strings conservatively. + """`_mint_priming_event` treats unrecognized version strings conservatively. A garbage version must not be mistaken for a future one (lexicographically "zzz" sorts after every date-shaped revision). """ - transport = StreamableHTTPServerTransport( - "/mcp", - event_store=SimpleEventStore(), - ) - - write_stream, read_stream = anyio.create_memory_object_stream[dict[str, Any]](1) - - try: - await transport._maybe_send_priming_event("test-request-id", write_stream, "zzz") - - # Nothing should have been written to the stream - assert write_stream.statistics().current_buffer_used == 0 - finally: - await write_stream.aclose() - await read_stream.aclose() + transport = StreamableHTTPServerTransport("/mcp", event_store=SimpleEventStore()) + assert await transport._mint_priming_event("test-request-id", "zzz") is None @pytest.mark.anyio