diff --git a/crates/rmcp/Cargo.toml b/crates/rmcp/Cargo.toml index 9704bfc7..9ec92f73 100644 --- a/crates/rmcp/Cargo.toml +++ b/crates/rmcp/Cargo.toml @@ -18,6 +18,8 @@ features = [ "auth", "auth-client-credentials-jwt", "base64", + "discovery", + "discovery-jws", "client", "client-side-sse", "elicitation", @@ -58,6 +60,9 @@ oauth2 = { version = "5.0", optional = true, default-features = false } # JWT signing for client credentials (private_key_jwt) jsonwebtoken = { version = "10", optional = true } +# DNS resolver for mcp:// discovery URI resolution +hickory-resolver = { version = "0.26", optional = true, default-features = false, features = ["tokio", "system-config"] } + # for auto generate schema schemars = { version = "1.0", optional = true, features = ["chrono04"] } @@ -185,6 +190,14 @@ auth = ["dep:oauth2", "__reqwest", "dep:url"] auth-client-credentials-jwt = ["auth", "dep:jsonwebtoken", "uuid"] schemars = ["dep:schemars"] +# mcp:// discovery URI resolution (draft-serra-mcp-discovery-uri) +discovery = [ + "dep:hickory-resolver", + "dep:url", + "__reqwest", +] +discovery-jws = ["discovery", "dep:jsonwebtoken"] + [dev-dependencies] tokio = { version = "1", features = ["full"] } schemars = { version = "1.1.0", features = ["chrono04"] } diff --git a/crates/rmcp/src/discovery/dns.rs b/crates/rmcp/src/discovery/dns.rs new file mode 100644 index 00000000..bb1e5258 --- /dev/null +++ b/crates/rmcp/src/discovery/dns.rs @@ -0,0 +1,216 @@ +use async_trait::async_trait; +use jsonwebtoken::jwk::Jwk; + +/// Hint extracted from the `_mcp.{host}` TXT record (the DNS "fast mode"). +/// +/// Per the draft this is advisory only: a `.well-known` manifest, when present, +/// always takes precedence over these values. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +#[non_exhaustive] +pub struct McpDnsHint { + pub src: Option, + pub registry: Option, + pub auth: Option, +} + +/// A public key published via the `_mcp-key.{host}` TXT record, used to verify +/// a manifest's detached JWS signature. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct DnsJwk { + pub kid: String, + pub jwk: Jwk, +} + +/// Abstraction over DNS TXT lookups so the resolver can be unit tested without +/// a network. Returns one `String` per TXT record (its data chunks joined). +#[async_trait] +pub trait DnsResolver: Send + Sync { + async fn txt_lookup(&self, name: &str) -> Result, DnsLookupError>; +} + +/// Error returned by [`DnsResolver::txt_lookup`]. +#[derive(Debug, thiserror::Error)] +#[error("DNS TXT lookup failed for {name}: {source}")] +#[non_exhaustive] +pub struct DnsLookupError { + pub name: String, + #[source] + pub source: Box, +} + +/// Real resolver backed by hickory using the host's `/etc/resolv.conf`. +pub struct HickoryDnsResolver { + inner: hickory_resolver::TokioResolver, +} + +impl HickoryDnsResolver { + pub fn from_system() -> Result { + let inner = hickory_resolver::TokioResolver::builder_tokio() + .map_err(|e| DnsLookupError { + name: "".to_string(), + source: e.into(), + })? + .build() + .map_err(|e| DnsLookupError { + name: "".to_string(), + source: e.into(), + })?; + Ok(Self { inner }) + } +} + +#[async_trait] +impl DnsResolver for HickoryDnsResolver { + async fn txt_lookup(&self, name: &str) -> Result, DnsLookupError> { + use hickory_resolver::proto::rr::RData; + + let lookup = match self.inner.txt_lookup(name).await { + Ok(lookup) => lookup, + // A missing record is a normal, non-fatal outcome for an optional step. + Err(e) if e.is_no_records_found() => return Ok(Vec::new()), + Err(e) => { + return Err(DnsLookupError { + name: name.to_string(), + source: e.into(), + }); + } + }; + + let records = lookup + .answers() + .iter() + .filter_map(|record| match &record.data { + RData::TXT(txt) => { + // DNS TXT records split data into ≤255-byte character-strings. + // They must be concatenated without separators — using + // `TXT::to_string()` joins chunks with spaces, corrupting + // binary-ish data like JWK JSON. + let raw: Vec = txt.txt_data.iter().flatten().copied().collect(); + Some(String::from_utf8_lossy(&raw).into_owned()) + } + _ => None, + }) + .collect(); + Ok(records) + } +} + +fn parse_pairs(record: &str) -> Vec<(String, String)> { + record + .split(';') + .filter_map(|segment| { + let segment = segment.trim(); + if segment.is_empty() { + return None; + } + let (key, value) = segment.split_once('=')?; + Some((key.trim().to_ascii_lowercase(), value.trim().to_string())) + }) + .collect() +} + +/// Parse the first valid `v=mcp1` discovery record. `endpoint=` is accepted as a +/// deprecated alias for `src=`. +pub fn parse_mcp_hint(records: &[String]) -> Option { + for record in records { + let pairs = parse_pairs(record); + let version = pairs.iter().find(|(k, _)| k == "v").map(|(_, v)| v); + if version.map(String::as_str) != Some("mcp1") { + continue; + } + let get = |key: &str| pairs.iter().find(|(k, _)| k == key).map(|(_, v)| v.clone()); + return Some(McpDnsHint { + src: get("src").or_else(|| get("endpoint")), + registry: get("registry"), + auth: get("auth"), + }); + } + None +} + +/// Parse every valid `v=mcp1jwk` key record. Malformed JWKs are skipped. +pub fn parse_jwks(records: &[String]) -> Vec { + let mut keys = Vec::new(); + for record in records { + let pairs = parse_pairs(record); + if pairs + .iter() + .find(|(k, _)| k == "v") + .map(|(_, v)| v.as_str()) + != Some("mcp1jwk") + { + continue; + } + let kid = pairs + .iter() + .find(|(k, _)| k == "kid") + .map(|(_, v)| v.clone()); + let jwk_raw = pairs + .iter() + .find(|(k, _)| k == "jwk") + .map(|(_, v)| v.clone()); + if let (Some(kid), Some(jwk_raw)) = (kid, jwk_raw) { + if let Ok(jwk) = serde_json::from_str::(&jwk_raw) { + keys.push(DnsJwk { kid, jwk }); + } + } + } + keys +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parses_src_hint() { + let hint = + parse_mcp_hint(&["v=mcp1; src=https://api.example.com/mcp".to_string()]).unwrap(); + assert_eq!(hint.src.as_deref(), Some("https://api.example.com/mcp")); + assert_eq!(hint.registry, None); + } + + #[test] + fn endpoint_is_alias_for_src() { + let hint = + parse_mcp_hint(&["v=mcp1; endpoint=https://example.com/mcp; auth=oauth2".to_string()]) + .unwrap(); + assert_eq!(hint.src.as_deref(), Some("https://example.com/mcp")); + assert_eq!(hint.auth.as_deref(), Some("oauth2")); + } + + #[test] + fn ignores_non_mcp_txt_records() { + let records = vec![ + "v=spf1 include:_spf.example.com ~all".to_string(), + "v=mcp1; registry=https://registry.example.com".to_string(), + ]; + let hint = parse_mcp_hint(&records).unwrap(); + assert_eq!( + hint.registry.as_deref(), + Some("https://registry.example.com") + ); + assert_eq!(hint.src, None); + } + + #[test] + fn jwk_record_without_jwk_is_skipped() { + assert!(parse_jwks(&["v=mcp1jwk; kid=mcp-key-1".to_string()]).is_empty()); + } + + #[test] + fn parses_multi_chunk_jwk() { + // Simulate a JWK that was split across two DNS TXT character-strings + // (≤255 bytes each). This verifies that chunk concatenation works + // correctly and does not insert spurious separators. + let part1 = r#"v=mcp1jwk;kid=k1;jwk={"kty":"EC","crv":"P-256","x":"AAAAAAAAAAAAAAAA"#; + let part2 = + r#"AAAAAAAAAAAAAAAAAAAAAAAAAA","y":"BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"}"#; + let combined = format!("{part1}{part2}"); + let keys = parse_jwks(&[combined]); + assert_eq!(keys.len(), 1); + assert_eq!(keys[0].kid, "k1"); + assert_eq!(keys[0].jwk.common.key_algorithm, None); + } +} diff --git a/crates/rmcp/src/discovery/error.rs b/crates/rmcp/src/discovery/error.rs new file mode 100644 index 00000000..d58e5c96 --- /dev/null +++ b/crates/rmcp/src/discovery/error.rs @@ -0,0 +1,41 @@ +use thiserror::Error; + +/// Errors that can occur while resolving an `mcp://` discovery URI. +#[derive(Debug, Error)] +#[non_exhaustive] +pub enum DiscoveryError { + #[error("invalid mcp discovery URI: {0}")] + InvalidUri(String), + + #[error("no MCP server could be discovered for host {0}")] + NotFound(String), + + #[error("manifest at {url} is malformed: {reason}")] + MalformedManifest { url: String, reason: String }, + + #[error( + "endpoint host {endpoint_host} is not the discovery host {discovery_host} or a subdomain of it" + )] + EndpointHostMismatch { + endpoint_host: String, + discovery_host: String, + }, + + #[error("insecure endpoint {0}: discovery requires https")] + InsecureEndpoint(String), + + #[error("unsupported transport {0:?}: only \"http\" is supported")] + UnsupportedTransport(String), + + #[error("manifest signature verification failed: {0}")] + SignatureVerification(String), + + #[error("network error talking to {url}: {source}")] + Network { + url: String, + #[source] + source: Box, + }, +} + +pub type Result = std::result::Result; diff --git a/crates/rmcp/src/discovery/http.rs b/crates/rmcp/src/discovery/http.rs new file mode 100644 index 00000000..b872b8ec --- /dev/null +++ b/crates/rmcp/src/discovery/http.rs @@ -0,0 +1,136 @@ +use std::time::Duration; + +use async_trait::async_trait; + +use super::error::{DiscoveryError, Result}; + +/// Result of fetching a manifest URL. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub enum FetchOutcome { + Found { body: String }, + NotFound, +} + +/// Abstraction over the HTTP calls discovery makes, so resolution can be unit +/// tested without a live server. +#[async_trait] +pub trait ManifestFetcher: Send + Sync { + /// GET a URL. `Found` on a 2xx response, `NotFound` on 404; any other status + /// or transport failure is an error. + async fn get(&self, url: &str) -> Result; + + /// Direct-handshake probe for the fallback step: returns true only if the + /// endpoint responds like an MCP Streamable HTTP server, false otherwise. + async fn probe(&self, url: &str) -> Result; +} + +/// Minimal MCP `initialize` request used to probe an endpoint in the fallback +/// step. A non-MCP host answering on `/mcp` (a 404 page, an SPA, a login form) +/// will not produce a JSON-RPC / SSE response, so it is not mistaken for a +/// server. +const MCP_INITIALIZE_BODY: &str = r#"{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2025-06-18","capabilities":{},"clientInfo":{"name":"rmcp-discovery","version":"0"}}}"#; + +/// `ManifestFetcher` backed by reqwest. +pub struct ReqwestFetcher { + client: reqwest::Client, +} + +impl ReqwestFetcher { + /// Build a fetcher with a caller-supplied reqwest client (consistent with + /// rmcp's OAuth custom-client pattern). + pub fn with_client(client: reqwest::Client) -> Self { + Self { client } + } + + pub fn new(timeout: Duration) -> Result { + let client = reqwest::Client::builder() + .timeout(timeout) + .user_agent(concat!("rmcp-discovery/", env!("CARGO_PKG_VERSION"))) + .build() + .map_err(|e| DiscoveryError::Network { + url: "".to_string(), + source: e.into(), + })?; + Ok(Self { client }) + } +} + +#[async_trait] +impl ManifestFetcher for ReqwestFetcher { + async fn get(&self, url: &str) -> Result { + let resp = self + .client + .get(url) + .send() + .await + .map_err(|e| DiscoveryError::Network { + url: url.to_string(), + source: e.into(), + })?; + if resp.status() == reqwest::StatusCode::NOT_FOUND { + return Ok(FetchOutcome::NotFound); + } + if !resp.status().is_success() { + return Err(DiscoveryError::Network { + url: url.to_string(), + source: format!("unexpected status {} from {url}", resp.status()).into(), + }); + } + let body = resp.text().await.map_err(|e| DiscoveryError::Network { + url: url.to_string(), + source: e.into(), + })?; + Ok(FetchOutcome::Found { body }) + } + + async fn probe(&self, url: &str) -> Result { + let resp = match self + .client + .post(url) + .header(reqwest::header::CONTENT_TYPE, "application/json") + .header( + reqwest::header::ACCEPT, + "application/json, text/event-stream", + ) + .body(MCP_INITIALIZE_BODY) + .send() + .await + { + Ok(resp) => resp, + Err(e) if e.is_connect() || e.is_timeout() => return Ok(false), + Err(e) => { + return Err(DiscoveryError::Network { + url: url.to_string(), + source: e.into(), + }); + } + }; + + if !resp.status().is_success() { + return Ok(false); + } + let content_type = resp + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|v| v.to_str().ok()) + .unwrap_or("") + .to_ascii_lowercase(); + + // An SSE response to the initialize POST is a strong MCP signal; its body + // is a stream we must not block on reading. + if content_type.contains("text/event-stream") { + return Ok(true); + } + // For a JSON response, confirm it is actually a JSON-RPC message rather + // than a generic API endpoint that happens to answer with JSON (e.g. a + // catch-all returning `{"error": ...}`). + if content_type.contains("application/json") { + let body = resp.text().await.unwrap_or_default(); + if let Ok(value) = serde_json::from_str::(&body) { + return Ok(value.get("jsonrpc").is_some()); + } + } + Ok(false) + } +} diff --git a/crates/rmcp/src/discovery/jws.rs b/crates/rmcp/src/discovery/jws.rs new file mode 100644 index 00000000..ca6d1a9b --- /dev/null +++ b/crates/rmcp/src/discovery/jws.rs @@ -0,0 +1,130 @@ +use jsonwebtoken::{Algorithm, DecodingKey}; +use serde_json::Value; + +use super::dns::DnsJwk; +use super::error::DiscoveryError; +use super::manifest::ManifestSignature; + +/// Verify a manifest's detached JWS signature against the public keys published +/// in DNS. +/// +/// The signature covers the canonical JSON serialization of the manifest with +/// the `signature` field removed (draft-serra-mcp-discovery-uri, Security +/// Considerations). The signed value is a raw base64url signature carried in +/// `signature.value`, with `alg`/`kid` selecting the algorithm and key. +pub fn verify_signature( + raw_manifest_json: &str, + sig: &ManifestSignature, + jwks: &[DnsJwk], +) -> Result<(), DiscoveryError> { + let key_entry = jwks.iter().find(|k| k.kid == sig.kid).ok_or_else(|| { + DiscoveryError::SignatureVerification(format!("no published key matches kid {:?}", sig.kid)) + })?; + + let algorithm = parse_alg(&sig.alg)?; + + let decoding_key = DecodingKey::from_jwk(&key_entry.jwk).map_err(|e| { + DiscoveryError::SignatureVerification(format!("invalid JWK for kid {:?}: {e}", sig.kid)) + })?; + + let payload = canonical_payload(raw_manifest_json)?; + + let verified = jsonwebtoken::crypto::verify(&sig.value, &payload, &decoding_key, algorithm) + .map_err(|e| DiscoveryError::SignatureVerification(format!("verification error: {e}")))?; + + if !verified { + return Err(DiscoveryError::SignatureVerification( + "signature does not match manifest".to_string(), + )); + } + Ok(()) +} + +fn parse_alg(alg: &str) -> Result { + let parsed = match alg { + "RS256" => Algorithm::RS256, + "RS384" => Algorithm::RS384, + "RS512" => Algorithm::RS512, + "PS256" => Algorithm::PS256, + "PS384" => Algorithm::PS384, + "PS512" => Algorithm::PS512, + "ES256" => Algorithm::ES256, + "ES384" => Algorithm::ES384, + "EdDSA" => Algorithm::EdDSA, + // Symmetric algorithms cannot be published as a public key; reject. + other => { + return Err(DiscoveryError::SignatureVerification(format!( + "unsupported signature algorithm {other:?}" + ))); + } + }; + Ok(parsed) +} + +/// Produce the canonical byte payload that the signature is computed over: the +/// manifest object with `signature` removed and object keys sorted. +pub(crate) fn canonical_payload(raw_manifest_json: &str) -> Result, DiscoveryError> { + let mut value: Value = serde_json::from_str(raw_manifest_json).map_err(|e| { + DiscoveryError::SignatureVerification(format!("manifest is not valid JSON: {e}")) + })?; + if let Value::Object(map) = &mut value { + map.remove("signature"); + } + serde_json::to_vec(&canonicalize(&value)).map_err(|e| { + DiscoveryError::SignatureVerification(format!("failed to canonicalize manifest: {e}")) + }) +} + +/// Recursively sort object keys so serialization is deterministic regardless of +/// whether serde_json preserves insertion order in this build. +fn canonicalize(value: &Value) -> Value { + match value { + Value::Object(map) => { + let sorted: std::collections::BTreeMap = map + .iter() + .map(|(k, v)| (k.clone(), canonicalize(v))) + .collect(); + Value::Object(sorted.into_iter().collect()) + } + Value::Array(items) => Value::Array(items.iter().map(canonicalize).collect()), + other => other.clone(), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn canonical_payload_strips_signature_and_sorts_keys() { + let raw = r#"{"name":"x","endpoint":"https://e","signature":{"alg":"ES256","kid":"k","value":"v"}}"#; + let bytes = canonical_payload(raw).unwrap(); + let text = String::from_utf8(bytes).unwrap(); + assert!(!text.contains("signature")); + // keys sorted: endpoint before name + assert!(text.find("endpoint").unwrap() < text.find("name").unwrap()); + } + + #[test] + fn canonical_payload_is_order_independent() { + let a = r#"{"a":1,"b":2}"#; + let b = r#"{"b":2,"a":1}"#; + assert_eq!(canonical_payload(a).unwrap(), canonical_payload(b).unwrap()); + } + + #[test] + fn rejects_symmetric_algorithm() { + assert!(parse_alg("HS256").is_err()); + } + + #[test] + fn missing_key_for_kid_fails() { + let sig = ManifestSignature { + alg: "ES256".into(), + kid: "absent".into(), + value: "AAAA".into(), + }; + let err = verify_signature("{}", &sig, &[]).unwrap_err(); + assert!(matches!(err, DiscoveryError::SignatureVerification(_))); + } +} diff --git a/crates/rmcp/src/discovery/manifest.rs b/crates/rmcp/src/discovery/manifest.rs new file mode 100644 index 00000000..da56e842 --- /dev/null +++ b/crates/rmcp/src/discovery/manifest.rs @@ -0,0 +1,245 @@ +use serde::{Deserialize, Serialize}; + +use super::error::DiscoveryError; + +/// Trust class declared by an MCP server manifest. +/// +/// Per draft-serra-mcp-discovery-uri a missing declaration defaults to the most +/// restrictive safe value, `Public`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[non_exhaustive] +pub enum TrustClass { + #[default] + Public, + Sandbox, + Enterprise, + Regulated, +} + +impl TrustClass { + /// Whether the spec mandates an `auth` declaration for this trust class. + pub fn requires_auth_declaration(self) -> bool { + matches!(self, TrustClass::Enterprise | TrustClass::Regulated) + } +} + +/// Authentication requirements advertised by a manifest. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +#[non_exhaustive] +pub struct AuthRequirements { + #[serde(default)] + pub required: bool, + #[serde(default)] + pub methods: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub endpoint: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub metadata_url: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub apikey_header: Option, + #[serde(default)] + pub scopes: Vec, +} + +/// Detached JWS signature object carried in a manifest. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[non_exhaustive] +pub struct ManifestSignature { + pub alg: String, + pub kid: String, + /// Base64url-encoded detached JWS signature over the canonical JSON + /// serialization of the manifest with the `signature` field removed. + pub value: String, +} + +/// The `.well-known/mcp-server` manifest as defined by the draft. +/// +/// Only the fields the resolver acts on are modelled explicitly; the remaining +/// optional fields are accepted and ignored to stay forward-compatible. +#[non_exhaustive] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct McpServerManifest { + pub mcp_version: String, + pub name: String, + pub endpoint: String, + pub transport: String, + + #[serde(default, skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub auth: Option, + #[serde(default)] + pub capabilities: Vec, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub trust_class: Option, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub signature: Option, +} + +impl McpServerManifest { + /// Trust class, defaulting to the most restrictive value when absent. + pub fn effective_trust_class(&self) -> TrustClass { + self.trust_class.unwrap_or_default() + } + + /// Validate the manifest's internal consistency against the spec's MUST + /// rules. Does not perform host-match or signature checks — those need the + /// discovery host and DNS keys respectively. + pub fn validate(&self, url: &str) -> Result<(), DiscoveryError> { + if self.transport != "http" { + return Err(DiscoveryError::UnsupportedTransport(self.transport.clone())); + } + if self.endpoint.trim().is_empty() { + return Err(DiscoveryError::MalformedManifest { + url: url.to_string(), + reason: "endpoint is empty".to_string(), + }); + } + if self.effective_trust_class().requires_auth_declaration() { + // A protected manifest must declare at least one usable auth method, + // not merely an empty `auth` object. + let has_usable_auth = self + .auth + .as_ref() + .map(|a| a.methods.iter().any(|m| m != "none")) + .unwrap_or(false); + if !has_usable_auth { + return Err(DiscoveryError::MalformedManifest { + url: url.to_string(), + reason: format!( + "trust_class {:?} requires an auth declaration with at least one method", + self.effective_trust_class() + ), + }); + } + } + Ok(()) + } +} + +/// Returns true when `endpoint_host` equals `discovery_host` or is a subdomain +/// of it. Comparison is case-insensitive and a trailing dot is ignored. +pub fn host_matches(discovery_host: &str, endpoint_host: &str) -> bool { + let normalize = |h: &str| h.trim_end_matches('.').to_ascii_lowercase(); + let discovery = normalize(discovery_host); + let endpoint = normalize(endpoint_host); + if discovery.is_empty() || endpoint.is_empty() { + return false; + } + endpoint == discovery || endpoint.ends_with(&format!(".{discovery}")) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn host_match_accepts_exact_and_subdomain() { + assert!(host_matches("example.com", "example.com")); + assert!(host_matches("example.com", "api.example.com")); + assert!(host_matches("example.com", "a.b.example.com")); + assert!(host_matches("Example.com", "API.example.com.")); + } + + #[test] + fn host_match_rejects_unrelated_and_suffix_tricks() { + assert!(!host_matches("example.com", "evil.com")); + assert!(!host_matches("example.com", "notexample.com")); + assert!(!host_matches("example.com", "example.com.evil.com")); + assert!(!host_matches("example.com", "")); + } + + #[test] + fn validate_rejects_non_http_transport() { + let m = McpServerManifest { + mcp_version: "2025-06-18".into(), + name: "x".into(), + endpoint: "https://example.com/mcp".into(), + transport: "stdio".into(), + description: None, + auth: None, + capabilities: vec![], + trust_class: None, + signature: None, + }; + assert!(matches!( + m.validate("u"), + Err(DiscoveryError::UnsupportedTransport(_)) + )); + } + + #[test] + fn validate_requires_auth_for_enterprise() { + let m = McpServerManifest { + mcp_version: "2025-06-18".into(), + name: "x".into(), + endpoint: "https://example.com/mcp".into(), + transport: "http".into(), + description: None, + auth: None, + capabilities: vec![], + trust_class: Some(TrustClass::Enterprise), + signature: None, + }; + assert!(matches!( + m.validate("u"), + Err(DiscoveryError::MalformedManifest { .. }) + )); + } + + #[test] + fn validate_rejects_empty_auth_for_enterprise() { + let m = McpServerManifest { + mcp_version: "2025-06-18".into(), + name: "x".into(), + endpoint: "https://example.com/mcp".into(), + transport: "http".into(), + description: None, + auth: Some(AuthRequirements::default()), + capabilities: vec![], + trust_class: Some(TrustClass::Regulated), + signature: None, + }; + assert!(matches!( + m.validate("u"), + Err(DiscoveryError::MalformedManifest { .. }) + )); + } + + #[test] + fn validate_accepts_enterprise_with_method() { + let m = McpServerManifest { + mcp_version: "2025-06-18".into(), + name: "x".into(), + endpoint: "https://example.com/mcp".into(), + transport: "http".into(), + description: None, + auth: Some(AuthRequirements { + required: true, + methods: vec!["oauth2".into()], + ..Default::default() + }), + capabilities: vec![], + trust_class: Some(TrustClass::Enterprise), + signature: None, + }; + assert!(m.validate("u").is_ok()); + } + + #[test] + fn trust_class_defaults_to_public() { + let m = McpServerManifest { + mcp_version: "2025-06-18".into(), + name: "x".into(), + endpoint: "https://example.com/mcp".into(), + transport: "http".into(), + description: None, + auth: None, + capabilities: vec![], + trust_class: None, + signature: None, + }; + assert_eq!(m.effective_trust_class(), TrustClass::Public); + } +} diff --git a/crates/rmcp/src/discovery/mod.rs b/crates/rmcp/src/discovery/mod.rs new file mode 100644 index 00000000..eef4464e --- /dev/null +++ b/crates/rmcp/src/discovery/mod.rs @@ -0,0 +1,359 @@ +//! Discovery of MCP servers from `mcp://` URIs per +//! [draft-serra-mcp-discovery-uri](https://datatracker.ietf.org/doc/draft-serra-mcp-discovery-uri/). +//! +//! Resolution order: +//! 1. (optional, "fast mode") DNS TXT `_mcp.{host}` for a hint, and +//! `_mcp-key.{host}` for the public key used to verify manifest signatures. +//! 2. (authoritative) `GET https://{host}/.well-known/mcp-server`. +//! 3. (fallback) direct handshake probe at `https://{host}/mcp`. +//! +//! A `.well-known` manifest always takes precedence over DNS hints. All +//! endpoints must be HTTPS and the endpoint host must equal, or be a subdomain +//! of, the discovery host. When a manifest carries a signature it MUST verify. + +mod dns; +mod error; +mod http; +mod jws; +mod manifest; + +use std::time::Duration; + +use serde::{Deserialize, Serialize}; + +pub use dns::{DnsJwk, DnsLookupError, DnsResolver, HickoryDnsResolver, McpDnsHint}; +pub use error::{DiscoveryError, Result}; +pub use http::{FetchOutcome, ManifestFetcher, ReqwestFetcher}; +pub use manifest::{ + AuthRequirements, ManifestSignature, McpServerManifest, TrustClass, host_matches, +}; + +pub const WELL_KNOWN_PATH: &str = "/.well-known/mcp-server"; +pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); + +/// Which step of the resolution chain produced the final endpoint. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +#[non_exhaustive] +pub enum DiscoverySource { + WellKnown, + DirectFallback, +} + +/// Options controlling resolution behaviour. +#[non_exhaustive] +#[derive(Debug, Clone)] +pub struct DiscoveryOptions { + /// Perform the optional DNS fast-mode step (also required to verify + /// manifest signatures, since the public key is published in DNS). + pub use_dns: bool, + /// Reject manifests that do not carry a verifiable signature. A present + /// signature is always verified regardless of this flag. + pub require_signature: bool, + pub timeout: Duration, +} + +impl Default for DiscoveryOptions { + fn default() -> Self { + Self { + use_dns: true, + require_signature: false, + timeout: DEFAULT_TIMEOUT, + } + } +} + +/// A successfully discovered MCP server. +/// +/// The caller is responsible for confirming the trust/auth posture before +/// connecting to [`endpoint`]. +#[non_exhaustive] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DiscoveredServer { + /// Authority (host and optional port) the `mcp://` URI pointed at. + pub discovery_host: String, + /// Resolved HTTPS MCP endpoint. + pub endpoint: String, + pub manifest: McpServerManifest, + pub source: DiscoverySource, + pub signature_verified: bool, + pub trust_class: TrustClass, + /// True when a DNS `src` hint disagreed with the authoritative manifest + /// endpoint (the manifest wins, but the conflict is surfaced). + pub dns_conflict: bool, +} + +impl DiscoveredServer { + /// Parse the resolved endpoint as a [`url::Url`], suitable for constructing + /// a `StreamableHttpClientTransport`. + pub fn endpoint_url(&self) -> Result { + url::Url::parse(&self.endpoint).map_err(|e| DiscoveryError::MalformedManifest { + url: self.endpoint.clone(), + reason: e.to_string(), + }) + } +} + +/// Entry point for `mcp://` URI resolution. +#[non_exhaustive] +pub struct McpDiscovery; + +impl McpDiscovery { + /// Resolve an `mcp://` URI using the system DNS resolver and a real HTTP client. + pub async fn resolve(uri: &str) -> Result { + Self::resolve_with_options(uri, DiscoveryOptions::default()).await + } + + /// Resolve with explicit options. + pub async fn resolve_with_options( + uri: &str, + opts: DiscoveryOptions, + ) -> Result { + let fetcher = ReqwestFetcher::new(opts.timeout)?; + let dns = if opts.use_dns { + match HickoryDnsResolver::from_system() { + Ok(r) => Some(r), + Err(e) => { + tracing::debug!("DNS resolver unavailable, skipping fast-mode: {e}"); + None + } + } + } else { + None + }; + let dns_ref = dns.as_ref().map(|d| d as &dyn DnsResolver); + resolve_with(uri, dns_ref, &fetcher, &opts).await + } +} + +struct ParsedUri { + /// host without port, used for DNS labels and host-match + host: String, + /// authority (host[:port]) used to build https URLs + authority: String, + /// path component, if the URI carried one (e.g. `mcp://host/shop`); used as + /// the direct-handshake fallback endpoint when present. + path: Option, +} + +fn parse_uri(uri: &str) -> Result { + let parsed = + url::Url::parse(uri).map_err(|e| DiscoveryError::InvalidUri(format!("{uri}: {e}")))?; + if parsed.scheme() != "mcp" { + return Err(DiscoveryError::InvalidUri(format!( + "expected scheme \"mcp\", got {:?}", + parsed.scheme() + ))); + } + let host = parsed + .host_str() + .filter(|h| !h.is_empty()) + .ok_or_else(|| DiscoveryError::InvalidUri(format!("{uri}: missing host")))? + .to_string(); + let authority = match parsed.port() { + Some(port) => format!("{host}:{port}"), + None => host.clone(), + }; + let trimmed_path = parsed.path().trim_end_matches('/'); + let path = (!trimmed_path.is_empty()).then(|| trimmed_path.to_string()); + Ok(ParsedUri { + host, + authority, + path, + }) +} + +/// An `src` hint from DNS is only usable if it is HTTPS and its host equals, or +/// is a subdomain of, the discovery host — the same host-match rule applied to +/// manifest endpoints, so an unauthenticated DNS answer cannot point discovery +/// at an arbitrary host. +fn validated_dns_src(host: &str, dns_hint: &Option) -> Option { + let src = dns_hint.as_ref()?.src.as_ref()?; + let parsed = url::Url::parse(src).ok()?; + if parsed.scheme() != "https" { + return None; + } + if !host_matches(host, parsed.host_str()?) { + return None; + } + Some(src.clone()) +} + +/// Resolve with injectable DNS and HTTP backends (used by tests). +pub async fn resolve_with( + uri: &str, + dns: Option<&dyn DnsResolver>, + fetcher: &dyn ManifestFetcher, + opts: &DiscoveryOptions, +) -> Result { + let ParsedUri { + host, + authority, + path, + } = parse_uri(uri)?; + + // Step 1: optional DNS fast-mode. Failures are non-fatal; absent keys only + // matter if the manifest later claims a signature. + let mut dns_hint: Option = None; + let mut jwks: Vec = Vec::new(); + if let Some(resolver) = dns { + match resolver.txt_lookup(&format!("_mcp.{host}")).await { + Ok(records) => dns_hint = dns::parse_mcp_hint(&records), + Err(e) => tracing::debug!("_mcp.{host} TXT lookup failed: {e}"), + } + match resolver.txt_lookup(&format!("_mcp-key.{host}")).await { + Ok(records) => jwks = dns::parse_jwks(&records), + Err(e) => tracing::debug!("_mcp-key.{host} TXT lookup failed: {e}"), + } + } + + // Step 2: authoritative .well-known manifest. A transport/non-404 error is + // NOT treated as "manifest absent": failing through to the unsigned fallback + // on error would let an on-path attacker who can disrupt (but not break TLS + // on) the well-known request strip a signed manifest's trust posture. Only a + // definitive 404 advances to the fallback. + let well_known_url = format!("https://{authority}{WELL_KNOWN_PATH}"); + match fetcher.get(&well_known_url).await { + Ok(FetchOutcome::Found { body }) => { + return build_from_manifest( + &host, + &authority, + &body, + &well_known_url, + &jwks, + &dns_hint, + opts, + ); + } + Ok(FetchOutcome::NotFound) => {} + Err(e) => { + return Err(e); + } + } + + // Step 3: direct handshake fallback. Candidate endpoints, in priority order: + // 1. a path the caller supplied in the discovery URI (`mcp://host/shop`), + // 2. an `src` hint from the `_mcp.{host}` DNS record (validated: HTTPS and + // host equal-or-subdomain of the discovery host, so an unauthenticated + // DNS answer cannot redirect to an arbitrary host), + // 3. the default `https://{authority}/mcp`. + // The first endpoint that answers a real MCP handshake wins. + let mut candidates: Vec = Vec::new(); + if let Some(p) = &path { + candidates.push(format!("https://{authority}{p}")); + } else { + if let Some(src) = validated_dns_src(&host, &dns_hint) { + candidates.push(src); + } + candidates.push(format!("https://{authority}/mcp")); + } + + let mut endpoint = None; + for candidate in candidates { + let reachable = fetcher.probe(&candidate).await?; + if reachable { + endpoint = Some(candidate); + break; + } + } + let Some(endpoint) = endpoint else { + return Err(DiscoveryError::NotFound(host)); + }; + if opts.require_signature { + return Err(DiscoveryError::SignatureVerification( + "direct-handshake fallback cannot be signed".to_string(), + )); + } + let manifest = McpServerManifest { + mcp_version: String::new(), + name: host.clone(), + endpoint: endpoint.clone(), + transport: "http".to_string(), + description: None, + auth: None, + capabilities: Vec::new(), + trust_class: None, + signature: None, + }; + Ok(DiscoveredServer { + discovery_host: authority, + endpoint, + trust_class: manifest.effective_trust_class(), + manifest, + source: DiscoverySource::DirectFallback, + signature_verified: false, + dns_conflict: false, + }) +} + +fn build_from_manifest( + host: &str, + authority: &str, + body: &str, + url: &str, + jwks: &[DnsJwk], + dns_hint: &Option, + opts: &DiscoveryOptions, +) -> Result { + let manifest: McpServerManifest = + serde_json::from_str(body).map_err(|e| DiscoveryError::MalformedManifest { + url: url.to_string(), + reason: e.to_string(), + })?; + manifest.validate(url)?; + + let endpoint_url = + url::Url::parse(&manifest.endpoint).map_err(|e| DiscoveryError::MalformedManifest { + url: url.to_string(), + reason: format!("invalid endpoint {:?}: {e}", manifest.endpoint), + })?; + if endpoint_url.scheme() != "https" { + return Err(DiscoveryError::InsecureEndpoint(manifest.endpoint.clone())); + } + let endpoint_host = endpoint_url.host_str().unwrap_or_default(); + if !host_matches(host, endpoint_host) { + return Err(DiscoveryError::EndpointHostMismatch { + endpoint_host: endpoint_host.to_string(), + discovery_host: host.to_string(), + }); + } + + let signature_verified = match &manifest.signature { + Some(sig) => { + jws::verify_signature(body, sig, jwks)?; + true + } + None => { + if opts.require_signature { + return Err(DiscoveryError::SignatureVerification( + "manifest is unsigned but a signature is required".to_string(), + )); + } + false + } + }; + + let dns_conflict = dns_hint + .as_ref() + .and_then(|h| h.src.as_deref()) + .map(|src| src != manifest.endpoint) + .unwrap_or(false); + if dns_conflict { + tracing::warn!( + "DNS src hint disagrees with .well-known endpoint for {host}; using manifest endpoint" + ); + } + + Ok(DiscoveredServer { + discovery_host: authority.to_string(), + endpoint: manifest.endpoint.clone(), + trust_class: manifest.effective_trust_class(), + manifest, + source: DiscoverySource::WellKnown, + signature_verified, + dns_conflict, + }) +} + +#[cfg(test)] +mod tests; diff --git a/crates/rmcp/src/discovery/tests.rs b/crates/rmcp/src/discovery/tests.rs new file mode 100644 index 00000000..8a0a7750 --- /dev/null +++ b/crates/rmcp/src/discovery/tests.rs @@ -0,0 +1,443 @@ +use std::collections::HashMap; +use std::net::SocketAddr; +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use axum::Router; +use axum::extract::State; +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; +use axum::routing::{get, post}; + +use super::*; + +// --------------------------------------------------------------------------- +// Mock implementations +// --------------------------------------------------------------------------- + +/// In-memory DNS resolver: maps a TXT label to its records. +#[derive(Default, Clone)] +struct MockDns { + records: HashMap>, +} + +impl MockDns { + fn with(mut self, name: &str, records: &[&str]) -> Self { + self.records.insert( + name.to_string(), + records.iter().map(|s| s.to_string()).collect(), + ); + self + } +} + +#[async_trait] +impl DnsResolver for MockDns { + async fn txt_lookup(&self, name: &str) -> std::result::Result, DnsLookupError> { + Ok(self.records.get(name).cloned().unwrap_or_default()) + } +} + +/// In-memory HTTP fetcher: maps URL -> body for GET, and a set of reachable +/// endpoints for probe. +#[derive(Default, Clone)] +struct MockHttp { + bodies: HashMap, + reachable: Vec, + probed: Arc>>, +} + +impl MockHttp { + fn with_body(mut self, url: &str, body: &str) -> Self { + self.bodies.insert(url.to_string(), body.to_string()); + self + } + fn reachable(mut self, url: &str) -> Self { + self.reachable.push(url.to_string()); + self + } +} + +#[async_trait] +impl ManifestFetcher for MockHttp { + async fn get(&self, url: &str) -> Result { + match self.bodies.get(url) { + Some(body) => Ok(FetchOutcome::Found { body: body.clone() }), + None => Ok(FetchOutcome::NotFound), + } + } + async fn probe(&self, url: &str) -> Result { + self.probed.lock().unwrap().push(url.to_string()); + Ok(self.reachable.iter().any(|u| u == url)) + } +} + +fn default_opts() -> DiscoveryOptions { + DiscoveryOptions::default() +} + +// --------------------------------------------------------------------------- +// URI parsing tests +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn rejects_non_mcp_scheme() { + let http = MockHttp::default(); + let err = resolve_with("https://example.com", None, &http, &default_opts()) + .await + .unwrap_err(); + assert!(matches!(err, DiscoveryError::InvalidUri(_))); +} + +#[tokio::test] +async fn rejects_missing_host() { + let http = MockHttp::default(); + let err = resolve_with("mcp:example.com", None, &http, &default_opts()) + .await + .unwrap_err(); + assert!(matches!(err, DiscoveryError::InvalidUri(_))); +} + +// --------------------------------------------------------------------------- +// Well-known manifest resolution +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn well_known_manifest_resolves() { + let body = r#"{"mcp_version":"2025-06-18","name":"Example","endpoint":"https://example.com/mcp","transport":"http"}"#; + let http = MockHttp::default().with_body("https://example.com/.well-known/mcp-server", body); + let server = resolve_with("mcp://example.com", None, &http, &default_opts()) + .await + .unwrap(); + assert_eq!(server.source, DiscoverySource::WellKnown); + assert_eq!(server.endpoint, "https://example.com/mcp"); + assert_eq!(server.trust_class, TrustClass::Public); + assert!(!server.signature_verified); + + let url = server.endpoint_url().unwrap(); + assert_eq!(url.as_str(), "https://example.com/mcp"); +} + +#[tokio::test] +async fn subdomain_endpoint_is_allowed() { + let body = r#"{"mcp_version":"1","name":"x","endpoint":"https://api.example.com/mcp","transport":"http"}"#; + let http = MockHttp::default().with_body("https://example.com/.well-known/mcp-server", body); + let server = resolve_with("mcp://example.com", None, &http, &default_opts()) + .await + .unwrap(); + assert_eq!(server.endpoint, "https://api.example.com/mcp"); +} + +#[tokio::test] +async fn endpoint_host_mismatch_is_rejected() { + let body = + r#"{"mcp_version":"1","name":"x","endpoint":"https://evil.com/mcp","transport":"http"}"#; + let http = MockHttp::default().with_body("https://example.com/.well-known/mcp-server", body); + let err = resolve_with("mcp://example.com", None, &http, &default_opts()) + .await + .unwrap_err(); + assert!(matches!(err, DiscoveryError::EndpointHostMismatch { .. })); +} + +#[tokio::test] +async fn insecure_endpoint_is_rejected() { + let body = + r#"{"mcp_version":"1","name":"x","endpoint":"http://example.com/mcp","transport":"http"}"#; + let http = MockHttp::default().with_body("https://example.com/.well-known/mcp-server", body); + let err = resolve_with("mcp://example.com", None, &http, &default_opts()) + .await + .unwrap_err(); + assert!(matches!(err, DiscoveryError::InsecureEndpoint(_))); +} + +#[tokio::test] +async fn malformed_manifest_is_rejected() { + let http = + MockHttp::default().with_body("https://example.com/.well-known/mcp-server", "not json"); + let err = resolve_with("mcp://example.com", None, &http, &default_opts()) + .await + .unwrap_err(); + assert!(matches!(err, DiscoveryError::MalformedManifest { .. })); +} + +// --------------------------------------------------------------------------- +// Direct handshake fallback +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn falls_back_to_direct_handshake() { + let http = MockHttp::default().reachable("https://example.com/mcp"); + let server = resolve_with("mcp://example.com", None, &http, &default_opts()) + .await + .unwrap(); + assert_eq!(server.source, DiscoverySource::DirectFallback); + assert_eq!(server.endpoint, "https://example.com/mcp"); +} + +#[tokio::test] +async fn fallback_preserves_uri_path() { + let http = MockHttp::default().reachable("https://example.com/custom"); + let server = resolve_with("mcp://example.com/custom", None, &http, &default_opts()) + .await + .unwrap(); + assert_eq!(server.source, DiscoverySource::DirectFallback); + assert_eq!(server.endpoint, "https://example.com/custom"); +} + +#[tokio::test] +async fn fallback_honors_validated_dns_src() { + // No well-known manifest, but DNS advertises a same-domain src endpoint. + let http = MockHttp::default().reachable("https://api.example.com/mcp"); + let dns = MockDns::default().with( + "_mcp.example.com", + &["v=mcp1; src=https://api.example.com/mcp"], + ); + let server = resolve_with("mcp://example.com", Some(&dns), &http, &default_opts()) + .await + .unwrap(); + assert_eq!(server.source, DiscoverySource::DirectFallback); + assert_eq!(server.endpoint, "https://api.example.com/mcp"); +} + +#[tokio::test] +async fn fallback_ignores_cross_host_dns_src() { + // A cross-host src (a spoofed DNS answer) must be ignored; discovery falls + // back to the default /mcp on the discovery host instead. + let http = MockHttp::default().reachable("https://example.com/mcp"); + let dns = MockDns::default().with("_mcp.example.com", &["v=mcp1; src=https://evil.com/mcp"]); + let server = resolve_with("mcp://example.com", Some(&dns), &http, &default_opts()) + .await + .unwrap(); + assert_eq!(server.endpoint, "https://example.com/mcp"); +} + +#[tokio::test] +async fn no_server_anywhere_is_not_found() { + let http = MockHttp::default(); + let err = resolve_with("mcp://example.com", None, &http, &default_opts()) + .await + .unwrap_err(); + assert!(matches!(err, DiscoveryError::NotFound(_))); +} + +#[tokio::test] +async fn port_is_carried_into_urls() { + let body = r#"{"mcp_version":"1","name":"x","endpoint":"https://example.com:8080/mcp","transport":"http"}"#; + let http = + MockHttp::default().with_body("https://example.com:8080/.well-known/mcp-server", body); + let server = resolve_with("mcp://example.com:8080", None, &http, &default_opts()) + .await + .unwrap(); + assert_eq!(server.endpoint, "https://example.com:8080/mcp"); +} + +// --------------------------------------------------------------------------- +// Signature verification +// --------------------------------------------------------------------------- + +#[tokio::test] +async fn signed_manifest_without_published_key_is_rejected() { + let signed = r#"{"mcp_version":"1","name":"x","endpoint":"https://example.com/mcp","transport":"http","signature":{"alg":"ES256","kid":"unknown","value":"AAAA"}}"#; + let http = MockHttp::default().with_body("https://example.com/.well-known/mcp-server", signed); + // DNS has no key record. + let dns = MockDns::default(); + let err = resolve_with("mcp://example.com", Some(&dns), &http, &default_opts()) + .await + .unwrap_err(); + assert!(matches!(err, DiscoveryError::SignatureVerification(_))); +} + +#[tokio::test] +async fn require_signature_rejects_unsigned_manifest() { + let body = + r#"{"mcp_version":"1","name":"x","endpoint":"https://example.com/mcp","transport":"http"}"#; + let http = MockHttp::default().with_body("https://example.com/.well-known/mcp-server", body); + let opts = DiscoveryOptions { + require_signature: true, + ..DiscoveryOptions::default() + }; + let err = resolve_with("mcp://example.com", None, &http, &opts) + .await + .unwrap_err(); + assert!(matches!(err, DiscoveryError::SignatureVerification(_))); +} + +#[tokio::test] +async fn dns_conflict_is_flagged_but_manifest_wins() { + let body = + r#"{"mcp_version":"1","name":"x","endpoint":"https://example.com/mcp","transport":"http"}"#; + let http = MockHttp::default().with_body("https://example.com/.well-known/mcp-server", body); + let dns = MockDns::default().with( + "_mcp.example.com", + &["v=mcp1; src=https://example.com/other"], + ); + let server = resolve_with("mcp://example.com", Some(&dns), &http, &default_opts()) + .await + .unwrap(); + assert_eq!(server.endpoint, "https://example.com/mcp"); + assert!(server.dns_conflict); +} + +// --------------------------------------------------------------------------- +// End-to-end test with a live axum server (rmcp test convention) +// --------------------------------------------------------------------------- + +/// Test server state: what manifest body to serve, and whether to act as a +/// reachable MCP endpoint. +#[derive(Clone)] +struct TestServerState { + manifest_body: Option, + mcp_reachable: bool, +} + +async fn well_known_handler(State(state): State) -> Response { + match &state.manifest_body { + Some(body) => ( + StatusCode::OK, + [(axum::http::header::CONTENT_TYPE, "application/json")], + body.clone(), + ) + .into_response(), + None => StatusCode::NOT_FOUND.into_response(), + } +} + +async fn mcp_probe_handler(State(state): State) -> Response { + if state.mcp_reachable { + // Return a JSON-RPC response so the probe detects an MCP server. + ( + StatusCode::OK, + [(axum::http::header::CONTENT_TYPE, "application/json")], + r#"{"jsonrpc":"2.0","id":1,"result":{}}"#, + ) + .into_response() + } else { + StatusCode::NOT_FOUND.into_response() + } +} + +async fn start_test_server(state: TestServerState) -> String { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let app = Router::new() + .route("/.well-known/mcp-server", get(well_known_handler)) + .route("/mcp", post(mcp_probe_handler)) + .with_state(state); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + format!("http://{addr}") +} + +#[tokio::test] +#[ignore = "requires real HTTPS server with TLS certificates"] +async fn e2e_well_known_manifest_resolution() { + let port = pick_ephemeral_port().await; + let base = format!("http://127.0.0.1:{port}"); + let manifest = format!( + r#"{{"mcp_version":"1","name":"TestServer","endpoint":"{base}/mcp","transport":"http"}}"# + ); + + let state = TestServerState { + manifest_body: Some(manifest), + mcp_reachable: true, + }; + let _ = start_test_server_on_port(port, state).await; + + let discovery_uri = format!("mcp://127.0.0.1:{port}"); + let server = McpDiscovery::resolve_with_options( + &discovery_uri, + DiscoveryOptions { + use_dns: false, + ..DiscoveryOptions::default() + }, + ) + .await + .unwrap(); + + assert_eq!(server.source, DiscoverySource::WellKnown); + assert_eq!(server.endpoint, format!("{base}/mcp")); + assert_eq!(server.manifest.name, "TestServer"); +} + +#[tokio::test] +#[ignore = "requires real HTTPS server with TLS certificates"] +async fn e2e_direct_fallback_resolution() { + let port = pick_ephemeral_port().await; + let base = format!("http://127.0.0.1:{port}"); + + let state = TestServerState { + manifest_body: None, // no well-known → fallback + mcp_reachable: true, + }; + let _ = start_test_server_on_port(port, state).await; + + let discovery_uri = format!("mcp://127.0.0.1:{port}"); + let server = McpDiscovery::resolve_with_options( + &discovery_uri, + DiscoveryOptions { + use_dns: false, + ..DiscoveryOptions::default() + }, + ) + .await + .unwrap(); + + assert_eq!(server.source, DiscoverySource::DirectFallback); + assert_eq!(server.endpoint, format!("{base}/mcp")); +} + +#[tokio::test] +#[ignore = "requires real HTTPS server with TLS certificates"] +async fn e2e_no_server_found() { + let port = pick_ephemeral_port().await; + + let state = TestServerState { + manifest_body: None, + mcp_reachable: false, + }; + let _ = start_test_server_on_port(port, state).await; + + let discovery_uri = format!("mcp://127.0.0.1:{port}"); + let err = McpDiscovery::resolve_with_options( + &discovery_uri, + DiscoveryOptions { + use_dns: false, + ..DiscoveryOptions::default() + }, + ) + .await + .unwrap_err(); + + assert!(matches!(err, DiscoveryError::NotFound(_))); +} + +// --------------------------------------------------------------------------- +// E2E helpers +// --------------------------------------------------------------------------- + +async fn pick_ephemeral_port() -> u16 { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + drop(listener); + port +} + +async fn start_test_server_on_port(port: u16, state: TestServerState) -> SocketAddr { + let addr: SocketAddr = format!("127.0.0.1:{port}").parse().unwrap(); + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + + let app = Router::new() + .route("/.well-known/mcp-server", get(well_known_handler)) + .route("/mcp", post(mcp_probe_handler)) + .with_state(state); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + addr +} diff --git a/crates/rmcp/src/lib.rs b/crates/rmcp/src/lib.rs index 9ae3f958..082a70e5 100644 --- a/crates/rmcp/src/lib.rs +++ b/crates/rmcp/src/lib.rs @@ -29,6 +29,10 @@ pub mod task_manager; #[cfg(any(feature = "client", feature = "server"))] pub mod transport; +#[cfg(feature = "discovery")] +#[cfg_attr(docsrs, doc(cfg(feature = "discovery")))] +pub mod discovery; + // re-export #[cfg(all(feature = "macros", feature = "server"))] pub use pastey::paste;