Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ features = [
"auth",
"auth-client-credentials-jwt",
"base64",
"discovery",
"discovery-jws",
"client",
"client-side-sse",
"elicitation",
Expand Down Expand Up @@ -58,6 +60,9 @@ oauth2 = { version = "5.0", optional = true, default-features = false }
# JWT signing for client credentials (private_key_jwt)
jsonwebtoken = { version = "10", optional = true }

# DNS resolver for mcp:// discovery URI resolution
hickory-resolver = { version = "0.26", optional = true, default-features = false, features = ["tokio", "system-config"] }

# for auto generate schema
schemars = { version = "1.0", optional = true, features = ["chrono04"] }

Expand Down Expand Up @@ -185,6 +190,14 @@ auth = ["dep:oauth2", "__reqwest", "dep:url"]
auth-client-credentials-jwt = ["auth", "dep:jsonwebtoken", "uuid"]
schemars = ["dep:schemars"]

# mcp:// discovery URI resolution (draft-serra-mcp-discovery-uri)
discovery = [
"dep:hickory-resolver",
"dep:url",
"__reqwest",
]
discovery-jws = ["discovery", "dep:jsonwebtoken"]

[dev-dependencies]
tokio = { version = "1", features = ["full"] }
schemars = { version = "1.1.0", features = ["chrono04"] }
Expand Down
216 changes: 216 additions & 0 deletions crates/rmcp/src/discovery/dns.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
use async_trait::async_trait;
use jsonwebtoken::jwk::Jwk;

/// Hint extracted from the `_mcp.{host}` TXT record (the DNS "fast mode").
///
/// Per the draft this is advisory only: a `.well-known` manifest, when present,
/// always takes precedence over these values.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct McpDnsHint {
pub src: Option<String>,
pub registry: Option<String>,
pub auth: Option<String>,
}

/// A public key published via the `_mcp-key.{host}` TXT record, used to verify
/// a manifest's detached JWS signature.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct DnsJwk {
pub kid: String,
pub jwk: Jwk,
}

/// Abstraction over DNS TXT lookups so the resolver can be unit tested without
/// a network. Returns one `String` per TXT record (its data chunks joined).
#[async_trait]
pub trait DnsResolver: Send + Sync {
async fn txt_lookup(&self, name: &str) -> Result<Vec<String>, DnsLookupError>;
}

/// Error returned by [`DnsResolver::txt_lookup`].
#[derive(Debug, thiserror::Error)]
#[error("DNS TXT lookup failed for {name}: {source}")]
#[non_exhaustive]
pub struct DnsLookupError {
pub name: String,
#[source]
pub source: Box<dyn std::error::Error + Send + Sync>,
}

/// Real resolver backed by hickory using the host's `/etc/resolv.conf`.
pub struct HickoryDnsResolver {
inner: hickory_resolver::TokioResolver,
}

impl HickoryDnsResolver {
pub fn from_system() -> Result<Self, DnsLookupError> {
let inner = hickory_resolver::TokioResolver::builder_tokio()
.map_err(|e| DnsLookupError {
name: "<system>".to_string(),
source: e.into(),
})?
.build()
.map_err(|e| DnsLookupError {
name: "<system>".to_string(),
source: e.into(),
})?;
Ok(Self { inner })
}
}

#[async_trait]
impl DnsResolver for HickoryDnsResolver {
async fn txt_lookup(&self, name: &str) -> Result<Vec<String>, DnsLookupError> {
use hickory_resolver::proto::rr::RData;

let lookup = match self.inner.txt_lookup(name).await {
Ok(lookup) => lookup,
// A missing record is a normal, non-fatal outcome for an optional step.
Err(e) if e.is_no_records_found() => return Ok(Vec::new()),
Err(e) => {
return Err(DnsLookupError {
name: name.to_string(),
source: e.into(),
});
}
};

let records = lookup
.answers()
.iter()
.filter_map(|record| match &record.data {
RData::TXT(txt) => {
// DNS TXT records split data into ≤255-byte character-strings.
// They must be concatenated without separators — using
// `TXT::to_string()` joins chunks with spaces, corrupting
// binary-ish data like JWK JSON.
let raw: Vec<u8> = txt.txt_data.iter().flatten().copied().collect();
Some(String::from_utf8_lossy(&raw).into_owned())
}
_ => None,
})
.collect();
Ok(records)
}
}

fn parse_pairs(record: &str) -> Vec<(String, String)> {
record
.split(';')
.filter_map(|segment| {
let segment = segment.trim();
if segment.is_empty() {
return None;
}
let (key, value) = segment.split_once('=')?;
Some((key.trim().to_ascii_lowercase(), value.trim().to_string()))
})
.collect()
}

/// Parse the first valid `v=mcp1` discovery record. `endpoint=` is accepted as a
/// deprecated alias for `src=`.
pub fn parse_mcp_hint(records: &[String]) -> Option<McpDnsHint> {
for record in records {
let pairs = parse_pairs(record);
let version = pairs.iter().find(|(k, _)| k == "v").map(|(_, v)| v);
if version.map(String::as_str) != Some("mcp1") {
continue;
}
let get = |key: &str| pairs.iter().find(|(k, _)| k == key).map(|(_, v)| v.clone());
return Some(McpDnsHint {
src: get("src").or_else(|| get("endpoint")),
registry: get("registry"),
auth: get("auth"),
});
}
None
}

/// Parse every valid `v=mcp1jwk` key record. Malformed JWKs are skipped.
pub fn parse_jwks(records: &[String]) -> Vec<DnsJwk> {
let mut keys = Vec::new();
for record in records {
let pairs = parse_pairs(record);
if pairs
.iter()
.find(|(k, _)| k == "v")
.map(|(_, v)| v.as_str())
!= Some("mcp1jwk")
{
continue;
}
let kid = pairs
.iter()
.find(|(k, _)| k == "kid")
.map(|(_, v)| v.clone());
let jwk_raw = pairs
.iter()
.find(|(k, _)| k == "jwk")
.map(|(_, v)| v.clone());
if let (Some(kid), Some(jwk_raw)) = (kid, jwk_raw) {
if let Ok(jwk) = serde_json::from_str::<Jwk>(&jwk_raw) {
keys.push(DnsJwk { kid, jwk });
}
}
}
keys
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn parses_src_hint() {
let hint =
parse_mcp_hint(&["v=mcp1; src=https://api.example.com/mcp".to_string()]).unwrap();
assert_eq!(hint.src.as_deref(), Some("https://api.example.com/mcp"));
assert_eq!(hint.registry, None);
}

#[test]
fn endpoint_is_alias_for_src() {
let hint =
parse_mcp_hint(&["v=mcp1; endpoint=https://example.com/mcp; auth=oauth2".to_string()])
.unwrap();
assert_eq!(hint.src.as_deref(), Some("https://example.com/mcp"));
assert_eq!(hint.auth.as_deref(), Some("oauth2"));
}

#[test]
fn ignores_non_mcp_txt_records() {
let records = vec![
"v=spf1 include:_spf.example.com ~all".to_string(),
"v=mcp1; registry=https://registry.example.com".to_string(),
];
let hint = parse_mcp_hint(&records).unwrap();
assert_eq!(
hint.registry.as_deref(),
Some("https://registry.example.com")
);
assert_eq!(hint.src, None);
}

#[test]
fn jwk_record_without_jwk_is_skipped() {
assert!(parse_jwks(&["v=mcp1jwk; kid=mcp-key-1".to_string()]).is_empty());
}

#[test]
fn parses_multi_chunk_jwk() {
// Simulate a JWK that was split across two DNS TXT character-strings
// (≤255 bytes each). This verifies that chunk concatenation works
// correctly and does not insert spurious separators.
let part1 = r#"v=mcp1jwk;kid=k1;jwk={"kty":"EC","crv":"P-256","x":"AAAAAAAAAAAAAAAA"#;
let part2 =
r#"AAAAAAAAAAAAAAAAAAAAAAAAAA","y":"BBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBBB"}"#;
let combined = format!("{part1}{part2}");
let keys = parse_jwks(&[combined]);
assert_eq!(keys.len(), 1);
assert_eq!(keys[0].kid, "k1");
assert_eq!(keys[0].jwk.common.key_algorithm, None);
}
}
41 changes: 41 additions & 0 deletions crates/rmcp/src/discovery/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
use thiserror::Error;

/// Errors that can occur while resolving an `mcp://` discovery URI.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum DiscoveryError {
#[error("invalid mcp discovery URI: {0}")]
InvalidUri(String),

#[error("no MCP server could be discovered for host {0}")]
NotFound(String),

#[error("manifest at {url} is malformed: {reason}")]
MalformedManifest { url: String, reason: String },

#[error(
"endpoint host {endpoint_host} is not the discovery host {discovery_host} or a subdomain of it"
)]
EndpointHostMismatch {
endpoint_host: String,
discovery_host: String,
},

#[error("insecure endpoint {0}: discovery requires https")]
InsecureEndpoint(String),

#[error("unsupported transport {0:?}: only \"http\" is supported")]
UnsupportedTransport(String),

#[error("manifest signature verification failed: {0}")]
SignatureVerification(String),

#[error("network error talking to {url}: {source}")]
Network {
url: String,
#[source]
source: Box<dyn std::error::Error + Send + Sync>,
},
}

pub type Result<T> = std::result::Result<T, DiscoveryError>;
Loading