Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 81 additions & 73 deletions src/mcp/server/streamable_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@
from collections.abc import AsyncGenerator, Awaitable, Callable
from contextlib import asynccontextmanager
from dataclasses import dataclass
from functools import partial
from http import HTTPStatus
from typing import Any
from typing import Any, Final

import anyio
import pydantic_core
Expand Down Expand Up @@ -59,13 +60,20 @@
# Special key for the standalone GET stream
GET_STREAM_KEY = "_GET_stream"

# Buffer for the per-request `_request_streams` so the serial `message_router`
# can deposit a response and move on instead of head-of-line blocking the
# whole session on a lazily-started `sse_writer`. See #1764.
REQUEST_STREAM_BUFFER_SIZE: Final = 16

# Session ID validation pattern (visible ASCII characters ranging from 0x21 to 0x7E)
# Pattern ensures entire string contains only valid characters by using ^ and $ anchors
SESSION_ID_PATTERN = re.compile(r"^[\x21-\x7E]+$")

# Type aliases
StreamId = str
EventId = str
# An SSE event-dict as accepted by sse-starlette (`event`, `data`, `id`, `retry`).
SSEEvent = dict[str, Any]


@dataclass
Expand Down Expand Up @@ -169,7 +177,7 @@
MemoryObjectReceiveStream[EventMessage],
],
] = {}
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[dict[str, str]]] = {}
self._sse_stream_writers: dict[RequestId, MemoryObjectSendStream[SSEEvent]] = {}
self._terminated = False
# Idle timeout cancel scope; managed by the session manager.
self.idle_scope: anyio.CancelScope | None = None
Expand Down Expand Up @@ -257,31 +265,48 @@

return SessionMessage(message, metadata=metadata)

async def _maybe_send_priming_event(
self,
request_id: RequestId,
sse_stream_writer: MemoryObjectSendStream[dict[str, Any]],
protocol_version: str,
) -> None:
"""Send priming event for SSE resumability if event_store is configured.
async def _mint_priming_event(self, stream_id: StreamId, protocol_version: str) -> SSEEvent | None:
"""Store the priming cursor for `stream_id` and return its SSE wire form.

Only sends priming events to clients with protocol version >= 2025-11-25,
which includes the fix for handling empty SSE data. Older clients would
crash trying to parse empty data as JSON.
Called before the request is dispatched so the priming row precedes
anything `message_router` can store for this stream. Returns `None`
when no event store is configured or the client predates 2025-11-25
(older clients cannot parse the empty-data event).
"""
if not self._event_store:
return
# Priming events have empty data which older clients cannot handle.
return None
if not is_version_at_least(protocol_version, "2025-11-25"):
return
priming_event_id = await self._event_store.store_event(
str(request_id), # Convert RequestId to StreamId (str)
None, # Priming event has no payload
)
priming_event: dict[str, str | int] = {"id": priming_event_id, "data": ""}
return None
priming_event_id = await self._event_store.store_event(stream_id, None)
priming_event: SSEEvent = {"id": priming_event_id, "data": ""}
if self._retry_interval is not None:
priming_event["retry"] = self._retry_interval
await sse_stream_writer.send(priming_event)
return priming_event

async def _run_sse_writer(
self,
request_id: RequestId,
sse_stream_writer: MemoryObjectSendStream[SSEEvent],
request_stream_reader: MemoryObjectReceiveStream[EventMessage],
priming_event: SSEEvent | None,
) -> None:
"""Forward `_request_streams[request_id]` onto the SSE wire for one POST."""
try:
async with sse_stream_writer, request_stream_reader:
if priming_event is not None:
await sse_stream_writer.send(priming_event)
async for event_message in request_stream_reader:
await sse_stream_writer.send(self._create_event_data(event_message))
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
break
except anyio.ClosedResourceError: # pragma: lax no cover
logger.debug("SSE stream closed by close_sse_stream()")
except Exception: # pragma: lax no cover
logger.exception("Error in SSE writer")
finally:
logger.debug("Closing SSE writer")
self._sse_stream_writers.pop(request_id, None)
await self._clean_up_memory_streams(request_id)

def _create_error_response(
self,
Expand Down Expand Up @@ -335,7 +360,7 @@
"""Extract the session ID from request headers."""
return request.headers.get(MCP_SESSION_ID_HEADER)

def _create_event_data(self, event_message: EventMessage) -> dict[str, str]:
def _create_event_data(self, event_message: EventMessage) -> SSEEvent:
"""Create event data dictionary from an EventMessage."""
event_data = {
"event": "message",
Expand Down Expand Up @@ -525,13 +550,13 @@
else request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)
)

# Extract the request ID outside the try block for proper scope
request_id = str(message.id)
# Register this stream for the request ID
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](0)
request_stream_reader = self._request_streams[request_id][1]

if self.is_json_response_enabled:
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)
request_stream_reader = self._request_streams[request_id][1]
# Process the message
metadata = ServerMessageMetadata(request_context=request, protocol_version=protocol_version)
session_message = SessionMessage(message, metadata=metadata)
Expand Down Expand Up @@ -575,41 +600,19 @@
finally:
await self._clean_up_memory_streams(request_id)
else:
# Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
# Mint the priming event before any per-request state exists:
# `EventStore.store_event` is user code and may raise, in which
# case the outer handler returns a 500 with nothing to clean up.
# Still strictly precedes dispatch, so storage order == wire order.
priming_event = await self._mint_priming_event(request_id, protocol_version)

# Store writer reference so close_sse_stream() can close it
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)
self._sse_stream_writers[request_id] = sse_stream_writer
self._request_streams[request_id] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)
request_stream_reader = self._request_streams[request_id][1]

async def sse_writer():
# Get the request ID from the incoming request message
try:
async with sse_stream_writer, request_stream_reader:
# Send priming event for SSE resumability
await self._maybe_send_priming_event(request_id, sse_stream_writer, protocol_version)

# Process messages from the request-specific stream
async for event_message in request_stream_reader:
# Build the event data
event_data = self._create_event_data(event_message)
await sse_stream_writer.send(event_data)

# If response, remove from pending streams and close
if isinstance(event_message.message, JSONRPCResponse | JSONRPCError):
break
except anyio.ClosedResourceError: # pragma: lax no cover
# Expected when close_sse_stream() is called
logger.debug("SSE stream closed by close_sse_stream()")
except Exception: # pragma: lax no cover
logger.exception("Error in SSE writer")
finally:
logger.debug("Closing SSE writer")
self._sse_stream_writers.pop(request_id, None)
await self._clean_up_memory_streams(request_id)

# Create and start EventSourceResponse
# SSE stream mode (original behavior)
# Set up headers
headers = {
"Cache-Control": "no-cache, no-transform",
"Connection": "keep-alive",
Expand All @@ -618,7 +621,9 @@
}
response = EventSourceResponse(
content=sse_stream_reader,
data_sender_callable=sse_writer,
data_sender_callable=partial(
self._run_sse_writer, request_id, sse_stream_writer, request_stream_reader, priming_event
),
headers=headers,
)

Expand All @@ -637,20 +642,16 @@
finally:
await sse_stream_reader.aclose()

except Exception as err: # pragma: lax no cover
# Reached only when something raises during POST handling outside
# the per-SSE-stream guard above; whether tests reach this depends
# on client teardown timing.
except Exception as err:
logger.exception("Error handling POST request")
response = self._create_error_response(
f"Error handling POST request: {err}",
HTTPStatus.INTERNAL_SERVER_ERROR,
INTERNAL_ERROR,
)
await response(scope, receive, send)
if writer:
await writer.send(Exception(err))
return # pragma: no cover
await writer.send(Exception(err))
return

async def _handle_get_request(self, request: Request, send: Send) -> None:
"""Handle GET request to establish SSE.
Expand Down Expand Up @@ -701,13 +702,15 @@
return

# Create SSE stream
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)

async def standalone_sse_writer():
try:
# Create a standalone message stream for server-initiated messages

self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](0)
self._request_streams[GET_STREAM_KEY] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)
standalone_stream_reader = self._request_streams[GET_STREAM_KEY][1]

async with sse_stream_writer, standalone_stream_reader:
Expand Down Expand Up @@ -894,11 +897,10 @@
if self.mcp_session_id: # pragma: no branch
headers[MCP_SESSION_ID_HEADER] = self.mcp_session_id

# Get protocol version from header (already validated in _validate_protocol_version)
replay_protocol_version = request.headers.get(MCP_PROTOCOL_VERSION_HEADER, DEFAULT_NEGOTIATED_VERSION)

# Create SSE stream for replay
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[dict[str, str]](0)
sse_stream_writer, sse_stream_reader = anyio.create_memory_object_stream[SSEEvent](0)

async def replay_sender():
try:
Expand All @@ -913,14 +915,20 @@

# If stream ID not in mapping, create it
if stream_id and stream_id not in self._request_streams: # pragma: no branch
# Register SSE writer so close_sse_stream() can close it
self._sse_stream_writers[stream_id] = sse_stream_writer

# Send priming event for this new connection
await self._maybe_send_priming_event(stream_id, sse_stream_writer, replay_protocol_version)
# Prime the resumed connection so the client sees the stream
# is re-registered. The replay→live-tail ordering window here
# is pre-existing and tracked separately.
priming_event = await self._mint_priming_event(stream_id, replay_protocol_version)
if priming_event is not None:
await sse_stream_writer.send(priming_event)

# Create new request streams for this connection
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](0)
self._request_streams[stream_id] = anyio.create_memory_object_stream[EventMessage](
REQUEST_STREAM_BUFFER_SIZE
)

Check notice on line 931 in src/mcp/server/streamable_http.py

View check run for this annotation

Claude / Claude Code Review

replay_sender lacks finally cleanup, leaking _sse_stream_writers/_request_streams entries (pre-existing)

Pre-existing issue (not introduced by this PR, but adjacent to the priming logic rewritten here): `replay_sender` has no `finally` cleanup, so when a resumed SSE connection ends, the closed writer stays in `_sse_stream_writers[stream_id]` and the `_request_streams[stream_id]` entry can linger — long-lived sessions with many resumptions accumulate stale entries. A `finally` mirroring `_run_sse_writer` (pop `_sse_stream_writers` + `_clean_up_memory_streams`) would close the gap, either here or in
Comment on lines 918 to +931

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🟣 Pre-existing issue (not introduced by this PR, but adjacent to the priming logic rewritten here): replay_sender has no finally cleanup, so when a resumed SSE connection ends, the closed writer stays in _sse_stream_writers[stream_id] and the _request_streams[stream_id] entry can linger — long-lived sessions with many resumptions accumulate stale entries. A finally mirroring _run_sse_writer (pop _sse_stream_writers + _clean_up_memory_streams) would close the gap, either here or in the separately-tracked replay follow-up.

Extended reasoning...

What the bug is. The replay path registers per-stream state but never tears it down. In _replay_eventsreplay_sender, the code does self._sse_stream_writers[stream_id] = sse_stream_writer (line 919), mints/sends the priming event, and creates self._request_streams[stream_id] (line ~929). The other two SSE writer paths clean up after themselves: _run_sse_writer has a finally that pops _sse_stream_writers[request_id] and calls _clean_up_memory_streams(request_id), and standalone_sse_writer's finally calls _clean_up_memory_streams(GET_STREAM_KEY). replay_sender has no finally at all — the outer finally in _replay_events only closes the wire-level sse_stream_writer/sse_stream_reader, not the dict entries.

The code path that triggers it. The common case is the normal resume lifecycle: a client reconnects with Last-Event-ID, the replay completes, the resumed request finishes (its response is delivered over the replay stream), and the client drops the GET. sse-starlette tears down the response and replay_sender unwinds — but nothing removes _sse_stream_writers[stream_id] (only close_sse_stream() ever pops it on this path) or _request_streams[stream_id].

Why existing code doesn't prevent it. _request_streams[stream_id] is reaped lazily only if message_router later attempts a send to that id and hits BrokenResourceError — which for a request that has already completed never happens, so it persists until terminate() or connect() teardown. _sse_stream_writers is worse: neither terminate() nor connect()'s finally clears it, so those entries persist for the transport's lifetime. (One refinement to the original finding: the router does not keep buffering 16 messages and then block on the dead stream — replay_sender's async with msg_reader closes the receive end on unwind, so a stray routed message raises BrokenResourceError immediately and the entry self-heals. The leak is stale dict entries, not a wedge.)

Concrete walk-through. (1) Client POSTs tools/call id=7; tool emits notifications; client disconnects mid-stream. (2) Client reconnects: GET /mcp with Last-Event-ID. replay_sender replays missed events, registers _sse_stream_writers["7"] and _request_streams["7"], and tails live messages. (3) The tool finishes; the router stores and deposits the JSONRPCResponse; replay_sender forwards it; the client closes the GET. (4) replay_sender unwinds — but _sse_stream_writers["7"] still holds the closed writer and _request_streams["7"] still holds closed streams. Repeat for ids 8, 9, 10… on a long-lived session and the dicts grow monotonically. A secondary wrinkle: the writer is registered before _mint_priming_event, so a user EventStore.store_event exception (swallowed by the broad except Exception) leaves an orphaned _sse_stream_writers entry with no matching _request_streams entry. A later resume for the same stream id also finds the stale _request_streams entry and skips re-registration/priming.

Impact. Bounded, slow growth — one small dict entry plus a closed memory-stream pair per resumed request, capped by session lifetime for _request_streams and by transport lifetime for _sse_stream_writers. Not a correctness wedge, which is why this is flagged as pre-existing rather than blocking.

How to fix. Give replay_sender a finally mirroring _run_sse_writer: self._sse_stream_writers.pop(stream_id, None) (guarded for the case where stream_id was never resolved) and await self._clean_up_memory_streams(stream_id). Since this PR already rewrote the priming logic in this exact block and notes the replay-path ordering window is tracked separately, this could ride along there if preferred.

msg_reader = self._request_streams[stream_id][1]

# Forward messages to SSE
Expand Down
40 changes: 40 additions & 0 deletions tests/interaction/transports/test_hosting_resume.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,46 @@ async def count(ctx: Context) -> str:
)


@requirement("hosting:resume:priming")
async def test_a_pre_2025_11_25_reconnect_replays_without_minting_a_priming_event() -> None:
"""A pre-2025-11-25 client reconnecting via Last-Event-ID gets the replay with no priming row.

The store-length assertion is the load-bearing proof that no priming cursor was minted.
"""
release = anyio.Event()
store = SequencedEventStore()
mcp = MCPServer("resumable")

@mcp.tool()
async def count(ctx: Context) -> str:
await ctx.info("tick 1") # pyright: ignore[reportDeprecated]
await release.wait()
await ctx.info("tick 2") # pyright: ignore[reportDeprecated]
return "counted"

async with mounted_app(mcp, event_store=store, retry_interval=0) as (http, _):
session_id = await initialize_via_http(http)
with anyio.fail_after(5):
async with http.stream(
"POST", "/mcp", content=_tools_call(1, "count", {}), headers=base_headers(session_id=session_id)
) as response:
_, first = await _read_events(response, 2)
release.set()
await store.wait_until_stored(6)
old_client_headers = base_headers(session_id=session_id) | {
"mcp-protocol-version": "2025-06-18",
"last-event-id": first.id,
}
async with http.stream("GET", "/mcp", headers=old_client_headers) as replay: # pragma: no branch
assert replay.status_code == 200
missed = await _read_events(replay, 2)

assert [(event.id, bool(event.data)) for event in missed] == snapshot([("5", True), ("6", True)])
# No priming cursor was minted on reconnect: the store still holds only the six rows
# written before the GET (init priming+response, POST priming, tick 1, tick 2, result).
assert len(store._events) == 6


@requirement("hosting:resume:bad-event-id")
async def test_an_unknown_last_event_id_yields_an_empty_replay_stream() -> None:
"""A Last-Event-ID the event store cannot map produces an empty SSE stream rather than an error.
Expand Down
Loading
Loading