diff --git a/src/mcp/client/session.py b/src/mcp/client/session.py index 4b24e98b1d..7d08a1b96f 100644 --- a/src/mcp/client/session.py +++ b/src/mcp/client/session.py @@ -85,6 +85,8 @@ async def __call__( async def _default_message_handler( message: RequestResponder[types.ServerRequest, types.ClientResult] | types.ServerNotification | Exception, ) -> None: + if isinstance(message, Exception): + raise message await anyio.lowlevel.checkpoint() diff --git a/tests/client/test_session.py b/tests/client/test_session.py index c171360de2..7c2446e9f3 100644 --- a/tests/client/test_session.py +++ b/tests/client/test_session.py @@ -1476,3 +1476,14 @@ async def test_send_notification_after_close_is_dropped_silently(): finally: for s in (s2c_send, s2c_recv, c2s_send, c2s_recv): s.close() + + +@pytest.mark.anyio +async def test_default_message_handler(): + from mcp.client.session import _default_message_handler + + with pytest.raises(ValueError, match="test error"): + await _default_message_handler(ValueError("test error")) + + # Should not raise for non-exception + await _default_message_handler(types.ToolListChangedNotification()) diff --git a/tests/transports/stdio/test_lifecycle.py b/tests/transports/stdio/test_lifecycle.py index 8a370c10f6..f185d77900 100644 --- a/tests/transports/stdio/test_lifecycle.py +++ b/tests/transports/stdio/test_lifecycle.py @@ -228,7 +228,7 @@ def test_fallback_process_reports_death_through_returncode_without_a_wait_call() try: process = FallbackProcess(popen) - os.waitid(os.P_PID, popen.pid, os.WEXITED | os.WNOWAIT) + os.waitid(os.P_PID, popen.pid, os.WEXITED | os.WNOWAIT) # pyright: ignore[reportAttributeAccessIssue, reportUnknownMemberType] assert process.returncode == 0 finally: popen.stdin.close()