diff --git a/conformance/src/bin/server.rs b/conformance/src/bin/server.rs index fda1ab938..078d76821 100644 --- a/conformance/src/bin/server.rs +++ b/conformance/src/bin/server.rs @@ -213,9 +213,9 @@ impl ServerHandler for ConformanceServer { &self, request: CallToolRequestParams, cx: RequestContext, - ) -> Result { + ) -> Result { let args = request.arguments.unwrap_or_default(); - match request.name.as_ref() { + let result = match request.name.as_ref() { "test_simple_text" => Ok(CallToolResult::success(vec![ContentBlock::text( "This is a simple text response for testing.", )])), @@ -530,7 +530,8 @@ impl ServerHandler for ConformanceServer { format!("Unknown tool: {}", request.name), None, )), - } + }; + result.map(Into::into) } async fn list_resources( @@ -555,9 +556,9 @@ impl ServerHandler for ConformanceServer { &self, request: ReadResourceRequestParams, _cx: RequestContext, - ) -> Result { + ) -> Result { let uri = request.uri.as_str(); - match uri { + let result = match uri { "test://static-text" => Ok(ReadResourceResult::new(vec![ ResourceContents::TextResourceContents { uri: uri.into(), @@ -598,7 +599,8 @@ impl ServerHandler for ConformanceServer { )) } } - } + }; + result.map(Into::into) } async fn list_resource_templates( @@ -679,8 +681,8 @@ impl ServerHandler for ConformanceServer { &self, request: GetPromptRequestParams, _cx: RequestContext, - ) -> Result { - match request.name.as_str() { + ) -> Result { + let result = match request.name.as_str() { "test_simple_prompt" => Ok(GetPromptResult::new(vec![PromptMessage::new_text( Role::User, "This is a simple test prompt.", @@ -724,7 +726,8 @@ impl ServerHandler for ConformanceServer { format!("Unknown prompt: {}", request.name), None, )), - } + }; + result.map(Into::into) } async fn complete( diff --git a/crates/rmcp-macros/src/prompt_handler.rs b/crates/rmcp-macros/src/prompt_handler.rs index af4f24bd8..24032eaab 100644 --- a/crates/rmcp-macros/src/prompt_handler.rs +++ b/crates/rmcp-macros/src/prompt_handler.rs @@ -35,7 +35,7 @@ pub fn prompt_handler(attr: TokenStream, input: TokenStream) -> syn::Result, - ) -> Result { + ) -> Result { let prompt_context = rmcp::handler::server::prompt::PromptContext::new( self, request.name, diff --git a/crates/rmcp-macros/src/tool_handler.rs b/crates/rmcp-macros/src/tool_handler.rs index dc935828d..1e39eb5f3 100644 --- a/crates/rmcp-macros/src/tool_handler.rs +++ b/crates/rmcp-macros/src/tool_handler.rs @@ -47,7 +47,7 @@ pub fn tool_handler(attr: TokenStream, input: TokenStream) -> syn::Result, - ) -> Result { + ) -> Result { let tcc = rmcp::handler::server::tool::ToolCallContext::new(self, request, context); #router.call(tcc).await } diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 9704bfc72..92879578f 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -22,6 +22,7 @@ features = [ "client-side-sse", "elicitation", "macros", + "request-state", "reqwest", "reqwest-native-tls", "reqwest-tls-no-provider", @@ -64,6 +65,10 @@ schemars = { version = "1.0", optional = true, features = ["chrono04"] } # for image encoding base64 = { version = "0.22", optional = true } +# for SEP-2322 requestState integrity sealing (opt-in via the `request-state` feature) +hmac = { version = "0.12", optional = true } +sha2 = { version = "0.10", optional = true } + # for HTTP client reqwest = { version = "0.13.2", default-features = false, features = [ "json", @@ -120,6 +125,9 @@ server = ["transport-async-rw", "dep:schemars", "dep:pastey"] macros = ["dep:rmcp-macros", "dep:pastey"] elicitation = ["dep:url"] +# SEP-2322 requestState integrity helper (HMAC-SHA256 seal/open codec) +request-state = ["dep:hmac", "dep:sha2", "base64"] + # reqwest http client __reqwest = ["dep:reqwest"] @@ -309,6 +317,11 @@ name = "test_trace_context" required-features = ["server", "client"] path = "tests/test_trace_context.rs" +[[test]] +name = "test_mrtr_behavior" +required-features = ["server", "client"] +path = "tests/test_mrtr_behavior.rs" + [[test]] name = "test_prompt_macros" required-features = ["server", "client"] diff --git a/crates/rmcp/src/handler/server.rs b/crates/rmcp/src/handler/server.rs index 3cec563e8..779901764 100644 --- a/crates/rmcp/src/handler/server.rs +++ b/crates/rmcp/src/handler/server.rs @@ -27,6 +27,9 @@ impl Service for H { ) -> Result<::Resp, McpError> { // `context` is moved into the dispatch below, so read the negotiated version first. let protocol_version = context.protocol_version(); + let mrtr_supported = protocol_version + .as_ref() + .is_some_and(|v| v.as_str() >= ProtocolVersion::V_2026_07_28.as_str()); let result = match request { ClientRequest::InitializeRequest(request) => self .initialize(request.params, context) @@ -46,7 +49,7 @@ impl Service for H { ClientRequest::GetPromptRequest(request) => self .get_prompt(request.params, context) .await - .map(ServerResult::GetPromptResult), + .map(ServerResult::from), ClientRequest::ListPromptsRequest(request) => self .list_prompts(request.params, context) .await @@ -62,7 +65,7 @@ impl Service for H { ClientRequest::ReadResourceRequest(request) => self .read_resource(request.params, context) .await - .map(ServerResult::ReadResourceResult), + .map(ServerResult::from), ClientRequest::SubscribeRequest(request) => self .subscribe(request.params, context) .await @@ -105,7 +108,7 @@ impl Service for H { } else { self.call_tool(request.params, context) .await - .map(ServerResult::CallToolResult) + .map(ServerResult::from) } } ClientRequest::ListToolsRequest(request) => self @@ -133,6 +136,17 @@ impl Service for H { .await .map(ServerResult::CancelTaskResult), }; + let result = result.and_then(|result| { + if matches!(result, ServerResult::InputRequiredResult(_)) && !mrtr_supported { + Err(McpError::invalid_request( + "InputRequiredResult requires negotiated protocol version 2026-07-28 or newer", + None, + )) + } else { + Ok(result) + } + }); + // SEP-2164: peers negotiating 2026-07-28+ get the standard INVALID_PARAMS code for // resource-not-found; older peers keep RESOURCE_NOT_FOUND. ISO `YYYY-MM-DD` versions // compare lexically the same as chronologically. @@ -229,7 +243,7 @@ macro_rules! server_handler_methods { &self, request: GetPromptRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { std::future::ready(Err(McpError::method_not_found::())) } fn list_prompts( @@ -259,7 +273,7 @@ macro_rules! server_handler_methods { &self, request: ReadResourceRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { std::future::ready(Err( McpError::method_not_found::(), )) @@ -312,7 +326,7 @@ macro_rules! server_handler_methods { &self, request: CallToolRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { std::future::ready(Err(McpError::method_not_found::())) } fn list_tools( @@ -485,7 +499,7 @@ macro_rules! impl_server_handler_for_wrapper { &self, request: GetPromptRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { (**self).get_prompt(request, context) } @@ -518,7 +532,7 @@ macro_rules! impl_server_handler_for_wrapper { &self, request: ReadResourceRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { (**self).read_resource(request, context) } @@ -542,7 +556,7 @@ macro_rules! impl_server_handler_for_wrapper { &self, request: CallToolRequestParams, context: RequestContext, - ) -> impl Future> + MaybeSendFuture + '_ { + ) -> impl Future> + MaybeSendFuture + '_ { (**self).call_tool(request, context) } diff --git a/crates/rmcp/src/handler/server/prompt.rs b/crates/rmcp/src/handler/server/prompt.rs index ffce6b2e0..a75e02713 100644 --- a/crates/rmcp/src/handler/server/prompt.rs +++ b/crates/rmcp/src/handler/server/prompt.rs @@ -15,7 +15,7 @@ pub use super::common::{Extension, RequestId}; use crate::{ RoleServer, handler::server::wrapper::Parameters, - model::{GetPromptResult, PromptMessage}, + model::{GetPromptResponse, GetPromptResult, InputRequiredResult, PromptMessage}, service::{MaybeBoxFuture, MaybeSend, MaybeSendFuture, RequestContext}, }; @@ -59,12 +59,12 @@ pub trait GetPromptHandler { fn handle( self, context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result>; + ) -> MaybeBoxFuture<'_, Result>; } /// Type alias for dynamic prompt handlers #[cfg(not(feature = "local"))] -pub type DynGetPromptHandler = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> +pub type DynGetPromptHandler = dyn for<'a> Fn(PromptContext<'a, S>) -> BoxFuture<'a, Result> + Send + Sync; @@ -73,7 +73,7 @@ pub type DynGetPromptHandler = dyn for<'a> Fn( PromptContext<'a, S>, ) -> futures::future::LocalBoxFuture< 'a, - Result, + Result, >; /// Adapter type for async methods that return `Vec` @@ -91,28 +91,35 @@ pub struct SyncPromptMethodAdapter(PhantomData R>); /// Trait for types that can be converted into GetPromptResult pub trait IntoGetPromptResult { - fn into_get_prompt_result(self) -> Result; + fn into_get_prompt_result(self) -> Result; } impl IntoGetPromptResult for GetPromptResult { - fn into_get_prompt_result(self) -> Result { - Ok(self) + fn into_get_prompt_result(self) -> Result { + Ok(self.into()) + } +} + +impl IntoGetPromptResult for InputRequiredResult { + fn into_get_prompt_result(self) -> Result { + Ok(self.into()) } } impl IntoGetPromptResult for Vec { - fn into_get_prompt_result(self) -> Result { + fn into_get_prompt_result(self) -> Result { Ok(GetPromptResult { result_type: Default::default(), description: None, messages: self, meta: None, - }) + } + .into()) } } impl IntoGetPromptResult for Result { - fn into_get_prompt_result(self) -> Result { + fn into_get_prompt_result(self) -> Result { self.and_then(|v| v.into_get_prompt_result()) } } @@ -129,7 +136,7 @@ pin_project_lite::pin_project! { }, Ready { #[pin] - result: futures::future::Ready>, + result: futures::future::Ready>, } } } @@ -139,7 +146,7 @@ where F: Future, R: IntoGetPromptResult, { - type Output = Result; + type Output = Result; fn poll( self: std::pin::Pin<&mut Self>, @@ -216,7 +223,7 @@ macro_rules! impl_prompt_handler_for { fn handle( self, mut context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result> + ) -> MaybeBoxFuture<'_, Result> { $( let result = $Tn::from_context_part(&mut context); @@ -249,7 +256,7 @@ macro_rules! impl_prompt_handler_for { fn handle( self, mut context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result> + ) -> MaybeBoxFuture<'_, Result> { $( let result = $Tn::from_context_part(&mut context); @@ -280,7 +287,7 @@ macro_rules! impl_prompt_handler_for { fn handle( self, mut context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result> + ) -> MaybeBoxFuture<'_, Result> { // Extract all parameters before moving into the async block $( @@ -315,7 +322,7 @@ macro_rules! impl_prompt_handler_for { fn handle( self, mut context: PromptContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result> + ) -> MaybeBoxFuture<'_, Result> { $( let result = $Tn::from_context_part(&mut context); diff --git a/crates/rmcp/src/handler/server/router.rs b/crates/rmcp/src/handler/server/router.rs index 45ff9a586..e934137b8 100644 --- a/crates/rmcp/src/handler/server/router.rs +++ b/crates/rmcp/src/handler/server/router.rs @@ -106,7 +106,7 @@ where context, ); let result = self.tool_router.call(tool_call_context).await?; - Ok(ServerResult::CallToolResult(result)) + Ok(ServerResult::from(result)) } else { self.service .handle_request(ClientRequest::CallToolRequest(request), context) @@ -129,7 +129,7 @@ where context, ); let result = self.prompt_router.get_prompt(prompt_context).await?; - Ok(ServerResult::GetPromptResult(result)) + Ok(ServerResult::from(result)) } else { self.service .handle_request(ClientRequest::GetPromptRequest(request), context) @@ -193,7 +193,7 @@ mod tests { async fn test_router_deferred_notifier_e2e() { let mut router = Router::new(DummyHandler).with_tool(tool::ToolRoute::new_dyn( Tool::new("my_tool", "test", Arc::new(Default::default())), - |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + |_ctx| Box::pin(async { Ok(CallToolResult::default().into()) }), )); let id_provider: Arc = diff --git a/crates/rmcp/src/handler/server/router/prompt.rs b/crates/rmcp/src/handler/server/router/prompt.rs index e952b2a39..509fb4287 100644 --- a/crates/rmcp/src/handler/server/router/prompt.rs +++ b/crates/rmcp/src/handler/server/router/prompt.rs @@ -2,7 +2,7 @@ use std::{borrow::Cow, sync::Arc}; use crate::{ handler::server::prompt::{DynGetPromptHandler, GetPromptHandler, PromptContext}, - model::{GetPromptResult, Prompt}, + model::{GetPromptResponse, Prompt}, service::{MaybeBoxFuture, MaybeSend}, }; @@ -50,7 +50,8 @@ impl PromptRoute { where H: for<'a> Fn( PromptContext<'a, S>, - ) -> MaybeBoxFuture<'a, Result> + ) + -> MaybeBoxFuture<'a, Result> + MaybeSend + 'static, { @@ -175,7 +176,7 @@ where pub async fn get_prompt( &self, context: PromptContext<'_, S>, - ) -> Result { + ) -> Result { let item = self.map.get(context.name.as_str()).ok_or_else(|| { crate::ErrorData::invalid_params( format!("prompt '{}' not found", context.name), diff --git a/crates/rmcp/src/handler/server/router/tool.rs b/crates/rmcp/src/handler/server/router/tool.rs index dece66d95..215116250 100644 --- a/crates/rmcp/src/handler/server/router/tool.rs +++ b/crates/rmcp/src/handler/server/router/tool.rs @@ -137,21 +137,19 @@ use crate::{ tool::{CallToolHandler, DynCallToolHandler, ToolCallContext}, tool_name_validation::validate_and_warn_tool_name, }, - model::{CallToolResult, ContentBlock, ErrorCode, Tool, ToolAnnotations}, + model::{CallToolResponse, CallToolResult, ContentBlock, ErrorCode, Tool, ToolAnnotations}, service::{MaybeBoxFuture, MaybeSend}, }; const TOOL_ARGUMENT_DESERIALIZATION_ERROR_PREFIX: &str = "failed to deserialize parameters:"; -fn into_tool_argument_error(error: crate::ErrorData) -> Result { +fn into_tool_argument_error(error: crate::ErrorData) -> Result { if error.code == ErrorCode::INVALID_PARAMS && error .message .starts_with(TOOL_ARGUMENT_DESERIALIZATION_ERROR_PREFIX) { - return Ok(CallToolResult::error(vec![ContentBlock::text( - error.message, - )])); + return Ok(CallToolResult::error(vec![ContentBlock::text(error.message)]).into()); } Err(error) @@ -200,7 +198,8 @@ impl ToolRoute { where C: for<'a> Fn( ToolCallContext<'a, S>, - ) -> MaybeBoxFuture<'a, Result> + ) + -> MaybeBoxFuture<'a, Result> + MaybeSend + 'static, { @@ -561,7 +560,7 @@ where pub async fn call( &self, context: ToolCallContext<'_, S>, - ) -> Result { + ) -> Result { let name = context.name(); if self.disabled.contains(name) { return Err(crate::ErrorData::invalid_params("tool not found", None)); @@ -679,6 +678,9 @@ mod tests { .call(ctx) .await .expect("argument validation should be a tool result"); + let CallToolResponse::Complete(result) = result else { + panic!("expected complete CallToolResult"); + }; assert_eq!(result.is_error, Some(true)); let text = result @@ -696,7 +698,7 @@ mod tests { let service = DummyService; let mut router = ToolRouter::new().with_route(ToolRoute::new_dyn( crate::model::Tool::new("test_tool", "a test tool", Arc::new(Default::default())), - |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + |_ctx| Box::pin(async { Ok(CallToolResult::default().into()) }), )); router.disable_route("test_tool"); diff --git a/crates/rmcp/src/handler/server/tool.rs b/crates/rmcp/src/handler/server/tool.rs index bf350797d..cb4966df0 100644 --- a/crates/rmcp/src/handler/server/tool.rs +++ b/crates/rmcp/src/handler/server/tool.rs @@ -16,7 +16,10 @@ pub use super::{ use crate::{ RoleServer, handler::server::wrapper::Parameters, - model::{CallToolRequestParams, CallToolResult, IntoContents, JsonObject}, + model::{ + CallToolRequestParams, CallToolResponse, CallToolResult, InputRequiredResult, IntoContents, + JsonObject, + }, service::{MaybeBoxFuture, MaybeSend, MaybeSendFuture, RequestContext}, }; @@ -77,36 +80,46 @@ impl AsRequestContext for ToolCallContext<'_, S> { } pub trait IntoCallToolResult { - fn into_call_tool_result(self) -> Result; + fn into_call_tool_result(self) -> Result; } impl IntoCallToolResult for T { - fn into_call_tool_result(self) -> Result { - Ok(CallToolResult::success(self.into_contents())) + fn into_call_tool_result(self) -> Result { + Ok(CallToolResult::success(self.into_contents()).into()) } } impl IntoCallToolResult for CallToolResult { - fn into_call_tool_result(self) -> Result { - Ok(self) + fn into_call_tool_result(self) -> Result { + Ok(self.into()) + } +} + +impl IntoCallToolResult for InputRequiredResult { + fn into_call_tool_result(self) -> Result { + Ok(self.into()) } } impl IntoCallToolResult for crate::ErrorData { - fn into_call_tool_result(self) -> Result { + fn into_call_tool_result(self) -> Result { Err(self) } } impl IntoCallToolResult for Result { - fn into_call_tool_result(self) -> Result { + fn into_call_tool_result(self) -> Result { match self { Ok(value) => value.into_call_tool_result(), Err(error) => match error.into_call_tool_result() { - Ok(mut result) => { + Ok(CallToolResponse::Complete(mut result)) => { result.is_error = Some(true); - Ok(result) + Ok(result.into()) } + Ok(CallToolResponse::InputRequired(_)) => Err(crate::ErrorData::internal_error( + "InputRequiredResult cannot be returned from a tool error branch", + None, + )), Err(e) => Err(e), }, } @@ -124,7 +137,7 @@ pin_project_lite::pin_project! { }, Ready { #[pin] - result: Ready>, + result: Ready>, } } } @@ -134,7 +147,7 @@ where F: Future, R: IntoCallToolResult, { - type Output = Result; + type Output = Result; fn poll( self: std::pin::Pin<&mut Self>, @@ -153,20 +166,21 @@ pub trait CallToolHandler { fn call( self, context: ToolCallContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result>; + ) -> MaybeBoxFuture<'_, Result>; } #[cfg(not(feature = "local"))] -pub type DynCallToolHandler = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result> +pub type DynCallToolHandler = dyn for<'s> Fn(ToolCallContext<'s, S>) -> BoxFuture<'s, Result> + Send + Sync; #[cfg(feature = "local")] -pub type DynCallToolHandler = - dyn for<'s> Fn( - ToolCallContext<'s, S>, - ) - -> futures::future::LocalBoxFuture<'s, Result>; +pub type DynCallToolHandler = dyn for<'s> Fn( + ToolCallContext<'s, S>, +) -> futures::future::LocalBoxFuture< + 's, + Result, +>; // Tool-specific extractor for tool name #[expect(clippy::exhaustive_structs, reason = "intentionally exhaustive")] @@ -205,7 +219,10 @@ impl FromContextPart> for JsonObject { } impl<'s, S> ToolCallContext<'s, S> { - pub fn invoke(self, h: H) -> MaybeBoxFuture<'s, Result> + pub fn invoke( + self, + h: H, + ) -> MaybeBoxFuture<'s, Result> where H: CallToolHandler, { @@ -248,7 +265,7 @@ macro_rules! impl_for { fn call( self, mut context: ToolCallContext<'_, S>, - ) -> MaybeBoxFuture<'_, Result>{ + ) -> MaybeBoxFuture<'_, Result>{ $( let result = $Tn::from_context_part(&mut context); let $Tn = match result { @@ -279,7 +296,7 @@ macro_rules! impl_for { fn call( self, mut context: ToolCallContext, - ) -> MaybeBoxFuture<'static, Result>{ + ) -> MaybeBoxFuture<'static, Result>{ $( let result = $Tn::from_context_part(&mut context); let $Tn = match result { @@ -308,7 +325,7 @@ macro_rules! impl_for { fn call( self, mut context: ToolCallContext, - ) -> MaybeBoxFuture<'static, Result> { + ) -> MaybeBoxFuture<'static, Result> { $( let result = $Tn::from_context_part(&mut context); let $Tn = match result { @@ -333,7 +350,7 @@ macro_rules! impl_for { fn call( self, mut context: ToolCallContext, - ) -> MaybeBoxFuture<'static, Result> { + ) -> MaybeBoxFuture<'static, Result> { $( let result = $Tn::from_context_part(&mut context); let $Tn = match result { diff --git a/crates/rmcp/src/handler/server/wrapper/json.rs b/crates/rmcp/src/handler/server/wrapper/json.rs index c03fbd032..7c5297963 100644 --- a/crates/rmcp/src/handler/server/wrapper/json.rs +++ b/crates/rmcp/src/handler/server/wrapper/json.rs @@ -3,7 +3,10 @@ use std::borrow::Cow; use schemars::JsonSchema; use serde::Serialize; -use crate::{handler::server::tool::IntoCallToolResult, model::CallToolResult}; +use crate::{ + handler::server::tool::IntoCallToolResult, + model::{CallToolResponse, CallToolResult}, +}; /// Json wrapper for structured output /// @@ -27,7 +30,7 @@ impl JsonSchema for Json { // Implementation for Json to create structured content impl IntoCallToolResult for Json { - fn into_call_tool_result(self) -> Result { + fn into_call_tool_result(self) -> Result { let value = serde_json::to_value(self.0).map_err(|e| { crate::ErrorData::internal_error( format!("Failed to serialize structured content: {}", e), @@ -35,6 +38,6 @@ impl IntoCallToolResult for Json { ) })?; - Ok(CallToolResult::structured(value)) + Ok(CallToolResult::structured(value).into()) } } diff --git a/crates/rmcp/src/model.rs b/crates/rmcp/src/model.rs index f6c74c133..97e14923d 100644 --- a/crates/rmcp/src/model.rs +++ b/crates/rmcp/src/model.rs @@ -14,6 +14,8 @@ mod extension; mod meta; mod mrtr; mod prompt; +#[cfg(feature = "request-state")] +mod request_state; mod resource; mod serde_impl; mod task; @@ -26,6 +28,8 @@ pub use extension::*; pub use meta::*; pub use mrtr::*; pub use prompt::*; +#[cfg(feature = "request-state")] +pub use request_state::*; pub use resource::*; use serde::{Deserialize, Serialize, de::DeserializeOwned}; use serde_json::Value; diff --git a/crates/rmcp/src/model/mrtr.rs b/crates/rmcp/src/model/mrtr.rs index e4a5b3fb8..2b8ba71be 100644 --- a/crates/rmcp/src/model/mrtr.rs +++ b/crates/rmcp/src/model/mrtr.rs @@ -10,13 +10,47 @@ //! [`InputRequiredResult`] instead of the normal result. The client fulfills the //! [`InputRequests`], then retries the original request with [`InputResponses`] and //! the echoed `requestState`. +//! +//! # Using MRTR +//! +//! **Server:** return an [`InputRequiredResult`] from a tool/prompt/resource +//! handler via the matching outcome enum ([`CallToolResponse`], +//! [`GetPromptResponse`], [`ReadResourceResponse`]). The SDK only lets an +//! `InputRequiredResult` reach a peer that negotiated protocol version +//! `2026-07-28` or newer; older peers get a protocol error instead. +//! +//! **Client:** the high-level `RunningService` helpers — `call_tool`, +//! `get_prompt`, and `read_resource` — automatically fulfil each +//! [`InputRequest`] through the local `ClientHandler` and retry, up to +//! [`DEFAULT_MRTR_MAX_ROUNDS`]. Use the `*_once` variants (e.g. +//! `call_tool_once`) to receive an [`InputRequiredResult`] directly and drive +//! the rounds yourself. +//! +//! # `requestState` is untrusted +//! +//! The client echoes `requestState` back verbatim, so a stateless server that +//! stores meaningful data in it MUST verify integrity before trusting the echoed +//! value. Enable the `request-state` feature and use `RequestStateCodec` to seal +//! and open it, or keep the state server-side and use `requestState` only as an +//! opaque handle. +//! +//! A complete runnable walkthrough lives in the `servers_mrtr` example. use std::collections::BTreeMap; use serde::{Deserialize, Serialize}; use serde_json::Value; -use super::{CreateMessageRequest, ElicitRequest, ListRootsRequest, Meta, ResultType}; +use super::{ + CallToolResult, CreateMessageRequest, ElicitRequest, GetPromptResult, ListRootsRequest, Meta, + ReadResourceResult, ResultType, ServerResult, +}; + +/// Default maximum number of MRTR rounds a high-level client call will drive. +/// +/// This matches the default used by other Tier 1 SDKs and prevents a +/// misbehaving peer from keeping a request alive indefinitely. +pub const DEFAULT_MRTR_MAX_ROUNDS: usize = 10; /// A server-initiated request that can appear inside [`InputRequests`]. /// @@ -53,6 +87,101 @@ pub type InputRequests = BTreeMap; /// for use as a `BTreeMap` value. pub type InputResponses = BTreeMap; +/// Result of a `tools/call` request, including the MRTR intermediate result. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum CallToolResponse { + /// The server completed the tool call. + Complete(CallToolResult), + /// The server requires client-side input before the tool call can complete. + InputRequired(InputRequiredResult), +} + +impl From for CallToolResponse { + fn from(result: CallToolResult) -> Self { + Self::Complete(result) + } +} + +impl From for CallToolResponse { + fn from(result: InputRequiredResult) -> Self { + Self::InputRequired(result) + } +} + +impl From for ServerResult { + fn from(response: CallToolResponse) -> Self { + match response { + CallToolResponse::Complete(result) => ServerResult::CallToolResult(result), + CallToolResponse::InputRequired(result) => ServerResult::InputRequiredResult(result), + } + } +} + +/// Result of a `prompts/get` request, including the MRTR intermediate result. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum GetPromptResponse { + /// The server completed the prompt request. + Complete(GetPromptResult), + /// The server requires client-side input before the prompt can be returned. + InputRequired(InputRequiredResult), +} + +impl From for GetPromptResponse { + fn from(result: GetPromptResult) -> Self { + Self::Complete(result) + } +} + +impl From for GetPromptResponse { + fn from(result: InputRequiredResult) -> Self { + Self::InputRequired(result) + } +} + +impl From for ServerResult { + fn from(response: GetPromptResponse) -> Self { + match response { + GetPromptResponse::Complete(result) => ServerResult::GetPromptResult(result), + GetPromptResponse::InputRequired(result) => ServerResult::InputRequiredResult(result), + } + } +} + +/// Result of a `resources/read` request, including the MRTR intermediate result. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum ReadResourceResponse { + /// The server completed the resource read. + Complete(ReadResourceResult), + /// The server requires client-side input before the resource can be returned. + InputRequired(InputRequiredResult), +} + +impl From for ReadResourceResponse { + fn from(result: ReadResourceResult) -> Self { + Self::Complete(result) + } +} + +impl From for ReadResourceResponse { + fn from(result: InputRequiredResult) -> Self { + Self::InputRequired(result) + } +} + +impl From for ServerResult { + fn from(response: ReadResourceResponse) -> Self { + match response { + ReadResourceResponse::Complete(result) => ServerResult::ReadResourceResult(result), + ReadResourceResponse::InputRequired(result) => { + ServerResult::InputRequiredResult(result) + } + } + } +} + /// A result indicating that additional input is needed before the request /// can be completed. /// diff --git a/crates/rmcp/src/model/request_state.rs b/crates/rmcp/src/model/request_state.rs new file mode 100644 index 000000000..479b4bf06 --- /dev/null +++ b/crates/rmcp/src/model/request_state.rs @@ -0,0 +1,323 @@ +//! Integrity protection for SEP-2322 `requestState`. +//! +//! In the multi round-trip request (MRTR) flow, a server places an opaque +//! `requestState` string in an [`InputRequiredResult`](super::InputRequiredResult) +//! and the client echoes it back verbatim on retry. From the server's point of +//! view the echoed value is **untrusted, attacker-controlled input**: a client +//! can send back anything it likes. A server that stores meaningful state inside +//! `requestState` (rather than in a server-side session) MUST verify that the +//! value it receives is one it actually produced. +//! +//! [`RequestStateCodec`] provides an opt-in way to do this. It seals a payload +//! into an opaque string with an HMAC-SHA256 tag and opens it again, rejecting +//! any value that was forged or tampered with. This mirrors the approach taken +//! by other MCP SDKs for stateless MRTR servers. +//! +//! This helper is only about *integrity*, not *confidentiality*: the sealed +//! payload is signed, not encrypted, so it is base64url-readable by anyone. +//! Do not put secrets in it. Replay protection (nonces, expiry) is the caller's +//! responsibility and can be embedded in the sealed payload. +//! +//! Using the codec is entirely optional. A server that keeps its state +//! server-side, or that does not trust `requestState` for anything security +//! sensitive, can keep building the string by hand via +//! [`InputRequiredResult::from_request_state`](super::InputRequiredResult::from_request_state). +//! +//! # Examples +//! +//! ``` +//! use rmcp::model::RequestStateCodec; +//! +//! // Derive the key from a per-process secret; keep it out of client reach. +//! let codec = RequestStateCodec::new(b"a-32-byte-or-longer-secret-key!!!"); +//! +//! let sealed = codec.seal(b"tool=weather;step=2"); +//! // `sealed` is what the server puts in `InputRequiredResult::request_state`. +//! +//! // On retry the client echoes `sealed` back untouched. +//! let opened = codec.open(&sealed).expect("integrity check passes"); +//! assert_eq!(opened, b"tool=weather;step=2"); +//! +//! // A tampered value is rejected instead of silently trusted. +//! let mut tampered = sealed.clone(); +//! tampered.push('x'); +//! assert!(codec.open(&tampered).is_err()); +//! ``` + +use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD}; +use hmac::{Hmac, Mac}; +use serde::{Serialize, de::DeserializeOwned}; +use sha2::Sha256; +use thiserror::Error; + +type HmacSha256 = Hmac; + +/// Version tag prefixing every sealed value, so the wire format can evolve. +const VERSION: &str = "rs1"; + +/// Domain-separation label mixed into the HMAC so a `requestState` tag can never +/// be confused with an HMAC computed for some other purpose using the same key. +const DOMAIN: &[u8] = b"rmcp/mrtr/request-state/v1"; + +/// Errors returned when opening a sealed [`RequestStateCodec`] value. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum RequestStateError { + /// The value is not a well-formed sealed request state (wrong prefix or + /// missing sections). + #[error("request state is malformed or uses an unsupported format")] + MalformedFormat, + + /// A section of the value was not valid base64url. + #[error("request state is not valid base64url")] + InvalidEncoding, + + /// The HMAC tag did not match; the value was forged or tampered with. + #[error("request state failed integrity verification")] + IntegrityCheckFailed, + + /// The sealed payload could not be serialized to JSON. + #[error("failed to serialize request state payload: {0}")] + Serialization(#[source] serde_json::Error), + + /// The opened payload could not be deserialized from JSON. + #[error("failed to deserialize request state payload: {0}")] + Deserialization(#[source] serde_json::Error), +} + +/// A keyed codec that seals and opens SEP-2322 `requestState` values with +/// HMAC-SHA256 integrity protection. +/// +/// Construct one codec per signing key and reuse it for the lifetime of the +/// key. The same key must be used to [`seal`](Self::seal) and +/// [`open`](Self::open) a value, so it has to survive across the rounds of a +/// single MRTR exchange (e.g. a stable per-process or per-deployment secret). +/// +/// The key may be any length; HMAC internally normalizes it. For meaningful +/// security use a high-entropy key of at least 32 bytes. +#[derive(Clone)] +pub struct RequestStateCodec { + key: Box<[u8]>, +} + +impl std::fmt::Debug for RequestStateCodec { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + // Never leak the signing key through Debug output. + f.debug_struct("RequestStateCodec") + .field("key", &"") + .finish() + } +} + +impl RequestStateCodec { + /// Creates a codec from a signing key. + pub fn new(key: impl Into>) -> Self { + Self { + key: key.into().into_boxed_slice(), + } + } + + /// Seals raw bytes into an opaque, integrity-protected string suitable for + /// use as `requestState`. + pub fn seal(&self, payload: &[u8]) -> String { + let mut mac = self.mac(); + mac.update(payload); + let tag = mac.finalize().into_bytes(); + + // base64url without padding encodes 3 bytes as 4 chars, rounding up. + let b64_len = |n: usize| n.div_ceil(3) * 4; + let mut out = + String::with_capacity(VERSION.len() + 2 + b64_len(payload.len()) + b64_len(tag.len())); + out.push_str(VERSION); + out.push('.'); + URL_SAFE_NO_PAD.encode_string(payload, &mut out); + out.push('.'); + URL_SAFE_NO_PAD.encode_string(tag.as_slice(), &mut out); + out + } + + /// Seals a serializable value by encoding it as JSON before sealing. + /// + /// # Errors + /// + /// Returns [`RequestStateError::Serialization`] if `value` cannot be encoded + /// as JSON. + pub fn seal_json(&self, value: &T) -> Result { + let payload = serde_json::to_vec(value).map_err(RequestStateError::Serialization)?; + Ok(self.seal(&payload)) + } + + /// Opens a sealed value, verifying its integrity and returning the original + /// bytes. + /// + /// # Errors + /// + /// Returns [`RequestStateError::IntegrityCheckFailed`] if the value was not + /// produced by this key, and [`RequestStateError::MalformedFormat`] or + /// [`RequestStateError::InvalidEncoding`] if it is not a well-formed sealed + /// value. + pub fn open(&self, sealed: &str) -> Result, RequestStateError> { + let mut parts = sealed.split('.'); + let version = parts.next().ok_or(RequestStateError::MalformedFormat)?; + let payload_b64 = parts.next().ok_or(RequestStateError::MalformedFormat)?; + let tag_b64 = parts.next().ok_or(RequestStateError::MalformedFormat)?; + if parts.next().is_some() || version != VERSION { + return Err(RequestStateError::MalformedFormat); + } + + let payload = URL_SAFE_NO_PAD + .decode(payload_b64) + .map_err(|_| RequestStateError::InvalidEncoding)?; + let tag = URL_SAFE_NO_PAD + .decode(tag_b64) + .map_err(|_| RequestStateError::InvalidEncoding)?; + + let mut mac = self.mac(); + mac.update(&payload); + // `verify_slice` compares in constant time and rejects wrong-length tags. + mac.verify_slice(&tag) + .map_err(|_| RequestStateError::IntegrityCheckFailed)?; + + Ok(payload) + } + + /// Opens a sealed value and deserializes its JSON payload. + /// + /// # Errors + /// + /// Returns the same integrity and format errors as [`Self::open`], plus + /// [`RequestStateError::Deserialization`] if the payload is not valid JSON + /// for `T`. + pub fn open_json(&self, sealed: &str) -> Result { + let payload = self.open(sealed)?; + serde_json::from_slice(&payload).map_err(RequestStateError::Deserialization) + } + + /// Builds an HMAC instance keyed for request-state tags, pre-fed with the + /// domain-separation label so `seal` and `open` stay in agreement. + fn mac(&self) -> HmacSha256 { + let mut mac = + HmacSha256::new_from_slice(&self.key).expect("HMAC accepts keys of any length"); + mac.update(DOMAIN); + mac + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn seal_open_roundtrips_bytes() { + let codec = RequestStateCodec::new(b"test-key-test-key-test-key-32byte".to_vec()); + let sealed = codec.seal(b"hello world"); + assert!(sealed.starts_with("rs1.")); + assert_eq!(codec.open(&sealed).unwrap(), b"hello world"); + } + + #[test] + fn seal_open_roundtrips_json() { + #[derive(serde::Serialize, serde::Deserialize, PartialEq, Debug)] + struct State { + tool: String, + round: u32, + } + let codec = RequestStateCodec::new(b"another-strong-signing-key-here!!".to_vec()); + let state = State { + tool: "weather".into(), + round: 3, + }; + let sealed = codec.seal_json(&state).unwrap(); + let opened: State = codec.open_json(&sealed).unwrap(); + assert_eq!(opened, state); + } + + #[test] + fn empty_payload_roundtrips() { + let codec = RequestStateCodec::new(b"k".to_vec()); + let sealed = codec.seal(b""); + assert_eq!(codec.open(&sealed).unwrap(), b""); + } + + #[test] + fn tampered_payload_is_rejected() { + let codec = RequestStateCodec::new(b"signing-key-signing-key-signing!!".to_vec()); + let sealed = codec.seal(b"amount=100"); + + // Flip the payload section but keep the original tag. + let mut parts: Vec<&str> = sealed.split('.').collect(); + let forged_payload = URL_SAFE_NO_PAD.encode(b"amount=999"); + parts[1] = &forged_payload; + let forged = parts.join("."); + + assert!(matches!( + codec.open(&forged), + Err(RequestStateError::IntegrityCheckFailed) + )); + } + + #[test] + fn different_key_is_rejected() { + let signer = RequestStateCodec::new(b"the-real-signing-key-value-here!!".to_vec()); + let attacker = RequestStateCodec::new(b"a-totally-different-forged-key!!!".to_vec()); + let sealed = signer.seal(b"trusted"); + assert!(matches!( + attacker.open(&sealed), + Err(RequestStateError::IntegrityCheckFailed) + )); + } + + #[test] + fn appended_bytes_are_rejected() { + let codec = RequestStateCodec::new(b"key-key-key-key-key-key-key-key!!".to_vec()); + let mut sealed = codec.seal(b"state"); + sealed.push('x'); + assert!(codec.open(&sealed).is_err()); + } + + #[test] + fn wrong_version_prefix_is_malformed() { + let codec = RequestStateCodec::new(b"key".to_vec()); + let sealed = codec.seal(b"state"); + let bumped = sealed.replacen("rs1.", "rs2.", 1); + assert!(matches!( + codec.open(&bumped), + Err(RequestStateError::MalformedFormat) + )); + } + + #[test] + fn missing_sections_are_malformed() { + let codec = RequestStateCodec::new(b"key".to_vec()); + assert!(matches!( + codec.open("rs1"), + Err(RequestStateError::MalformedFormat) + )); + assert!(matches!( + codec.open("rs1.onlypayload"), + Err(RequestStateError::MalformedFormat) + )); + assert!(matches!( + codec.open("rs1.a.b.c"), + Err(RequestStateError::MalformedFormat) + )); + } + + #[test] + fn non_base64_sections_are_invalid_encoding() { + let codec = RequestStateCodec::new(b"key".to_vec()); + // '.' and '+'/'/' are not valid in URL-safe base64; use an invalid char. + assert!(matches!( + codec.open("rs1.!!!!.!!!!"), + Err(RequestStateError::InvalidEncoding) + )); + } + + #[test] + fn debug_does_not_leak_key() { + let codec = RequestStateCodec::new(b"super-secret-key".to_vec()); + let rendered = format!("{codec:?}"); + assert!(!rendered.contains("super-secret-key")); + assert!(rendered.contains("redacted")); + } +} diff --git a/crates/rmcp/src/service.rs b/crates/rmcp/src/service.rs index 29d822a58..991df3b3b 100644 --- a/crates/rmcp/src/service.rs +++ b/crates/rmcp/src/service.rs @@ -86,6 +86,9 @@ pub enum ServiceError { Cancelled { reason: Option }, #[error("request timeout after {}", chrono::Duration::from_std(*timeout).unwrap_or_default())] Timeout { timeout: Duration }, + /// The peer kept returning `input_required` beyond the configured round cap. + #[error("input_required did not complete within {max_rounds} MRTR rounds")] + InputRequiredRoundsExceeded { max_rounds: usize }, } trait TransferObject: diff --git a/crates/rmcp/src/service/client.rs b/crates/rmcp/src/service/client.rs index 05c2749fe..929512615 100644 --- a/crates/rmcp/src/service/client.rs +++ b/crates/rmcp/src/service/client.rs @@ -1,24 +1,26 @@ // Sampling/Roots/Logging are SEP-2577-deprecated; internal references are expected. #![expect(deprecated)] -use std::borrow::Cow; +use std::{borrow::Cow, sync::Arc, time::Duration}; use thiserror::Error; use super::*; use crate::{ model::{ - ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResult, + ArgumentInfo, CallToolRequest, CallToolRequestParams, CallToolResponse, CallToolResult, CancelledNotification, CancelledNotificationParam, ClientInfo, ClientJsonRpcMessage, ClientNotification, ClientRequest, ClientResult, CompleteRequest, CompleteRequestParams, - CompleteResult, CompletionContext, CompletionInfo, ErrorData, GetPromptRequest, - GetPromptRequestParams, GetPromptResult, InitializeRequest, InitializedNotification, - JsonRpcResponse, ListPromptsRequest, ListPromptsResult, ListResourceTemplatesRequest, - ListResourceTemplatesResult, ListResourcesRequest, ListResourcesResult, ListToolsRequest, - ListToolsResult, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam, - ReadResourceRequest, ReadResourceRequestParams, ReadResourceResult, Reference, RequestId, - RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, ServerNotification, - ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, SubscribeRequest, - SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, + CompleteResult, CompletionContext, CompletionInfo, DEFAULT_MRTR_MAX_ROUNDS, ErrorData, + GetExtensions, GetMeta, GetPromptRequest, GetPromptRequestParams, GetPromptResponse, + GetPromptResult, InitializeRequest, InitializedNotification, InputRequest, + InputRequiredResult, InputResponses, JsonRpcResponse, ListPromptsRequest, + ListPromptsResult, ListResourceTemplatesRequest, ListResourceTemplatesResult, + ListResourcesRequest, ListResourcesResult, ListToolsRequest, ListToolsResult, + NumberOrString, PaginatedRequestParams, ProgressNotification, ProgressNotificationParam, + ReadResourceRequest, ReadResourceRequestParams, ReadResourceResponse, ReadResourceResult, + Reference, RequestId, RootsListChangedNotification, ServerInfo, ServerJsonRpcMessage, + ServerNotification, ServerRequest, ServerResult, SetLevelRequest, SetLevelRequestParams, + SubscribeRequest, SubscribeRequestParams, UnsubscribeRequest, UnsubscribeRequestParams, }, transport::DynamicTransportError, }; @@ -361,6 +363,72 @@ macro_rules! method { } impl Peer { + /// Send one `tools/call` request and return either a final result or an MRTR + /// `InputRequiredResult` without driving any follow-up rounds. + pub async fn call_tool_once( + &self, + params: CallToolRequestParams, + ) -> Result { + let result = self + .send_request(ClientRequest::CallToolRequest(CallToolRequest { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::CallToolResult(result) => Ok(CallToolResponse::Complete(result)), + ServerResult::InputRequiredResult(result) => { + Ok(CallToolResponse::InputRequired(result)) + } + _ => Err(ServiceError::UnexpectedResponse), + } + } + + /// Send one `prompts/get` request and return either a final result or an MRTR + /// `InputRequiredResult` without driving any follow-up rounds. + pub async fn get_prompt_once( + &self, + params: GetPromptRequestParams, + ) -> Result { + let result = self + .send_request(ClientRequest::GetPromptRequest(GetPromptRequest { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::GetPromptResult(result) => Ok(GetPromptResponse::Complete(result)), + ServerResult::InputRequiredResult(result) => { + Ok(GetPromptResponse::InputRequired(result)) + } + _ => Err(ServiceError::UnexpectedResponse), + } + } + + /// Send one `resources/read` request and return either a final result or an + /// MRTR `InputRequiredResult` without driving any follow-up rounds. + pub async fn read_resource_once( + &self, + params: ReadResourceRequestParams, + ) -> Result { + let result = self + .send_request(ClientRequest::ReadResourceRequest(ReadResourceRequest { + method: Default::default(), + params, + extensions: Default::default(), + })) + .await?; + match result { + ServerResult::ReadResourceResult(result) => Ok(ReadResourceResponse::Complete(result)), + ServerResult::InputRequiredResult(result) => { + Ok(ReadResourceResponse::InputRequired(result)) + } + _ => Err(ServiceError::UnexpectedResponse), + } + } + method!(peer_req complete CompleteRequest(CompleteRequestParams) => CompleteResult); method!( #[deprecated( @@ -558,3 +626,294 @@ impl Peer { Ok(completion.values) } } + +impl RunningService +where + S: Service, +{ + /// Send one `tools/call` request without driving MRTR follow-up rounds. + pub async fn call_tool_once( + &self, + params: CallToolRequestParams, + ) -> Result { + self.peer.call_tool_once(params).await + } + + /// Send one `prompts/get` request without driving MRTR follow-up rounds. + pub async fn get_prompt_once( + &self, + params: GetPromptRequestParams, + ) -> Result { + self.peer.get_prompt_once(params).await + } + + /// Send one `resources/read` request without driving MRTR follow-up rounds. + pub async fn read_resource_once( + &self, + params: ReadResourceRequestParams, + ) -> Result { + self.peer.read_resource_once(params).await + } + + /// High-level `tools/call` helper that automatically fulfils SEP-2322 + /// `input_required` rounds through the local [`ClientHandler`](crate::ClientHandler) service. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] if the peer does + /// not produce a final [`CallToolResult`] within the default MRTR round cap. + /// Other transport, protocol, and local input-handler errors are propagated. + pub async fn call_tool( + &self, + params: CallToolRequestParams, + ) -> Result { + self.call_tool_with_mrtr_max_rounds(params, DEFAULT_MRTR_MAX_ROUNDS) + .await + } + + /// Same as [`Self::call_tool`], with an explicit MRTR round cap. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] once `max_rounds` + /// `input_required` responses have been driven without receiving a final + /// [`CallToolResult`]. Other transport, protocol, and local input-handler + /// errors are propagated. + pub async fn call_tool_with_mrtr_max_rounds( + &self, + mut params: CallToolRequestParams, + max_rounds: usize, + ) -> Result { + let mut state_only_rounds = 0usize; + for _round in 0..max_rounds { + match self.peer.call_tool_once(params.clone()).await? { + CallToolResponse::Complete(result) => return Ok(result), + CallToolResponse::InputRequired(result) => { + let (input_responses, request_state) = self + .prepare_input_required_retry(result, &mut state_only_rounds) + .await?; + params.input_responses = input_responses; + params.request_state = request_state; + } + } + } + Err(ServiceError::InputRequiredRoundsExceeded { max_rounds }) + } + + /// High-level `prompts/get` helper that automatically fulfils SEP-2322 + /// `input_required` rounds through the local [`ClientHandler`](crate::ClientHandler) service. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] if the peer does + /// not produce a final [`GetPromptResult`] within the default MRTR round cap. + /// Other transport, protocol, and local input-handler errors are propagated. + pub async fn get_prompt( + &self, + params: GetPromptRequestParams, + ) -> Result { + self.get_prompt_with_mrtr_max_rounds(params, DEFAULT_MRTR_MAX_ROUNDS) + .await + } + + /// Same as [`Self::get_prompt`], with an explicit MRTR round cap. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] once `max_rounds` + /// `input_required` responses have been driven without receiving a final + /// [`GetPromptResult`]. Other transport, protocol, and local input-handler + /// errors are propagated. + pub async fn get_prompt_with_mrtr_max_rounds( + &self, + mut params: GetPromptRequestParams, + max_rounds: usize, + ) -> Result { + let mut state_only_rounds = 0usize; + for _round in 0..max_rounds { + match self.peer.get_prompt_once(params.clone()).await? { + GetPromptResponse::Complete(result) => return Ok(result), + GetPromptResponse::InputRequired(result) => { + let (input_responses, request_state) = self + .prepare_input_required_retry(result, &mut state_only_rounds) + .await?; + params.input_responses = input_responses; + params.request_state = request_state; + } + } + } + Err(ServiceError::InputRequiredRoundsExceeded { max_rounds }) + } + + /// High-level `resources/read` helper that automatically fulfils SEP-2322 + /// `input_required` rounds through the local [`ClientHandler`](crate::ClientHandler) service. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] if the peer does + /// not produce a final [`ReadResourceResult`] within the default MRTR round + /// cap. Other transport, protocol, and local input-handler errors are + /// propagated. + pub async fn read_resource( + &self, + params: ReadResourceRequestParams, + ) -> Result { + self.read_resource_with_mrtr_max_rounds(params, DEFAULT_MRTR_MAX_ROUNDS) + .await + } + + /// Same as [`Self::read_resource`], with an explicit MRTR round cap. + /// + /// # Errors + /// + /// Returns [`ServiceError::InputRequiredRoundsExceeded`] once `max_rounds` + /// `input_required` responses have been driven without receiving a final + /// [`ReadResourceResult`]. Other transport, protocol, and local input-handler + /// errors are propagated. + pub async fn read_resource_with_mrtr_max_rounds( + &self, + mut params: ReadResourceRequestParams, + max_rounds: usize, + ) -> Result { + let mut state_only_rounds = 0usize; + for _round in 0..max_rounds { + match self.peer.read_resource_once(params.clone()).await? { + ReadResourceResponse::Complete(result) => return Ok(result), + ReadResourceResponse::InputRequired(result) => { + let (input_responses, request_state) = self + .prepare_input_required_retry(result, &mut state_only_rounds) + .await?; + params.input_responses = input_responses; + params.request_state = request_state; + } + } + } + Err(ServiceError::InputRequiredRoundsExceeded { max_rounds }) + } + + async fn prepare_input_required_retry( + &self, + result: InputRequiredResult, + state_only_rounds: &mut usize, + ) -> Result<(Option, Option), ServiceError> { + let had_input_requests = result + .input_requests + .as_ref() + .is_some_and(|requests| !requests.is_empty()); + if !had_input_requests && result.request_state.is_none() { + return Err(ServiceError::UnexpectedResponse); + } + + let responses = self + .fulfill_input_requests(result.input_requests.unwrap_or_default()) + .await?; + if had_input_requests { + *state_only_rounds = 0; + } else { + Self::sleep_state_only_round(*state_only_rounds).await; + *state_only_rounds += 1; + } + + Ok(( + (!responses.is_empty()).then_some(responses), + result.request_state, + )) + } + + async fn fulfill_input_requests( + &self, + requests: crate::model::InputRequests, + ) -> Result { + let responses = futures::future::try_join_all( + requests + .into_iter() + .map(|(key, request)| self.fulfill_input_request(key, request)), + ) + .await?; + Ok(responses.into_iter().collect()) + } + + async fn fulfill_input_request( + &self, + key: String, + request: InputRequest, + ) -> Result<(String, serde_json::Value), ServiceError> { + let response = match request { + InputRequest::CreateMessage(request) => { + let mut request = ServerRequest::CreateMessageRequest(request); + let context = self.input_request_context(&key, &mut request); + match self + .service + .handle_request(request, context) + .await + .map_err(ServiceError::McpError)? + { + ClientResult::CreateMessageResult(result) => { + serde_json::to_value(result).map_err(Self::serde_to_service_error)? + } + _ => return Err(ServiceError::UnexpectedResponse), + } + } + InputRequest::Elicitation(request) => { + let mut request = ServerRequest::ElicitRequest(request); + let context = self.input_request_context(&key, &mut request); + match self + .service + .handle_request(request, context) + .await + .map_err(ServiceError::McpError)? + { + ClientResult::ElicitResult(result) => { + serde_json::to_value(result).map_err(Self::serde_to_service_error)? + } + _ => return Err(ServiceError::UnexpectedResponse), + } + } + InputRequest::ListRoots(request) => { + let mut request = ServerRequest::ListRootsRequest(request); + let context = self.input_request_context(&key, &mut request); + match self + .service + .handle_request(request, context) + .await + .map_err(ServiceError::McpError)? + { + ClientResult::ListRootsResult(result) => { + serde_json::to_value(result).map_err(Self::serde_to_service_error)? + } + _ => return Err(ServiceError::UnexpectedResponse), + } + } + }; + Ok((key, response)) + } + + fn input_request_context(&self, key: &str, request: &mut T) -> RequestContext + where + T: GetMeta + GetExtensions, + { + let mut meta = Default::default(); + let mut extensions = Default::default(); + std::mem::swap(&mut meta, request.get_meta_mut()); + std::mem::swap(&mut extensions, request.extensions_mut()); + RequestContext { + ct: tokio_util::sync::CancellationToken::new(), + id: NumberOrString::String(Arc::from(key)), + peer: self.peer.clone(), + meta, + extensions, + } + } + + async fn sleep_state_only_round(state_only_rounds: usize) { + let millis = (50u64.saturating_mul(1_u64 << state_only_rounds.min(3))).min(250); + tokio::time::sleep(Duration::from_millis(millis)).await; + } + + fn serde_to_service_error(error: serde_json::Error) -> ServiceError { + ServiceError::McpError(ErrorData::internal_error( + format!("failed to serialize MRTR input response: {error}"), + None, + )) + } +} diff --git a/crates/rmcp/tests/test_mrtr_behavior.rs b/crates/rmcp/tests/test_mrtr_behavior.rs new file mode 100644 index 000000000..b9087f706 --- /dev/null +++ b/crates/rmcp/tests/test_mrtr_behavior.rs @@ -0,0 +1,594 @@ +//! Behavior and edge-case coverage for SEP-2322 multi round-trip requests (MRTR). +//! +//! These tests drive a real client/server pair over an in-memory duplex stream +//! and exercise the auto fulfill/retry loop, the manual `*_once` escape hatch, +//! and the server-side version gating. + +// Sampling/Roots are SEP-2577-deprecated but still used to model MRTR input requests. +#![allow(deprecated)] +use std::sync::{ + Arc, + atomic::{AtomicUsize, Ordering}, +}; + +use rmcp::{ + ClientHandler, ServerHandler, + model::*, + service::{RequestContext, RoleClient, RoleServer, ServiceError, serve_directly}, +}; +use serde_json::json; + +/// A `requestState` value with characters that must survive a byte-exact echo: +/// dots (the codec delimiter), base64 punctuation, whitespace, and quotes. +const TRICKY_STATE: &str = "st.ate/with+special=chars and spaces \"quotes\"\n\ttab"; + +// ============================================================================= +// Test handlers +// ============================================================================= + +/// A stateless MRTR server whose behavior is selected by the tool/prompt/resource +/// name. Round progression is derived entirely from `request_state` and +/// `input_responses`, as required by the stateless MRTR pattern. +#[derive(Clone, Default)] +struct MrtrServer { + calls: Arc, +} + +fn elicitation_request(message: &str) -> InputRequest { + InputRequest::Elicitation(ElicitRequest::new( + ElicitRequestParams::FormElicitationParams { + meta: None, + message: message.into(), + requested_schema: serde_json::from_value(json!({ + "type": "object", + "properties": { "name": { "type": "string" } }, + "required": ["name"] + })) + .unwrap(), + }, + )) +} + +fn sampling_request() -> InputRequest { + InputRequest::CreateMessage(CreateMessageRequest::new(CreateMessageRequestParams::new( + vec![SamplingMessage::user_text("What is the capital of France?")], + 100, + ))) +} + +fn roots_request() -> InputRequest { + InputRequest::ListRoots(ListRootsRequest::default()) +} + +fn single_elicitation(state: &str) -> InputRequiredResult { + let mut requests = InputRequests::new(); + requests.insert("answer".to_string(), elicitation_request("Name?")); + InputRequiredResult::new(Some(requests), Some(state.into())) +} + +impl MrtrServer { + fn call_tool_impl( + &self, + request: CallToolRequestParams, + ) -> Result { + let responses = request.input_responses.as_ref(); + let state = request.request_state.as_deref(); + match request.name.as_ref() { + // Single round: one elicitation, then complete. + "single" => match responses { + None => Ok(single_elicitation("state-single").into()), + Some(map) => { + if state != Some("state-single") { + return Err(ErrorData::internal_error("request_state not echoed", None)); + } + let answer = map + .get("answer") + .ok_or_else(|| ErrorData::internal_error("missing answer", None))?; + if answer["action"] != "accept" || answer["content"]["name"] != "Ferris" { + return Err(ErrorData::internal_error("unexpected elicit result", None)); + } + Ok(CallToolResult::success(vec![ContentBlock::text("done")]).into()) + } + }, + // Two elicitation rounds before completing. + "multi_round" => match state { + None => Ok(single_elicitation("round-1").into()), + Some("round-1") => Ok(single_elicitation("round-2").into()), + Some("round-2") => { + Ok(CallToolResult::success(vec![ContentBlock::text("multi-done")]).into()) + } + Some(other) => Err(ErrorData::internal_error( + format!("unexpected round state {other:?}"), + None, + )), + }, + // Several input requests fulfilled concurrently in a single round. + "multi_request" => match responses { + None => { + let mut requests = InputRequests::new(); + requests.insert("form".to_string(), elicitation_request("Name?")); + requests.insert("sample".to_string(), sampling_request()); + requests.insert("roots".to_string(), roots_request()); + Ok(InputRequiredResult::new(Some(requests), Some("multi-req".into())).into()) + } + Some(map) => { + for key in ["form", "sample", "roots"] { + if !map.contains_key(key) { + return Err(ErrorData::internal_error( + format!("missing response for {key}"), + None, + )); + } + } + Ok(CallToolResult::success(vec![ContentBlock::text("multi-req-done")]).into()) + } + }, + // State-only load shedding: two state-only rounds, then complete. + "state_only" => match state { + None => Ok(InputRequiredResult::from_request_state("so-1").into()), + Some("so-1") => Ok(InputRequiredResult::from_request_state("so-2").into()), + Some("so-2") => { + Ok(CallToolResult::success(vec![ContentBlock::text("state-done")]).into()) + } + Some(other) => Err(ErrorData::internal_error( + format!("unexpected state {other:?}"), + None, + )), + }, + // Never completes: used to exercise the max-rounds cap. + "loops" => Ok(single_elicitation("loop").into()), + // Triggers a failure inside the client's elicitation handler. + "handler_error" => { + let mut requests = InputRequests::new(); + requests.insert("answer".to_string(), elicitation_request("FAIL")); + Ok(InputRequiredResult::new(Some(requests), Some("state".into())).into()) + } + // Verifies the client echoes `request_state` byte-for-byte. + "echo_state" => match responses { + None => { + let mut requests = InputRequests::new(); + requests.insert("answer".to_string(), elicitation_request("Name?")); + Ok(InputRequiredResult::new(Some(requests), Some(TRICKY_STATE.into())).into()) + } + Some(_) => { + if state != Some(TRICKY_STATE) { + return Err(ErrorData::internal_error( + "request_state was not echoed byte-exact", + None, + )); + } + Ok(CallToolResult::success(vec![ContentBlock::text("echo-ok")]).into()) + } + }, + _ => Ok(CallToolResult::success(vec![ContentBlock::text("noop")]).into()), + } + } +} + +impl ServerHandler for MrtrServer { + fn get_info(&self) -> ServerInfo { + let mut info = ServerInfo::new( + ServerCapabilities::builder() + .enable_tools() + .enable_prompts() + .build(), + ); + info.protocol_version = ProtocolVersion::V_2026_07_28; + info + } + + async fn call_tool( + &self, + request: CallToolRequestParams, + _context: RequestContext, + ) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + self.call_tool_impl(request) + } + + async fn get_prompt( + &self, + request: GetPromptRequestParams, + _context: RequestContext, + ) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + match request.request_state.as_deref() { + None => Ok(single_elicitation("prompt-1").into()), + Some("prompt-1") => Ok(GetPromptResult::new(vec![PromptMessage::new_text( + Role::Assistant, + "prompt-done", + )]) + .into()), + Some(other) => Err(ErrorData::internal_error( + format!("unexpected prompt state {other:?}"), + None, + )), + } + } + + async fn read_resource( + &self, + request: ReadResourceRequestParams, + _context: RequestContext, + ) -> Result { + self.calls.fetch_add(1, Ordering::SeqCst); + match request.request_state.as_deref() { + None => Ok(single_elicitation("res-1").into()), + Some("res-1") => Ok(ReadResourceResult::new(vec![ResourceContents::text( + "resource-done", + request.uri, + )]) + .into()), + Some(other) => Err(ErrorData::internal_error( + format!("unexpected resource state {other:?}"), + None, + )), + } + } +} + +/// A client that fulfills every kind of MRTR input request. Elicitation fails +/// deliberately when the prompt message is `"FAIL"`. +#[derive(Clone, Default)] +struct MrtrClient; + +impl ClientHandler for MrtrClient { + async fn create_elicitation( + &self, + request: ElicitRequestParams, + _context: RequestContext, + ) -> Result { + if let ElicitRequestParams::FormElicitationParams { message, .. } = &request { + if message == "FAIL" { + return Err(ErrorData::internal_error( + "elicitation handler failed", + None, + )); + } + } + Ok(ElicitResult::new(ElicitationAction::Accept).with_content(json!({ "name": "Ferris" }))) + } + + async fn create_message( + &self, + _request: CreateMessageRequestParams, + _context: RequestContext, + ) -> Result { + Ok(CreateMessageResult::new( + SamplingMessage::assistant_text("Paris."), + "test-model".into(), + ) + .with_stop_reason(CreateMessageResult::STOP_REASON_END_TURN)) + } + + async fn list_roots( + &self, + _context: RequestContext, + ) -> Result { + Ok(ListRootsResult::new(vec![Root::new("file:///workspace")])) + } +} + +// ============================================================================= +// Harness +// ============================================================================= + +fn client_info(protocol_version: ProtocolVersion) -> ClientInfo { + ClientInfo::new( + ClientCapabilities::builder().enable_elicitation().build(), + Implementation::new("mrtr-test-client", "0.0.0"), + ) + .with_protocol_version(protocol_version) +} + +fn server_info_2026() -> ServerInfo { + let mut info = ServerInfo::new(ServerCapabilities::builder().enable_tools().build()); + info.protocol_version = ProtocolVersion::V_2026_07_28; + info +} + +/// Runs `body` inside a `LocalSet` so `spawn_local` (used when the `local` +/// feature is active) is available, wiring up a connected client/server pair. +async fn with_pair( + server: MrtrServer, + client_protocol: ProtocolVersion, + body: F, +) -> anyhow::Result<()> +where + F: FnOnce(rmcp::service::RunningService) -> Fut, + Fut: std::future::Future>, +{ + tokio::task::LocalSet::new() + .run_until(async move { + let (server_transport, client_transport) = tokio::io::duplex(8192); + let server_peer_info = client_info(client_protocol); + let server_task = tokio::task::spawn_local(async move { + let running = serve_directly::( + server, + server_transport, + Some(server_peer_info), + ); + running.waiting().await?; + anyhow::Ok(()) + }); + + let client = serve_directly::( + MrtrClient, + client_transport, + Some(server_info_2026()), + ); + + let result = body(client).await; + + server_task.abort(); + result + }) + .await +} + +// ============================================================================= +// Tests +// ============================================================================= + +#[tokio::test(flavor = "current_thread")] +async fn client_auto_fulfills_input_required_tool_call() -> anyhow::Result<()> { + let server = MrtrServer::default(); + let calls = server.calls.clone(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let result = client + .call_tool(CallToolRequestParams::new("single")) + .await?; + assert_eq!(result.content[0].as_text().unwrap().text, "done"); + assert_eq!(calls.load(Ordering::SeqCst), 2); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn manual_once_returns_input_required_without_retry() -> anyhow::Result<()> { + let server = MrtrServer::default(); + let calls = server.calls.clone(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let result = client + .call_tool_once(CallToolRequestParams::new("single")) + .await?; + assert!(matches!(result, CallToolResponse::InputRequired(_))); + // A manual round makes exactly one server call and never retries. + assert_eq!(calls.load(Ordering::SeqCst), 1); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn multi_round_input_required_completes() -> anyhow::Result<()> { + let server = MrtrServer::default(); + let calls = server.calls.clone(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let result = client + .call_tool(CallToolRequestParams::new("multi_round")) + .await?; + assert_eq!(result.content[0].as_text().unwrap().text, "multi-done"); + // round 0 + two retries = 3 server calls. + assert_eq!(calls.load(Ordering::SeqCst), 3); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn multiple_input_requests_fulfilled_in_one_round() -> anyhow::Result<()> { + let server = MrtrServer::default(); + let calls = server.calls.clone(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let result = client + .call_tool(CallToolRequestParams::new("multi_request")) + .await?; + assert_eq!(result.content[0].as_text().unwrap().text, "multi-req-done"); + assert_eq!(calls.load(Ordering::SeqCst), 2); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn state_only_input_required_completes() -> anyhow::Result<()> { + let server = MrtrServer::default(); + let calls = server.calls.clone(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let result = client + .call_tool(CallToolRequestParams::new("state_only")) + .await?; + assert_eq!(result.content[0].as_text().unwrap().text, "state-done"); + // round 0 + two state-only retries = 3 server calls. + assert_eq!(calls.load(Ordering::SeqCst), 3); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn max_rounds_exceeded_returns_error() -> anyhow::Result<()> { + let server = MrtrServer::default(); + let calls = server.calls.clone(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let err = client + .call_tool_with_mrtr_max_rounds(CallToolRequestParams::new("loops"), 3) + .await + .expect_err("a tool that never completes must exhaust the round cap"); + assert!(matches!( + err, + ServiceError::InputRequiredRoundsExceeded { max_rounds: 3 } + )); + assert_eq!(calls.load(Ordering::SeqCst), 3); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn client_handler_error_propagates() -> anyhow::Result<()> { + let server = MrtrServer::default(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let err = client + .call_tool(CallToolRequestParams::new("handler_error")) + .await + .expect_err("a failing input handler must fail the whole call"); + assert!(matches!(err, ServiceError::McpError(_))); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn request_state_is_echoed_byte_exact() -> anyhow::Result<()> { + let server = MrtrServer::default(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + // The server returns an error result unless it sees TRICKY_STATE echoed + // back unchanged, so a successful completion proves the byte-exact echo. + let result = client + .call_tool(CallToolRequestParams::new("echo_state")) + .await?; + assert_eq!(result.content[0].as_text().unwrap().text, "echo-ok"); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn get_prompt_auto_fulfills_input_required() -> anyhow::Result<()> { + let server = MrtrServer::default(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let result = client.get_prompt(GetPromptRequestParams::new("p")).await?; + assert_eq!( + result.messages[0].content.as_text().unwrap().text, + "prompt-done" + ); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn read_resource_auto_fulfills_input_required() -> anyhow::Result<()> { + let server = MrtrServer::default(); + with_pair(server, ProtocolVersion::V_2026_07_28, |client| async move { + let result = client + .read_resource(ReadResourceRequestParams::new("res://x")) + .await?; + let text = match &result.contents[0] { + ResourceContents::TextResourceContents { text, .. } => text.clone(), + _ => panic!("expected text resource"), + }; + assert_eq!(text, "resource-done"); + Ok(()) + }) + .await +} + +#[tokio::test(flavor = "current_thread")] +async fn old_protocol_rejects_input_required() -> anyhow::Result<()> { + let server = MrtrServer::default(); + // The client negotiated 2025-11-25, so the server must refuse to emit an + // InputRequiredResult and return a protocol error instead. + with_pair(server, ProtocolVersion::V_2025_11_25, |client| async move { + let err = client + .call_tool_once(CallToolRequestParams::new("single")) + .await + .expect_err("MRTR must be rejected for pre-2026 peers"); + match err { + ServiceError::McpError(error) => { + assert!( + error.message.contains("2026-07-28"), + "unexpected error message: {}", + error.message + ); + } + other => panic!("expected an McpError, got {other:?}"), + } + Ok(()) + }) + .await +} + +#[cfg(feature = "request-state")] +#[tokio::test(flavor = "current_thread")] +async fn request_state_codec_seals_and_verifies_through_the_loop() -> anyhow::Result<()> { + use std::sync::OnceLock; + + use rmcp::model::RequestStateCodec; + + // A shared per-process signing key, mirroring how a real server would derive one. + static KEY: &[u8] = b"integration-signing-key-32-bytes!"; + + fn codec() -> &'static RequestStateCodec { + static CODEC: OnceLock = OnceLock::new(); + CODEC.get_or_init(|| RequestStateCodec::new(KEY)) + } + + #[derive(Clone, Default)] + struct SealingServer; + + impl ServerHandler for SealingServer { + fn get_info(&self) -> ServerInfo { + let mut info = ServerInfo::new(ServerCapabilities::builder().enable_tools().build()); + info.protocol_version = ProtocolVersion::V_2026_07_28; + info + } + + async fn call_tool( + &self, + request: CallToolRequestParams, + _context: RequestContext, + ) -> Result { + match request.request_state { + None => { + let sealed = codec() + .seal_json(&json!({ "step": 1, "tool": request.name })) + .map_err(|e| ErrorData::internal_error(e.to_string(), None))?; + let mut requests = InputRequests::new(); + requests.insert("answer".to_string(), elicitation_request("Name?")); + Ok(InputRequiredResult::new(Some(requests), Some(sealed)).into()) + } + Some(sealed) => { + // The echoed state is untrusted; verify it before use. + let state: serde_json::Value = codec() + .open_json(&sealed) + .map_err(|e| ErrorData::internal_error(e.to_string(), None))?; + assert_eq!(state["step"], 1); + Ok(CallToolResult::success(vec![ContentBlock::text("sealed-done")]).into()) + } + } + } + } + + tokio::task::LocalSet::new() + .run_until(async move { + let (server_transport, client_transport) = tokio::io::duplex(8192); + let server_task = tokio::task::spawn_local(async move { + let running = serve_directly::( + SealingServer, + server_transport, + Some(client_info(ProtocolVersion::V_2026_07_28)), + ); + running.waiting().await?; + anyhow::Ok(()) + }); + + let client = serve_directly::( + MrtrClient, + client_transport, + Some(server_info_2026()), + ); + + let result = client + .call_tool(CallToolRequestParams::new("sealed")) + .await?; + assert_eq!(result.content[0].as_text().unwrap().text, "sealed-done"); + + server_task.abort(); + anyhow::Ok(()) + }) + .await +} diff --git a/crates/rmcp/tests/test_resource_not_found_version.rs b/crates/rmcp/tests/test_resource_not_found_version.rs index 255accb8c..44eb3631f 100644 --- a/crates/rmcp/tests/test_resource_not_found_version.rs +++ b/crates/rmcp/tests/test_resource_not_found_version.rs @@ -9,7 +9,7 @@ use rmcp::{ ClientHandler, RoleServer, ServerHandler, ServiceError, ServiceExt, model::{ ClientInfo, ErrorCode, ErrorData, ProtocolVersion, ReadResourceRequestParams, - ReadResourceResult, + ReadResourceResponse, }, service::RequestContext, }; @@ -22,7 +22,7 @@ impl ServerHandler for ResourceServer { &self, _request: ReadResourceRequestParams, _context: RequestContext, - ) -> Result { + ) -> Result { Err(ErrorData::resource_not_found("resource not found", None)) } } diff --git a/crates/rmcp/tests/test_result_type_wire.rs b/crates/rmcp/tests/test_result_type_wire.rs new file mode 100644 index 000000000..9c340ff88 --- /dev/null +++ b/crates/rmcp/tests/test_result_type_wire.rs @@ -0,0 +1,29 @@ +//! Wire-shape regression guards for the SEP-2322 `resultType` discriminator. +//! +//! These pin the behavior that keeps older/strict peers working: +//! - `EmptyResult` stays a bare `{}` (some peers strict-validate empty results +//! and reject extra keys), and +//! - ordinary results carry `resultType: "complete"`. + +use rmcp::model::{CallToolResult, ContentBlock, EmptyResult, ListToolsResult}; +use serde_json::json; + +#[test] +fn empty_result_serializes_without_result_type() { + let value = serde_json::to_value(EmptyResult {}).expect("serialize EmptyResult"); + assert_eq!(value, json!({})); +} + +#[test] +fn call_tool_result_serializes_complete_result_type() { + let value = serde_json::to_value(CallToolResult::success(vec![ContentBlock::text("ok")])) + .expect("serialize CallToolResult"); + assert_eq!(value["resultType"], "complete"); +} + +#[test] +fn paginated_result_serializes_complete_result_type() { + let value = + serde_json::to_value(ListToolsResult::default()).expect("serialize ListToolsResult"); + assert_eq!(value["resultType"], "complete"); +} diff --git a/crates/rmcp/tests/test_structured_output.rs b/crates/rmcp/tests/test_structured_output.rs index bb0d5e029..513c5ee5e 100644 --- a/crates/rmcp/tests/test_structured_output.rs +++ b/crates/rmcp/tests/test_structured_output.rs @@ -3,7 +3,7 @@ use rmcp::{ Json, ServerHandler, handler::server::{router::tool::ToolRouter, tool::IntoCallToolResult, wrapper::Parameters}, - model::{CallToolResult, ContentBlock, ServerResult, Tool}, + model::{CallToolResponse, CallToolResult, ContentBlock, ServerResult, Tool}, tool, tool_handler, tool_router, }; use schemars::JsonSchema; @@ -224,11 +224,13 @@ async fn test_structured_return_conversion() { }; let structured = Json(calc_result); - let result: Result = + let result: Result = rmcp::handler::server::tool::IntoCallToolResult::into_call_tool_result(structured); assert!(result.is_ok()); - let call_result = result.unwrap(); + let CallToolResponse::Complete(call_result) = result.unwrap() else { + panic!("expected complete CallToolResult"); + }; // Tools which return structured content should also return a serialized version as // Content::text for backwards compatibility. @@ -285,11 +287,13 @@ async fn test_output_schema_requires_structured_content() { let result = server.calculate(params).await.unwrap(); // Convert the Json to CallToolResult - let call_result: Result = + let call_result: Result = IntoCallToolResult::into_call_tool_result(result); assert!(call_result.is_ok()); - let call_result = call_result.unwrap(); + let CallToolResponse::Complete(call_result) = call_result.unwrap() else { + panic!("expected complete CallToolResult"); + }; // Verify it has structured_content and content assert!(call_result.structured_content.is_some()); diff --git a/crates/rmcp/tests/test_tool_disable_notification.rs b/crates/rmcp/tests/test_tool_disable_notification.rs index b30a58e18..cd8780591 100644 --- a/crates/rmcp/tests/test_tool_disable_notification.rs +++ b/crates/rmcp/tests/test_tool_disable_notification.rs @@ -9,7 +9,7 @@ use std::sync::{ use rmcp::{ ClientHandler, RoleClient, RoleServer, ServerHandler, ServiceExt, handler::server::{router::tool::ToolRoute, tool::ToolCallContext}, - model::{CallToolResult, ServerCapabilities, ServerInfo, Tool}, + model::{CallToolResponse, CallToolResult, ServerCapabilities, ServerInfo, Tool}, service::{MaybeSendFuture, NotificationContext}, }; use tokio::sync::{Notify, RwLock}; @@ -26,11 +26,11 @@ impl TestToolServer { let mut tool_router = rmcp::handler::server::router::tool::ToolRouter::::new(); tool_router.add_route(ToolRoute::new_dyn( Tool::new("tool_a", "Tool A", Arc::new(Default::default())), - |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + |_ctx| Box::pin(async { Ok(CallToolResult::default().into()) }), )); tool_router.add_route(ToolRoute::new_dyn( Tool::new("tool_b", "Tool B", Arc::new(Default::default())), - |_ctx| Box::pin(async { Ok(CallToolResult::default()) }), + |_ctx| Box::pin(async { Ok(CallToolResult::default().into()) }), )); Self { router: Arc::new(RwLock::new(tool_router)), @@ -49,7 +49,7 @@ impl ServerHandler for TestToolServer { &self, request: rmcp::model::CallToolRequestParams, context: rmcp::service::RequestContext, - ) -> Result { + ) -> Result { let router = self.router.read().await; let tcc = ToolCallContext::new(self, request, context); router.call(tcc).await diff --git a/examples/servers/Cargo.toml b/examples/servers/Cargo.toml index a544f3ea1..f189c9e7f 100644 --- a/examples/servers/Cargo.toml +++ b/examples/servers/Cargo.toml @@ -13,6 +13,7 @@ rmcp = { workspace = true, features = [ "transport-streamable-http-server", "auth", "elicitation", + "request-state", "schemars", ] } tokio = { version = "1", features = [ @@ -113,3 +114,7 @@ path = "src/elicitation_enum_inference.rs" [[example]] name = "servers_task_stdio" path = "src/task_stdio.rs" + +[[example]] +name = "servers_mrtr" +path = "src/mrtr.rs" diff --git a/examples/servers/README.md b/examples/servers/README.md index 69126519e..a6f0dcf76 100644 --- a/examples/servers/README.md +++ b/examples/servers/README.md @@ -72,6 +72,17 @@ A minimal stdio server demonstrating task-based tool invocation per - Wires up `enqueue_task` / `tasks/get` / `tasks/result` / `tasks/cancel` via `#[task_handler]` - Pair with `examples/clients/src/task_stdio.rs` to see the full lifecycle (create → poll → fetch result) +### MRTR Demo (`mrtr.rs`) + +An end-to-end walkthrough of SEP-2322 Multi Round-Trip Requests, running a +server and client in one process over an in-memory stream. + +- Server answers `tools/call` with an `InputRequiredResult` asking the client to elicit a value +- Client uses `call_tool` to auto-fulfil the elicitation and retry, then `call_tool_once` for manual control +- Seals/opens the untrusted `requestState` with `RequestStateCodec` (HMAC integrity) +- Both sides negotiate `2026-07-28`, the minimum version for MRTR +- Run with `cargo run -p mcp-server-examples --example servers_mrtr` + ### Progress Demo Server (`progress_demo.rs`) A server that demonstrates progress notifications during long-running operations. diff --git a/examples/servers/src/mrtr.rs b/examples/servers/src/mrtr.rs new file mode 100644 index 000000000..40b53c3e3 --- /dev/null +++ b/examples/servers/src/mrtr.rs @@ -0,0 +1,199 @@ +//! SEP-2322 Multi Round-Trip Request (MRTR) end-to-end example. +//! +//! This runs a server and a client in the same process, connected over an +//! in-memory duplex stream, to show the full MRTR flow: +//! +//! * The **server** answers `tools/call` with an [`InputRequiredResult`] instead +//! of a final result, asking the client to elicit a value first. It stores its +//! progress in an opaque, integrity-protected `requestState` produced by a +//! [`RequestStateCodec`], and verifies that state when the client retries. +//! * The **client** uses the high-level [`RunningService::call_tool`] helper, +//! which automatically fulfils the elicitation through the local +//! [`ClientHandler`] and retries the original request. The example then repeats +//! the call with [`RunningService::call_tool_once`] to show the manual escape +//! hatch that returns the intermediate result without retrying. +//! +//! ## Version gating +//! +//! `InputRequiredResult` is only valid once the peers have negotiated protocol +//! version `2026-07-28` or newer. Both sides advertise that version below. If a +//! server emits an `InputRequiredResult` to an older client, the SDK turns it +//! into a protocol error instead of sending it on the wire. +//! +//! ## `requestState` is untrusted input +//! +//! The client echoes `requestState` back verbatim, so from the server's point of +//! view it is attacker-controlled. A stateless server that puts meaningful data +//! in `requestState` MUST verify it. This example uses [`RequestStateCodec`] to +//! seal and open it with an HMAC tag; tampered values are rejected. +//! +//! Run with: +//! +//! ```sh +//! cargo run -p mcp-server-examples --example servers_mrtr +//! ``` + +use rmcp::{ + ClientHandler, ServerHandler, ServiceExt, + model::*, + service::{RequestContext, RoleClient, RoleServer}, +}; +use serde_json::json; + +/// A stable, high-entropy secret. In a real deployment, load this from your +/// secret manager and keep it out of clients' reach. It must stay constant for +/// the lifetime of any in-flight MRTR exchange. +const REQUEST_STATE_KEY: &[u8] = b"example-request-state-signing-key-32b!"; + +/// A server that needs a city name before it can answer a weather query. +#[derive(Clone)] +struct WeatherServer { + codec: RequestStateCodec, +} + +impl Default for WeatherServer { + fn default() -> Self { + Self { + codec: RequestStateCodec::new(REQUEST_STATE_KEY), + } + } +} + +impl ServerHandler for WeatherServer { + fn get_info(&self) -> ServerInfo { + let mut info = ServerInfo::new(ServerCapabilities::builder().enable_tools().build()); + // MRTR requires 2026-07-28 or newer. + info.protocol_version = ProtocolVersion::V_2026_07_28; + info + } + + async fn call_tool( + &self, + request: CallToolRequestParams, + _context: RequestContext, + ) -> Result { + match request.request_state { + // First round: ask the client to provide a city, and remember where + // we are by sealing our progress into `requestState`. + None => { + let sealed = self + .codec + .seal_json(&json!({ "awaiting": "city" })) + .map_err(|e| ErrorData::internal_error(e.to_string(), None))?; + + let mut input_requests = InputRequests::new(); + input_requests.insert( + "city".to_string(), + InputRequest::Elicitation(ElicitRequest::new( + ElicitRequestParams::FormElicitationParams { + meta: None, + message: "Which city do you want the weather for?".into(), + requested_schema: serde_json::from_value(json!({ + "type": "object", + "properties": { "city": { "type": "string" } }, + "required": ["city"] + })) + .expect("valid schema"), + }, + )), + ); + + Ok(InputRequiredResult::new(Some(input_requests), Some(sealed)).into()) + } + // Retry round: verify the echoed state before trusting it, read the + // elicited city, and return the final result. + Some(sealed) => { + let _state: serde_json::Value = self.codec.open_json(&sealed).map_err(|_| { + ErrorData::invalid_params("tampered or unknown request state", None) + })?; + + let city = request + .input_responses + .as_ref() + .and_then(|r| r.get("city")) + .and_then(|v| v["content"]["city"].as_str()) + .unwrap_or("your area"); + + Ok(CallToolResult::success(vec![ContentBlock::text(format!( + "It is sunny in {city}." + ))]) + .into()) + } + } + } +} + +/// A client that fulfils elicitation requests. A real client would prompt a user. +#[derive(Clone, Default)] +struct InteractiveClient; + +impl ClientHandler for InteractiveClient { + fn get_info(&self) -> ClientInfo { + ClientInfo::new( + ClientCapabilities::builder().enable_elicitation().build(), + Implementation::new("mrtr-example-client", env!("CARGO_PKG_VERSION")), + ) + .with_protocol_version(ProtocolVersion::V_2026_07_28) + } + + async fn create_elicitation( + &self, + request: ElicitRequestParams, + _context: RequestContext, + ) -> Result { + if let ElicitRequestParams::FormElicitationParams { message, .. } = &request { + println!(" [client] server asked: {message}"); + } + // Pretend the user typed "Paris". + Ok(ElicitResult::new(ElicitationAction::Accept).with_content(json!({ "city": "Paris" }))) + } +} + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let (server_transport, client_transport) = tokio::io::duplex(8192); + + // Spin up the server side. + tokio::spawn(async move { + let server = WeatherServer::default() + .serve(server_transport) + .await + .expect("server should start"); + let _ = server.waiting().await; + }); + + // Connect the client (this performs the initialize handshake). + let client = InteractiveClient::default().serve(client_transport).await?; + + // 1. High-level auto mode: the SDK fulfils the elicitation and retries for us. + println!("== auto mode (call_tool) =="); + let result = client + .call_tool(CallToolRequestParams::new("weather")) + .await?; + println!( + " [client] final result: {}\n", + result.content[0].as_text().unwrap().text + ); + + // 2. Manual mode: get the intermediate InputRequiredResult without retrying. + println!("== manual mode (call_tool_once) =="); + match client + .call_tool_once(CallToolRequestParams::new("weather")) + .await? + { + CallToolResponse::InputRequired(input_required) => { + let requests = input_required.input_requests.unwrap_or_default(); + println!( + " [client] server requested {} input(s); handling them yourself is up to you.", + requests.len() + ); + } + CallToolResponse::Complete(result) => { + println!(" [client] completed immediately: {result:?}"); + } + _ => println!(" [client] unhandled response variant"), + } + + client.cancel().await?; + Ok(()) +}