From cc8f5a9235ef3a6dc55cd1f85ec4f2ff2c3fa333 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Thu, 25 Jun 2026 17:05:52 -0700 Subject: [PATCH 01/16] feat(middleware): add in-process supervisor middleware Signed-off-by: Piotr Mlocek --- Cargo.lock | 15 + Cargo.toml | 1 + crates/openshell-core/src/proto/mod.rs | 14 + crates/openshell-policy/Cargo.toml | 1 + crates/openshell-policy/src/compose.rs | 1 + crates/openshell-policy/src/lib.rs | 220 +++++- crates/openshell-policy/src/merge.rs | 24 + crates/openshell-providers/src/profiles.rs | 2 + .../Cargo.toml | 25 + .../src/builtins/mod.rs | 4 + .../src/builtins/secrets.rs | 83 +++ .../src/lib.rs | 437 ++++++++++++ .../src/service.rs | 75 ++ .../openshell-supervisor-network/Cargo.toml | 2 + .../data/sandbox-policy.rego | 22 +- .../src/l7/relay.rs | 478 ++++++++++++- .../src/l7/rest.rs | 122 ++++ .../openshell-supervisor-network/src/opa.rs | 638 +++++++++++++++++- .../src/policy_local.rs | 5 + .../openshell-supervisor-network/src/proxy.rs | 2 + proto/middleware.proto | 95 +++ proto/sandbox.proto | 29 +- 22 files changed, 2278 insertions(+), 17 deletions(-) create mode 100644 crates/openshell-supervisor-middleware/Cargo.toml create mode 100644 crates/openshell-supervisor-middleware/src/builtins/mod.rs create mode 100644 crates/openshell-supervisor-middleware/src/builtins/secrets.rs create mode 100644 crates/openshell-supervisor-middleware/src/lib.rs create mode 100644 crates/openshell-supervisor-middleware/src/service.rs create mode 100644 proto/middleware.proto diff --git a/Cargo.lock b/Cargo.lock index c86773bb7..4b43f48c5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3769,6 +3769,7 @@ version = "0.0.0" dependencies = [ "miette", "openshell-core", + "prost-types", "serde", "serde_json", "serde_yml", @@ -3924,6 +3925,18 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "openshell-supervisor-middleware" +version = "0.0.0" +dependencies = [ + "miette", + "openshell-core", + "prost-types", + "regex", + "tokio", + 
"tonic", +] + [[package]] name = "openshell-supervisor-network" version = "0.0.0" @@ -3946,6 +3959,8 @@ dependencies = [ "openshell-ocsf", "openshell-policy", "openshell-router", + "openshell-supervisor-middleware", + "prost-types", "rcgen", "regorus", "reqwest 0.12.28", diff --git a/Cargo.toml b/Cargo.toml index f450cd5c8..fd3641d68 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -74,6 +74,7 @@ serde_yml = "0.0.12" toml = "0.8" apollo-parser = "0.8.5" tower-mcp-types = "0.12.0" +regex = "1" # HTTP client reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls-native-roots"] } diff --git a/crates/openshell-core/src/proto/mod.rs b/crates/openshell-core/src/proto/mod.rs index 08b062d2e..1ac6fc94c 100644 --- a/crates/openshell-core/src/proto/mod.rs +++ b/crates/openshell-core/src/proto/mod.rs @@ -79,8 +79,22 @@ pub mod inference { } } +#[allow( + clippy::all, + clippy::pedantic, + clippy::nursery, + unused_qualifications, + rust_2018_idioms +)] +pub mod middleware { + pub mod v1 { + include!(concat!(env!("OUT_DIR"), "/openshell.middleware.v1.rs")); + } +} + pub use datamodel::v1::*; pub use inference::v1::*; +pub use middleware::v1::*; pub use openshell::*; pub use sandbox::v1::*; pub use test::ObjectForTest; diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 16719de13..50bea5b32 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -12,6 +12,7 @@ repository.workspace = true [dependencies] openshell-core = { path = "../openshell-core", default-features = false } +prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } serde_yml = { workspace = true } diff --git a/crates/openshell-policy/src/compose.rs b/crates/openshell-policy/src/compose.rs index 7ca8584d9..1ad0d4617 100644 --- a/crates/openshell-policy/src/compose.rs +++ b/crates/openshell-policy/src/compose.rs @@ -115,6 +115,7 @@ mod tests { ..Default::default() }], 
binaries: Vec::new(), + middleware: Vec::new(), } } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 9d5dc5b25..6ccfd1158 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -19,8 +19,8 @@ use std::path::Path; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ FilesystemPolicy, GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, - LandlockPolicy, McpOptions, NetworkBinary, NetworkEndpoint, NetworkPolicyRule, ProcessPolicy, - SandboxPolicy, + LandlockPolicy, MiddlewareEndpointSelector, NetworkBinary, NetworkEndpoint, + NetworkMiddlewareConfig, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, McpOptions, }; use serde::{Deserialize, Serialize}; @@ -49,6 +49,8 @@ struct PolicyFile { process: Option, #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] network_policies: BTreeMap, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + network_middlewares: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -87,6 +89,30 @@ struct NetworkPolicyRuleDef { endpoints: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] binaries: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + middleware: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct NetworkMiddlewareConfigDef { + name: String, + middleware: String, + #[serde(default, skip_serializing_if = "BTreeMap::is_empty")] + config: BTreeMap, + #[serde(default, skip_serializing_if = "String::is_empty")] + on_error: String, + #[serde(default, skip_serializing_if = "Option::is_none")] + endpoints: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +struct MiddlewareEndpointSelectorDef { + #[serde(default, skip_serializing_if = "Vec::is_empty")] + include: Vec, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + exclude: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -148,6 +174,8 @@ struct 
NetworkEndpointDef { json_rpc: Option, #[serde(default, skip_serializing_if = "Option::is_none")] mcp: Option, + #[serde(default, skip_serializing_if = "Vec::is_empty")] + middleware: Vec, } // Signature dictated by serde's `skip_serializing_if`, which requires `&T`. @@ -672,6 +700,21 @@ fn yaml_mcp_method( } fn to_proto(raw: PolicyFile) -> SandboxPolicy { + let network_middlewares = raw + .network_middlewares + .into_iter() + .map(|mw| NetworkMiddlewareConfig { + name: mw.name, + middleware: mw.middleware, + config: Some(json_map_to_struct(mw.config)), + on_error: mw.on_error, + endpoints: mw.endpoints.map(|selector| MiddlewareEndpointSelector { + include: selector.include, + exclude: selector.exclude, + }), + }) + .collect(); + let network_policies = raw .network_policies .into_iter() @@ -745,6 +788,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { signing_region: e.signing_region, json_rpc_max_body_bytes: json_rpc_max_body_bytes(&e.json_rpc, &e.mcp), mcp: mcp_options(&e.mcp), + middleware: e.middleware, } }) .collect(), @@ -756,6 +800,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { ..Default::default() }) .collect(), + middleware: rule.middleware, }; (key, proto_rule) }) @@ -776,6 +821,7 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { run_as_group: p.run_as_group, }), network_policies, + network_middlewares, } } @@ -892,6 +938,7 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { signing_region: e.signing_region.clone(), json_rpc, mcp, + middleware: e.middleware.clone(), } }) .collect(), @@ -903,17 +950,103 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { harness: false, }) .collect(), + middleware: rule.middleware.clone(), }; (key.clone(), yaml_rule) }) .collect(); + let network_middlewares = policy + .network_middlewares + .iter() + .map(|mw| NetworkMiddlewareConfigDef { + name: mw.name.clone(), + middleware: mw.middleware.clone(), + config: mw + .config + .as_ref() + .map(struct_to_json_map) + .unwrap_or_default(), + on_error: 
mw.on_error.clone(), + endpoints: mw + .endpoints + .as_ref() + .map(|selector| MiddlewareEndpointSelectorDef { + include: selector.include.clone(), + exclude: selector.exclude.clone(), + }), + }) + .collect(); + PolicyFile { version: policy.version, filesystem_policy, landlock, process, network_policies, + network_middlewares, + } +} + +fn json_map_to_struct(map: BTreeMap) -> prost_types::Struct { + prost_types::Struct { + fields: map + .into_iter() + .map(|(key, value)| (key, json_to_protobuf_value(value))) + .collect(), + } +} + +fn json_to_protobuf_value(value: serde_json::Value) -> prost_types::Value { + use prost_types::{ListValue, Struct, Value, value::Kind}; + Value { + kind: Some(match value { + serde_json::Value::Null => Kind::NullValue(0), + serde_json::Value::Bool(value) => Kind::BoolValue(value), + serde_json::Value::Number(value) => { + Kind::NumberValue(value.as_f64().unwrap_or_default()) + } + serde_json::Value::String(value) => Kind::StringValue(value), + serde_json::Value::Array(values) => Kind::ListValue(ListValue { + values: values.into_iter().map(json_to_protobuf_value).collect(), + }), + serde_json::Value::Object(values) => Kind::StructValue(Struct { + fields: values + .into_iter() + .map(|(key, value)| (key, json_to_protobuf_value(value))) + .collect(), + }), + }), + } +} + +fn struct_to_json_map(config: &prost_types::Struct) -> BTreeMap { + config + .fields + .iter() + .map(|(key, value)| (key.clone(), protobuf_value_to_json(value))) + .collect() +} + +fn protobuf_value_to_json(value: &prost_types::Value) -> serde_json::Value { + match value.kind.as_ref() { + Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, + Some(prost_types::value::Kind::BoolValue(value)) => serde_json::Value::Bool(*value), + Some(prost_types::value::Kind::NumberValue(value)) => serde_json::Number::from_f64(*value) + .map_or(serde_json::Value::Null, serde_json::Value::Number), + Some(prost_types::value::Kind::StringValue(value)) => { + 
serde_json::Value::String(value.clone()) + } + Some(prost_types::value::Kind::ListValue(value)) => { + serde_json::Value::Array(value.values.iter().map(protobuf_value_to_json).collect()) + } + Some(prost_types::value::Kind::StructValue(value)) => serde_json::Value::Object( + value + .fields + .iter() + .map(|(key, value)| (key.clone(), protobuf_value_to_json(value))) + .collect(), + ), } } @@ -1029,6 +1162,7 @@ pub fn restrictive_default_policy() -> SandboxPolicy { run_as_group: "sandbox".into(), }), network_policies: HashMap::new(), + network_middlewares: vec![], } } @@ -1399,6 +1533,87 @@ network_policies: assert_eq!(proto2.network_policies["my_api"].name, "my-custom-api-name"); } + #[test] + fn round_trip_preserves_network_middlewares() { + let yaml = r#" +version: 1 +network_middlewares: + - name: global-redactor + middleware: openshell/secrets + on_error: fail_open + endpoints: + include: ["api.example.com", "*.service.test"] + exclude: ["internal.example.com"] + config: + secrets: ["api_key", "authorization"] + service: + mode: redact + max_matches: 2 + - name: endpoint-redactor + middleware: openshell/secrets +network_policies: + api: + name: api + middleware: ["global-redactor"] + endpoints: + - host: api.example.com + port: 443 + protocol: rest + middleware: ["endpoint-redactor"] + binaries: + - path: /usr/bin/curl +"#; + let proto = parse_sandbox_policy(yaml).expect("parse failed"); + assert_eq!(proto.network_middlewares.len(), 2); + assert_eq!(proto.network_middlewares[0].name, "global-redactor"); + assert_eq!(proto.network_middlewares[0].middleware, "openshell/secrets"); + assert_eq!(proto.network_middlewares[0].on_error, "fail_open"); + assert_eq!( + proto.network_middlewares[0] + .endpoints + .as_ref() + .expect("selector") + .include, + vec!["api.example.com", "*.service.test"] + ); + assert_eq!( + proto.network_middlewares[0] + .endpoints + .as_ref() + .expect("selector") + .exclude, + vec!["internal.example.com"] + ); + assert!( + 
proto.network_middlewares[0] + .config + .as_ref() + .expect("config") + .fields + .contains_key("service") + ); + assert_eq!( + proto.network_policies["api"].middleware, + vec!["global-redactor"] + ); + assert_eq!( + proto.network_policies["api"].endpoints[0].middleware, + vec!["endpoint-redactor"] + ); + + let yaml_out = serialize_sandbox_policy(&proto).expect("serialize failed"); + let reparsed = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); + assert_eq!(reparsed.network_middlewares, proto.network_middlewares); + assert_eq!( + reparsed.network_policies["api"].middleware, + vec!["global-redactor"] + ); + assert_eq!( + reparsed.network_policies["api"].endpoints[0].middleware, + vec!["endpoint-redactor"] + ); + } + #[test] fn restrictive_default_has_no_network_policies() { let policy = restrictive_default_policy(); @@ -1714,6 +1929,7 @@ network_policies: filesystem: None, landlock: None, network_policies: HashMap::new(), + network_middlewares: Vec::new(), }; assert!(validate_sandbox_policy(&policy).is_ok()); } diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 04f390198..1c63e6ebc 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -989,6 +989,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1007,6 +1008,7 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], + ..Default::default() }; let result = merge_policy( @@ -1035,6 +1037,7 @@ mod tests { name: "existing".to_string(), endpoints: vec![endpoint("api.github.com", 443)], binaries: vec![advisor_binary("/usr/bin/curl")], + ..Default::default() }, ); @@ -1045,6 +1048,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let result = merge_policy( @@ -1076,6 +1080,7 @@ mod tests { ..Default::default() }, ], + ..Default::default() }; let result = merge_policy( @@ -1107,6 +1112,7 @@ mod tests 
{ path: "/usr/bin/python".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1120,6 +1126,7 @@ mod tests { ..Default::default() }], binaries: vec![advisor_binary("/usr/bin/python")], + ..Default::default() }; let result = merge_policy( @@ -1447,6 +1454,7 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1471,6 +1479,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let merged = merge_policy( @@ -1494,6 +1503,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; // Merge an *unrelated* rule for a different host. The proposed rule @@ -1524,6 +1534,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1536,6 +1547,7 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1567,6 +1579,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; // Endpoint exists in the policy but with a *different* binary. 
The @@ -1582,6 +1595,7 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1618,6 +1632,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1637,6 +1652,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1664,6 +1680,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1686,6 +1703,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1709,6 +1727,7 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], + ..Default::default() }; let merged = merge_policy( @@ -1733,6 +1752,7 @@ mod tests { name: "any_binary_rule".to_string(), endpoints: vec![endpoint("api.github.com", 443)], binaries: vec![], + ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1745,6 +1765,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -1802,6 +1823,7 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], + ..Default::default() }; let composed = compose_effective_policy( &SandboxPolicy::default(), @@ -1833,6 +1855,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let result = merge_policy( composed, @@ -1901,6 +1924,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let result = merge_policy( policy, diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index ddfbcaf7d..1eb1b54d2 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -450,6 +450,7 @@ impl ProviderTypeProfile { NetworkPolicyRule { name: 
rule_name.to_string(), endpoints: self.endpoints.iter().map(endpoint_to_proto).collect(), + middleware: Vec::new(), binaries: self.binaries.iter().map(binary_to_proto).collect(), } } @@ -787,6 +788,7 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { request_body_credential_rewrite: endpoint.request_body_credential_rewrite, advisor_proposed: false, persisted_queries: endpoint.persisted_queries.clone(), + middleware: Vec::new(), graphql_persisted_queries: endpoint .graphql_persisted_queries .iter() diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml new file mode 100644 index 000000000..fdaeb2e82 --- /dev/null +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -0,0 +1,25 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +[package] +name = "openshell-supervisor-middleware" +description = "In-process supervisor middleware contract and built-ins for OpenShell" +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +rust-version.workspace = true + +[dependencies] +openshell-core = { path = "../openshell-core" } + +miette = { workspace = true } +prost-types = { workspace = true } +regex = { workspace = true } +tonic = { workspace = true } + +[dev-dependencies] +tokio = { workspace = true } + +[lints] +workspace = true diff --git a/crates/openshell-supervisor-middleware/src/builtins/mod.rs b/crates/openshell-supervisor-middleware/src/builtins/mod.rs new file mode 100644 index 000000000..60572d3e8 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/builtins/mod.rs @@ -0,0 +1,4 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+// SPDX-License-Identifier: Apache-2.0 + +pub(crate) mod secrets; diff --git a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs new file mode 100644 index 000000000..6c94eb439 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs @@ -0,0 +1,83 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::collections::HashMap; + +use miette::{Result, miette}; +use openshell_core::proto::{Decision, Finding, HttpRequestEvaluation, HttpRequestResult}; +use regex::Regex; + +use crate::BUILTIN_SECRETS; + +pub(crate) fn validate_config(config: &prost_types::Struct) -> Result<()> { + let mode = config + .fields + .get("secrets") + .and_then(|value| match value.kind.as_ref() { + Some(prost_types::value::Kind::StringValue(value)) => Some(value.as_str()), + _ => None, + }) + .unwrap_or("redact"); + if mode != "redact" { + return Err(miette!( + "{} only supports config.secrets: redact in phase 1", + BUILTIN_SECRETS + )); + } + Ok(()) +} + +pub(crate) fn evaluate_http_request( + evaluation: &HttpRequestEvaluation, +) -> Result<HttpRequestResult> { + let default_config = prost_types::Struct::default(); + validate_config(evaluation.config.as_ref().unwrap_or(&default_config))?; + let text = String::from_utf8(evaluation.body.clone()) + .map_err(|_| miette!("{} requires UTF-8 request bodies", BUILTIN_SECRETS))?; + let (body, count) = redact_common_secrets(&text)?; + let mut result = HttpRequestResult { + decision: Decision::Allow as i32, + reason: String::new(), + body: body.into_bytes(), + has_body: count > 0, + add_headers: HashMap::new(), + findings: Vec::new(), + metadata: HashMap::new(), + }; + if count > 0 { + result.findings.push(Finding { + r#type: "secret.common".into(), + label: "common secret pattern".into(), + count, + confidence: "medium".into(), + severity: "medium".into(), + }); + 
result + .metadata + .insert("secrets_redacted".into(), count.to_string()); + } + Ok(result) +} + +fn redact_common_secrets(input: &str) -> Result<(String, u32)> { + let patterns = [ + r#"(?i)(api[_-]?key|access[_-]?token|secret|password)(["']?\s*[:=]\s*["'])[^"',\s}]+(["']?)"#, + r#"(sk-[A-Za-z0-9_-]{16,})"#, + ]; + let mut output = input.to_string(); + let mut count = 0u32; + for pattern in patterns { + let regex = Regex::new(pattern).map_err(|e| miette!("{e}"))?; + count = count.saturating_add(regex.find_iter(&output).count() as u32); + output = regex + .replace_all(&output, |captures: &regex::Captures<'_>| { + if captures.len() >= 4 { + format!("{}{}[REDACTED]{}", &captures[1], &captures[2], &captures[3]) + } else { + "[REDACTED]".to_string() + } + }) + .into_owned(); + } + Ok((output, count)) +} diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs new file mode 100644 index 000000000..7d9161fcf --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -0,0 +1,437 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +//! In-process supervisor middleware chain execution. 
+ +mod builtins; +mod service; + +use std::collections::{BTreeMap, HashMap, HashSet}; +use std::sync::Arc; + +use miette::{Result, miette}; +pub use service::InProcessMiddlewareService; + +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, NetworkMiddlewareConfig, Process, + RequestContext, +}; +use tonic::Request; + +pub const API_VERSION: &str = "openshell.middleware.v1"; +pub const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; +pub const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; +pub const BUILTIN_SECRETS: &str = "openshell/secrets"; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum OnError { + FailClosed, + FailOpen, +} + +impl OnError { + pub fn parse(value: &str) -> Result { + match value { + "" | "fail_closed" => Ok(Self::FailClosed), + "fail_open" => Ok(Self::FailOpen), + other => Err(miette!( + "invalid middleware on_error '{other}', expected fail_closed or fail_open" + )), + } + } +} + +#[derive(Debug, Clone)] +pub struct ChainEntry { + pub name: String, + pub implementation: String, + pub config: prost_types::Struct, + pub on_error: OnError, +} + +impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { + type Error = miette::Report; + + fn try_from(value: &NetworkMiddlewareConfig) -> Result { + if value.name.is_empty() { + return Err(miette!("middleware config name cannot be empty")); + } + if value.middleware.is_empty() { + return Err(miette!( + "middleware config '{}' must name an implementation", + value.name + )); + } + Ok(Self { + name: value.name.clone(), + implementation: value.middleware.clone(), + config: value.config.clone().unwrap_or_default(), + on_error: OnError::parse(&value.on_error)?, + }) + } +} + +#[derive(Debug, Clone)] +pub struct HttpRequestInput { + pub request_id: String, + pub sandbox_id: String, + pub binary: String, + pub pid: u32, + pub ancestors: Vec, + pub scheme: String, + 
pub host: String, + pub port: u16, + pub method: String, + pub path: String, + pub query: String, + pub headers: BTreeMap, + pub body: Vec, +} + +#[derive(Debug, Clone)] +pub struct ChainOutcome { + pub allowed: bool, + pub reason: String, + pub body: Vec, + pub added_headers: BTreeMap, + pub findings: Vec, + pub metadata: BTreeMap>, + pub applied: Vec, +} + +#[derive(Debug, Clone, PartialEq)] +pub struct NamespacedFinding { + pub middleware: String, + pub finding: Finding, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct MiddlewareInvocation { + pub name: String, + pub implementation: String, + pub decision: Decision, + pub transformed: bool, +} + +#[derive(Clone)] +pub struct ChainRunner { + service: Arc, +} + +impl Default for ChainRunner { + fn default() -> Self { + Self::new(Arc::new(InProcessMiddlewareService)) + } +} + +impl ChainRunner { + pub fn new(service: Arc) -> Self { + Self { service } + } + + pub async fn evaluate( + &self, + entries: &[ChainEntry], + input: HttpRequestInput, + ) -> Result { + let mut headers = input.headers.clone(); + let mut body = input.body.clone(); + let mut added_headers = BTreeMap::new(); + let mut findings = Vec::new(); + let mut metadata = BTreeMap::new(); + let mut applied = Vec::new(); + + for entry in entries { + let evaluation = build_evaluation(entry, &input, &headers, &body); + let result = match self + .service + .evaluate_http_request(Request::new(evaluation)) + .await + { + Ok(result) => result.into_inner(), + Err(err) => match entry.on_error { + OnError::FailOpen => { + applied.push(MiddlewareInvocation { + name: entry.name.clone(), + implementation: entry.implementation.clone(), + decision: Decision::Allow, + transformed: false, + }); + continue; + } + OnError::FailClosed => { + return Ok(ChainOutcome { + allowed: false, + reason: format!("middleware_failed: {}", safe_reason(&err.to_string())), + body, + added_headers, + findings, + metadata, + applied, + }); + } + }, + }; + + 
validate_header_mutations(&headers, &result.add_headers)?; + for (name, value) in &result.add_headers { + headers.insert(name.to_ascii_lowercase(), value.clone()); + added_headers.insert(name.to_ascii_lowercase(), value.clone()); + } + let transformed = result.has_body; + if result.has_body { + body = result.body.clone(); + } + for finding in result.findings { + findings.push(NamespacedFinding { + middleware: entry.name.clone(), + finding, + }); + } + if !result.metadata.is_empty() { + metadata.insert( + entry.name.clone(), + result.metadata.clone().into_iter().collect(), + ); + } + applied.push(MiddlewareInvocation { + name: entry.name.clone(), + implementation: entry.implementation.clone(), + decision: Decision::try_from(result.decision).unwrap_or(Decision::Unspecified), + transformed, + }); + if result.decision == Decision::Deny as i32 { + return Ok(ChainOutcome { + allowed: false, + reason: safe_reason(&result.reason), + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + + Ok(ChainOutcome { + allowed: true, + reason: String::new(), + body, + added_headers, + findings, + metadata, + applied, + }) + } +} + +fn build_evaluation( + entry: &ChainEntry, + input: &HttpRequestInput, + headers: &BTreeMap, + body: &[u8], +) -> HttpRequestEvaluation { + HttpRequestEvaluation { + api_version: API_VERSION.into(), + binding_id: entry.implementation.clone(), + phase: PRE_CREDENTIALS_PHASE.into(), + context: Some(RequestContext { + request_id: input.request_id.clone(), + sandbox_id: input.sandbox_id.clone(), + originating_process: Some(Process { + binary: input.binary.clone(), + pid: input.pid, + ancestors: input.ancestors.clone(), + }), + }), + config: Some(entry.config.clone()), + target: Some(HttpRequestTarget { + scheme: input.scheme.clone(), + host: input.host.clone(), + port: u32::from(input.port), + method: input.method.clone(), + path: input.path.clone(), + query: input.query.clone(), + }), + headers: headers.clone().into_iter().collect(), + 
body: body.to_vec(), + } +} + +fn validate_header_mutations( + existing_headers: &BTreeMap, + mutations: &HashMap, +) -> Result<()> { + let mut seen = HashSet::new(); + for name in mutations.keys() { + let lower = name.to_ascii_lowercase(); + if !seen.insert(lower.clone()) || existing_headers.contains_key(&lower) { + return Err(miette!( + "middleware cannot rewrite existing header '{name}'" + )); + } + if !is_safe_append_header(&lower) { + return Err(miette!("middleware cannot append unsafe header '{name}'")); + } + } + Ok(()) +} + +fn is_safe_append_header(name: &str) -> bool { + if name.is_empty() + || name.contains(':') + || name.bytes().any(|b| b <= 0x20 || b >= 0x7f) + || matches!( + name, + "authorization" | "cookie" | "host" | "content-length" | "transfer-encoding" + ) + || name.starts_with("x-amz-") + || name.starts_with("x-openshell-credential") + { + return false; + } + name.starts_with("x-openshell-middleware-") +} + +pub(crate) fn safe_reason(reason: &str) -> String { + reason + .chars() + .filter(|ch| ch.is_ascii_alphanumeric() || matches!(ch, '_' | '-' | ':' | ' ')) + .take(160) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; + + fn entry(name: &str, on_error: OnError) -> ChainEntry { + ChainEntry { + name: name.into(), + implementation: BUILTIN_SECRETS.into(), + config: prost_types::Struct { + fields: [( + "secrets".into(), + prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("redact".into())), + }, + )] + .into_iter() + .collect(), + }, + on_error, + } + } + + fn input(body: &str) -> HttpRequestInput { + HttpRequestInput { + request_id: "req".into(), + sandbox_id: "sbx".into(), + binary: "/usr/bin/curl".into(), + pid: 42, + ancestors: vec![], + scheme: "https".into(), + host: "api.example.com".into(), + port: 443, + method: "POST".into(), + path: "/v1".into(), + query: String::new(), + headers: BTreeMap::new(), + body: 
body.as_bytes().to_vec(), + } + } + + #[tokio::test] + async fn redacts_common_secret_patterns() { + let outcome = ChainRunner::default() + .evaluate( + &[entry("redact", OnError::FailClosed)], + input(r#"{"api_key":"sk-1234567890abcdef"}"#), + ) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!( + String::from_utf8(outcome.body).expect("utf8"), + r#"{"api_key":"[REDACTED]"}"# + ); + assert_eq!(outcome.findings[0].finding.count, 1); + } + + #[tokio::test] + async fn transformed_body_feeds_next_stage() { + let entries = [ + entry("first", OnError::FailClosed), + entry("second", OnError::FailClosed), + ]; + let outcome = ChainRunner::default() + .evaluate(&entries, input(r#"password="top-secret""#)) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!( + String::from_utf8(outcome.body).expect("utf8"), + r#"password="[REDACTED]""# + ); + assert_eq!(outcome.applied.len(), 2); + } + + #[tokio::test] + async fn fail_open_allows_unavailable_middleware() { + let unavailable = ChainEntry { + name: "missing".into(), + implementation: "third-party/missing".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + let outcome = ChainRunner::default() + .evaluate(&[unavailable], input("hello")) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!(outcome.body, b"hello"); + } + + #[tokio::test] + async fn fail_closed_denies_unavailable_middleware() { + let unavailable = ChainEntry { + name: "missing".into(), + implementation: "third-party/missing".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + let outcome = ChainRunner::default() + .evaluate(&[unavailable], input("hello")) + .await + .expect("evaluate"); + assert!(!outcome.allowed); + assert!(outcome.reason.starts_with("middleware_failed:")); + } + + #[tokio::test] + async fn in_process_service_describes_builtin_binding() { + let manifest = InProcessMiddlewareService + 
.describe(Request::new(())) + .await + .expect("describe") + .into_inner(); + assert_eq!(manifest.api_version, API_VERSION); + assert_eq!(manifest.bindings[0].id, BUILTIN_SECRETS); + assert_eq!(manifest.bindings[0].operation, HTTP_REQUEST_OPERATION); + assert_eq!(manifest.bindings[0].phase, PRE_CREDENTIALS_PHASE); + } + + #[test] + fn unsafe_header_mutation_is_rejected() { + let err = validate_header_mutations( + &BTreeMap::new(), + &[("Authorization".into(), "Bearer nope".into())] + .into_iter() + .collect(), + ) + .expect_err("unsafe header"); + assert!(err.to_string().contains("unsafe header")); + } +} diff --git a/crates/openshell-supervisor-middleware/src/service.rs b/crates/openshell-supervisor-middleware/src/service.rs new file mode 100644 index 000000000..31cca5694 --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/service.rs @@ -0,0 +1,75 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding, MiddlewareManifest, + ValidateConfigRequest, ValidateConfigResponse, +}; +use tonic::{Request, Response, Status}; + +use crate::{ + API_VERSION, BUILTIN_SECRETS, HTTP_REQUEST_OPERATION, PRE_CREDENTIALS_PHASE, builtins, + safe_reason, +}; + +#[derive(Debug, Default)] +pub struct InProcessMiddlewareService; + +#[tonic::async_trait] +impl SupervisorMiddleware for InProcessMiddlewareService { + async fn describe( + &self, + _request: Request<()>, + ) -> Result, Status> { + Ok(Response::new(MiddlewareManifest { + api_version: API_VERSION.into(), + name: "openshell/in-process".into(), + service_version: env!("CARGO_PKG_VERSION").into(), + bindings: vec![MiddlewareBinding { + id: BUILTIN_SECRETS.into(), + operation: HTTP_REQUEST_OPERATION.into(), + phase: 
PRE_CREDENTIALS_PHASE.into(), + }], + })) + } + + async fn validate_config( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let config = request.config.unwrap_or_default(); + let validation = match request.binding_id.as_str() { + BUILTIN_SECRETS => builtins::secrets::validate_config(&config), + other => Err(miette::miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + }; + Ok(Response::new(match validation { + Ok(()) => ValidateConfigResponse { + valid: true, + reason: String::new(), + }, + Err(err) => ValidateConfigResponse { + valid: false, + reason: safe_reason(&err.to_string()), + }, + })) + } + + async fn evaluate_http_request( + &self, + request: Request, + ) -> Result, Status> { + let request = request.into_inner(); + let result = match request.binding_id.as_str() { + BUILTIN_SECRETS => builtins::secrets::evaluate_http_request(&request), + other => Err(miette::miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } + .map_err(|err| Status::invalid_argument(safe_reason(&err.to_string())))?; + Ok(Response::new(result)) + } +} diff --git a/crates/openshell-supervisor-network/Cargo.toml b/crates/openshell-supervisor-network/Cargo.toml index 7d0079f7b..fd8fad5f7 100644 --- a/crates/openshell-supervisor-network/Cargo.toml +++ b/crates/openshell-supervisor-network/Cargo.toml @@ -15,6 +15,7 @@ openshell-core = { path = "../openshell-core" } openshell-ocsf = { path = "../openshell-ocsf" } openshell-policy = { path = "../openshell-policy" } openshell-router = { path = "../openshell-router" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } apollo-parser = { workspace = true } aws-sigv4 = { version = "1", features = ["sign-http", "http1"] } @@ -28,6 +29,7 @@ glob = { workspace = true } hex = "0.4" ipnet = "2" miette = { workspace = true } +prost-types = { workspace = true } rcgen = { workspace = true } regorus = { version = 
"0.9", default-features = false, features = ["std", "arc", "glob"] } reqwest = { workspace = true } diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index efcdf0732..afa4f6947 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -856,6 +856,22 @@ matched_endpoint_config := _matching_endpoint_configs[0] if { count(_matching_endpoint_configs) > 0 } +network_middlewares := object.get(data, "network_middlewares", []) + +_matching_middleware_contexts := [ctx | + some pname + _matching_policy_names[pname] + policy := data.network_policies[pname] + some ep + ep := policy.endpoints[_] + endpoint_matches_request(ep, input.network) + ctx := { + "policy": pname, + "policy_middleware": object.get(policy, "middleware", []), + "endpoint": ep, + } +] + _policy_has_exact_declared_endpoint(policy) if { some ep ep := policy.endpoints[_] @@ -909,7 +925,7 @@ endpoint_path_matches_request(ep, request) if { } # An endpoint has extended config if it specifies L7 protocol, allowed_ips, -# or an explicit tls mode (e.g. tls: skip). +# middleware, or an explicit tls mode (e.g. tls: skip). 
endpoint_has_extended_config(ep) if { ep.protocol } @@ -918,6 +934,10 @@ endpoint_has_extended_config(ep) if { count(object.get(ep, "allowed_ips", [])) > 0 } +endpoint_has_extended_config(ep) if { + count(object.get(ep, "middleware", [])) > 0 +} + endpoint_has_extended_config(ep) if { ep.tls } diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index ed2bde113..4d501d0a3 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -15,9 +15,12 @@ use miette::{IntoDiagnostic, Result, miette}; use openshell_core::activity::{ActivitySender, try_record_activity}; use openshell_core::secrets::{self, SecretResolver}; use openshell_ocsf::{ - ActionId, ActivityId, DispositionId, Endpoint, HttpActivityBuilder, HttpRequest, - NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, ocsf_emit, + ActionId, ActivityId, DetectionFindingBuilder, DispositionId, Endpoint, FindingInfo, + HttpActivityBuilder, HttpRequest, NetworkActivityBuilder, SeverityId, StatusId, Url as OcsfUrl, + ocsf_emit, }; +use std::collections::BTreeMap; +use std::path::PathBuf; use std::sync::Arc; use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; use tracing::{debug, warn}; @@ -450,6 +453,37 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = + engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let req = + match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) + .await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, @@ -734,6 +768,167 @@ fn jsonrpc_engine_type(protocol: L7Protocol) -> &'static str { } } +enum MiddlewareApplyResult { + Allowed(crate::l7::provider::L7Request), + Denied(String), +} + +async fn apply_middleware_chain( + req: crate::l7::provider::L7Request, + client: &mut C, + ctx: &L7EvalContext, + chain: Vec, + generation_guard: &PolicyGenerationGuard, +) -> Result { + if chain.is_empty() { + return Ok(MiddlewareApplyResult::Allowed(req)); + } + let buffered = + crate::l7::rest::buffer_request_body_for_middleware(&req, client, Some(generation_guard)) + .await?; + let headers = safe_middleware_headers(&buffered.headers)?; + let input = openshell_supervisor_middleware::HttpRequestInput { + request_id: uuid::Uuid::new_v4().to_string(), + sandbox_id: String::new(), + binary: ctx.binary_path.clone(), + pid: 0, + ancestors: ctx.ancestors.clone(), + scheme: "https".into(), + host: ctx.host.clone(), + port: ctx.port, + method: req.action.clone(), + path: req.target.clone(), + query: String::new(), + headers, + body: buffered.body, + }; + let outcome = openshell_supervisor_middleware::ChainRunner::default() + .evaluate(&chain, input) + .await?; + emit_middleware_events(ctx, &req, &outcome); + let rebuilt = 
crate::l7::rest::rebuild_request_with_buffered_body( + &req, + &buffered.headers, + &outcome.body, + &outcome.added_headers, + )?; + if outcome.allowed { + Ok(MiddlewareApplyResult::Allowed(rebuilt)) + } else { + Ok(MiddlewareApplyResult::Denied(outcome.reason)) + } +} + +fn safe_middleware_headers(headers: &[u8]) -> Result> { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = BTreeMap::new(); + for line in header_str.lines().skip(1) { + let Some((name, value)) = line.split_once(':') else { + continue; + }; + let name = name.trim().to_ascii_lowercase(); + if name.is_empty() + || matches!( + name.as_str(), + "authorization" | "cookie" | "host" | "content-length" | "transfer-encoding" + ) + || name.starts_with("x-amz-") + || name.starts_with("x-openshell-credential") + { + continue; + } + out.insert(name, value.trim().to_string()); + } + Ok(out) +} + +fn middleware_network_input(ctx: &L7EvalContext) -> crate::opa::NetworkInput { + crate::opa::NetworkInput { + host: ctx.host.clone(), + port: ctx.port, + binary_path: PathBuf::from(&ctx.binary_path), + binary_sha256: String::new(), + ancestors: ctx.ancestors.iter().map(PathBuf::from).collect(), + cmdline_paths: ctx.cmdline_paths.iter().map(PathBuf::from).collect(), + } +} + +fn emit_middleware_events( + ctx: &L7EvalContext, + req: &crate::l7::provider::L7Request, + outcome: &openshell_supervisor_middleware::ChainOutcome, +) { + for invocation in &outcome.applied { + let allowed = invocation.decision == openshell_core::proto::Decision::Allow; + let event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) + .activity(ActivityId::Other) + .action(if allowed { + ActionId::Allowed + } else { + ActionId::Denied + }) + .disposition(if allowed { + DispositionId::Allowed + } else { + DispositionId::Blocked + }) + .severity(if allowed { + SeverityId::Informational + } else { + SeverityId::Medium + }) + .http_request(HttpRequest::new( + &req.action, + 
OcsfUrl::new("http", &ctx.host, &req.target, ctx.port), + )) + .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) + .firewall_rule(&ctx.policy_name, "middleware") + .message(format!( + "MIDDLEWARE {} {} decision={:?} transformed={}", + invocation.name, + invocation.implementation, + invocation.decision, + invocation.transformed + )) + .build(); + ocsf_emit!(event); + } + if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(SeverityId::High) + .finding_info(FindingInfo::new( + "openshell.middleware.failure", + "Supervisor middleware failure", + )) + .message("Required supervisor middleware failed closed") + .build(); + ocsf_emit!(event); + } + for finding in &outcome.findings { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(match finding.finding.severity.as_str() { + "high" => SeverityId::High, + "low" => SeverityId::Low, + _ => SeverityId::Medium, + }) + .finding_info(FindingInfo::new( + &finding.finding.r#type, + &finding.finding.label, + )) + .evidence_pairs(&[ + ("middleware", &finding.middleware), + ("count", &finding.finding.count.to_string()), + ]) + .message(format!( + "Middleware finding {} count={}", + finding.finding.r#type, finding.finding.count + )) + .build(); + ocsf_emit!(event); + } +} + /// REST relay loop: parse request -> evaluate -> allow/deny -> relay response -> repeat. async fn relay_rest( config: &L7EndpointConfig, @@ -903,6 +1098,37 @@ where let _ = &eval_target; if allowed || config.enforcement == EnforcementMode::Audit { + let chain = + engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let req = + match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) + .await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + provider + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { Ok(req) => req, @@ -1336,6 +1562,37 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = + engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let req = + match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) + .await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( &req, client, @@ -1674,6 +1931,7 @@ pub async fn relay_passthrough_with_credentials( upstream: &mut U, ctx: &L7EvalContext, generation_guard: &PolicyGenerationGuard, + middleware_engine: Option<&crate::opa::OpaEngine>, ) -> Result<()> where C: AsyncRead + AsyncWrite + Unpin + Send, @@ -1756,6 +2014,43 @@ where ocsf_emit!(event); } + let req = if let Some(engine) = middleware_engine { + let input = middleware_network_input(ctx); + let (chain, generation) = + engine.query_middleware_chain_with_generation(&input, &req.target)?; + if generation != generation_guard.captured_generation() { + return Ok(()); + } + match apply_middleware_chain(req, client, ctx, chain, generation_guard).await? 
{ + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: "HTTP".into(), + target: redacted_target.clone(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + } + } else { + req + }; + let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { Ok(req) => req, @@ -1901,6 +2196,63 @@ network_policies: (config, tunnel_engine, ctx, fixture) } + fn middleware_relay_context( + middleware_impl: &str, + on_error: &str, + ) -> (L7EndpointConfig, TunnelPolicyEngine, L7EvalContext) { + let data = format!( + r#" +network_middlewares: + - name: request-middleware + middleware: {middleware_impl} + on_error: {on_error} +network_policies: + rest_api: + name: rest_api + middleware: ["request-middleware"] + endpoints: + - host: api.example.test + port: 8080 + protocol: rest + enforcement: enforce + rules: + - allow: + method: POST + path: "/v1/**" + binaries: + - {{ path: /usr/bin/curl }} +"# + ); + let engine = OpaEngine::from_strings(TEST_POLICY, &data).unwrap(); + let input = NetworkInput { + host: "api.example.test".into(), + port: 8080, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: 
"api.example.test".into(), + port: 8080, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + (config, tunnel_engine, ctx) + } + fn passthrough_token_grant_relay_context( resolver_response: std::result::Result<&str, &str>, ) -> ( @@ -2112,7 +2464,10 @@ network_policies: .unwrap(); let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); - assert!(upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n")); + assert!( + upstream_request.starts_with("GET /v1/projects HTTP/1.1\r\n"), + "unexpected upstream request: {upstream_request:?}" + ); assert!(upstream_request.contains("Authorization: Bearer grant-token\r\n")); assert!(!upstream_request.contains("stale-token")); assert_eq!(authorization_header_count(&upstream_request), 1); @@ -2194,6 +2549,115 @@ network_policies: fixture.assert_one_request("api.example.test\t8080\t/v1/**\tprovider:access_token"); } + #[tokio::test] + async fn l7_rest_middleware_redacts_body_before_upstream() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("openshell/secrets", "fail_closed"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = br#"{"api_key":"sk-1234567890abcdef"}"#; + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + 
.await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + assert!(upstream_request.contains(r#""api_key":"[REDACTED]""#)); + assert!(!upstream_request.contains("sk-1234567890abcdef")); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn l7_rest_middleware_fail_closed_does_not_reach_upstream() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("example/unavailable", "fail_closed"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: 2\r\nConnection: close\r\n\r\n{}", + ) + .await + .unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + 
assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + #[tokio::test] async fn passthrough_relay_injects_token_grant_authorization_header() { let (generation_guard, ctx, fixture) = @@ -2206,6 +2670,7 @@ network_policies: &mut relay_upstream, &ctx, &generation_guard, + None, ) .await }); @@ -2268,6 +2733,7 @@ network_policies: &mut relay_upstream, &ctx, &generation_guard, + None, ) .await }); @@ -3173,7 +3639,10 @@ network_policies: .expect("first request should reach upstream") .unwrap(); let first_upstream = String::from_utf8_lossy(&first_upstream[..n]); - assert!(first_upstream.starts_with("POST /write HTTP/1.1")); + assert!( + first_upstream.starts_with("POST /write HTTP/1.1"), + "unexpected upstream request: {first_upstream:?}" + ); upstream .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\nConnection: keep-alive\r\n\r\nOK") @@ -3243,6 +3712,7 @@ network_policies: &mut relay_upstream, &ctx, &generation_guard, + None, ) .await }); diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 0558a67e5..1a4036abd 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -27,6 +27,7 @@ const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; /// Maximum body bytes for `SigV4` body-signing mode. Larger than the credential /// rewrite limit because Bedrock payloads can be several megabytes. 
const MAX_SIGV4_BODY_BYTES: usize = 10 * 1024 * 1024; +pub(crate) const MAX_MIDDLEWARE_BODY_BYTES: usize = MAX_REWRITE_BODY_BYTES; const RELAY_BUF_SIZE: usize = 8192; const HTTP_METHOD_PREFIXES: &[&[u8]] = &[ b"GET ", @@ -768,6 +769,83 @@ struct PreparedRequestBody { body: Vec, } +pub(crate) struct BufferedRequestBody { + pub(crate) headers: Vec, + pub(crate) body: Vec, +} + +pub(crate) async fn buffer_request_body_for_middleware( + req: &L7Request, + client: &mut C, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { + let header_end = req + .raw_header + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(req.raw_header.len(), |p| p + 4); + let headers = req.raw_header[..header_end].to_vec(); + let already_read = &req.raw_header[header_end..]; + match req.body_length { + BodyLength::None => Ok(BufferedRequestBody { + headers, + body: already_read.to_vec(), + }), + BodyLength::ContentLength(len) => { + let len = usize::try_from(len) + .map_err(|_| miette!("request body is too large for middleware"))?; + if len > MAX_MIDDLEWARE_BODY_BYTES { + return Err(miette!( + "middleware buffers at most {MAX_MIDDLEWARE_BODY_BYTES} request body bytes" + )); + } + let initial_len = already_read.len().min(len); + let mut body = Vec::with_capacity(len); + body.extend_from_slice(&already_read[..initial_len]); + let mut remaining = len.saturating_sub(initial_len); + let mut buf = [0u8; RELAY_BUF_SIZE]; + while remaining > 0 { + let to_read = remaining.min(buf.len()); + let n = client.read(&mut buf[..to_read]).await.into_diagnostic()?; + if n == 0 { + return Err(miette!( + "Connection closed with {remaining} body bytes remaining" + )); + } + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + body.extend_from_slice(&buf[..n]); + remaining -= n; + } + Ok(BufferedRequestBody { headers, body }) + } + BodyLength::Chunked => { + let body = collect_chunked_body(client, already_read, generation_guard).await?; + Ok(BufferedRequestBody { headers, body 
}) + } + } +} + +pub(crate) fn rebuild_request_with_buffered_body( + req: &L7Request, + headers: &[u8], + body: &[u8], + add_headers: &std::collections::BTreeMap, +) -> Result { + let mut header_bytes = set_content_length(headers, body.len())?; + header_bytes = strip_header(&header_bytes, "transfer-encoding")?; + header_bytes = append_headers(&header_bytes, add_headers)?; + header_bytes.extend_from_slice(body); + Ok(L7Request { + action: req.action.clone(), + target: req.target.clone(), + query_params: req.query_params.clone(), + raw_header: header_bytes, + body_length: BodyLength::ContentLength(body.len() as u64), + }) +} + async fn collect_and_rewrite_request_body( req: &L7Request, client: &mut C, @@ -1160,6 +1238,50 @@ fn set_content_length(headers: &[u8], len: usize) -> Result> { Ok(out.into_bytes()) } +fn strip_header(headers: &[u8], strip_name: &str) -> Result> { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let mut out = String::with_capacity(header_str.len()); + for line in header_str.split("\r\n") { + if line.is_empty() { + out.push_str("\r\n"); + break; + } + if line + .split_once(':') + .is_some_and(|(name, _)| name.trim().eq_ignore_ascii_case(strip_name)) + { + continue; + } + out.push_str(line); + out.push_str("\r\n"); + } + Ok(out.into_bytes()) +} + +fn append_headers( + headers: &[u8], + add_headers: &std::collections::BTreeMap, +) -> Result> { + if add_headers.is_empty() { + return Ok(headers.to_vec()); + } + let split = headers + .windows(4) + .position(|w| w == b"\r\n\r\n") + .map_or(headers.len(), |pos| pos); + let mut out = Vec::with_capacity(headers.len() + add_headers.len() * 32); + out.extend_from_slice(&headers[..split]); + for (name, value) in add_headers { + out.extend_from_slice(b"\r\n"); + out.extend_from_slice(name.as_bytes()); + out.extend_from_slice(b": "); + out.extend_from_slice(value.as_bytes()); + } + out.extend_from_slice(b"\r\n\r\n"); + Ok(out) +} + pub(crate) fn 
request_is_websocket_upgrade(raw_header: &[u8]) -> bool { let header_end = raw_header .windows(4) diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index fbab5fedd..451c57e59 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -13,6 +13,8 @@ use openshell_core::policy::{ }; use openshell_core::proto::SandboxPolicy as ProtoSandboxPolicy; use openshell_policy::L7ConfigStanza; +use openshell_supervisor_middleware::ChainEntry; +use std::collections::HashSet; use std::path::{Path, PathBuf}; use std::sync::{ Arc, Mutex, @@ -132,6 +134,19 @@ impl TunnelPolicyEngine { pub(crate) fn engine(&self) -> &Mutex { &self.engine } + + /// Query the ordered middleware chain for a request path within this tunnel. + pub fn query_middleware_chain( + &self, + input: &NetworkInput, + request_path: &str, + ) -> Result> { + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + query_middleware_chain_locked(&mut engine, input, request_path) + } } impl OpaEngine { @@ -200,6 +215,14 @@ impl OpaEngine { .map_err(|e| miette::miette!("internal: failed to parse proto JSON: {e}"))?; // Validate BEFORE expanding presets + let middleware_errors = validate_middleware_policies(&data); + if !middleware_errors.is_empty() { + return Err(miette::miette!( + "middleware policy validation failed:\n{}", + middleware_errors.join("\n") + )); + } + let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -548,6 +571,21 @@ impl OpaEngine { } } + /// Query the ordered middleware chain for a parsed HTTP request path. 
+ pub fn query_middleware_chain_with_generation( + &self, + input: &NetworkInput, + request_path: &str, + ) -> Result<(Vec, u64)> { + let mut engine = self + .engine + .lock() + .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; + let generation = self.current_generation(); + let chain = query_middleware_chain_locked(&mut engine, input, request_path)?; + Ok((chain, generation)) + } + /// Query `allowed_ips` from the matched endpoint config for a given request. /// /// Returns the list of CIDR/IP strings from the endpoint's `allowed_ips` @@ -687,6 +725,243 @@ fn get_str_array(val: ®orus::Value, key: &str) -> Vec { } } +fn network_input_json(input: &NetworkInput) -> serde_json::Value { + let ancestor_strs: Vec = input + .ancestors + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + let cmdline_strs: Vec = input + .cmdline_paths + .iter() + .map(|p| p.to_string_lossy().into_owned()) + .collect(); + serde_json::json!({ + "exec": { + "path": input.binary_path.to_string_lossy(), + "ancestors": ancestor_strs, + "cmdline_paths": cmdline_strs, + }, + "network": { + "host": input.host, + "port": input.port, + } + }) +} + +#[derive(Debug, Clone)] +struct MiddlewareContext { + policy_middleware: Vec, + endpoint_middleware: Vec, + endpoint_path: String, +} + +fn query_middleware_chain_locked( + engine: &mut regorus::Engine, + input: &NetworkInput, + request_path: &str, +) -> Result> { + engine + .set_input_json(&network_input_json(input).to_string()) + .map_err(|e| miette::miette!("{e}"))?; + + let configs_val = engine + .eval_rule("data.openshell.sandbox.network_middlewares".into()) + .map_err(|e| miette::miette!("{e}"))?; + let configs = parse_middleware_configs(&configs_val)?; + if configs.is_empty() { + return Ok(Vec::new()); + } + let contexts_val = engine + .eval_rule("data.openshell.sandbox._matching_middleware_contexts".into()) + .map_err(|e| miette::miette!("{e}"))?; + let contexts = parse_middleware_contexts(&contexts_val); + let 
Some(context) = select_middleware_context(&contexts, request_path) else { + return Ok(global_middleware_entries( + &configs, + &input.host, + &HashSet::new(), + )?); + }; + + let mut explicit = Vec::new(); + for name in context + .policy_middleware + .iter() + .chain(context.endpoint_middleware.iter()) + { + if !explicit.contains(name) { + explicit.push(name.clone()); + } + } + let explicit_set: HashSet = explicit.iter().cloned().collect(); + let mut ordered = global_middleware_entries(&configs, &input.host, &explicit_set)?; + for name in explicit { + if !ordered.iter().any(|entry| entry.name == name) { + let config = configs + .iter() + .find(|config| get_str(config, "name").as_deref() == Some(name.as_str())) + .ok_or_else(|| miette::miette!("unknown middleware config '{name}'"))?; + ordered.push(chain_entry_from_value(config)?); + } + } + Ok(ordered) +} + +fn parse_middleware_configs(value: ®orus::Value) -> Result> { + match value { + regorus::Value::Undefined => Ok(Vec::new()), + regorus::Value::Array(values) => Ok(values.to_vec()), + other => Err(miette::miette!( + "network_middlewares must be an array, got {other:?}" + )), + } +} + +fn parse_middleware_contexts(value: ®orus::Value) -> Vec { + let regorus::Value::Array(values) = value else { + return Vec::new(); + }; + values + .iter() + .filter_map(|value| { + let regorus::Value::Object(_) = value else { + return None; + }; + let endpoint = get_field(value, "endpoint")?; + Some(MiddlewareContext { + policy_middleware: get_str_array(value, "policy_middleware"), + endpoint_middleware: get_str_array(endpoint, "middleware"), + endpoint_path: get_str(endpoint, "path").unwrap_or_default(), + }) + }) + .collect() +} + +fn select_middleware_context<'a>( + contexts: &'a [MiddlewareContext], + request_path: &str, +) -> Option<&'a MiddlewareContext> { + contexts + .iter() + .filter(|context| crate::l7::endpoint_path_matches(&context.endpoint_path, request_path)) + .max_by_key(|context| { + if 
context.endpoint_path.is_empty() { + 0 + } else { + context.endpoint_path.chars().filter(|c| *c != '*').count() + } + }) +} + +fn global_middleware_entries( + configs: &[regorus::Value], + host: &str, + explicit: &HashSet, +) -> Result> { + let mut entries = Vec::new(); + for config in configs { + let name = get_str(config, "name").unwrap_or_default(); + if explicit.contains(&name) { + continue; + } + if middleware_selector_matches(config, host) { + entries.push(chain_entry_from_value(config)?); + } + } + Ok(entries) +} + +fn middleware_selector_matches(config: ®orus::Value, host: &str) -> bool { + let Some(selector) = get_field(config, "endpoints") else { + return false; + }; + let includes = get_str_array(selector, "include"); + let excludes = get_str_array(selector, "exclude"); + let included = + !includes.is_empty() && includes.iter().any(|pattern| host_matches(pattern, host)); + let excluded = excludes.iter().any(|pattern| host_matches(pattern, host)); + included && !excluded +} + +fn host_matches(pattern: &str, host: &str) -> bool { + if pattern == "*" || pattern == "**" { + return true; + } + if !pattern.contains('*') { + return pattern.eq_ignore_ascii_case(host); + } + glob::Pattern::new(&pattern.to_ascii_lowercase()) + .is_ok_and(|pattern| pattern.matches(&host.to_ascii_lowercase())) +} + +fn chain_entry_from_value(value: ®orus::Value) -> Result { + let name = get_str(value, "name").unwrap_or_default(); + let implementation = get_str(value, "middleware").unwrap_or_default(); + Ok(ChainEntry { + name, + implementation, + config: get_field(value, "config") + .map(regorus_value_to_struct) + .unwrap_or_default(), + on_error: openshell_supervisor_middleware::OnError::parse( + get_str(value, "on_error").as_deref().unwrap_or_default(), + )?, + }) +} + +fn get_field<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Value> { + let key_val = regorus::Value::String(key.into()); + match val { + regorus::Value::Object(map) => map.get(&key_val), + _ => 
None, + } +} + +fn regorus_value_to_struct(value: ®orus::Value) -> prost_types::Struct { + let regorus::Value::Object(map) = value else { + return prost_types::Struct::default(); + }; + prost_types::Struct { + fields: map + .iter() + .filter_map(|(key, value)| match key { + regorus::Value::String(key) => { + Some((key.to_string(), regorus_value_to_prost(value))) + } + _ => None, + }) + .collect(), + } +} + +fn regorus_value_to_prost(value: ®orus::Value) -> prost_types::Value { + use prost_types::{ListValue, Struct, Value, value::Kind}; + Value { + kind: Some(match value { + regorus::Value::Bool(value) => Kind::BoolValue(*value), + regorus::Value::Number(value) => Kind::NumberValue(value.as_f64().unwrap_or_default()), + regorus::Value::String(value) => Kind::StringValue(value.to_string()), + regorus::Value::Array(values) => Kind::ListValue(ListValue { + values: values.iter().map(regorus_value_to_prost).collect(), + }), + regorus::Value::Object(values) => Kind::StructValue(Struct { + fields: values + .iter() + .filter_map(|(key, value)| match key { + regorus::Value::String(key) => { + Some((key.to_string(), regorus_value_to_prost(value))) + } + _ => None, + }) + .collect(), + }), + regorus::Value::Null | regorus::Value::Undefined => Kind::NullValue(0), + _ => Kind::NullValue(0), + }), + } +} + fn parse_filesystem_policy(val: ®orus::Value) -> FilesystemPolicy { FilesystemPolicy { read_only: get_str_array(val, "read_only") @@ -735,6 +1010,14 @@ fn preprocess_yaml_data(yaml_str: &str) -> Result { } // Validate BEFORE expanding presets (catches user errors like rules+access) + let middleware_errors = validate_middleware_policies(&data); + if !middleware_errors.is_empty() { + return Err(miette::miette!( + "middleware policy validation failed:\n{}", + middleware_errors.join("\n") + )); + } + let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -955,6 +1238,131 @@ fn normalize_l7_rule_aliases( } } +fn 
validate_middleware_policies(data: &serde_json::Value) -> Vec { + let mut errors = Vec::new(); + let middlewares = data + .get("network_middlewares") + .and_then(serde_json::Value::as_array) + .map_or(&[][..], Vec::as_slice); + let mut names = HashSet::new(); + for mw in middlewares { + let name = mw + .get("name") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + let implementation = mw + .get("middleware") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + if name.is_empty() { + errors.push("network_middlewares entry has empty name".to_string()); + } else if !names.insert(name.to_string()) { + errors.push(format!("duplicate middleware config '{name}'")); + } + if implementation.is_empty() { + errors.push(format!( + "middleware config '{name}' has empty implementation" + )); + } + if implementation.starts_with("openshell/") + && implementation != openshell_supervisor_middleware::BUILTIN_SECRETS + { + errors.push(format!( + "middleware config '{name}' references unsupported built-in '{implementation}'" + )); + } + let on_error = mw + .get("on_error") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + if !matches!(on_error, "" | "fail_closed" | "fail_open") { + errors.push(format!( + "middleware config '{name}' has invalid on_error '{on_error}'" + )); + } + } + + let Some(policies) = data + .get("network_policies") + .and_then(serde_json::Value::as_object) + else { + return errors; + }; + + for (policy_name, policy) in policies { + let policy_middleware = json_string_array(policy.get("middleware")); + for name in &policy_middleware { + if !names.contains(name) { + errors.push(format!( + "network policy '{policy_name}' references unknown middleware config '{name}'" + )); + } + } + for endpoint in policy + .get("endpoints") + .and_then(serde_json::Value::as_array) + .map_or(&[][..], Vec::as_slice) + { + let endpoint_middleware = json_string_array(endpoint.get("middleware")); + for name in &endpoint_middleware { + if 
!names.contains(name) { + errors.push(format!( + "network policy '{policy_name}' endpoint references unknown middleware config '{name}'" + )); + } + } + let tls_skip = endpoint + .get("tls") + .and_then(serde_json::Value::as_str) + .is_some_and(|tls| tls == "skip"); + if tls_skip && (!policy_middleware.is_empty() || !endpoint_middleware.is_empty()) { + errors.push(format!( + "network policy '{policy_name}' attaches middleware to a tls: skip endpoint" + )); + } + if tls_skip && global_selector_matches_any_middleware(middlewares, endpoint) { + errors.push(format!( + "network policy '{policy_name}' tls: skip endpoint matches a global middleware selector" + )); + } + } + } + errors +} + +fn json_string_array(value: Option<&serde_json::Value>) -> Vec { + value + .and_then(serde_json::Value::as_array) + .map(|values| { + values + .iter() + .filter_map(serde_json::Value::as_str) + .map(ToString::to_string) + .collect() + }) + .unwrap_or_default() +} + +fn global_selector_matches_any_middleware( + middlewares: &[serde_json::Value], + endpoint: &serde_json::Value, +) -> bool { + let host = endpoint + .get("host") + .and_then(serde_json::Value::as_str) + .unwrap_or_default(); + middlewares.iter().any(|mw| { + let Some(selector) = mw.get("endpoints") else { + return false; + }; + let includes = json_string_array(selector.get("include")); + let excludes = json_string_array(selector.get("exclude")); + !includes.is_empty() + && includes.iter().any(|pattern| host_matches(pattern, host)) + && !excludes.iter().any(|pattern| host_matches(pattern, host)) + }) +} + /// Resolve a policy binary path through the container's root filesystem. 
/// /// On Linux, `/proc//root/` provides access to the container's mount @@ -1316,6 +1724,9 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St allow_all_known_mcp_methods.into(); } } + if !e.middleware.is_empty() { + ep["middleware"] = e.middleware.clone().into(); + } ep }) .collect(); @@ -1341,14 +1752,43 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St entries }) .collect(); - ( - key.clone(), - serde_json::json!({ - "name": rule.name, - "endpoints": endpoints, - "binaries": binaries, - }), - ) + let mut policy = serde_json::json!({ + "name": rule.name, + "endpoints": endpoints, + "binaries": binaries, + }); + if !rule.middleware.is_empty() { + policy["middleware"] = rule.middleware.clone().into(); + } + (key.clone(), policy) + }) + .collect(); + + let network_middlewares: Vec = proto + .network_middlewares + .iter() + .map(|mw| { + let mut value = serde_json::json!({ + "name": mw.name, + "middleware": mw.middleware, + }); + if let Some(config) = &mw.config { + value["config"] = prost_struct_to_json(config); + } + if !mw.on_error.is_empty() { + value["on_error"] = mw.on_error.clone().into(); + } + if let Some(selector) = &mw.endpoints { + let mut endpoints = serde_json::json!({}); + if !selector.include.is_empty() { + endpoints["include"] = selector.include.clone().into(); + } + if !selector.exclude.is_empty() { + endpoints["exclude"] = selector.exclude.clone().into(); + } + value["endpoints"] = endpoints; + } + value }) .collect(); @@ -1357,10 +1797,37 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St "landlock": landlock, "process": process, "network_policies": network_policies, + "network_middlewares": network_middlewares, }) .to_string() } +fn prost_struct_to_json(config: &prost_types::Struct) -> serde_json::Value { + serde_json::Value::Object( + config + .fields + .iter() + .map(|(key, value)| (key.clone(), prost_value_to_json(value))) + .collect(), + ) +} 
+ +fn prost_value_to_json(value: &prost_types::Value) -> serde_json::Value { + match value.kind.as_ref() { + Some(prost_types::value::Kind::NullValue(_)) | None => serde_json::Value::Null, + Some(prost_types::value::Kind::BoolValue(value)) => serde_json::Value::Bool(*value), + Some(prost_types::value::Kind::NumberValue(value)) => serde_json::Number::from_f64(*value) + .map_or(serde_json::Value::Null, serde_json::Value::Number), + Some(prost_types::value::Kind::StringValue(value)) => { + serde_json::Value::String(value.clone()) + } + Some(prost_types::value::Kind::ListValue(value)) => { + serde_json::Value::Array(value.values.iter().map(prost_value_to_json).collect()) + } + Some(prost_types::value::Kind::StructValue(value)) => prost_struct_to_json(value), + } +} + #[cfg(test)] #[allow( clippy::needless_raw_string_hashes, @@ -1407,6 +1874,7 @@ mod tests { path: "/usr/local/bin/claude".to_string(), ..Default::default() }], + ..Default::default() }, ); network_policies.insert( @@ -1422,6 +1890,7 @@ mod tests { path: "/usr/bin/glab".to_string(), ..Default::default() }], + ..Default::default() }, ); ProtoSandboxPolicy { @@ -1439,6 +1908,7 @@ mod tests { run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], } } @@ -2763,6 +3233,7 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -2781,6 +3252,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3783,6 +4255,7 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -3800,6 +4273,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3840,6 +4314,7 @@ 
network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -3857,6 +4332,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3898,6 +4374,7 @@ network_policies: path: "/usr/local/bin/claude".to_string(), ..Default::default() }], + middleware: vec![], }, ); let proto = ProtoSandboxPolicy { @@ -3915,6 +4392,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3958,6 +4436,7 @@ network_policies: path: "/usr/local/bin/aws".to_string(), ..Default::default() }], + middleware: vec![], }, ); let proto = ProtoSandboxPolicy { @@ -3975,6 +4454,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -4017,6 +4497,7 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -4034,6 +4515,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -4966,6 +5448,7 @@ process: ..Default::default() }], binaries: vec![proposal_binary], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -4983,6 +5466,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); let input = NetworkInput { @@ -5020,6 +5504,7 @@ process: path: "/usr/bin/python".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5037,6 
+5522,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); let input = NetworkInput { @@ -5090,6 +5576,7 @@ process: path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5107,6 +5594,7 @@ process: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("Failed to create engine from proto"); @@ -5320,6 +5808,7 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5337,6 +5826,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).unwrap(); // Port 443 @@ -6023,6 +6513,7 @@ network_policies: path: "/usr/bin/python3".to_string(), ..Default::default() }], + ..Default::default() }, ); @@ -6279,6 +6770,7 @@ network_policies: path: link_path, ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6296,6 +6788,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; // Build engine with our PID (symlink resolution will work via /proc/self/root/) @@ -6356,6 +6849,7 @@ network_policies: path: link_path, ..Default::default() }], + ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6373,6 +6867,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; // Initial load at pid=0 — no symlink expansion @@ -6415,6 +6910,133 @@ network_policies: assert!(eval_l7(&engine, &input)); } + #[test] + fn middleware_chain_orders_global_policy_endpoint_once() { + let data = r#" +network_middlewares: + - name: global-redactor + middleware: openshell/secrets 
+ endpoints: + include: ["api.example.com"] + - name: policy-redactor + middleware: openshell/secrets + - name: endpoint-redactor + middleware: openshell/secrets +network_policies: + api: + name: api + middleware: ["global-redactor", "policy-redactor"] + endpoints: + - host: api.example.com + port: 443 + protocol: rest + enforcement: enforce + middleware: ["policy-redactor", "endpoint-redactor"] + rules: + - allow: { method: POST, path: "/v1/**" } + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (chain, _) = engine + .query_middleware_chain_with_generation(&input, "/v1/messages") + .unwrap(); + let names: Vec<_> = chain.iter().map(|entry| entry.name.as_str()).collect(); + assert_eq!( + names, + vec!["global-redactor", "policy-redactor", "endpoint-redactor"] + ); + } + + #[test] + fn middleware_policy_validation_rejects_bad_configs() { + let cases = [ + ( + "missing reference", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets +network_policies: + api: + middleware: ["missing"] + endpoints: + - { host: api.example.com, port: 443 } + binaries: + - { path: /usr/bin/curl } +"#, + "unknown middleware config 'missing'", + ), + ( + "invalid on_error", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + on_error: maybe +"#, + "invalid on_error", + ), + ( + "duplicate names", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + - name: redactor + middleware: openshell/secrets +"#, + "duplicate middleware config 'redactor'", + ), + ( + "reserved builtin", + r#" +network_middlewares: + - name: sigv4 + middleware: openshell/sigv4 +"#, + "unsupported built-in", + ), + ( + "tls skip attachment", + r#" 
+network_middlewares: + - name: redactor + middleware: openshell/secrets +network_policies: + api: + endpoints: + - host: api.example.com + port: 443 + tls: skip + middleware: ["redactor"] + binaries: + - { path: /usr/bin/curl } +"#, + "tls: skip", + ), + ]; + + for (name, data, expected) in cases { + let err = match OpaEngine::from_strings(TEST_POLICY, data) { + Ok(_) => panic!("{name}: expected policy validation failure"), + Err(err) => err.to_string(), + }; + assert!( + err.contains(expected), + "{name}: expected {expected:?} in {err:?}" + ); + } + } + #[test] fn l7_head_denied_when_only_post_allowed() { let engine = OpaEngine::from_strings( diff --git a/crates/openshell-supervisor-network/src/policy_local.rs b/crates/openshell-supervisor-network/src/policy_local.rs index 3cbc31502..fa8029c72 100644 --- a/crates/openshell-supervisor-network/src/policy_local.rs +++ b/crates/openshell-supervisor-network/src/policy_local.rs @@ -1047,6 +1047,7 @@ fn network_rule_from_json( name: rule.name.unwrap_or_default(), endpoints, binaries, + middleware: Vec::new(), }) } @@ -1133,6 +1134,7 @@ fn network_endpoint_from_json( credential_signing: String::new(), signing_service: String::new(), signing_region: String::new(), + middleware: Vec::new(), }) } @@ -1829,6 +1831,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }), ..Default::default() }; @@ -1853,6 +1856,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() } } @@ -1916,6 +1920,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() })); }) }; diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index 0d2c8c025..af3331735 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -1183,6 +1183,7 @@ async fn handle_tcp_connection( &mut tls_upstream, &ctx, 
&generation_guard, + Some(&opa_engine), ) .await } @@ -1288,6 +1289,7 @@ async fn handle_tcp_connection( &mut upstream, &ctx, &generation_guard, + Some(&opa_engine), ) .await { diff --git a/proto/middleware.proto b/proto/middleware.proto new file mode 100644 index 000000000..d5d2ad48d --- /dev/null +++ b/proto/middleware.proto @@ -0,0 +1,95 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +syntax = "proto3"; + +package openshell.middleware.v1; + +import "google/protobuf/empty.proto"; +import "google/protobuf/struct.proto"; + +service SupervisorMiddleware { + rpc Describe(google.protobuf.Empty) returns (MiddlewareManifest); + rpc ValidateConfig(ValidateConfigRequest) returns (ValidateConfigResponse); + rpc EvaluateHttpRequest(HttpRequestEvaluation) returns (HttpRequestResult); +} + +message MiddlewareManifest { + string api_version = 1; + string name = 2; + string service_version = 3; + repeated MiddlewareBinding bindings = 4; +} + +message MiddlewareBinding { + string id = 1; + string operation = 2; + string phase = 3; +} + +message ValidateConfigRequest { + string api_version = 1; + string binding_id = 2; + google.protobuf.Struct config = 3; +} + +message ValidateConfigResponse { + bool valid = 1; + string reason = 2; +} + +message HttpRequestEvaluation { + string api_version = 1; + string binding_id = 2; + string phase = 3; + RequestContext context = 4; + google.protobuf.Struct config = 5; + HttpRequestTarget target = 6; + map headers = 7; + bytes body = 8; +} + +message RequestContext { + string request_id = 1; + string sandbox_id = 2; + Process originating_process = 3; +} + +message HttpRequestTarget { + string scheme = 1; + string host = 2; + uint32 port = 3; + string method = 4; + string path = 5; + string query = 6; +} + +message Process { + string binary = 1; + uint32 pid = 2; + repeated string ancestors = 3; +} + +enum Decision { + DECISION_UNSPECIFIED = 
0; + DECISION_ALLOW = 1; + DECISION_DENY = 2; +} + +message Finding { + string type = 1; + string label = 2; + uint32 count = 3; + string confidence = 4; + string severity = 5; +} + +message HttpRequestResult { + Decision decision = 1; + string reason = 2; + bytes body = 3; + bool has_body = 4; + map add_headers = 5; + repeated Finding findings = 6; + map metadata = 7; +} diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 8a5a59333..5d2bc31a5 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -5,6 +5,8 @@ syntax = "proto3"; package openshell.sandbox.v1; +import "google/protobuf/struct.proto"; + // Sandbox-supervisor configuration and policy messages. // // Conventions: @@ -25,6 +27,8 @@ message SandboxPolicy { ProcessPolicy process = 4; // Network access policies keyed by name (e.g. "claude_code", "gitlab"). map network_policies = 5; + // Reusable supervisor middleware configs for network egress. + repeated NetworkMiddlewareConfig network_middlewares = 6; } // Filesystem access policy. @@ -59,6 +63,27 @@ message NetworkPolicyRule { repeated NetworkEndpoint endpoints = 2; // Allowed binary identities. repeated NetworkBinary binaries = 3; + // Ordered middleware configs applied to every endpoint in this policy. + repeated string middleware = 4; +} + +// A reusable middleware config referenced by network policies/endpoints. +message NetworkMiddlewareConfig { + // Policy-local config name. + string name = 1; + // Built-in or registered middleware implementation name. + string middleware = 2; + // Service-specific configuration. + google.protobuf.Struct config = 3; + // Failure behavior: "fail_closed" (default) or "fail_open". + string on_error = 4; + // Optional global endpoint selector for this config. + MiddlewareEndpointSelector endpoints = 5; +} + +message MiddlewareEndpointSelector { + repeated string include = 1; + repeated string exclude = 2; } // A network endpoint (host + port) with optional L7 inspection config. 
@@ -143,6 +168,8 @@ message NetworkEndpoint { uint32 json_rpc_max_body_bytes = 22; // MCP-only policy and inspection options. Only used when protocol is "mcp". McpOptions mcp = 23; + // Ordered middleware configs applied to this endpoint after policy-level middleware. + repeated string middleware = 24; } // MCP options are grouped so MCP-specific policy can grow without adding more @@ -175,8 +202,6 @@ message McpOptions { // MCP-family methods at the method layer unless a tool-name policy narrows // tools/call. When unset or false, explicit method rules are required. optional bool allow_all_known_mcp_methods = 2; -} - // Trusted GraphQL operation classification. message GraphqlOperation { // Operation type: "query", "mutation", or "subscription". From 3abe7d2a01117a2c9005a0e3d0292cc9b2d34ace Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 26 Jun 2026 13:25:39 -0700 Subject: [PATCH 02/16] fix(supervisor-middleware): harden middleware relay handling Signed-off-by: Piotr Mlocek --- Cargo.lock | 1 + crates/openshell-cli/src/policy_update.rs | 1 + .../src/mechanistic_mapper.rs | 1 + crates/openshell-server/src/grpc/policy.rs | 45 +- .../src/builtins/mod.rs | 2 +- .../src/builtins/secrets.rs | 87 +++- .../src/lib.rs | 298 ++++++++++- .../openshell-supervisor-network/Cargo.toml | 1 + .../src/l7/relay.rs | 489 +++++++++++++++++- .../src/l7/rest.rs | 283 ++++++---- .../openshell-supervisor-network/src/opa.rs | 23 +- 11 files changed, 1059 insertions(+), 172 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 4b43f48c5..0a2f749b4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3980,6 +3980,7 @@ dependencies = [ "tokio-tungstenite 0.26.2", "tower-mcp-types", "tracing", + "tracing-subscriber", "uuid", "webpki-roots 1.0.7", ] diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 1f1f64750..824b1dde0 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -65,6 +65,7 @@ 
pub fn build_policy_update_plan( ..Default::default() }) .collect(), + middleware: Vec::new(), }; merge_operations.push(PolicyMergeOperation { operation: Some(policy_merge_operation::Operation::AddRule(AddNetworkRule { diff --git a/crates/openshell-sandbox/src/mechanistic_mapper.rs b/crates/openshell-sandbox/src/mechanistic_mapper.rs index 8ee2fc37f..bb83ddb66 100644 --- a/crates/openshell-sandbox/src/mechanistic_mapper.rs +++ b/crates/openshell-sandbox/src/mechanistic_mapper.rs @@ -162,6 +162,7 @@ pub fn generate_proposals(summaries: &[DenialSummary]) -> Vec { name: rule_name.clone(), endpoints: vec![endpoint], binaries, + middleware: Vec::new(), }; // Compute confidence. diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index cc8ff0d2e..09e311bb2 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -5746,6 +5746,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -5959,6 +5960,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -6075,6 +6077,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6180,6 +6183,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let mechanistic_submit = handle_submit_policy_analysis( &state, @@ -6257,6 +6261,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let agent_submit = handle_submit_policy_analysis( &state, @@ -6384,6 +6389,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6484,6 +6490,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() 
}], + ..Default::default() }; handle_submit_policy_analysis( @@ -6584,6 +6591,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6677,6 +6685,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6761,6 +6770,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6849,6 +6859,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -6940,6 +6951,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7026,6 +7038,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let response = handle_submit_policy_analysis( @@ -7201,6 +7214,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7297,6 +7311,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7382,6 +7397,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7523,6 +7539,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -7648,6 +7665,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let step1 = handle_submit_policy_analysis( &state, @@ -7689,6 +7707,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let step2 = handle_submit_policy_analysis( &state, @@ -7820,6 +7839,7 @@ mod tests { path: "/usr/bin/curl".to_string(), 
..Default::default() }], + ..Default::default() }; let submit_one = |rule_name: &str, rule: NetworkPolicyRule| { @@ -7928,6 +7948,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit_one = || { let state = state.clone(); @@ -8028,6 +8049,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -8159,6 +8181,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; handle_submit_policy_analysis( @@ -8357,6 +8380,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, }; @@ -8385,6 +8409,7 @@ mod tests { path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, }; @@ -8413,6 +8438,7 @@ mod tests { path: "/usr/bin/node".to_string(), ..Default::default() }], + ..Default::default() }, }; @@ -8440,6 +8466,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-1".to_string(), @@ -8508,6 +8535,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, )) .collect(), @@ -8536,6 +8564,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-merge".to_string(), @@ -8609,6 +8638,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }, )) .collect(), @@ -8637,6 +8667,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-new".to_string(), @@ -8773,7 +8804,7 @@ mod tests { allowed_ips: vec!["127.0.0.1".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); 
assert!(result.is_err()); @@ -8794,7 +8825,7 @@ mod tests { allowed_ips: vec!["169.254.169.254".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8812,7 +8843,7 @@ mod tests { port: 80, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8830,7 +8861,7 @@ mod tests { port: 8080, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8848,7 +8879,7 @@ mod tests { port: 80, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_err()); @@ -8896,7 +8927,7 @@ mod tests { allowed_ips: vec!["10.0.5.0/24".to_string()], ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_ok()); @@ -8913,7 +8944,7 @@ mod tests { port: 443, ..Default::default() }], - binaries: vec![], + ..Default::default() }; let result = validate_rule_not_always_blocked(&rule); assert!(result.is_ok()); diff --git a/crates/openshell-supervisor-middleware/src/builtins/mod.rs b/crates/openshell-supervisor-middleware/src/builtins/mod.rs index 60572d3e8..d91ee745e 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/mod.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/mod.rs @@ -1,4 +1,4 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
// SPDX-License-Identifier: Apache-2.0 -pub(crate) mod secrets; +pub mod secrets; diff --git a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs index 6c94eb439..572102559 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 use std::collections::HashMap; +use std::sync::LazyLock; use miette::{Result, miette}; use openshell_core::proto::{Decision, Finding, HttpRequestEvaluation, HttpRequestResult}; @@ -9,7 +10,36 @@ use regex::Regex; use crate::BUILTIN_SECRETS; -pub(crate) fn validate_config(config: &prost_types::Struct) -> Result<()> { +/// A named secret-detection pattern. The `kind` is an audit-safe label that +/// flows into findings so operators can see *what* matched without seeing the +/// raw value. +struct SecretPattern { + kind: &'static str, + regex: Regex, +} + +impl SecretPattern { + fn new(kind: &'static str, pattern: &str) -> Self { + Self { + kind, + regex: Regex::new(pattern).expect("valid built-in secret redaction pattern"), + } + } +} + +/// Compiled once: recompiling per request would put regex construction on the +/// egress hot path. 
+static SECRET_PATTERNS: LazyLock<[SecretPattern; 2]> = LazyLock::new(|| { + [ + SecretPattern::new( + "keyword", + r#"(?i)(api[_-]?key|access[_-]?token|secret|password)(["']?\s*[:=]\s*["'])[^"',\s}]+(["']?)"#, + ), + SecretPattern::new("openai", r"(sk-[A-Za-z0-9_-]{16,})"), + ] +}); + +pub fn validate_config(config: &prost_types::Struct) -> Result<()> { let mode = config .fields .get("secrets") @@ -27,49 +57,54 @@ pub(crate) fn validate_config(config: &prost_types::Struct) -> Result<()> { Ok(()) } -pub(crate) fn evaluate_http_request( - evaluation: &HttpRequestEvaluation, -) -> Result<HttpRequestResult> { +pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result<HttpRequestResult> { let default_config = prost_types::Struct::default(); validate_config(evaluation.config.as_ref().unwrap_or(&default_config))?; let text = String::from_utf8(evaluation.body.clone()) .map_err(|_| miette!("{} requires UTF-8 request bodies", BUILTIN_SECRETS))?; - let (body, count) = redact_common_secrets(&text)?; + let (body, matches) = redact_common_secrets(&text); + let total: u32 = matches + .iter() + .fold(0u32, |acc, (_, count)| acc.saturating_add(*count)); let mut result = HttpRequestResult { decision: Decision::Allow as i32, reason: String::new(), body: body.into_bytes(), - has_body: count > 0, + has_body: !matches.is_empty(), add_headers: HashMap::new(), findings: Vec::new(), metadata: HashMap::new(), }; - if count > 0 { - result.findings.push(Finding { - r#type: "secret.common".into(), - label: "common secret pattern".into(), - count, - confidence: "medium".into(), - severity: "medium".into(), - }); + if !matches.is_empty() { + // One finding per matched pattern kind, so audit shows what matched. 
+ for (kind, count) in &matches { + result.findings.push(Finding { + r#type: format!("secret.{kind}"), + label: format!("{kind} secret pattern"), + count: *count, + confidence: "medium".into(), + severity: "medium".into(), + }); + } result .metadata - .insert("secrets_redacted".into(), count.to_string()); + .insert("secrets_redacted".into(), total.to_string()); } Ok(result) } -fn redact_common_secrets(input: &str) -> Result<(String, u32)> { - let patterns = [ - r#"(?i)(api[_-]?key|access[_-]?token|secret|password)(["']?\s*[:=]\s*["'])[^"',\s}]+(["']?)"#, - r#"(sk-[A-Za-z0-9_-]{16,})"#, - ]; +/// Redact every configured secret pattern, returning the transformed text and +/// the per-kind match counts (only kinds that matched are included). +fn redact_common_secrets(input: &str) -> (String, Vec<(&'static str, u32)>) { let mut output = input.to_string(); - let mut count = 0u32; - for pattern in patterns { - let regex = Regex::new(pattern).map_err(|e| miette!("{e}"))?; - count = count.saturating_add(regex.find_iter(&output).count() as u32); - output = regex + let mut matches = Vec::new(); + for pattern in SECRET_PATTERNS.iter() { + let count = u32::try_from(pattern.regex.find_iter(&output).count()).unwrap_or(u32::MAX); + if count > 0 { + matches.push((pattern.kind, count)); + } + output = pattern + .regex .replace_all(&output, |captures: &regex::Captures<'_>| { if captures.len() >= 4 { format!("{}{}[REDACTED]{}", &captures[1], &captures[2], &captures[3]) @@ -79,5 +114,5 @@ fn redact_common_secrets(input: &str) -> Result<(String, u32)> { }) .into_owned(); } - Ok((output, count)) + (output, matches) } diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index 7d9161fcf..b68d83c86 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -100,7 +100,7 @@ pub struct ChainOutcome { pub applied: Vec<MiddlewareInvocation>, } -#[derive(Debug, 
Clone, PartialEq, Eq)] pub struct NamespacedFinding { pub middleware: String, pub finding: Finding, @@ -112,6 +112,48 @@ pub struct MiddlewareInvocation { pub implementation: String, pub decision: Decision, pub transformed: bool, + /// True when the middleware could not be evaluated and `on_error` was applied + /// (service error, malformed/unsafe response, etc.). The `decision` reflects + /// the `on_error` outcome, not a decision the middleware actually returned. + pub failed: bool, +} + +enum OnErrorAction { + /// `fail_open`: skip this middleware, leaving the request unchanged. + FailOpen, + /// `fail_closed`: short-circuit the chain and deny with the given reason. + FailClosed(String), +} + +/// Apply a middleware entry's `on_error` policy after a failure (service error or +/// malformed response). Records a `failed` invocation for telemetry in both cases. +fn apply_on_error( + entry: &ChainEntry, + reason: &str, + applied: &mut Vec<MiddlewareInvocation>, +) -> OnErrorAction { + match entry.on_error { + OnError::FailOpen => { + applied.push(MiddlewareInvocation { + name: entry.name.clone(), + implementation: entry.implementation.clone(), + decision: Decision::Allow, + transformed: false, + failed: true, + }); + OnErrorAction::FailOpen + } + OnError::FailClosed => { + applied.push(MiddlewareInvocation { + name: entry.name.clone(), + implementation: entry.implementation.clone(), + decision: Decision::Deny, + transformed: false, + failed: true, + }); + OnErrorAction::FailClosed(format!("middleware_failed: {reason}")) + } + } } #[derive(Clone)] @@ -150,20 +192,33 @@ impl ChainRunner { .await { Ok(result) => result.into_inner(), - Err(err) => match entry.on_error { - OnError::FailOpen => { - applied.push(MiddlewareInvocation { - name: entry.name.clone(), - implementation: entry.implementation.clone(), - decision: Decision::Allow, - transformed: false, - }); - continue; + Err(err) => { + match apply_on_error(entry, &safe_reason(&err.to_string()), &mut applied) { + OnErrorAction::FailOpen 
=> continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } } - OnError::FailClosed => { + } + }; + + // A result proposing unsafe header mutations is a malformed response: + // route it through `on_error` instead of applying any of it. + if validate_header_mutations(&headers, &result.add_headers).is_err() { + match apply_on_error(entry, "unsafe_response_headers", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { return Ok(ChainOutcome { allowed: false, - reason: format!("middleware_failed: {}", safe_reason(&err.to_string())), + reason, body, added_headers, findings, @@ -171,17 +226,15 @@ impl ChainRunner { applied, }); } - }, - }; - - validate_header_mutations(&headers, &result.add_headers)?; + } + } for (name, value) in &result.add_headers { headers.insert(name.to_ascii_lowercase(), value.clone()); added_headers.insert(name.to_ascii_lowercase(), value.clone()); } let transformed = result.has_body; if result.has_body { - body = result.body.clone(); + result.body.clone_into(&mut body); } for finding in result.findings { findings.push(NamespacedFinding { @@ -200,6 +253,7 @@ impl ChainRunner { implementation: entry.implementation.clone(), decision: Decision::try_from(result.decision).unwrap_or(Decision::Unspecified), transformed, + failed: false, }); if result.decision == Decision::Deny as i32 { return Ok(ChainOutcome { @@ -264,7 +318,7 @@ fn validate_header_mutations( mutations: &HashMap, ) -> Result<()> { let mut seen = HashSet::new(); - for name in mutations.keys() { + for (name, value) in mutations { let lower = name.to_ascii_lowercase(); if !seen.insert(lower.clone()) || existing_headers.contains_key(&lower) { return Err(miette!( @@ -274,10 +328,27 @@ fn validate_header_mutations( if !is_safe_append_header(&lower) { return Err(miette!("middleware cannot append unsafe header '{name}'")); } + // Reject 
CR/LF and other control characters in the value: writing them + // verbatim into the upstream header block would enable header injection + // and request smuggling past the credential boundary. + if !is_safe_header_value(value) { + return Err(miette!( + "middleware cannot append header '{name}' with an unsafe value" + )); + } } Ok(()) } +/// A header value is safe to append only if it contains no control characters. +/// Horizontal tab, printable ASCII, and obs-text (>= 0x80) are permitted; CR, LF, +/// NUL, and other control bytes are rejected. +fn is_safe_header_value(value: &str) -> bool { + value + .bytes() + .all(|b| b == b'\t' || (0x20..=0x7e).contains(&b) || b >= 0x80) +} + fn is_safe_append_header(name: &str) -> bool { if name.is_empty() || name.contains(':') @@ -312,13 +383,12 @@ mod tests { name: name.into(), implementation: BUILTIN_SECRETS.into(), config: prost_types::Struct { - fields: [( + fields: std::iter::once(( "secrets".into(), prost_types::Value { kind: Some(prost_types::value::Kind::StringValue("redact".into())), }, - )] - .into_iter() + )) .collect(), }, on_error, @@ -427,11 +497,191 @@ mod tests { fn unsafe_header_mutation_is_rejected() { let err = validate_header_mutations( &BTreeMap::new(), - &[("Authorization".into(), "Bearer nope".into())] - .into_iter() - .collect(), + &std::iter::once(("Authorization".into(), "Bearer nope".into())).collect(), ) .expect_err("unsafe header"); assert!(err.to_string().contains("unsafe header")); } + + #[test] + fn header_value_with_crlf_is_rejected() { + // A safe header *name* with a CRLF-bearing value must still be rejected, + // otherwise it would inject extra headers into the upstream request. 
+ let err = validate_header_mutations( + &BTreeMap::new(), + &std::iter::once(( + "x-openshell-middleware-inject".into(), + "ok\r\nAuthorization: Bearer evil".into(), + )) + .collect(), + ) + .expect_err("crlf value"); + assert!(err.to_string().contains("unsafe value")); + } + + /// A mock middleware that returns a fixed, caller-supplied result for every + /// evaluation. Used to exercise chain behavior the built-in cannot produce + /// (explicit deny, metadata, findings, unsafe header mutations). + struct ScriptedService { + result: openshell_core::proto::HttpRequestResult, + } + + #[tonic::async_trait] + impl SupervisorMiddleware for ScriptedService { + async fn describe( + &self, + _request: Request<()>, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + openshell_core::proto::MiddlewareManifest::default(), + )) + } + + async fn validate_config( + &self, + _request: Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new( + openshell_core::proto::ValidateConfigResponse { + valid: true, + reason: String::new(), + }, + )) + } + + async fn evaluate_http_request( + &self, + _request: Request, + ) -> std::result::Result< + tonic::Response, + tonic::Status, + > { + Ok(tonic::Response::new(self.result.clone())) + } + } + + fn allow_result() -> openshell_core::proto::HttpRequestResult { + openshell_core::proto::HttpRequestResult { + decision: Decision::Allow as i32, + reason: String::new(), + body: Vec::new(), + has_body: false, + add_headers: HashMap::new(), + findings: Vec::new(), + metadata: HashMap::new(), + } + } + + #[tokio::test] + async fn deny_decision_short_circuits_chain() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + result: openshell_core::proto::HttpRequestResult { + decision: Decision::Deny as i32, + reason: "blocked_by_policy".into(), + ..allow_result() + }, + })); + let outcome = runner + .evaluate( + &[ + entry("first", 
OnError::FailClosed), + entry("second", OnError::FailClosed), + ], + input("hello"), + ) + .await + .expect("evaluate"); + assert!(!outcome.allowed); + assert_eq!(outcome.reason, "blocked_by_policy"); + // The deny short-circuits the chain: the second middleware never runs. + assert_eq!(outcome.applied.len(), 1); + assert_eq!(outcome.applied[0].decision, Decision::Deny); + assert!(!outcome.applied[0].failed); + } + + #[tokio::test] + async fn metadata_and_findings_are_namespaced_per_config() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + result: openshell_core::proto::HttpRequestResult { + findings: vec![Finding { + r#type: "pii.email".into(), + label: "email address".into(), + count: 2, + confidence: "high".into(), + severity: "medium".into(), + }], + metadata: std::iter::once(("sensitivity".to_string(), "high".to_string())) + .collect(), + ..allow_result() + }, + })); + let outcome = runner + .evaluate( + &[ + entry("alpha", OnError::FailClosed), + entry("beta", OnError::FailClosed), + ], + input("hello"), + ) + .await + .expect("evaluate"); + assert!(outcome.allowed); + // Metadata is bucketed under each config's local name, so two configs + // emitting the same key do not collide. + assert_eq!(outcome.metadata["alpha"]["sensitivity"], "high"); + assert_eq!(outcome.metadata["beta"]["sensitivity"], "high"); + // Findings are tagged with the emitting config's name. 
+ assert_eq!(outcome.findings.len(), 2); + assert_eq!(outcome.findings[0].middleware, "alpha"); + assert_eq!(outcome.findings[1].middleware, "beta"); + assert_eq!(outcome.findings[0].finding.r#type, "pii.email"); + assert_eq!(outcome.findings[0].finding.count, 2); + } + + fn unsafe_header_service() -> ScriptedService { + ScriptedService { + result: openshell_core::proto::HttpRequestResult { + add_headers: std::iter::once(( + "x-openshell-middleware-inject".to_string(), + "ok\r\nHost: evil".to_string(), + )) + .collect(), + ..allow_result() + }, + } + } + + #[tokio::test] + async fn malformed_response_headers_fail_closed_denies() { + let runner = ChainRunner::new(Arc::new(unsafe_header_service())); + let outcome = runner + .evaluate(&[entry("redact", OnError::FailClosed)], input("hello")) + .await + .expect("evaluate"); + assert!(!outcome.allowed); + assert!(outcome.reason.starts_with("middleware_failed:")); + assert!(outcome.applied.iter().any(|inv| inv.failed)); + // The unsafe header is never forwarded. 
+ assert!(outcome.added_headers.is_empty()); + } + + #[tokio::test] + async fn malformed_response_headers_fail_open_continues() { + let runner = ChainRunner::new(Arc::new(unsafe_header_service())); + let outcome = runner + .evaluate(&[entry("redact", OnError::FailOpen)], input("hello")) + .await + .expect("evaluate"); + assert!(outcome.allowed); + assert_eq!(outcome.body, b"hello"); + assert!(outcome.added_headers.is_empty()); + assert_eq!(outcome.applied.len(), 1); + assert!(outcome.applied[0].failed); + } } diff --git a/crates/openshell-supervisor-network/Cargo.toml b/crates/openshell-supervisor-network/Cargo.toml index fd8fad5f7..b8cae5113 100644 --- a/crates/openshell-supervisor-network/Cargo.toml +++ b/crates/openshell-supervisor-network/Cargo.toml @@ -55,6 +55,7 @@ tempfile = "3" temp-env = "0.3" tokio-tungstenite = { workspace = true } futures = { workspace = true } +tracing-subscriber = { workspace = true } [target.'cfg(unix)'.dev-dependencies] libc = "0.2" diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 4d501d0a3..6e5c3c4e9 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -783,9 +783,18 @@ async fn apply_middleware_chain( if chain.is_empty() { return Ok(MiddlewareApplyResult::Allowed(req)); } - let buffered = - crate::l7::rest::buffer_request_body_for_middleware(&req, client, Some(generation_guard)) - .await?; + let buffered = match crate::l7::rest::buffer_request_body_for_middleware( + &req, + client, + Some(generation_guard), + ) + .await? 
+ { + crate::l7::rest::BufferResult::Buffered(buffered) => buffered, + crate::l7::rest::BufferResult::OverCapacity { recoverable } => { + return Ok(resolve_unbuffered_body(ctx, req, &chain, recoverable)); + } + }; let headers = safe_middleware_headers(&buffered.headers)?; let input = openshell_supervisor_middleware::HttpRequestInput { request_id: uuid::Uuid::new_v4().to_string(), @@ -819,6 +828,52 @@ async fn apply_middleware_chain( } } +/// Apply the chain's `on_error` policy when the request body cannot be buffered +/// for inspection because it exceeds the size cap. The RFC treats an unbufferable +/// body as an `on_error` event: it is denied unless every attached middleware is +/// `fail_open`, and passing it through is only safe when no bytes were consumed. +fn resolve_unbuffered_body( + ctx: &L7EvalContext, + req: crate::l7::provider::L7Request, + chain: &[openshell_supervisor_middleware::ChainEntry], + recoverable: bool, +) -> MiddlewareApplyResult { + let all_fail_open = chain + .iter() + .all(|entry| entry.on_error == openshell_supervisor_middleware::OnError::FailOpen); + if recoverable && all_fail_open { + emit_middleware_body_unavailable(ctx, false); + return MiddlewareApplyResult::Allowed(req); + } + emit_middleware_body_unavailable(ctx, true); + MiddlewareApplyResult::Denied("middleware_failed: request_body_over_capacity".into()) +} + +fn emit_middleware_body_unavailable(ctx: &L7EvalContext, denied: bool) { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(if denied { + SeverityId::High + } else { + SeverityId::Medium + }) + .finding_info(FindingInfo::new( + "openshell.middleware.body_unavailable", + "Supervisor middleware could not inspect request body", + )) + .evidence_pairs(&[ + ("policy", ctx.policy_name.as_str()), + ("host", ctx.host.as_str()), + ("disposition", if denied { "denied" } else { "fail_open" }), + ]) + .message(if denied { + "Request body exceeded middleware inspection cap; denied" + } else { + 
"Request body exceeded middleware inspection cap; passed through (fail_open)" + }) + .build(); + ocsf_emit!(event); +} + fn safe_middleware_headers(headers: &[u8]) -> Result> { let header_str = std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; @@ -885,14 +940,37 @@ fn emit_middleware_events( .dst_endpoint(Endpoint::from_domain(&ctx.host, ctx.port)) .firewall_rule(&ctx.policy_name, "middleware") .message(format!( - "MIDDLEWARE {} {} decision={:?} transformed={}", + "MIDDLEWARE {} {} decision={:?} transformed={} failed={}", invocation.name, invocation.implementation, invocation.decision, - invocation.transformed + invocation.transformed, + invocation.failed )) .build(); ocsf_emit!(event); + + // A middleware that failed but was bypassed under `fail_open` is an + // enforcement failure operators must be able to alert on, even though the + // request proceeded. + if invocation.failed && allowed { + let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) + .severity(SeverityId::Medium) + .finding_info(FindingInfo::new( + "openshell.middleware.failure", + "Supervisor middleware failed open", + )) + .evidence_pairs(&[ + ("middleware", invocation.name.as_str()), + ("implementation", invocation.implementation.as_str()), + ]) + .message(format!( + "Middleware {} failed and was bypassed (fail_open)", + invocation.name + )) + .build(); + ocsf_emit!(event); + } } if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) @@ -2658,6 +2736,407 @@ network_policies: .unwrap(); } + #[tokio::test] + async fn l7_rest_middleware_over_capacity_fails_closed() { + let (config, tunnel_engine, ctx) = + middleware_relay_context("openshell/secrets", "fail_closed"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + 
&config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + // A declared body far above the 256 KiB inspection cap must be denied + // (fail-closed) before the body is read or reaches the upstream. + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n", + 300 * 1024 + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("request_body_over_capacity")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[test] + fn over_capacity_resolution_honors_on_error() { + use openshell_supervisor_middleware::{ChainEntry, OnError}; + + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "p".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let req = || crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let fail_open = ChainEntry { + name: "m".into(), + implementation: "openshell/secrets".into(), + config: 
prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + let fail_closed = ChainEntry { + on_error: OnError::FailClosed, + ..fail_open.clone() + }; + + // Recoverable (Content-Length over cap, nothing consumed) + all fail-open + // -> stream through unprocessed. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), std::slice::from_ref(&fail_open), true), + MiddlewareApplyResult::Allowed(_) + )); + // Any fail-closed entry -> deny. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), &[fail_open.clone(), fail_closed], true), + MiddlewareApplyResult::Denied(_) + )); + // Not recoverable (chunked overflow already consumed bytes) -> deny even + // when every entry is fail-open. + assert!(matches!( + resolve_unbuffered_body(&ctx, req(), &[fail_open], false), + MiddlewareApplyResult::Denied(_) + )); + } + + /// Tracing layer that captures emitted `OcsfEvent`s for assertions. + struct OcsfCaptureLayer(Arc>>); + + impl tracing_subscriber::Layer for OcsfCaptureLayer { + fn on_event( + &self, + event: &tracing::Event<'_>, + _ctx: tracing_subscriber::layer::Context<'_, S>, + ) { + if event.metadata().target() == openshell_ocsf::OCSF_TARGET + && let Some(ocsf_event) = openshell_ocsf::clone_current_event() + { + self.0.lock().unwrap().push(ocsf_event); + } + } + } + + #[test] + fn middleware_ocsf_events_are_audit_safe() { + use openshell_supervisor_middleware::{ + ChainOutcome, MiddlewareInvocation, NamespacedFinding, + }; + use tracing_subscriber::layer::SubscriberExt; + + const RAW_SECRET: &str = "sk-RAWSECRETVALUE0123456789"; + + let events = Arc::new(std::sync::Mutex::new(Vec::new())); + let subscriber = tracing_subscriber::registry().with(OcsfCaptureLayer(Arc::clone(&events))); + let _guard = tracing::subscriber::set_default(subscriber); + + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "rest_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + 
secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let req = crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1/messages".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let outcome = ChainOutcome { + allowed: true, + reason: String::new(), + // The transformed body still holds the raw secret; emission must never + // serialize it. + body: format!(r#"{{"api_key":"{RAW_SECRET}"}}"#).into_bytes(), + added_headers: BTreeMap::new(), + findings: vec![NamespacedFinding { + middleware: "redact-secrets".into(), + finding: openshell_core::proto::Finding { + r#type: "secret.common".into(), + label: "common secret pattern".into(), + count: 1, + confidence: "medium".into(), + severity: "medium".into(), + }, + }], + metadata: BTreeMap::new(), + applied: vec![MiddlewareInvocation { + name: "redact-secrets".into(), + implementation: "openshell/secrets".into(), + decision: openshell_core::proto::Decision::Allow, + transformed: true, + failed: false, + }], + }; + + emit_middleware_events(&ctx, &req, &outcome); + + let captured = events.lock().unwrap(); + // Per-invocation decisions are HTTP Activity (class 4002). + assert!( + captured.iter().any(|e| e.class_uid() == 4002), + "expected an HTTP Activity event for the middleware invocation" + ); + // Findings are Detection Finding (class 2004) with the finding's severity. + let finding_event = captured + .iter() + .find(|e| e.class_uid() == 2004) + .expect("expected a Detection Finding event"); + assert_eq!(finding_event.base().severity, SeverityId::Medium); + + // No raw payload material may appear in any emitted event. + let serialized = serde_json::to_string(&*captured).expect("serialize events"); + assert!( + !serialized.contains(RAW_SECRET), + "raw secret leaked into OCSF events: {serialized}" + ); + // Safe finding metadata is still present. 
+ assert!(serialized.contains("secret.common")); + } + + #[tokio::test] + async fn passthrough_relay_runs_middleware_redaction() { + // A no-protocol endpoint takes the credential-injection passthrough path; + // policy-level middleware must still inspect and redact its body. + let data = r#" +network_middlewares: + - name: request-middleware + middleware: openshell/secrets + on_error: fail_closed +network_policies: + passthrough_api: + name: passthrough_api + middleware: ["request-middleware"] + endpoints: + - host: api.example.test + port: 8080 + binaries: + - { path: /usr/bin/curl } +"#; + let engine = Arc::new(OpaEngine::from_strings(TEST_POLICY, data).unwrap()); + let generation_guard = engine + .generation_guard(engine.current_generation()) + .unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 8080, + policy_name: "passthrough_api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let engine_task = Arc::clone(&engine); + let relay = tokio::spawn(async move { + relay_passthrough_with_credentials( + &mut relay_client, + &mut relay_upstream, + &ctx, + &generation_guard, + Some(engine_task.as_ref()), + ) + .await + }); + + let body = br#"{"api_key":"sk-1234567890abcdef"}"#; + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request 
= String::from_utf8_lossy(&upstream_request[..n]); + assert!( + upstream_request.contains(r#""api_key":"[REDACTED]""#), + "unexpected upstream request: {upstream_request:?}" + ); + assert!(!upstream_request.contains("sk-1234567890abcdef")); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + + #[tokio::test] + async fn websocket_upgrade_request_is_inspected_and_denied() { + // The WebSocket upgrade handshake is an HTTP request the hook can inspect + // and deny: a fail-closed middleware blocks the upgrade before it is + // forwarded. 
+ let data = r#" +network_middlewares: + - name: request-middleware + middleware: example/unavailable + on_error: fail_closed +network_policies: + ws_api: + name: ws_api + middleware: ["request-middleware"] + endpoints: + - host: gateway.example.test + port: 443 + protocol: websocket + enforcement: enforce + rules: + - allow: + method: GET + path: "/ws" + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "gateway.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .unwrap(); + let config = crate::l7::parse_l7_config(&endpoint_config.unwrap()).unwrap(); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "gateway.example.test".into(), + port: 443, + policy_name: "ws_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + app.write_all( + b"GET /ws HTTP/1.1\r\nHost: gateway.example.test\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\nSec-WebSocket-Version: 13\r\n\r\n", + ) + .await + .unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = 
String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive the upgrade request" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + #[tokio::test] async fn passthrough_relay_injects_token_grant_authorization_header() { let (generation_guard, ctx, fixture) = diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 1a4036abd..2c85cacf6 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -774,11 +774,23 @@ pub(crate) struct BufferedRequestBody { pub(crate) body: Vec, } +/// Result of attempting to buffer a request body for middleware inspection. +pub(crate) enum BufferResult { + /// The full body was buffered within the size cap. + Buffered(BufferedRequestBody), + /// The body exceeded the inspection cap. `recoverable` is true when no body + /// bytes were consumed yet (a declared `Content-Length` over the cap), so the + /// request can still be streamed through unprocessed under fail-open. It is + /// false once bytes have been consumed (chunked overflow), where denying is + /// the only safe outcome. 
+ OverCapacity { recoverable: bool }, +} + pub(crate) async fn buffer_request_body_for_middleware( req: &L7Request, client: &mut C, generation_guard: Option<&PolicyGenerationGuard>, -) -> Result { +) -> Result { let header_end = req .raw_header .windows(4) @@ -787,17 +799,19 @@ pub(crate) async fn buffer_request_body_for_middleware( let headers = req.raw_header[..header_end].to_vec(); let already_read = &req.raw_header[header_end..]; match req.body_length { - BodyLength::None => Ok(BufferedRequestBody { + BodyLength::None => Ok(BufferResult::Buffered(BufferedRequestBody { headers, body: already_read.to_vec(), - }), + })), BodyLength::ContentLength(len) => { - let len = usize::try_from(len) - .map_err(|_| miette!("request body is too large for middleware"))?; + // The declared length is known before any further reads, so an + // over-cap body here has not consumed the stream and can be passed + // through unprocessed if every middleware is fail-open. + let Ok(len) = usize::try_from(len) else { + return Ok(BufferResult::OverCapacity { recoverable: true }); + }; if len > MAX_MIDDLEWARE_BODY_BYTES { - return Err(miette!( - "middleware buffers at most {MAX_MIDDLEWARE_BODY_BYTES} request body bytes" - )); + return Ok(BufferResult::OverCapacity { recoverable: true }); } let initial_len = already_read.len().min(len); let mut body = Vec::with_capacity(len); @@ -818,11 +832,21 @@ pub(crate) async fn buffer_request_body_for_middleware( body.extend_from_slice(&buf[..n]); remaining -= n; } - Ok(BufferedRequestBody { headers, body }) + Ok(BufferResult::Buffered(BufferedRequestBody { + headers, + body, + })) } BodyLength::Chunked => { - let body = collect_chunked_body(client, already_read, generation_guard).await?; - Ok(BufferedRequestBody { headers, body }) + // Chunked bodies are decoded incrementally into the payload bytes + // middleware expects. 
On overflow, we have already consumed wire + // bytes from the client stream and cannot re-enter the normal raw + // relay path without a separate splice-through buffer. + Ok(collect_chunked_body(client, already_read, generation_guard) + .await + .map_or(BufferResult::OverCapacity { recoverable: false }, |body| { + BufferResult::Buffered(BufferedRequestBody { headers, body }) + })) } } } @@ -835,7 +859,7 @@ pub(crate) fn rebuild_request_with_buffered_body( ) -> Result { let mut header_bytes = set_content_length(headers, body.len())?; header_bytes = strip_header(&header_bytes, "transfer-encoding")?; - header_bytes = append_headers(&header_bytes, add_headers)?; + header_bytes = append_headers(&header_bytes, add_headers); header_bytes.extend_from_slice(body); Ok(L7Request { action: req.action.clone(), @@ -900,15 +924,11 @@ async fn collect_and_rewrite_request_body( } BodyLength::Chunked => { let body = collect_chunked_body(client, already_read, generation_guard).await?; - if body_bytes_contain_reserved_marker(&body) { - return Err(miette!( - "request body credential rewrite does not support chunked bodies containing credential placeholders" - )); - } - Ok(PreparedRequestBody { - headers: rewritten_headers.to_vec(), - body, - }) + let (mut headers, body) = + rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; + headers = set_content_length(&headers, body.len())?; + headers = strip_header(&headers, "transfer-encoding")?; + Ok(PreparedRequestBody { headers, body }) } } } @@ -1076,37 +1096,15 @@ async fn collect_chunked_body( already_read: &[u8], generation_guard: Option<&PolicyGenerationGuard>, ) -> Result> { - let mut read_buf = [0u8; RELAY_BUF_SIZE]; - let mut parse_buf = Vec::from(already_read); - let mut pos = 0usize; + let mut buffered_pos = 0usize; + let mut body = Vec::new(); loop { - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - 
)); - } - - let size_line_end = loop { - if let Some(end) = find_crlf(&parse_buf, pos) { - break end; - } - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended before chunk-size line")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } - }; - - let size_line = std::str::from_utf8(&parse_buf[pos..size_line_end]) + let size_line = + read_chunked_line(client, already_read, &mut buffered_pos, generation_guard) + .await + .map_err(|e| miette!("Chunked body ended before chunk-size line: {e}"))?; + let size_line = std::str::from_utf8(&size_line) .into_diagnostic() .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; let size_token = size_line @@ -1117,64 +1115,109 @@ async fn collect_chunked_body( let chunk_size = usize::from_str_radix(size_token, 16) .into_diagnostic() .map_err(|_| miette!("Invalid chunk size token: {size_token:?}"))?; - pos = size_line_end + 2; if chunk_size == 0 { loop { - let trailer_end = loop { - if let Some(end) = find_crlf(&parse_buf, pos) { - break end; - } - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended before trailer terminator")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } - }; - let trailer_line = &parse_buf[pos..trailer_end]; - pos = trailer_end + 2; + let trailer_line = + read_chunked_line(client, already_read, &mut buffered_pos, generation_guard) + .await + .map_err(|e| { + miette!("Chunked body ended before trailer terminator: 
{e}") + })?; if trailer_line.is_empty() { - return Ok(parse_buf); + return Ok(body); } } } - let chunk_end = pos - .checked_add(chunk_size) - .ok_or_else(|| miette!("Chunk size overflow"))?; - let chunk_with_crlf_end = chunk_end - .checked_add(2) - .ok_or_else(|| miette!("Chunk size overflow"))?; - while parse_buf.len() < chunk_with_crlf_end { - let n = client.read(&mut read_buf).await.into_diagnostic()?; - if n == 0 { - return Err(miette!("Chunked body ended mid-chunk")); - } - if let Some(guard) = generation_guard { - guard.ensure_current()?; - } - parse_buf.extend_from_slice(&read_buf[..n]); - if parse_buf.len() > MAX_REWRITE_BODY_BYTES { - return Err(miette!( - "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" - )); - } + if body.len().saturating_add(chunk_size) > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); } - if &parse_buf[chunk_end..chunk_with_crlf_end] != b"\r\n" { + read_buffered_exact( + client, + already_read, + &mut buffered_pos, + chunk_size, + &mut body, + generation_guard, + ) + .await + .map_err(|e| miette!("Chunked body ended mid-chunk: {e}"))?; + + let mut chunk_crlf = Vec::with_capacity(2); + read_buffered_exact( + client, + already_read, + &mut buffered_pos, + 2, + &mut chunk_crlf, + generation_guard, + ) + .await + .map_err(|e| miette!("Chunked body ended before chunk terminator: {e}"))?; + if chunk_crlf.as_slice() != b"\r\n" { return Err(miette!("Chunk missing terminating CRLF")); } - pos = chunk_with_crlf_end; } } +async fn read_chunked_line( + client: &mut C, + already_read: &[u8], + buffered_pos: &mut usize, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result> { + let mut line = Vec::new(); + loop { + let byte = read_buffered_byte(client, already_read, buffered_pos, generation_guard).await?; + line.push(byte); + if line.len() > MAX_REWRITE_BODY_BYTES { + return Err(miette!( + "request body credential 
rewrite buffers at most {MAX_REWRITE_BODY_BYTES} bytes" + )); + } + if line.ends_with(b"\r\n") { + line.truncate(line.len() - 2); + return Ok(line); + } + } +} + +async fn read_buffered_exact( + client: &mut C, + already_read: &[u8], + buffered_pos: &mut usize, + len: usize, + out: &mut Vec, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result<()> { + for _ in 0..len { + let byte = read_buffered_byte(client, already_read, buffered_pos, generation_guard).await?; + out.push(byte); + } + Ok(()) +} + +async fn read_buffered_byte( + client: &mut C, + already_read: &[u8], + buffered_pos: &mut usize, + generation_guard: Option<&PolicyGenerationGuard>, +) -> Result { + if *buffered_pos < already_read.len() { + let byte = already_read[*buffered_pos]; + *buffered_pos += 1; + return Ok(byte); + } + let byte = client.read_u8().await.into_diagnostic()?; + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + Ok(byte) +} + fn content_type(headers: &str) -> Option { headers.lines().skip(1).find_map(|line| { let (name, value) = line.split_once(':')?; @@ -1262,9 +1305,9 @@ fn strip_header(headers: &[u8], strip_name: &str) -> Result> { fn append_headers( headers: &[u8], add_headers: &std::collections::BTreeMap, -) -> Result> { +) -> Vec { if add_headers.is_empty() { - return Ok(headers.to_vec()); + return headers.to_vec(); } let split = headers .windows(4) @@ -1279,7 +1322,7 @@ fn append_headers( out.extend_from_slice(value.as_bytes()); } out.extend_from_slice(b"\r\n\r\n"); - Ok(out) + out } pub(crate) fn request_is_websocket_upgrade(raw_header: &[u8]) -> bool { @@ -3151,6 +3194,20 @@ mod tests { } } + #[tokio::test] + async fn collect_chunked_body_decodes_payload_bytes() { + let mut client = tokio::io::empty(); + let body = collect_chunked_body( + &mut client, + b"5\r\nhello\r\n6;ext=value\r\n world\r\n0\r\nx-checksum: abc\r\n\r\n", + None, + ) + .await + .expect("chunked body should decode"); + + assert_eq!(body, b"hello world"); + } + /// SEC-009: 
Bare LF in headers enables header injection. #[tokio::test] async fn reject_bare_lf_in_headers() { @@ -5257,6 +5314,38 @@ mod tests { assert!(!forwarded.contains("OPENSHELL-RESOLVE-ENV")); } + #[tokio::test] + async fn relay_request_body_rewrite_normalizes_chunked_payload() { + let (_, resolver) = SecretResolver::from_provider_env( + [("API_TOKEN".to_string(), "provider-real-token".to_string())] + .into_iter() + .collect(), + ); + let resolver = resolver.expect("resolver"); + let alias = "provider.v1-OPENSHELL-RESOLVE-ENV-API_TOKEN"; + let raw = format!( + "POST /api/messages HTTP/1.1\r\n\ + Host: api.example.com\r\n\ + Authorization: Bearer {alias}\r\n\ + Transfer-Encoding: chunked\r\n\r\n\ + 5\r\nhello\r\n0\r\n\r\n", + ); + + let forwarded = relay_and_capture_with_options( + raw.into_bytes(), + BodyLength::Chunked, + Some(&resolver), + true, + ) + .await + .expect("relay should succeed"); + + assert!(forwarded.contains("Authorization: Bearer provider-real-token\r\n")); + assert!(forwarded.contains("Content-Length: 5\r\n")); + assert!(!forwarded.contains("Transfer-Encoding: chunked\r\n")); + assert!(forwarded.ends_with("hello")); + } + #[tokio::test] async fn relay_request_body_rewrites_percent_encoded_canonical_urlencoded_token() { let (_, resolver) = SecretResolver::from_provider_env( diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 451c57e59..a584b414b 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -777,11 +777,7 @@ fn query_middleware_chain_locked( .map_err(|e| miette::miette!("{e}"))?; let contexts = parse_middleware_contexts(&contexts_val); let Some(context) = select_middleware_context(&contexts, request_path) else { - return Ok(global_middleware_entries( - &configs, - &input.host, - &HashSet::new(), - )?); + return global_middleware_entries(&configs, &input.host, &HashSet::new()); }; let mut explicit = Vec::new(); @@ -876,12 
+872,16 @@ fn middleware_selector_matches(config: ®orus::Value, host: &str) -> bool { let Some(selector) = get_field(config, "endpoints") else { return false; }; - let includes = get_str_array(selector, "include"); - let excludes = get_str_array(selector, "exclude"); - let included = - !includes.is_empty() && includes.iter().any(|pattern| host_matches(pattern, host)); - let excluded = excludes.iter().any(|pattern| host_matches(pattern, host)); - included && !excluded + let include_patterns = get_str_array(selector, "include"); + let exclude_patterns = get_str_array(selector, "exclude"); + let matches_include = !include_patterns.is_empty() + && include_patterns + .iter() + .any(|pattern| host_matches(pattern, host)); + let matches_exclude = exclude_patterns + .iter() + .any(|pattern| host_matches(pattern, host)); + matches_include && !matches_exclude } fn host_matches(pattern: &str, host: &str) -> bool { @@ -956,7 +956,6 @@ fn regorus_value_to_prost(value: ®orus::Value) -> prost_types::Value { }) .collect(), }), - regorus::Value::Null | regorus::Value::Undefined => Kind::NullValue(0), _ => Kind::NullValue(0), }), } From fb7544a234220e0184d72254b8a095ed296fb8cf Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 26 Jun 2026 16:20:15 -0700 Subject: [PATCH 03/16] fix(supervisor-middleware): default stored policy rule fields Signed-off-by: Piotr Mlocek --- crates/openshell-server/src/grpc/policy.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 09e311bb2..ad4fdf5ba 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -7100,6 +7100,7 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], + ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-provider-prefix".to_string(), From c836c846215b25311950bd37b2305aa8561cefd9 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Fri, 26 Jun 2026 
16:58:26 -0700 Subject: [PATCH 04/16] fix(supervisor-middleware): resolve rebase policy conflicts Signed-off-by: Piotr Mlocek --- crates/openshell-policy/src/lib.rs | 4 +- .../data/sandbox-policy.rego | 35 ++-- .../openshell-supervisor-network/src/opa.rs | 183 +++++++++++++++--- proto/sandbox.proto | 2 + 4 files changed, 177 insertions(+), 47 deletions(-) diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 6ccfd1158..21561596c 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -19,8 +19,8 @@ use std::path::Path; use miette::{IntoDiagnostic, Result, WrapErr}; use openshell_core::proto::{ FilesystemPolicy, GraphqlOperation, L7Allow, L7DenyRule, L7QueryMatcher, L7Rule, - LandlockPolicy, MiddlewareEndpointSelector, NetworkBinary, NetworkEndpoint, - NetworkMiddlewareConfig, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, McpOptions, + LandlockPolicy, McpOptions, MiddlewareEndpointSelector, NetworkBinary, NetworkEndpoint, + NetworkMiddlewareConfig, NetworkPolicyRule, ProcessPolicy, SandboxPolicy, }; use serde::{Deserialize, Serialize}; diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index afa4f6947..9228416e1 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -842,14 +842,29 @@ _policy_endpoint_configs(policy) := [ep | endpoint_has_extended_config(ep) ] -# Collect matching endpoint configs across all policies. Iterates over -# _matching_policy_names (a set, safe from regorus variable collisions) -# then collects per-policy configs via the helper function. +# Collect matching endpoint identities across all policies. Iterates over +# _matching_policy_names (a set, safe from regorus variable collisions) then +# returns the selected policy name plus endpoint index/path. 
Rust uses that +# identity to look up middleware attachment from policy data. +_matching_endpoint_contexts := [ctx | + some pname + _matching_policy_names[pname] + policy := data.network_policies[pname] + ep := policy.endpoints[i] + endpoint_matches_request(ep, input.network) + ctx := { + "policy": pname, + "endpoint_index": i, + "endpoint_path": object.get(ep, "path", ""), + } +] + _matching_endpoint_configs := [cfg | some pname _matching_policy_names[pname] cfgs := _policy_endpoint_configs(data.network_policies[pname]) cfg := cfgs[_] + endpoint_has_extended_config(cfg) ] matched_endpoint_config := _matching_endpoint_configs[0] if { @@ -858,20 +873,6 @@ matched_endpoint_config := _matching_endpoint_configs[0] if { network_middlewares := object.get(data, "network_middlewares", []) -_matching_middleware_contexts := [ctx | - some pname - _matching_policy_names[pname] - policy := data.network_policies[pname] - some ep - ep := policy.endpoints[_] - endpoint_matches_request(ep, input.network) - ctx := { - "policy": pname, - "policy_middleware": object.get(policy, "middleware", []), - "endpoint": ep, - } -] - _policy_has_exact_declared_endpoint(policy) if { some ep ep := policy.endpoints[_] diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index a584b414b..3d0f75bf7 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -750,9 +750,9 @@ fn network_input_json(input: &NetworkInput) -> serde_json::Value { } #[derive(Debug, Clone)] -struct MiddlewareContext { - policy_middleware: Vec, - endpoint_middleware: Vec, +struct MatchedEndpointContext { + policy_name: String, + endpoint_index: usize, endpoint_path: String, } @@ -773,19 +773,20 @@ fn query_middleware_chain_locked( return Ok(Vec::new()); } let contexts_val = engine - .eval_rule("data.openshell.sandbox._matching_middleware_contexts".into()) + 
.eval_rule("data.openshell.sandbox._matching_endpoint_contexts".into()) .map_err(|e| miette::miette!("{e}"))?; - let contexts = parse_middleware_contexts(&contexts_val); - let Some(context) = select_middleware_context(&contexts, request_path) else { + let contexts = parse_endpoint_contexts(&contexts_val); + let Some(context) = select_endpoint_context(&contexts, request_path)? else { return global_middleware_entries(&configs, &input.host, &HashSet::new()); }; + let policies_val = engine + .eval_rule("data.network_policies".into()) + .map_err(|e| miette::miette!("{e}"))?; + let (policy_middleware, endpoint_middleware) = + middleware_for_endpoint_identity(&policies_val, context)?; let mut explicit = Vec::new(); - for name in context - .policy_middleware - .iter() - .chain(context.endpoint_middleware.iter()) - { + for name in policy_middleware.iter().chain(endpoint_middleware.iter()) { if !explicit.contains(name) { explicit.push(name.clone()); } @@ -814,7 +815,7 @@ fn parse_middleware_configs(value: ®orus::Value) -> Result Vec { +fn parse_endpoint_contexts(value: ®orus::Value) -> Vec { let regorus::Value::Array(values) = value else { return Vec::new(); }; @@ -824,30 +825,87 @@ fn parse_middleware_contexts(value: ®orus::Value) -> Vec { let regorus::Value::Object(_) = value else { return None; }; - let endpoint = get_field(value, "endpoint")?; - Some(MiddlewareContext { - policy_middleware: get_str_array(value, "policy_middleware"), - endpoint_middleware: get_str_array(endpoint, "middleware"), - endpoint_path: get_str(endpoint, "path").unwrap_or_default(), + Some(MatchedEndpointContext { + policy_name: get_str(value, "policy").unwrap_or_default(), + endpoint_index: get_usize(value, "endpoint_index").unwrap_or_default(), + endpoint_path: get_str(value, "endpoint_path").unwrap_or_default(), }) }) .collect() } -fn select_middleware_context<'a>( - contexts: &'a [MiddlewareContext], +fn middleware_for_endpoint_identity( + policies: ®orus::Value, + context: 
&MatchedEndpointContext, +) -> Result<(Vec, Vec)> { + let policy = get_field(policies, &context.policy_name).ok_or_else(|| { + miette::miette!( + "matched endpoint policy '{}' was not found in OPA data", + context.policy_name + ) + })?; + let endpoint = get_array(policy, "endpoints") + .and_then(|endpoints| endpoints.get(context.endpoint_index)) + .ok_or_else(|| { + miette::miette!( + "matched endpoint {}[{}] was not found in OPA data", + context.policy_name, + context.endpoint_index + ) + })?; + Ok(( + get_str_array(policy, "middleware"), + get_str_array(endpoint, "middleware"), + )) +} + +fn select_endpoint_context<'a>( + contexts: &'a [MatchedEndpointContext], request_path: &str, -) -> Option<&'a MiddlewareContext> { - contexts +) -> Result> { + let matching: Vec<_> = contexts .iter() .filter(|context| crate::l7::endpoint_path_matches(&context.endpoint_path, request_path)) - .max_by_key(|context| { - if context.endpoint_path.is_empty() { - 0 - } else { - context.endpoint_path.chars().filter(|c| *c != '*').count() - } - }) + .map(|context| (endpoint_path_specificity(&context.endpoint_path), context)) + .collect(); + let Some(max_specificity) = matching.iter().map(|(specificity, _)| *specificity).max() else { + return Ok(None); + }; + let best: Vec<_> = matching + .into_iter() + .filter(|(specificity, _)| *specificity == max_specificity) + .map(|(_, context)| context) + .collect(); + if best.len() > 1 { + let matches = best + .iter() + .map(|context| { + format!( + "{}[{}] path={}", + context.policy_name, + context.endpoint_index, + if context.endpoint_path.is_empty() { + "" + } else { + context.endpoint_path.as_str() + } + ) + }) + .collect::>() + .join(", "); + return Err(miette::miette!( + "ambiguous middleware endpoint match for request path '{request_path}': {matches}" + )); + } + Ok(best.into_iter().next()) +} + +fn endpoint_path_specificity(path: &str) -> usize { + if path.is_empty() { + 0 + } else { + path.chars().filter(|c| *c != '*').count() + } } fn 
global_middleware_entries( @@ -918,6 +976,25 @@ fn get_field<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Valu } } +fn get_array<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a [regorus::Value]> { + let regorus::Value::Array(values) = get_field(val, key)? else { + return None; + }; + Some(values) +} + +fn get_usize(val: ®orus::Value, key: &str) -> Option { + let value = get_field(val, key)?; + let regorus::Value::Number(number) = value else { + return None; + }; + let value = number.as_f64()?; + if !value.is_finite() || value.fract() != 0.0 || value < 0.0 { + return None; + } + format!("{value:.0}").parse::().ok() +} + fn regorus_value_to_struct(value: ®orus::Value) -> prost_types::Struct { let regorus::Value::Object(map) = value else { return prost_types::Struct::default(); @@ -3305,6 +3382,7 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], + middleware: vec![], }, ); @@ -3323,6 +3401,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -3377,6 +3456,7 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], + middleware: vec![], }, ); @@ -3395,6 +3475,7 @@ network_policies: run_as_group: "sandbox".to_string(), }), network_policies, + network_middlewares: vec![], }; let engine = OpaEngine::from_proto(&proto).expect("engine from proto"); @@ -6955,6 +7036,52 @@ network_policies: ); } + #[test] + fn middleware_chain_rejects_ambiguous_duplicate_endpoint_identity() { + let data = r#" +network_middlewares: + - name: first-redactor + middleware: openshell/secrets + - name: second-redactor + middleware: openshell/secrets +network_policies: + api: + name: api + endpoints: + - host: api.example.com + port: 443 + protocol: rest + enforcement: enforce + middleware: ["first-redactor"] + access: full + - host: api.example.com + port: 443 + protocol: rest + 
enforcement: enforce + middleware: ["second-redactor"] + access: full + binaries: + - { path: /usr/bin/curl } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.com".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/curl"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let err = engine + .query_middleware_chain_with_generation(&input, "/v1/messages") + .expect_err("equivalent endpoint identities should be ambiguous"); + assert!( + err.to_string() + .contains("ambiguous middleware endpoint match"), + "{err:?}" + ); + } + #[test] fn middleware_policy_validation_rejects_bad_configs() { let cases = [ diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 5d2bc31a5..a73d762e5 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -202,6 +202,8 @@ message McpOptions { // MCP-family methods at the method layer unless a tool-name policy narrows // tools/call. When unset or false, explicit method rules are required. optional bool allow_all_known_mcp_methods = 2; +} + // Trusted GraphQL operation classification. message GraphqlOperation { // Operation type: "query", "mutation", or "subscription". 
From eb661690c2b28bad040737d6d42adfabd4c3f1a3 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 10:15:03 -0700 Subject: [PATCH 05/16] feat(supervisor-middleware): implement phase one runtime Signed-off-by: Piotr Mlocek --- Cargo.lock | 1 + crates/openshell-policy/Cargo.toml | 1 + crates/openshell-policy/src/lib.rs | 47 +++++ .../src/lib.rs | 88 +++++++-- .../src/service.rs | 9 +- .../data/sandbox-policy.rego | 2 + .../src/l7/relay.rs | 164 ++++++++++++++++- .../src/l7/rest.rs | 63 +++++++ .../openshell-supervisor-network/src/opa.rs | 174 ++++++++++++++---- .../openshell-supervisor-network/src/proxy.rs | 67 +++++++ 10 files changed, 551 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0a2f749b4..c6fb086df 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3769,6 +3769,7 @@ version = "0.0.0" dependencies = [ "miette", "openshell-core", + "openshell-supervisor-middleware", "prost-types", "serde", "serde_json", diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 50bea5b32..7ccd5d967 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -12,6 +12,7 @@ repository.workspace = true [dependencies] openshell-core = { path = "../openshell-core", default-features = false } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } prost-types = { workspace = true } serde = { workspace = true } serde_json = { workspace = true } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 21561596c..6b5045a9d 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -1218,6 +1218,8 @@ pub enum PolicyViolation { }, /// `credential_signing` and `request_body_credential_rewrite` are both set. CredentialSigningWithBodyRewrite { policy_name: String, host: String }, + /// A built-in middleware configuration is invalid. 
+ InvalidBuiltinMiddlewareConfig { name: String, reason: String }, } impl fmt::Display for PolicyViolation { @@ -1279,6 +1281,9 @@ impl fmt::Display for PolicyViolation { and request_body_credential_rewrite set; these options are mutually exclusive" ) } + Self::InvalidBuiltinMiddlewareConfig { name, reason } => { + write!(f, "middleware config '{name}' is invalid: {reason}") + } } } } @@ -1410,6 +1415,21 @@ pub fn validate_sandbox_policy( } } + for middleware in &policy.network_middlewares { + if middleware.middleware.starts_with("openshell/") { + let config = middleware.config.as_ref().cloned().unwrap_or_default(); + if let Err(error) = openshell_supervisor_middleware::validate_builtin_config( + &middleware.middleware, + &config, + ) { + violations.push(PolicyViolation::InvalidBuiltinMiddlewareConfig { + name: middleware.name.clone(), + reason: error.to_string(), + }); + } + } + } + if violations.is_empty() { Ok(()) } else { @@ -1845,6 +1865,33 @@ network_policies: assert_eq!(violations.len(), 2); } + #[test] + fn validate_rejects_invalid_builtin_middleware_config() { + let mut policy = restrictive_default_policy(); + policy.network_middlewares.push(NetworkMiddlewareConfig { + name: "redact-secrets".into(), + middleware: "openshell/secrets".into(), + config: Some(prost_types::Struct { + fields: std::iter::once(( + "secrets".into(), + prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("allow".into())), + }, + )) + .collect(), + }), + on_error: String::new(), + endpoints: None, + }); + + let violations = validate_sandbox_policy(&policy).expect_err("invalid config"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::InvalidBuiltinMiddlewareConfig { name, .. 
} + if name == "redact-secrets" + ))); + } + #[test] fn validate_rejects_non_sandbox_user() { let mut policy = restrictive_default_policy(); diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index b68d83c86..4ec7e2782 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -14,7 +14,7 @@ pub use service::InProcessMiddlewareService; use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; use openshell_core::proto::{ - Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, NetworkMiddlewareConfig, Process, + Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, NetworkMiddlewareConfig, RequestContext, }; use tonic::Request; @@ -24,6 +24,19 @@ pub const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; pub const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; pub const BUILTIN_SECRETS: &str = "openshell/secrets"; +/// Validate the configuration for an in-process middleware implementation. +/// +/// Policy admission uses this same implementation-specific validation before a +/// configuration can reach the request path. 
+pub fn validate_builtin_config(implementation: &str, config: &prost_types::Struct) -> Result<()> { + match implementation { + BUILTIN_SECRETS => builtins::secrets::validate_config(config), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} + #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum OnError { FailClosed, @@ -76,9 +89,6 @@ impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { pub struct HttpRequestInput { pub request_id: String, pub sandbox_id: String, - pub binary: String, - pub pid: u32, - pub ancestors: Vec, pub scheme: String, pub host: String, pub port: u16, @@ -210,6 +220,26 @@ impl ChainRunner { } }; + let decision = match Decision::try_from(result.decision) { + Ok(decision @ (Decision::Allow | Decision::Deny)) => decision, + Ok(Decision::Unspecified) | Err(_) => { + match apply_on_error(entry, "invalid_response_decision", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + }; + // A result proposing unsafe header mutations is a malformed response: // route it through `on_error` instead of applying any of it. 
if validate_header_mutations(&headers, &result.add_headers).is_err() { @@ -251,11 +281,11 @@ impl ChainRunner { applied.push(MiddlewareInvocation { name: entry.name.clone(), implementation: entry.implementation.clone(), - decision: Decision::try_from(result.decision).unwrap_or(Decision::Unspecified), + decision, transformed, failed: false, }); - if result.decision == Decision::Deny as i32 { + if decision == Decision::Deny { return Ok(ChainOutcome { allowed: false, reason: safe_reason(&result.reason), @@ -293,11 +323,7 @@ fn build_evaluation( context: Some(RequestContext { request_id: input.request_id.clone(), sandbox_id: input.sandbox_id.clone(), - originating_process: Some(Process { - binary: input.binary.clone(), - pid: input.pid, - ancestors: input.ancestors.clone(), - }), + originating_process: None, }), config: Some(entry.config.clone()), target: Some(HttpRequestTarget { @@ -399,9 +425,6 @@ mod tests { HttpRequestInput { request_id: "req".into(), sandbox_id: "sbx".into(), - binary: "/usr/bin/curl".into(), - pid: 42, - ancestors: vec![], scheme: "https".into(), host: "api.example.com".into(), port: 443, @@ -413,6 +436,21 @@ mod tests { } } + #[test] + fn phase_one_evaluation_omits_originating_process() { + let entry = entry("redact", OnError::FailClosed); + let input = input("payload"); + let evaluation = build_evaluation(&entry, &input, &BTreeMap::new(), b"payload"); + + assert!( + evaluation + .context + .expect("request context") + .originating_process + .is_none() + ); + } + #[tokio::test] async fn redacts_common_secret_patterns() { let outcome = ChainRunner::default() @@ -684,4 +722,26 @@ mod tests { assert_eq!(outcome.applied.len(), 1); assert!(outcome.applied[0].failed); } + + #[tokio::test] + async fn unspecified_decision_uses_fail_closed() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + result: openshell_core::proto::HttpRequestResult { + decision: Decision::Unspecified as i32, + ..allow_result() + }, + })); + + let outcome = runner + 
.evaluate(&[entry("redact", OnError::FailClosed)], input("hello")) + .await + .expect("evaluate"); + + assert!(!outcome.allowed); + assert_eq!( + outcome.reason, + "middleware_failed: invalid_response_decision" + ); + assert!(outcome.applied[0].failed); + } } diff --git a/crates/openshell-supervisor-middleware/src/service.rs b/crates/openshell-supervisor-middleware/src/service.rs index 31cca5694..cbd9231cd 100644 --- a/crates/openshell-supervisor-middleware/src/service.rs +++ b/crates/openshell-supervisor-middleware/src/service.rs @@ -10,7 +10,7 @@ use tonic::{Request, Response, Status}; use crate::{ API_VERSION, BUILTIN_SECRETS, HTTP_REQUEST_OPERATION, PRE_CREDENTIALS_PHASE, builtins, - safe_reason, + safe_reason, validate_builtin_config, }; #[derive(Debug, Default)] @@ -40,12 +40,7 @@ impl SupervisorMiddleware for InProcessMiddlewareService { ) -> Result, Status> { let request = request.into_inner(); let config = request.config.unwrap_or_default(); - let validation = match request.binding_id.as_str() { - BUILTIN_SECRETS => builtins::secrets::validate_config(&config), - other => Err(miette::miette!( - "middleware implementation '{other}' is not available in phase 1" - )), - }; + let validation = validate_builtin_config(&request.binding_id, &config); Ok(Response::new(match validation { Ok(()) => ValidateConfigResponse { valid: true, diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index 9228416e1..52f6f1046 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -871,6 +871,8 @@ matched_endpoint_config := _matching_endpoint_configs[0] if { count(_matching_endpoint_configs) > 0 } +network_policies := object.get(data, "network_policies", {}) + network_middlewares := object.get(data, "network_middlewares", []) _policy_has_exact_declared_endpoint(policy) if { diff --git 
a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 6e5c3c4e9..c773fdcf4 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -768,12 +768,12 @@ fn jsonrpc_engine_type(protocol: L7Protocol) -> &'static str { } } -enum MiddlewareApplyResult { +pub(crate) enum MiddlewareApplyResult { Allowed(crate::l7::provider::L7Request), Denied(String), } -async fn apply_middleware_chain( +pub(crate) async fn apply_middleware_chain( req: crate::l7::provider::L7Request, client: &mut C, ctx: &L7EvalContext, @@ -796,18 +796,16 @@ async fn apply_middleware_chain( } }; let headers = safe_middleware_headers(&buffered.headers)?; + let query = raw_query_from_request_headers(&buffered.headers)?; let input = openshell_supervisor_middleware::HttpRequestInput { request_id: uuid::Uuid::new_v4().to_string(), - sandbox_id: String::new(), - binary: ctx.binary_path.clone(), - pid: 0, - ancestors: ctx.ancestors.clone(), + sandbox_id: openshell_ocsf::ctx::ctx().sandbox_id.clone(), scheme: "https".into(), host: ctx.host.clone(), port: ctx.port, method: req.action.clone(), path: req.target.clone(), - query: String::new(), + query, headers, body: buffered.body, }; @@ -828,6 +826,19 @@ async fn apply_middleware_chain( } } +fn raw_query_from_request_headers(headers: &[u8]) -> Result { + let header_str = + std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let target = header_str + .lines() + .next() + .and_then(|line| line.split_whitespace().nth(1)) + .ok_or_else(|| miette!("HTTP request line is missing a target"))?; + Ok(target + .split_once('?') + .map_or_else(String::new, |(_, query)| query.to_string())) +} + /// Apply the chain's `on_error` policy when the request body cannot be buffered /// for inspection because it exceeds the size cap. 
The RFC treats an unbufferable /// body as an `on_error` event: it is denied unless every attached middleware is @@ -1446,6 +1457,37 @@ where } if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { + let chain = + engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let req = + match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; // Future MCP response/SSE introspection or rewrite would hook here // before returning upstream bytes. 
The current policy schema has no // trusted-annotations or version-profile field, so MCP responses and @@ -2736,6 +2778,104 @@ network_policies: .unwrap(); } + #[tokio::test] + async fn jsonrpc_middleware_fail_closed_does_not_reach_upstream() { + let data = r#" +network_middlewares: + - name: request-middleware + middleware: example/unavailable + on_error: fail_closed +network_policies: + jsonrpc_api: + name: jsonrpc_api + middleware: ["request-middleware"] + endpoints: + - host: api.example.test + port: 443 + protocol: json-rpc + enforcement: enforce + rules: + - allow: + method: reports.list + binaries: + - { path: /usr/bin/node } +"#; + let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); + let input = NetworkInput { + host: "api.example.test".into(), + port: 443, + binary_path: PathBuf::from("/usr/bin/node"), + binary_sha256: "unused".into(), + ancestors: vec![], + cmdline_paths: vec![], + }; + let (endpoint_config, generation) = engine + .query_endpoint_config_with_generation(&input) + .expect("endpoint config"); + let config = crate::l7::parse_l7_config(&endpoint_config.expect("json-rpc config")) + .expect("parse JSON-RPC config"); + let tunnel_engine = engine.clone_engine_for_tunnel(generation).unwrap(); + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 443, + policy_name: "jsonrpc_api".into(), + binary_path: "/usr/bin/node".into(), + ancestors: vec![], + cmdline_paths: vec![], + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_jsonrpc( + &config, + &tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = br#"{"jsonrpc":"2.0","id":1,"method":"reports.list"}"#; + let request = format!( + "POST /rpc HTTP/1.1\r\nHost: 
api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut response = [0u8; 512]; + let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) + .await + .expect("denial should reach client") + .unwrap(); + let response = String::from_utf8_lossy(&response[..n]); + assert!(response.contains("403 Forbidden")); + assert!(response.contains("middleware_failed")); + + let mut upstream_request = [0u8; 32]; + let result = tokio::time::timeout( + std::time::Duration::from_millis(100), + upstream.read(&mut upstream_request), + ) + .await; + assert!( + matches!(result, Err(_) | Ok(Ok(0))), + "upstream should not receive request bytes" + ); + + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + #[tokio::test] async fn l7_rest_middleware_over_capacity_fails_closed() { let (config, tunnel_engine, ctx) = @@ -2842,6 +2982,16 @@ network_policies: )); } + #[test] + fn middleware_keeps_the_raw_request_query() { + let query = raw_query_from_request_headers( + b"POST /v1/messages?token=a%2Bb&scope=private HTTP/1.1\r\nHost: api.example.test\r\n\r\n", + ) + .expect("query from request headers"); + + assert_eq!(query, "token=a%2Bb&scope=private"); + } + /// Tracing layer that captures emitted `OcsfEvent`s for assertions. struct OcsfCaptureLayer(Arc>>); diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 2c85cacf6..19f73e2ad 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -246,6 +246,36 @@ async fn parse_http_request( })) } +/// Build an L7 request from a request already buffered by another proxy path. 
+/// +/// The forward proxy needs this after it has consumed the incoming HTTP/1 +/// headers itself. Keep the framing and query parsing here so it matches the +/// stream-based REST parser rather than growing another local parser. +pub(crate) fn request_from_buffered_http( + action: impl Into, + target: impl Into, + query_target: &str, + raw_header: Vec, +) -> Result { + let header_end = raw_header + .windows(4) + .position(|window| window == b"\r\n\r\n") + .ok_or_else(|| miette!("HTTP request headers are missing the CRLF terminator"))? + + 4; + let header_str = std::str::from_utf8(&raw_header[..header_end]) + .map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; + let body_length = parse_body_length(header_str)?; + let (_, query_params) = parse_target_query(query_target)?; + + Ok(L7Request { + action: action.into(), + target: target.into(), + query_params, + raw_header, + body_length, + }) +} + /// Rebuild the request line in a raw HTTP header block with a canonicalized /// target. Called when the canonical path differs from what the client sent, /// so the upstream dispatches on the exact bytes the policy engine evaluated. 
@@ -3015,6 +3045,39 @@ mod tests { } } + #[test] + fn buffered_request_parser_uses_shared_framing_and_query_parsing() { + let request = request_from_buffered_http( + "POST", + "/v1/items", + "/v1/items?tag=first&tag=second", + b"POST /v1/items?tag=first&tag=second HTTP/1.1\r\nHost: api.example.com\r\nContent-Length: 3\r\n\r\nabc" + .to_vec(), + ) + .expect("parse buffered request"); + + assert_eq!(request.action, "POST"); + assert_eq!(request.target, "/v1/items"); + assert_eq!( + request.query_params.get("tag"), + Some(&vec!["first".to_string(), "second".to_string()]) + ); + assert!(matches!(request.body_length, BodyLength::ContentLength(3))); + } + + #[test] + fn buffered_request_parser_rejects_missing_header_terminator() { + let err = request_from_buffered_http( + "GET", + "/v1/items", + "/v1/items", + b"GET /v1/items HTTP/1.1\r\nHost: api.example.com\r\n".to_vec(), + ) + .expect_err("unterminated headers must be rejected"); + + assert!(err.to_string().contains("missing the CRLF terminator")); + } + #[test] fn parse_chunked() { let headers = diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 3d0f75bf7..3efec0212 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -776,12 +776,12 @@ fn query_middleware_chain_locked( .eval_rule("data.openshell.sandbox._matching_endpoint_contexts".into()) .map_err(|e| miette::miette!("{e}"))?; let contexts = parse_endpoint_contexts(&contexts_val); - let Some(context) = select_endpoint_context(&contexts, request_path)? else { - return global_middleware_entries(&configs, &input.host, &HashSet::new()); - }; let policies_val = engine - .eval_rule("data.network_policies".into()) + .eval_rule("data.openshell.sandbox.network_policies".into()) .map_err(|e| miette::miette!("{e}"))?; + let Some(context) = select_endpoint_context(&contexts, request_path, &policies_val)? 
else { + return global_middleware_entries(&configs, &input.host, &HashSet::new()); + }; let (policy_middleware, endpoint_middleware) = middleware_for_endpoint_identity(&policies_val, context)?; @@ -862,6 +862,7 @@ fn middleware_for_endpoint_identity( fn select_endpoint_context<'a>( contexts: &'a [MatchedEndpointContext], request_path: &str, + policies: ®orus::Value, ) -> Result> { let matching: Vec<_> = contexts .iter() @@ -876,30 +877,56 @@ fn select_endpoint_context<'a>( .filter(|(specificity, _)| *specificity == max_specificity) .map(|(_, context)| context) .collect(); - if best.len() > 1 { - let matches = best - .iter() - .map(|context| { - format!( - "{}[{}] path={}", - context.policy_name, - context.endpoint_index, - if context.endpoint_path.is_empty() { - "" - } else { - context.endpoint_path.as_str() - } - ) - }) - .collect::>() - .join(", "); - return Err(miette::miette!( - "ambiguous middleware endpoint match for request path '{request_path}': {matches}" - )); + if let Some((first, rest)) = best.split_first() { + let first_middleware = explicit_middleware_for_endpoint_identity(policies, first)?; + for context in rest { + if explicit_middleware_for_endpoint_identity(policies, context)? 
!= first_middleware { + let matches = best + .iter() + .map(|context| { + format!( + "{}[{}] path={}", + context.policy_name, + context.endpoint_index, + if context.endpoint_path.is_empty() { + "" + } else { + context.endpoint_path.as_str() + } + ) + }) + .collect::>() + .join(", "); + return Err(miette::miette!( + "ambiguous middleware endpoint match for request path '{request_path}': {matches}" + )); + } + } } Ok(best.into_iter().next()) } +fn explicit_middleware_for_endpoint_identity( + policies: ®orus::Value, + context: &MatchedEndpointContext, +) -> Result> { + let (policy_middleware, endpoint_middleware) = + middleware_for_endpoint_identity(policies, context)?; + Ok(dedup_middleware_names( + policy_middleware.iter().chain(endpoint_middleware.iter()), + )) +} + +fn dedup_middleware_names<'a>(names: impl IntoIterator) -> Vec { + let mut deduped = Vec::new(); + for name in names { + if !deduped.contains(name) { + deduped.push(name.clone()); + } + } + deduped +} + fn endpoint_path_specificity(path: &str) -> usize { if path.is_empty() { 0 @@ -1402,10 +1429,91 @@ fn validate_middleware_policies(data: &serde_json::Value) -> Vec { )); } } + validate_ambiguous_middleware_endpoints( + policy_name, + policy, + &policy_middleware, + &mut errors, + ); } errors } +fn validate_ambiguous_middleware_endpoints( + policy_name: &str, + policy: &serde_json::Value, + policy_middleware: &[String], + errors: &mut Vec, +) { + let endpoints = policy + .get("endpoints") + .and_then(serde_json::Value::as_array) + .map_or(&[][..], Vec::as_slice); + let mut seen: Vec<(usize, MiddlewareEndpointKey, Vec)> = Vec::new(); + for (index, endpoint) in endpoints.iter().enumerate() { + let key = middleware_endpoint_key(endpoint); + let endpoint_middleware = json_string_array(endpoint.get("middleware")); + let chain = + dedup_middleware_names(policy_middleware.iter().chain(endpoint_middleware.iter())); + for (previous_index, previous_key, previous_chain) in &seen { + if previous_key == &key && 
previous_chain != &chain { + errors.push(format!( + "network policy '{policy_name}' endpoints[{previous_index}] and endpoints[{index}] have equivalent middleware selection keys ({key}) but different middleware chains" + )); + } + } + seen.push((index, key, chain)); + } +} + +#[derive(Debug, PartialEq, Eq)] +struct MiddlewareEndpointKey { + host: String, + ports: Vec, + path: String, +} + +impl std::fmt::Display for MiddlewareEndpointKey { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "host={} ports={:?} path={}", + if self.host.is_empty() { + "" + } else { + self.host.as_str() + }, + self.ports, + if self.path.is_empty() { + "" + } else { + self.path.as_str() + } + ) + } +} + +fn middleware_endpoint_key(endpoint: &serde_json::Value) -> MiddlewareEndpointKey { + let host = endpoint + .get("host") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_ascii_lowercase(); + let mut ports: Vec = endpoint + .get("ports") + .and_then(serde_json::Value::as_array) + .map(|ports| ports.iter().filter_map(serde_json::Value::as_u64).collect()) + .unwrap_or_default(); + ports.sort_unstable(); + ports.dedup(); + let path = endpoint + .get("path") + .and_then(serde_json::Value::as_str) + .unwrap_or_default() + .to_string(); + MiddlewareEndpointKey { host, ports, path } +} + fn json_string_array(value: Option<&serde_json::Value>) -> Vec { value .and_then(serde_json::Value::as_array) @@ -7037,7 +7145,7 @@ network_policies: } #[test] - fn middleware_chain_rejects_ambiguous_duplicate_endpoint_identity() { + fn middleware_validation_rejects_ambiguous_duplicate_endpoint_middleware() { let data = r#" network_middlewares: - name: first-redactor @@ -7063,21 +7171,13 @@ network_policies: binaries: - { path: /usr/bin/curl } "#; - let engine = OpaEngine::from_strings(TEST_POLICY, data).unwrap(); - let input = NetworkInput { - host: "api.example.com".into(), - port: 443, - binary_path: PathBuf::from("/usr/bin/curl"), - binary_sha256: 
"unused".into(), - ancestors: vec![], - cmdline_paths: vec![], + let err = match OpaEngine::from_strings(TEST_POLICY, data) { + Ok(_) => panic!("equivalent endpoints with different middleware should be invalid"), + Err(err) => err, }; - let err = engine - .query_middleware_chain_with_generation(&input, "/v1/messages") - .expect_err("equivalent endpoint identities should be ambiguous"); assert!( err.to_string() - .contains("ambiguous middleware endpoint match"), + .contains("equivalent middleware selection keys"), "{err:?}" ); } diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index af3331735..f8310fbdc 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -4184,6 +4184,73 @@ async fn handle_forward_proxy( } emit_forward_success_activity(activity_tx, l7_activity_pending); + let middleware_path = path.split_once('?').map_or(path.as_str(), |(path, _)| path); + let middleware_input = crate::opa::NetworkInput { + host: host_lc.clone(), + port, + binary_path: decision.binary.clone().unwrap_or_default(), + binary_sha256: String::new(), + ancestors: decision.ancestors.clone(), + cmdline_paths: decision.cmdline_paths.clone(), + }; + let (chain, generation) = + opa_engine.query_middleware_chain_with_generation(&middleware_input, middleware_path)?; + if generation != forward_generation_guard.captured_generation() { + emit_l7_tunnel_close_after_policy_change( + &host_lc, + port, + miette::miette!( + "policy changed before forward middleware evaluation [expected_generation:{} current_generation:{}]", + forward_generation_guard.captured_generation(), + generation, + ), + ); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "policy_denied", + &format!("{method} {host_lc}:{port}{path} not permitted by policy"), + ), + ) + .await?; + return Ok(()); + } + if !chain.is_empty() { + let request = 
crate::l7::rest::request_from_buffered_http( + method, + middleware_path, + &upstream_target, + forward_request_bytes, + )?; + forward_request_bytes = match crate::l7::relay::apply_middleware_chain( + request, + client, + &l7_ctx, + chain, + &forward_generation_guard, + ) + .await? + { + crate::l7::relay::MiddlewareApplyResult::Allowed(request) => request.raw_header, + crate::l7::relay::MiddlewareApplyResult::Denied(reason) => { + emit_activity_simple(activity_tx, true, "middleware"); + respond( + client, + &build_json_error_response( + 403, + "Forbidden", + "middleware_denied", + &format!("{method} {host_lc}:{port}{path} denied by middleware: {reason}"), + ), + ) + .await?; + return Ok(()); + } + }; + } + forward_request_bytes = match inject_token_grant_for_forward_request( method, &upstream_target, From ac1d6c867f9da22dce67000c371c6b6164050f7d Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 14:10:40 -0700 Subject: [PATCH 06/16] fix(supervisor-middleware): harden selection and buffering Signed-off-by: Piotr Mlocek --- Cargo.lock | 1 + crates/openshell-cli/src/policy_update.rs | 1 - crates/openshell-policy/Cargo.toml | 1 + crates/openshell-policy/src/compose.rs | 1 - crates/openshell-policy/src/lib.rs | 373 ++++++++++-- crates/openshell-policy/src/merge.rs | 24 - crates/openshell-providers/src/profiles.rs | 2 - .../src/mechanistic_mapper.rs | 1 - crates/openshell-server/src/grpc/policy.rs | 32 -- .../openshell-server/src/grpc/validation.rs | 22 + .../Cargo.toml | 4 +- .../src/builtins/mod.rs | 25 + .../src/builtins/secrets.rs | 22 +- .../src/lib.rs | 383 +++++++++++-- .../src/service.rs | 24 +- .../data/sandbox-policy.rego | 19 - .../src/l7/relay.rs | 151 +++-- .../src/l7/rest.rs | 138 +++-- .../openshell-supervisor-network/src/opa.rs | 534 ++++-------------- .../src/policy_local.rs | 5 - .../openshell-supervisor-network/src/proxy.rs | 5 +- proto/middleware.proto | 2 + proto/sandbox.proto | 8 +- 23 files changed, 1073 insertions(+), 705 
deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c6fb086df..2de3c0353 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3767,6 +3767,7 @@ dependencies = [ name = "openshell-policy" version = "0.0.0" dependencies = [ + "glob", "miette", "openshell-core", "openshell-supervisor-middleware", diff --git a/crates/openshell-cli/src/policy_update.rs b/crates/openshell-cli/src/policy_update.rs index 824b1dde0..1f1f64750 100644 --- a/crates/openshell-cli/src/policy_update.rs +++ b/crates/openshell-cli/src/policy_update.rs @@ -65,7 +65,6 @@ pub fn build_policy_update_plan( ..Default::default() }) .collect(), - middleware: Vec::new(), }; merge_operations.push(PolicyMergeOperation { operation: Some(policy_merge_operation::Operation::AddRule(AddNetworkRule { diff --git a/crates/openshell-policy/Cargo.toml b/crates/openshell-policy/Cargo.toml index 7ccd5d967..073728db1 100644 --- a/crates/openshell-policy/Cargo.toml +++ b/crates/openshell-policy/Cargo.toml @@ -11,6 +11,7 @@ license.workspace = true repository.workspace = true [dependencies] +glob = { workspace = true } openshell-core = { path = "../openshell-core", default-features = false } openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } prost-types = { workspace = true } diff --git a/crates/openshell-policy/src/compose.rs b/crates/openshell-policy/src/compose.rs index 1ad0d4617..7ca8584d9 100644 --- a/crates/openshell-policy/src/compose.rs +++ b/crates/openshell-policy/src/compose.rs @@ -115,7 +115,6 @@ mod tests { ..Default::default() }], binaries: Vec::new(), - middleware: Vec::new(), } } diff --git a/crates/openshell-policy/src/lib.rs b/crates/openshell-policy/src/lib.rs index 6b5045a9d..c301af01a 100644 --- a/crates/openshell-policy/src/lib.rs +++ b/crates/openshell-policy/src/lib.rs @@ -12,7 +12,7 @@ mod compose; mod merge; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::fmt; use std::path::Path; @@ -89,8 +89,6 @@ struct 
NetworkPolicyRuleDef { endpoints: Vec, #[serde(default, skip_serializing_if = "Vec::is_empty")] binaries: Vec, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - middleware: Vec, } #[derive(Debug, Serialize, Deserialize)] @@ -174,8 +172,6 @@ struct NetworkEndpointDef { json_rpc: Option, #[serde(default, skip_serializing_if = "Option::is_none")] mcp: Option, - #[serde(default, skip_serializing_if = "Vec::is_empty")] - middleware: Vec, } // Signature dictated by serde's `skip_serializing_if`, which requires `&T`. @@ -788,7 +784,6 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { signing_region: e.signing_region, json_rpc_max_body_bytes: json_rpc_max_body_bytes(&e.json_rpc, &e.mcp), mcp: mcp_options(&e.mcp), - middleware: e.middleware, } }) .collect(), @@ -800,7 +795,6 @@ fn to_proto(raw: PolicyFile) -> SandboxPolicy { ..Default::default() }) .collect(), - middleware: rule.middleware, }; (key, proto_rule) }) @@ -938,7 +932,6 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { signing_region: e.signing_region.clone(), json_rpc, mcp, - middleware: e.middleware.clone(), } }) .collect(), @@ -950,7 +943,6 @@ fn from_proto(policy: &SandboxPolicy) -> PolicyFile { harness: false, }) .collect(), - middleware: rule.middleware.clone(), }; (key.clone(), yaml_rule) }) @@ -1220,6 +1212,16 @@ pub enum PolicyViolation { CredentialSigningWithBodyRewrite { policy_name: String, host: String }, /// A built-in middleware configuration is invalid. InvalidBuiltinMiddlewareConfig { name: String, reason: String }, + /// A middleware configuration is structurally invalid. + InvalidMiddlewareConfig { name: String, reason: String }, + /// Middleware configuration names must be unique. + DuplicateMiddlewareConfigName { name: String }, + /// A middleware selector conflicts with an endpoint that skips TLS inspection. 
+ MiddlewareTlsSkipConflict { + middleware_name: String, + policy_name: String, + host: String, + }, } impl fmt::Display for PolicyViolation { @@ -1281,13 +1283,67 @@ impl fmt::Display for PolicyViolation { and request_body_credential_rewrite set; these options are mutually exclusive" ) } - Self::InvalidBuiltinMiddlewareConfig { name, reason } => { + Self::InvalidBuiltinMiddlewareConfig { name, reason } + | Self::InvalidMiddlewareConfig { name, reason } => { write!(f, "middleware config '{name}' is invalid: {reason}") } + Self::DuplicateMiddlewareConfigName { name } => { + write!(f, "duplicate middleware config '{name}'") + } + Self::MiddlewareTlsSkipConflict { + middleware_name, + policy_name, + host, + } => { + write!( + f, + "middleware config '{middleware_name}' selects network policy \ + '{policy_name}' tls: skip endpoint '{host}'" + ) + } } } } +/// Match a middleware host selector pattern using the runtime's glob semantics. +/// +/// Invalid or empty patterns return an error instead of silently becoming a +/// non-match. 
+pub fn middleware_host_matches(pattern: &str, host: &str) -> std::result::Result { + if pattern.is_empty() { + return Err("host pattern must not be empty".to_string()); + } + if pattern.chars().any(char::is_whitespace) { + return Err("host pattern must not contain whitespace".to_string()); + } + + let pattern = glob::Pattern::new(&pattern.to_ascii_lowercase()) + .map_err(|error| format!("invalid host pattern: {error}"))?; + Ok(pattern.matches(&host.to_ascii_lowercase())) +} + +fn middleware_selector_matches_host( + middleware: &NetworkMiddlewareConfig, + host: &str, +) -> std::result::Result { + let Some(selector) = &middleware.endpoints else { + return Ok(false); + }; + let matches_include = selector + .include + .iter() + .try_fold(false, |matched, pattern| { + middleware_host_matches(pattern, host).map(|matches| matched || matches) + })?; + let matches_exclude = selector + .exclude + .iter() + .try_fold(false, |matched, pattern| { + middleware_host_matches(pattern, host).map(|matches| matched || matches) + })?; + Ok(matches_include && !matches_exclude) +} + /// Validate that a sandbox policy does not contain unsafe content. /// /// Returns `Ok(())` if the policy is safe, or `Err(violations)` listing all @@ -1302,6 +1358,9 @@ impl fmt::Display for PolicyViolation { /// - Individual path lengths must not exceed [`MAX_PATH_LENGTH`] /// - Total path count must not exceed [`MAX_FILESYSTEM_PATHS`] /// - Network endpoint hosts must not use TLD wildcards (e.g. 
`*.com`) +/// - Middleware names, implementations, failure modes, selectors, and built-in +/// configurations must be valid +/// - Middleware selectors must not match endpoints that skip TLS inspection pub fn validate_sandbox_policy( policy: &SandboxPolicy, ) -> std::result::Result<(), Vec> { @@ -1415,9 +1474,67 @@ pub fn validate_sandbox_policy( } } + let mut middleware_names = HashSet::new(); for middleware in &policy.network_middlewares { - if middleware.middleware.starts_with("openshell/") { - let config = middleware.config.as_ref().cloned().unwrap_or_default(); + if middleware.name.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "name must not be empty".to_string(), + }); + } else if !middleware_names.insert(middleware.name.clone()) { + violations.push(PolicyViolation::DuplicateMiddlewareConfigName { + name: middleware.name.clone(), + }); + } + + if middleware.middleware.is_empty() { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "implementation must not be empty".to_string(), + }); + } else if middleware.middleware.starts_with("openshell/") + && middleware.middleware != openshell_supervisor_middleware::BUILTIN_SECRETS + { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("unsupported built-in '{}'", middleware.middleware), + }); + } + + if !matches!( + middleware.on_error.as_str(), + "" | "fail_closed" | "fail_open" + ) { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("invalid on_error '{}'", middleware.on_error), + }); + } + + let Some(selector) = &middleware.endpoints else { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "endpoint selector is required".to_string(), + }); + continue; + }; + if selector.include.is_empty() { + 
violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: "endpoint selector must include at least one host pattern".to_string(), + }); + } + for pattern in selector.include.iter().chain(&selector.exclude) { + if let Err(reason) = middleware_host_matches(pattern, "validation.invalid") { + violations.push(PolicyViolation::InvalidMiddlewareConfig { + name: middleware.name.clone(), + reason: format!("endpoint selector pattern '{pattern}' is invalid: {reason}"), + }); + } + } + + if middleware.middleware == openshell_supervisor_middleware::BUILTIN_SECRETS { + let config = middleware.config.clone().unwrap_or_default(); if let Err(error) = openshell_supervisor_middleware::validate_builtin_config( &middleware.middleware, &config, @@ -1428,6 +1545,25 @@ pub fn validate_sandbox_policy( }); } } + + for (key, rule) in &policy.network_policies { + let policy_name = if rule.name.is_empty() { + key + } else { + &rule.name + }; + for endpoint in &rule.endpoints { + if endpoint.tls == "skip" + && middleware_selector_matches_host(middleware, &endpoint.host).unwrap_or(false) + { + violations.push(PolicyViolation::MiddlewareTlsSkipConflict { + middleware_name: middleware.name.clone(), + policy_name: policy_name.clone(), + host: endpoint.host.clone(), + }); + } + } + } } if violations.is_empty() { @@ -1569,17 +1705,17 @@ network_middlewares: service: mode: redact max_matches: 2 - - name: endpoint-redactor + - name: secondary-redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] network_policies: api: name: api - middleware: ["global-redactor"] endpoints: - host: api.example.com port: 443 protocol: rest - middleware: ["endpoint-redactor"] binaries: - path: /usr/bin/curl "#; @@ -1612,26 +1748,9 @@ network_policies: .fields .contains_key("service") ); - assert_eq!( - proto.network_policies["api"].middleware, - vec!["global-redactor"] - ); - assert_eq!( - proto.network_policies["api"].endpoints[0].middleware, - 
vec!["endpoint-redactor"] - ); - let yaml_out = serialize_sandbox_policy(&proto).expect("serialize failed"); let reparsed = parse_sandbox_policy(&yaml_out).expect("re-parse failed"); assert_eq!(reparsed.network_middlewares, proto.network_middlewares); - assert_eq!( - reparsed.network_policies["api"].middleware, - vec!["global-redactor"] - ); - assert_eq!( - reparsed.network_policies["api"].endpoints[0].middleware, - vec!["endpoint-redactor"] - ); } #[test] @@ -1764,6 +1883,31 @@ network_policies: assert!(parse_sandbox_policy(yaml).is_err()); } + #[test] + fn parse_rejects_middleware_attachments_on_network_policies_and_endpoints() { + let policy_attachment = r" +version: 1 +network_policies: + api: + middleware: [redact] + endpoints: + - host: api.example.com + port: 443 +"; + assert!(parse_sandbox_policy(policy_attachment).is_err()); + + let endpoint_attachment = r" +version: 1 +network_policies: + api: + endpoints: + - host: api.example.com + port: 443 + middleware: [redact] +"; + assert!(parse_sandbox_policy(endpoint_attachment).is_err()); + } + #[test] fn l7_config_stanza_runtime_fields_use_canonical_schema() { let fields = l7_config_alias_runtime_fields( @@ -1837,6 +1981,19 @@ network_policies: // ---- Policy validation tests ---- + fn middleware_config(name: &str, implementation: &str) -> NetworkMiddlewareConfig { + NetworkMiddlewareConfig { + name: name.into(), + middleware: implementation.into(), + config: None, + on_error: String::new(), + endpoints: Some(MiddlewareEndpointSelector { + include: vec!["api.example.com".into()], + exclude: Vec::new(), + }), + } + } + #[test] fn validate_rejects_root_run_as_user() { let mut policy = restrictive_default_policy(); @@ -1868,21 +2025,17 @@ network_policies: #[test] fn validate_rejects_invalid_builtin_middleware_config() { let mut policy = restrictive_default_policy(); - policy.network_middlewares.push(NetworkMiddlewareConfig { - name: "redact-secrets".into(), - middleware: "openshell/secrets".into(), - config: 
Some(prost_types::Struct { - fields: std::iter::once(( - "secrets".into(), - prost_types::Value { - kind: Some(prost_types::value::Kind::StringValue("allow".into())), - }, - )) - .collect(), - }), - on_error: String::new(), - endpoints: None, + let mut middleware = middleware_config("redact-secrets", "openshell/secrets"); + middleware.config = Some(prost_types::Struct { + fields: std::iter::once(( + "secrets".into(), + prost_types::Value { + kind: Some(prost_types::value::Kind::StringValue("allow".into())), + }, + )) + .collect(), }); + policy.network_middlewares.push(middleware); let violations = validate_sandbox_policy(&policy).expect_err("invalid config"); assert!(violations.iter().any(|violation| matches!( @@ -1892,6 +2045,134 @@ network_policies: ))); } + #[test] + fn validate_rejects_invalid_middleware_control_fields() { + let cases = [ + ( + middleware_config("", "openshell/secrets"), + "name must not be empty", + ), + ( + middleware_config("redactor", ""), + "implementation must not be empty", + ), + ( + middleware_config("redactor", "openshell/unknown"), + "unsupported built-in", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.on_error = "maybe".into(); + middleware + }, + "invalid on_error", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints = None; + middleware + }, + "endpoint selector is required", + ), + ( + { + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints.as_mut().unwrap().include.clear(); + middleware + }, + "must include at least one host pattern", + ), + ]; + + for (middleware, expected) in cases { + let mut policy = restrictive_default_policy(); + policy.network_middlewares.push(middleware); + let errors = validate_sandbox_policy(&policy) + .expect_err("invalid middleware must be rejected") + .into_iter() + .map(|violation| violation.to_string()) + .collect::>() + .join("; "); + 
assert!( + errors.contains(expected), + "expected {expected:?} in {errors:?}" + ); + } + } + + #[test] + fn validate_rejects_duplicate_middleware_config_names() { + let mut policy = restrictive_default_policy(); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + + let violations = validate_sandbox_policy(&policy).expect_err("duplicate name"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::DuplicateMiddlewareConfigName { name } if name == "redactor" + ))); + } + + #[test] + fn validate_rejects_malformed_middleware_selector_patterns() { + let mut policy = restrictive_default_policy(); + let mut middleware = middleware_config("redactor", "openshell/secrets"); + middleware.endpoints.as_mut().unwrap().include = vec!["api[.example.com".into()]; + policy.network_middlewares.push(middleware); + + let errors = validate_sandbox_policy(&policy) + .expect_err("malformed selector") + .into_iter() + .map(|violation| violation.to_string()) + .collect::>() + .join("; "); + assert!(errors.contains("invalid host pattern"), "{errors}"); + } + + #[test] + fn middleware_host_selector_matching_is_case_insensitive() { + assert!(middleware_host_matches("*.Example.COM", "API.example.com").unwrap()); + assert!(!middleware_host_matches("*.example.com", "example.com").unwrap()); + assert!(middleware_host_matches("*", "deep.api.example.com").unwrap()); + } + + #[test] + fn validate_rejects_middleware_selector_matching_tls_skip_endpoint() { + let mut policy = restrictive_default_policy(); + policy + .network_middlewares + .push(middleware_config("redactor", "openshell/secrets")); + policy.network_policies.insert( + "api".into(), + NetworkPolicyRule { + name: "api".into(), + endpoints: vec![NetworkEndpoint { + host: "api.example.com".into(), + port: 443, + tls: "skip".into(), + ..Default::default() + }], + binaries: 
Vec::new(), + }, + ); + + let violations = validate_sandbox_policy(&policy).expect_err("tls skip conflict"); + assert!(violations.iter().any(|violation| matches!( + violation, + PolicyViolation::MiddlewareTlsSkipConflict { + middleware_name, + policy_name, + host, + } if middleware_name == "redactor" && policy_name == "api" && host == "api.example.com" + ))); + } + #[test] fn validate_rejects_non_sandbox_user() { let mut policy = restrictive_default_policy(); diff --git a/crates/openshell-policy/src/merge.rs b/crates/openshell-policy/src/merge.rs index 1c63e6ebc..04f390198 100644 --- a/crates/openshell-policy/src/merge.rs +++ b/crates/openshell-policy/src/merge.rs @@ -989,7 +989,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1008,7 +1007,6 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], - ..Default::default() }; let result = merge_policy( @@ -1037,7 +1035,6 @@ mod tests { name: "existing".to_string(), endpoints: vec![endpoint("api.github.com", 443)], binaries: vec![advisor_binary("/usr/bin/curl")], - ..Default::default() }, ); @@ -1048,7 +1045,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let result = merge_policy( @@ -1080,7 +1076,6 @@ mod tests { ..Default::default() }, ], - ..Default::default() }; let result = merge_policy( @@ -1112,7 +1107,6 @@ mod tests { path: "/usr/bin/python".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1126,7 +1120,6 @@ mod tests { ..Default::default() }], binaries: vec![advisor_binary("/usr/bin/python")], - ..Default::default() }; let result = merge_policy( @@ -1454,7 +1447,6 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1479,7 +1471,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let merged = merge_policy( @@ -1503,7 +1494,6 @@ mod tests { path: 
"/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; // Merge an *unrelated* rule for a different host. The proposed rule @@ -1534,7 +1524,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1547,7 +1536,6 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1579,7 +1567,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; // Endpoint exists in the policy but with a *different* binary. The @@ -1595,7 +1582,6 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1632,7 +1618,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1652,7 +1637,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1680,7 +1664,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1703,7 +1686,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1727,7 +1709,6 @@ mod tests { path: "/usr/bin/git".to_string(), ..Default::default() }], - ..Default::default() }; let merged = merge_policy( @@ -1752,7 +1733,6 @@ mod tests { name: "any_binary_rule".to_string(), endpoints: vec![endpoint("api.github.com", 443)], binaries: vec![], - ..Default::default() }; let mut policy = restrictive_default_policy(); @@ -1765,7 +1745,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -1823,7 +1802,6 @@ mod tests { path: "/usr/bin/gh".to_string(), ..Default::default() }], - ..Default::default() }; let composed = compose_effective_policy( &SandboxPolicy::default(), @@ 
-1855,7 +1833,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let result = merge_policy( composed, @@ -1924,7 +1901,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let result = merge_policy( policy, diff --git a/crates/openshell-providers/src/profiles.rs b/crates/openshell-providers/src/profiles.rs index 1eb1b54d2..ddfbcaf7d 100644 --- a/crates/openshell-providers/src/profiles.rs +++ b/crates/openshell-providers/src/profiles.rs @@ -450,7 +450,6 @@ impl ProviderTypeProfile { NetworkPolicyRule { name: rule_name.to_string(), endpoints: self.endpoints.iter().map(endpoint_to_proto).collect(), - middleware: Vec::new(), binaries: self.binaries.iter().map(binary_to_proto).collect(), } } @@ -788,7 +787,6 @@ fn endpoint_to_proto(endpoint: &EndpointProfile) -> NetworkEndpoint { request_body_credential_rewrite: endpoint.request_body_credential_rewrite, advisor_proposed: false, persisted_queries: endpoint.persisted_queries.clone(), - middleware: Vec::new(), graphql_persisted_queries: endpoint .graphql_persisted_queries .iter() diff --git a/crates/openshell-sandbox/src/mechanistic_mapper.rs b/crates/openshell-sandbox/src/mechanistic_mapper.rs index bb83ddb66..8ee2fc37f 100644 --- a/crates/openshell-sandbox/src/mechanistic_mapper.rs +++ b/crates/openshell-sandbox/src/mechanistic_mapper.rs @@ -162,7 +162,6 @@ pub fn generate_proposals(summaries: &[DenialSummary]) -> Vec { name: rule_name.clone(), endpoints: vec![endpoint], binaries, - middleware: Vec::new(), }; // Compute confidence. 
diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index ad4fdf5ba..d3bc213ba 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -5746,7 +5746,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -5960,7 +5959,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -6077,7 +6075,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6183,7 +6180,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let mechanistic_submit = handle_submit_policy_analysis( &state, @@ -6261,7 +6257,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let agent_submit = handle_submit_policy_analysis( &state, @@ -6389,7 +6384,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6490,7 +6484,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6591,7 +6584,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6685,7 +6677,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6770,7 +6761,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6859,7 +6849,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -6951,7 +6940,6 
@@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7038,7 +7026,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let response = handle_submit_policy_analysis( @@ -7100,7 +7087,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-provider-prefix".to_string(), @@ -7215,7 +7201,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7312,7 +7297,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7398,7 +7382,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7540,7 +7523,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -7666,7 +7648,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let step1 = handle_submit_policy_analysis( &state, @@ -7708,7 +7689,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let step2 = handle_submit_policy_analysis( &state, @@ -7840,7 +7820,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit_one = |rule_name: &str, rule: NetworkPolicyRule| { @@ -7949,7 +7928,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit_one = || { let state = state.clone(); @@ -8050,7 +8028,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let submit = handle_submit_policy_analysis( @@ -8182,7 +8159,6 @@ mod tests { path: 
"/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; handle_submit_policy_analysis( @@ -8381,7 +8357,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, }; @@ -8410,7 +8385,6 @@ mod tests { path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, }; @@ -8439,7 +8413,6 @@ mod tests { path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, }; @@ -8467,7 +8440,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-1".to_string(), @@ -8536,7 +8508,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, )) .collect(), @@ -8565,7 +8536,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-merge".to_string(), @@ -8639,7 +8609,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, )) .collect(), @@ -8668,7 +8637,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }; let chunk = DraftChunkRecord { id: "chunk-new".to_string(), diff --git a/crates/openshell-server/src/grpc/validation.rs b/crates/openshell-server/src/grpc/validation.rs index 09e9f1cad..af6b84af6 100644 --- a/crates/openshell-server/src/grpc/validation.rs +++ b/crates/openshell-server/src/grpc/validation.rs @@ -1621,6 +1621,28 @@ mod tests { assert!(err.message().contains("TLD wildcard")); } + #[test] + fn validate_policy_safety_rejects_invalid_middleware_before_acceptance() { + use openshell_core::proto::{MiddlewareEndpointSelector, NetworkMiddlewareConfig}; + + let mut policy = openshell_policy::restrictive_default_policy(); + policy.network_middlewares.push(NetworkMiddlewareConfig { + name: "redactor".into(), + middleware: "openshell/secrets".into(), + 
on_error: "maybe".into(), + endpoints: Some(MiddlewareEndpointSelector { + include: vec!["api[.example.com".into()], + exclude: Vec::new(), + }), + ..Default::default() + }); + + let err = validate_policy_safety(&policy).unwrap_err(); + assert_eq!(err.code(), Code::InvalidArgument); + assert!(err.message().contains("invalid on_error")); + assert!(err.message().contains("invalid host pattern")); + } + #[test] fn validate_no_reserved_provider_policy_keys_rejects_reserved_key() { use openshell_core::proto::NetworkPolicyRule; diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml index fdaeb2e82..4ae355894 100644 --- a/crates/openshell-supervisor-middleware/Cargo.toml +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -16,10 +16,8 @@ openshell-core = { path = "../openshell-core" } miette = { workspace = true } prost-types = { workspace = true } regex = { workspace = true } -tonic = { workspace = true } - -[dev-dependencies] tokio = { workspace = true } +tonic = { workspace = true } [lints] workspace = true diff --git a/crates/openshell-supervisor-middleware/src/builtins/mod.rs b/crates/openshell-supervisor-middleware/src/builtins/mod.rs index d91ee745e..1db620220 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/mod.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/mod.rs @@ -2,3 +2,28 @@ // SPDX-License-Identifier: Apache-2.0 pub mod secrets; + +use miette::{Result, miette}; +use openshell_core::proto::{HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding}; + +pub fn describe() -> Vec { + vec![secrets::describe()] +} + +pub fn validate_config(binding_id: &str, config: &prost_types::Struct) -> Result<()> { + match binding_id { + secrets::BINDING_ID => secrets::validate_config(config), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} + +pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result { + 
match evaluation.binding_id.as_str() { + secrets::BINDING_ID => secrets::evaluate_http_request(evaluation), + other => Err(miette!( + "middleware implementation '{other}' is not available in phase 1" + )), + } +} diff --git a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs index 572102559..d88ac080d 100644 --- a/crates/openshell-supervisor-middleware/src/builtins/secrets.rs +++ b/crates/openshell-supervisor-middleware/src/builtins/secrets.rs @@ -5,10 +5,24 @@ use std::collections::HashMap; use std::sync::LazyLock; use miette::{Result, miette}; -use openshell_core::proto::{Decision, Finding, HttpRequestEvaluation, HttpRequestResult}; +use openshell_core::proto::{ + Decision, Finding, HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding, +}; use regex::Regex; -use crate::BUILTIN_SECRETS; +pub const BINDING_ID: &str = "openshell/secrets"; +const OPERATION: &str = "HttpRequest"; +const PHASE: &str = "pre_credentials"; +const MAX_BODY_BYTES: u64 = 256 * 1024; + +pub fn describe() -> MiddlewareBinding { + MiddlewareBinding { + id: BINDING_ID.into(), + operation: OPERATION.into(), + phase: PHASE.into(), + max_body_bytes: MAX_BODY_BYTES, + } +} /// A named secret-detection pattern. 
The `kind` is an audit-safe label that /// flows into findings so operators can see *what* matched without seeing the @@ -51,7 +65,7 @@ pub fn validate_config(config: &prost_types::Struct) -> Result<()> { if mode != "redact" { return Err(miette!( "{} only supports config.secrets: redact in phase 1", - BUILTIN_SECRETS + BINDING_ID )); } Ok(()) @@ -61,7 +75,7 @@ pub fn evaluate_http_request(evaluation: &HttpRequestEvaluation) -> Result Result<()> { - match implementation { - BUILTIN_SECRETS => builtins::secrets::validate_config(config), - other => Err(miette!( - "middleware implementation '{other}' is not available in phase 1" - )), - } + builtins::validate_config(implementation, config) } #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -85,6 +79,26 @@ impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { } } +/// A policy-selected middleware config joined with metadata reported by its +/// service's `Describe` call. A missing binding is retained so `on_error` can +/// decide whether the request fails open or closed. +#[derive(Debug, Clone)] +pub struct DescribedChainEntry { + entry: ChainEntry, + binding: Option, + max_body_bytes: usize, +} + +impl DescribedChainEntry { + pub fn max_body_bytes(&self) -> usize { + self.max_body_bytes + } + + pub fn on_error(&self) -> OnError { + self.entry.on_error + } +} + #[derive(Debug, Clone)] pub struct HttpRequestInput { pub request_id: String, @@ -138,15 +152,15 @@ enum OnErrorAction { /// Apply a middleware entry's `on_error` policy after a failure (service error or /// malformed response). Records a `failed` invocation for telemetry in both cases. 
fn apply_on_error( - entry: &ChainEntry, + entry: &DescribedChainEntry, reason: &str, applied: &mut Vec, ) -> OnErrorAction { - match entry.on_error { + match entry.entry.on_error { OnError::FailOpen => { applied.push(MiddlewareInvocation { - name: entry.name.clone(), - implementation: entry.implementation.clone(), + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), decision: Decision::Allow, transformed: false, failed: true, @@ -155,8 +169,8 @@ fn apply_on_error( } OnError::FailClosed => { applied.push(MiddlewareInvocation { - name: entry.name.clone(), - implementation: entry.implementation.clone(), + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), decision: Decision::Deny, transformed: false, failed: true, @@ -168,24 +182,102 @@ fn apply_on_error( #[derive(Clone)] pub struct ChainRunner { + state: Arc, +} + +struct MiddlewareServiceState { service: Arc, + manifest: OnceCell, } +static IN_PROCESS_SERVICE: LazyLock> = LazyLock::new(|| { + Arc::new(MiddlewareServiceState { + service: Arc::new(InProcessMiddlewareService), + manifest: OnceCell::new(), + }) +}); + impl Default for ChainRunner { fn default() -> Self { - Self::new(Arc::new(InProcessMiddlewareService)) + Self { + state: Arc::clone(&IN_PROCESS_SERVICE), + } } } impl ChainRunner { pub fn new(service: Arc) -> Self { - Self { service } + Self { + state: Arc::new(MiddlewareServiceState { + service, + manifest: OnceCell::new(), + }), + } + } + + async fn manifest(&self) -> Result<&MiddlewareManifest> { + self.state + .manifest + .get_or_try_init(|| async { + self.state + .service + .describe(Request::new(())) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware Describe failed: {}", + safe_reason(&error.to_string()) + ) + }) + }) + .await + } + + pub async fn describe_chain(&self, entries: &[ChainEntry]) -> Result> { + let manifest = self.manifest().await?; + entries + .iter() + .map(|entry| { + let 
binding = manifest + .bindings + .iter() + .find(|binding| binding.id == entry.implementation) + .cloned(); + let max_body_bytes = binding + .as_ref() + .map(|binding| { + usize::try_from(binding.max_body_bytes).map_err(|_| { + miette!( + "middleware binding '{}' reports a body limit too large for this platform", + binding.id + ) + }) + }) + .transpose()? + .unwrap_or(0); + Ok(DescribedChainEntry { + entry: entry.clone(), + binding, + max_body_bytes, + }) + }) + .collect() } pub async fn evaluate( &self, entries: &[ChainEntry], input: HttpRequestInput, + ) -> Result { + let entries = self.describe_chain(entries).await?; + self.evaluate_described(&entries, input).await + } + + pub async fn evaluate_described( + &self, + entries: &[DescribedChainEntry], + input: HttpRequestInput, ) -> Result { let mut headers = input.headers.clone(); let mut body = input.body.clone(); @@ -195,8 +287,41 @@ impl ChainRunner { let mut applied = Vec::new(); for entry in entries { - let evaluation = build_evaluation(entry, &input, &headers, &body); + let Some(binding) = entry.binding.as_ref() else { + match apply_on_error(entry, "binding_not_described", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + }; + if body.len() > entry.max_body_bytes { + match apply_on_error(entry, "request_body_over_capacity", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + let evaluation = build_evaluation(entry, binding, &input, &headers, &body); let result = match self + .state .service .evaluate_http_request(Request::new(evaluation)) .await @@ -240,6 +365,23 @@ impl ChainRunner { } }; + if result.has_body && result.body.len() > entry.max_body_bytes { + 
match apply_on_error(entry, "response_body_over_capacity", &mut applied) { + OnErrorAction::FailOpen => continue, + OnErrorAction::FailClosed(reason) => { + return Ok(ChainOutcome { + allowed: false, + reason, + body, + added_headers, + findings, + metadata, + applied, + }); + } + } + } + // A result proposing unsafe header mutations is a malformed response: // route it through `on_error` instead of applying any of it. if validate_header_mutations(&headers, &result.add_headers).is_err() { @@ -268,19 +410,19 @@ impl ChainRunner { } for finding in result.findings { findings.push(NamespacedFinding { - middleware: entry.name.clone(), + middleware: entry.entry.name.clone(), finding, }); } if !result.metadata.is_empty() { metadata.insert( - entry.name.clone(), + entry.entry.name.clone(), result.metadata.clone().into_iter().collect(), ); } applied.push(MiddlewareInvocation { - name: entry.name.clone(), - implementation: entry.implementation.clone(), + name: entry.entry.name.clone(), + implementation: entry.entry.implementation.clone(), decision, transformed, failed: false, @@ -311,21 +453,22 @@ impl ChainRunner { } fn build_evaluation( - entry: &ChainEntry, + entry: &DescribedChainEntry, + binding: &MiddlewareBinding, input: &HttpRequestInput, headers: &BTreeMap, body: &[u8], ) -> HttpRequestEvaluation { HttpRequestEvaluation { api_version: API_VERSION.into(), - binding_id: entry.implementation.clone(), - phase: PRE_CREDENTIALS_PHASE.into(), + binding_id: binding.id.clone(), + phase: binding.phase.clone(), context: Some(RequestContext { request_id: input.request_id.clone(), sandbox_id: input.sandbox_id.clone(), originating_process: None, }), - config: Some(entry.config.clone()), + config: Some(entry.entry.config.clone()), target: Some(HttpRequestTarget { scheme: input.scheme.clone(), host: input.host.clone(), @@ -436,11 +579,16 @@ mod tests { } } - #[test] - fn phase_one_evaluation_omits_originating_process() { - let entry = entry("redact", OnError::FailClosed); + 
#[tokio::test] + async fn phase_one_evaluation_omits_originating_process() { + let entries = ChainRunner::default() + .describe_chain(&[entry("redact", OnError::FailClosed)]) + .await + .expect("describe chain"); + let entry = &entries[0]; + let binding = entry.binding.as_ref().expect("described binding"); let input = input("payload"); - let evaluation = build_evaluation(&entry, &input, &BTreeMap::new(), b"payload"); + let evaluation = build_evaluation(entry, binding, &input, &BTreeMap::new(), b"payload"); assert!( evaluation @@ -527,8 +675,9 @@ mod tests { .into_inner(); assert_eq!(manifest.api_version, API_VERSION); assert_eq!(manifest.bindings[0].id, BUILTIN_SECRETS); - assert_eq!(manifest.bindings[0].operation, HTTP_REQUEST_OPERATION); - assert_eq!(manifest.bindings[0].phase, PRE_CREDENTIALS_PHASE); + assert_eq!(manifest.bindings[0].operation, "HttpRequest"); + assert_eq!(manifest.bindings[0].phase, "pre_credentials"); + assert_eq!(manifest.bindings[0].max_body_bytes, 256 * 1024); } #[test] @@ -561,6 +710,8 @@ mod tests { /// evaluation. Used to exercise chain behavior the built-in cannot produce /// (explicit deny, metadata, findings, unsafe header mutations). 
struct ScriptedService { + binding_id: String, + max_body_bytes: u64, result: openshell_core::proto::HttpRequestResult, } @@ -569,13 +720,18 @@ mod tests { async fn describe( &self, _request: Request<()>, - ) -> std::result::Result< - tonic::Response, - tonic::Status, - > { - Ok(tonic::Response::new( - openshell_core::proto::MiddlewareManifest::default(), - )) + ) -> std::result::Result, tonic::Status> { + Ok(tonic::Response::new(MiddlewareManifest { + api_version: API_VERSION.into(), + name: "test/middleware".into(), + service_version: "test".into(), + bindings: vec![MiddlewareBinding { + id: self.binding_id.clone(), + operation: "HttpRequest".into(), + phase: "pre_credentials".into(), + max_body_bytes: self.max_body_bytes, + }], + })) } async fn validate_config( @@ -604,6 +760,14 @@ mod tests { } } + fn scripted_service(result: openshell_core::proto::HttpRequestResult) -> ScriptedService { + ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 256 * 1024, + result, + } + } + fn allow_result() -> openshell_core::proto::HttpRequestResult { openshell_core::proto::HttpRequestResult { decision: Decision::Allow as i32, @@ -617,14 +781,49 @@ mod tests { } #[tokio::test] - async fn deny_decision_short_circuits_chain() { + async fn descriptors_are_resolved_from_any_middleware_service() { let runner = ChainRunner::new(Arc::new(ScriptedService { - result: openshell_core::proto::HttpRequestResult { + binding_id: "example/redactor".into(), + max_body_bytes: 4096, + result: allow_result(), + })); + let entry = ChainEntry { + name: "external".into(), + implementation: "example/redactor".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + + let described = runner + .describe_chain(std::slice::from_ref(&entry)) + .await + .expect("describe external middleware"); + assert_eq!(described[0].max_body_bytes(), 4096); + assert_eq!( + described[0] + .binding + .as_ref() + .expect("described binding") + .phase, + 
"pre_credentials" + ); + + let outcome = runner + .evaluate_described(&described, input("hello")) + .await + .expect("evaluate external middleware"); + assert!(outcome.allowed); + } + + #[tokio::test] + async fn deny_decision_short_circuits_chain() { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { decision: Decision::Deny as i32, reason: "blocked_by_policy".into(), ..allow_result() }, - })); + ))); let outcome = runner .evaluate( &[ @@ -645,8 +844,8 @@ mod tests { #[tokio::test] async fn metadata_and_findings_are_namespaced_per_config() { - let runner = ChainRunner::new(Arc::new(ScriptedService { - result: openshell_core::proto::HttpRequestResult { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { findings: vec![Finding { r#type: "pii.email".into(), label: "email address".into(), @@ -658,7 +857,7 @@ mod tests { .collect(), ..allow_result() }, - })); + ))); let outcome = runner .evaluate( &[ @@ -683,16 +882,14 @@ mod tests { } fn unsafe_header_service() -> ScriptedService { - ScriptedService { - result: openshell_core::proto::HttpRequestResult { - add_headers: std::iter::once(( - "x-openshell-middleware-inject".to_string(), - "ok\r\nHost: evil".to_string(), - )) - .collect(), - ..allow_result() - }, - } + scripted_service(openshell_core::proto::HttpRequestResult { + add_headers: std::iter::once(( + "x-openshell-middleware-inject".to_string(), + "ok\r\nHost: evil".to_string(), + )) + .collect(), + ..allow_result() + }) } #[tokio::test] @@ -724,13 +921,79 @@ mod tests { } #[tokio::test] - async fn unspecified_decision_uses_fail_closed() { + async fn oversized_replacement_body_honors_on_error() { let runner = ChainRunner::new(Arc::new(ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 4, result: openshell_core::proto::HttpRequestResult { - decision: Decision::Unspecified as i32, + body: b"too large".to_vec(), + has_body: true, 
..allow_result() }, })); + let fail_open = entry("small", OnError::FailOpen); + let mut fail_closed = fail_open.clone(); + fail_closed.on_error = OnError::FailClosed; + + let open_outcome = runner + .evaluate(&[fail_open], input("safe")) + .await + .expect("fail-open evaluation"); + assert!(open_outcome.allowed); + assert_eq!(open_outcome.body, b"safe"); + assert!(open_outcome.applied[0].failed); + + let closed_outcome = runner + .evaluate(&[fail_closed], input("safe")) + .await + .expect("fail-closed evaluation"); + assert!(!closed_outcome.allowed); + assert_eq!( + closed_outcome.reason, + "middleware_failed: response_body_over_capacity" + ); + assert!(closed_outcome.applied[0].failed); + } + + #[tokio::test] + async fn oversized_request_body_honors_on_error() { + let runner = ChainRunner::new(Arc::new(ScriptedService { + binding_id: BUILTIN_SECRETS.into(), + max_body_bytes: 4, + result: allow_result(), + })); + let fail_open = entry("small", OnError::FailOpen); + let mut fail_closed = fail_open.clone(); + fail_closed.on_error = OnError::FailClosed; + + let open_outcome = runner + .evaluate(&[fail_open], input("hello")) + .await + .expect("fail-open evaluation"); + assert!(open_outcome.allowed); + assert_eq!(open_outcome.body, b"hello"); + assert!(open_outcome.applied[0].failed); + + let closed_outcome = runner + .evaluate(&[fail_closed], input("hello")) + .await + .expect("fail-closed evaluation"); + assert!(!closed_outcome.allowed); + assert_eq!( + closed_outcome.reason, + "middleware_failed: request_body_over_capacity" + ); + assert!(closed_outcome.applied[0].failed); + } + + #[tokio::test] + async fn unspecified_decision_uses_fail_closed() { + let runner = ChainRunner::new(Arc::new(scripted_service( + openshell_core::proto::HttpRequestResult { + decision: Decision::Unspecified as i32, + ..allow_result() + }, + ))); let outcome = runner .evaluate(&[entry("redact", OnError::FailClosed)], input("hello")) diff --git 
a/crates/openshell-supervisor-middleware/src/service.rs b/crates/openshell-supervisor-middleware/src/service.rs index cbd9231cd..51df8d070 100644 --- a/crates/openshell-supervisor-middleware/src/service.rs +++ b/crates/openshell-supervisor-middleware/src/service.rs @@ -3,15 +3,12 @@ use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; use openshell_core::proto::{ - HttpRequestEvaluation, HttpRequestResult, MiddlewareBinding, MiddlewareManifest, - ValidateConfigRequest, ValidateConfigResponse, + HttpRequestEvaluation, HttpRequestResult, MiddlewareManifest, ValidateConfigRequest, + ValidateConfigResponse, }; use tonic::{Request, Response, Status}; -use crate::{ - API_VERSION, BUILTIN_SECRETS, HTTP_REQUEST_OPERATION, PRE_CREDENTIALS_PHASE, builtins, - safe_reason, validate_builtin_config, -}; +use crate::{API_VERSION, builtins, safe_reason, validate_builtin_config}; #[derive(Debug, Default)] pub struct InProcessMiddlewareService; @@ -26,11 +23,7 @@ impl SupervisorMiddleware for InProcessMiddlewareService { api_version: API_VERSION.into(), name: "openshell/in-process".into(), service_version: env!("CARGO_PKG_VERSION").into(), - bindings: vec![MiddlewareBinding { - id: BUILTIN_SECRETS.into(), - operation: HTTP_REQUEST_OPERATION.into(), - phase: PRE_CREDENTIALS_PHASE.into(), - }], + bindings: builtins::describe(), })) } @@ -58,13 +51,8 @@ impl SupervisorMiddleware for InProcessMiddlewareService { request: Request, ) -> Result, Status> { let request = request.into_inner(); - let result = match request.binding_id.as_str() { - BUILTIN_SECRETS => builtins::secrets::evaluate_http_request(&request), - other => Err(miette::miette!( - "middleware implementation '{other}' is not available in phase 1" - )), - } - .map_err(|err| Status::invalid_argument(safe_reason(&err.to_string())))?; + let result = builtins::evaluate_http_request(&request) + .map_err(|err| Status::invalid_argument(safe_reason(&err.to_string())))?; 
Ok(Response::new(result)) } } diff --git a/crates/openshell-supervisor-network/data/sandbox-policy.rego b/crates/openshell-supervisor-network/data/sandbox-policy.rego index 52f6f1046..fcc5838e1 100644 --- a/crates/openshell-supervisor-network/data/sandbox-policy.rego +++ b/crates/openshell-supervisor-network/data/sandbox-policy.rego @@ -842,23 +842,6 @@ _policy_endpoint_configs(policy) := [ep | endpoint_has_extended_config(ep) ] -# Collect matching endpoint identities across all policies. Iterates over -# _matching_policy_names (a set, safe from regorus variable collisions) then -# returns the selected policy name plus endpoint index/path. Rust uses that -# identity to look up middleware attachment from policy data. -_matching_endpoint_contexts := [ctx | - some pname - _matching_policy_names[pname] - policy := data.network_policies[pname] - ep := policy.endpoints[i] - endpoint_matches_request(ep, input.network) - ctx := { - "policy": pname, - "endpoint_index": i, - "endpoint_path": object.get(ep, "path", ""), - } -] - _matching_endpoint_configs := [cfg | some pname _matching_policy_names[pname] @@ -871,8 +854,6 @@ matched_endpoint_config := _matching_endpoint_configs[0] if { count(_matching_endpoint_configs) > 0 } -network_policies := object.get(data, "network_policies", {}) - network_middlewares := object.get(data, "network_middlewares", []) _policy_has_exact_declared_endpoint(policy) if { diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index c773fdcf4..8383b6bb2 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -453,8 +453,7 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let chain = - engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; let req = match 
apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) .await? @@ -773,20 +772,45 @@ pub(crate) enum MiddlewareApplyResult { Denied(String), } +fn middleware_chain_body_limit( + chain: &[openshell_supervisor_middleware::DescribedChainEntry], +) -> Option { + chain + .iter() + .map(openshell_supervisor_middleware::DescribedChainEntry::max_body_bytes) + .min() +} + pub(crate) async fn apply_middleware_chain( req: crate::l7::provider::L7Request, client: &mut C, ctx: &L7EvalContext, chain: Vec, generation_guard: &PolicyGenerationGuard, +) -> Result { + apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, generation_guard).await +} + +pub(crate) async fn apply_middleware_chain_for_scheme( + req: crate::l7::provider::L7Request, + client: &mut C, + ctx: &L7EvalContext, + scheme: &str, + chain: Vec, + generation_guard: &PolicyGenerationGuard, ) -> Result { if chain.is_empty() { return Ok(MiddlewareApplyResult::Allowed(req)); } + let runner = openshell_supervisor_middleware::ChainRunner::default(); + let chain = runner.describe_chain(&chain).await?; + let max_body_bytes = + middleware_chain_body_limit(&chain).expect("non-empty middleware chain has a body limit"); let buffered = match crate::l7::rest::buffer_request_body_for_middleware( &req, client, Some(generation_guard), + max_body_bytes, ) .await? 
{ @@ -797,21 +821,8 @@ pub(crate) async fn apply_middleware_chain, + query: String, + body: Vec, +) -> openshell_supervisor_middleware::HttpRequestInput { + openshell_supervisor_middleware::HttpRequestInput { + request_id: uuid::Uuid::new_v4().to_string(), + sandbox_id: openshell_ocsf::ctx::ctx().sandbox_id.clone(), + scheme: scheme.into(), + host: ctx.host.clone(), + port: ctx.port, + method: req.action.clone(), + path: req.target.clone(), + query, + headers, + body, + } +} + fn raw_query_from_request_headers(headers: &[u8]) -> Result { let header_str = std::str::from_utf8(headers).map_err(|_| miette!("HTTP headers contain invalid UTF-8"))?; @@ -846,12 +879,12 @@ fn raw_query_from_request_headers(headers: &[u8]) -> Result { fn resolve_unbuffered_body( ctx: &L7EvalContext, req: crate::l7::provider::L7Request, - chain: &[openshell_supervisor_middleware::ChainEntry], + chain: &[openshell_supervisor_middleware::DescribedChainEntry], recoverable: bool, ) -> MiddlewareApplyResult { let all_fail_open = chain .iter() - .all(|entry| entry.on_error == openshell_supervisor_middleware::OnError::FailOpen); + .all(|entry| entry.on_error() == openshell_supervisor_middleware::OnError::FailOpen); if recoverable && all_fail_open { emit_middleware_body_unavailable(ctx, false); return MiddlewareApplyResult::Allowed(req); @@ -1187,8 +1220,7 @@ where let _ = &eval_target; if allowed || config.enforcement == EnforcementMode::Audit { - let chain = - engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; let req = match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) .await? 
@@ -1457,8 +1489,7 @@ where } if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let chain = - engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; let req = match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) .await? @@ -1682,8 +1713,7 @@ where let _ = &eval_target; if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { - let chain = - engine.query_middleware_chain(&middleware_network_input(ctx), &req.target)?; + let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; let req = match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) .await? @@ -2136,8 +2166,7 @@ where let req = if let Some(engine) = middleware_engine { let input = middleware_network_input(ctx); - let (chain, generation) = - engine.query_middleware_chain_with_generation(&input, &req.target)?; + let (chain, generation) = engine.query_middleware_chain_with_generation(&input)?; if generation != generation_guard.captured_generation() { return Ok(()); } @@ -2326,10 +2355,11 @@ network_middlewares: - name: request-middleware middleware: {middleware_impl} on_error: {on_error} + endpoints: + include: ["api.example.test"] network_policies: rest_api: name: rest_api - middleware: ["request-middleware"] endpoints: - host: api.example.test port: 8080 @@ -2785,10 +2815,11 @@ network_middlewares: - name: request-middleware middleware: example/unavailable on_error: fail_closed + endpoints: + include: ["api.example.test"] network_policies: jsonrpc_api: name: jsonrpc_api - middleware: ["request-middleware"] endpoints: - host: api.example.test port: 443 @@ -2929,8 +2960,8 @@ network_policies: .unwrap(); } - #[test] - fn over_capacity_resolution_honors_on_error() { + #[tokio::test] + async fn over_capacity_resolution_honors_on_error() { use openshell_supervisor_middleware::{ChainEntry, 
OnError}; let ctx = L7EvalContext { @@ -2963,21 +2994,31 @@ network_policies: ..fail_open.clone() }; + let runner = openshell_supervisor_middleware::ChainRunner::default(); + let open_chain = runner + .describe_chain(std::slice::from_ref(&fail_open)) + .await + .expect("describe fail-open chain"); + let mixed_chain = runner + .describe_chain(&[fail_open.clone(), fail_closed]) + .await + .expect("describe mixed chain"); + // Recoverable (Content-Length over cap, nothing consumed) + all fail-open // -> stream through unprocessed. assert!(matches!( - resolve_unbuffered_body(&ctx, req(), std::slice::from_ref(&fail_open), true), + resolve_unbuffered_body(&ctx, req(), &open_chain, true), MiddlewareApplyResult::Allowed(_) )); // Any fail-closed entry -> deny. assert!(matches!( - resolve_unbuffered_body(&ctx, req(), &[fail_open.clone(), fail_closed], true), + resolve_unbuffered_body(&ctx, req(), &mixed_chain, true), MiddlewareApplyResult::Denied(_) )); // Not recoverable (chunked overflow already consumed bytes) -> deny even // when every entry is fail-open. 
assert!(matches!( - resolve_unbuffered_body(&ctx, req(), &[fail_open], false), + resolve_unbuffered_body(&ctx, req(), &open_chain, false), MiddlewareApplyResult::Denied(_) )); } @@ -2992,6 +3033,40 @@ network_policies: assert_eq!(query, "token=a%2Bb&scope=private"); } + #[test] + fn middleware_request_input_preserves_plain_http_scheme() { + let req = crate::l7::provider::L7Request { + action: "POST".into(), + target: "/v1/messages".into(), + query_params: std::collections::HashMap::new(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }; + let ctx = L7EvalContext { + host: "api.example.test".into(), + port: 80, + policy_name: "api".into(), + binary_path: "/usr/bin/curl".into(), + ancestors: Vec::new(), + cmdline_paths: Vec::new(), + secret_resolver: None, + activity_tx: None, + dynamic_credentials: None, + token_grant_resolver: None, + }; + + let input = middleware_request_input( + "http", + &req, + &ctx, + BTreeMap::new(), + String::new(), + Vec::new(), + ); + + assert_eq!(input.scheme, "http"); + } + /// Tracing layer that captures emitted `OcsfEvent`s for assertions. struct OcsfCaptureLayer(Arc>>); @@ -3096,16 +3171,17 @@ network_policies: #[tokio::test] async fn passthrough_relay_runs_middleware_redaction() { // A no-protocol endpoint takes the credential-injection passthrough path; - // policy-level middleware must still inspect and redact its body. + // host-selected middleware must still inspect and redact its body. 
let data = r#" network_middlewares: - name: request-middleware middleware: openshell/secrets on_error: fail_closed + endpoints: + include: ["api.example.test"] network_policies: passthrough_api: name: passthrough_api - middleware: ["request-middleware"] endpoints: - host: api.example.test port: 8080 @@ -3197,10 +3273,11 @@ network_middlewares: - name: request-middleware middleware: example/unavailable on_error: fail_closed + endpoints: + include: ["gateway.example.test"] network_policies: ws_api: name: ws_api - middleware: ["request-middleware"] endpoints: - host: gateway.example.test port: 443 diff --git a/crates/openshell-supervisor-network/src/l7/rest.rs b/crates/openshell-supervisor-network/src/l7/rest.rs index 19f73e2ad..15825d1b2 100644 --- a/crates/openshell-supervisor-network/src/l7/rest.rs +++ b/crates/openshell-supervisor-network/src/l7/rest.rs @@ -27,7 +27,19 @@ const MAX_REWRITE_BODY_BYTES: usize = 256 * 1024; /// Maximum body bytes for `SigV4` body-signing mode. Larger than the credential /// rewrite limit because Bedrock payloads can be several megabytes. 
const MAX_SIGV4_BODY_BYTES: usize = 10 * 1024 * 1024; -pub(crate) const MAX_MIDDLEWARE_BODY_BYTES: usize = MAX_REWRITE_BODY_BYTES; +#[cfg(test)] +async fn max_middleware_body_bytes() -> usize { + let chain = openshell_supervisor_middleware::ChainRunner::default() + .describe_chain(&[openshell_supervisor_middleware::ChainEntry { + name: "test".into(), + implementation: openshell_supervisor_middleware::BUILTIN_SECRETS.into(), + config: prost_types::Struct::default(), + on_error: openshell_supervisor_middleware::OnError::FailClosed, + }]) + .await + .expect("describe built-in middleware"); + chain[0].max_body_bytes() +} const RELAY_BUF_SIZE: usize = 8192; const HTTP_METHOD_PREFIXES: &[&[u8]] = &[ b"GET ", @@ -820,6 +832,7 @@ pub(crate) async fn buffer_request_body_for_middleware( req: &L7Request, client: &mut C, generation_guard: Option<&PolicyGenerationGuard>, + max_body_bytes: usize, ) -> Result { let header_end = req .raw_header @@ -840,7 +853,7 @@ pub(crate) async fn buffer_request_body_for_middleware( let Ok(len) = usize::try_from(len) else { return Ok(BufferResult::OverCapacity { recoverable: true }); }; - if len > MAX_MIDDLEWARE_BODY_BYTES { + if len > max_body_bytes { return Ok(BufferResult::OverCapacity { recoverable: true }); } let initial_len = already_read.len().min(len); @@ -869,14 +882,18 @@ pub(crate) async fn buffer_request_body_for_middleware( } BodyLength::Chunked => { // Chunked bodies are decoded incrementally into the payload bytes - // middleware expects. On overflow, we have already consumed wire - // bytes from the client stream and cannot re-enter the normal raw - // relay path without a separate splice-through buffer. 
- Ok(collect_chunked_body(client, already_read, generation_guard) - .await - .map_or(BufferResult::OverCapacity { recoverable: false }, |body| { - BufferResult::Buffered(BufferedRequestBody { headers, body }) - })) + // middleware expects, but the middleware cap counts the complete + // wire representation, including framing and trailers. On overflow, + // we have already consumed wire bytes from the client stream and + // cannot re-enter the normal raw relay path without a separate + // splice-through buffer. + Ok( + collect_chunked_body(client, already_read, generation_guard, Some(max_body_bytes)) + .await + .map_or(BufferResult::OverCapacity { recoverable: false }, |body| { + BufferResult::Buffered(BufferedRequestBody { headers, body }) + }), + ) } } } @@ -953,7 +970,7 @@ async fn collect_and_rewrite_request_body( Ok(PreparedRequestBody { headers, body }) } BodyLength::Chunked => { - let body = collect_chunked_body(client, already_read, generation_guard).await?; + let body = collect_chunked_body(client, already_read, generation_guard, None).await?; let (mut headers, body) = rewrite_buffered_body(rewritten_headers, original_header_str, body, resolver)?; headers = set_content_length(&headers, body.len())?; @@ -1125,15 +1142,19 @@ async fn collect_chunked_body( client: &mut C, already_read: &[u8], generation_guard: Option<&PolicyGenerationGuard>, + max_wire_bytes: Option, ) -> Result> { - let mut buffered_pos = 0usize; + let mut read_state = ChunkedReadState { + buffered_pos: 0, + wire_bytes: 0, + max_wire_bytes, + }; let mut body = Vec::new(); loop { - let size_line = - read_chunked_line(client, already_read, &mut buffered_pos, generation_guard) - .await - .map_err(|e| miette!("Chunked body ended before chunk-size line: {e}"))?; + let size_line = read_chunked_line(client, already_read, &mut read_state, generation_guard) + .await + .map_err(|e| miette!("Chunked body ended before chunk-size line: {e}"))?; let size_line = std::str::from_utf8(&size_line) 
.into_diagnostic() .map_err(|_| miette!("Invalid UTF-8 in chunk-size line"))?; @@ -1149,7 +1170,7 @@ async fn collect_chunked_body( if chunk_size == 0 { loop { let trailer_line = - read_chunked_line(client, already_read, &mut buffered_pos, generation_guard) + read_chunked_line(client, already_read, &mut read_state, generation_guard) .await .map_err(|e| { miette!("Chunked body ended before trailer terminator: {e}") @@ -1168,7 +1189,7 @@ async fn collect_chunked_body( read_buffered_exact( client, already_read, - &mut buffered_pos, + &mut read_state, chunk_size, &mut body, generation_guard, @@ -1180,7 +1201,7 @@ async fn collect_chunked_body( read_buffered_exact( client, already_read, - &mut buffered_pos, + &mut read_state, 2, &mut chunk_crlf, generation_guard, @@ -1193,15 +1214,21 @@ async fn collect_chunked_body( } } +struct ChunkedReadState { + buffered_pos: usize, + wire_bytes: usize, + max_wire_bytes: Option, +} + async fn read_chunked_line( client: &mut C, already_read: &[u8], - buffered_pos: &mut usize, + state: &mut ChunkedReadState, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result> { let mut line = Vec::new(); loop { - let byte = read_buffered_byte(client, already_read, buffered_pos, generation_guard).await?; + let byte = read_buffered_byte(client, already_read, state, generation_guard).await?; line.push(byte); if line.len() > MAX_REWRITE_BODY_BYTES { return Err(miette!( @@ -1218,13 +1245,13 @@ async fn read_chunked_line( async fn read_buffered_exact( client: &mut C, already_read: &[u8], - buffered_pos: &mut usize, + state: &mut ChunkedReadState, len: usize, out: &mut Vec, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result<()> { for _ in 0..len { - let byte = read_buffered_byte(client, already_read, buffered_pos, generation_guard).await?; + let byte = read_buffered_byte(client, already_read, state, generation_guard).await?; out.push(byte); } Ok(()) @@ -1233,18 +1260,30 @@ async fn read_buffered_exact( async fn read_buffered_byte( 
client: &mut C, already_read: &[u8], - buffered_pos: &mut usize, + state: &mut ChunkedReadState, generation_guard: Option<&PolicyGenerationGuard>, ) -> Result { - if *buffered_pos < already_read.len() { - let byte = already_read[*buffered_pos]; - *buffered_pos += 1; - return Ok(byte); - } - let byte = client.read_u8().await.into_diagnostic()?; - if let Some(guard) = generation_guard { - guard.ensure_current()?; + if state + .max_wire_bytes + .is_some_and(|max| state.wire_bytes >= max) + { + return Err(miette!( + "chunked body wire representation exceeds middleware buffer limit" + )); } + + let byte = if state.buffered_pos < already_read.len() { + let byte = already_read[state.buffered_pos]; + state.buffered_pos += 1; + byte + } else { + let byte = client.read_u8().await.into_diagnostic()?; + if let Some(guard) = generation_guard { + guard.ensure_current()?; + } + byte + }; + state.wire_bytes += 1; Ok(byte) } @@ -3264,6 +3303,7 @@ mod tests { &mut client, b"5\r\nhello\r\n6;ext=value\r\n world\r\n0\r\nx-checksum: abc\r\n\r\n", None, + None, ) .await .expect("chunked body should decode"); @@ -3271,6 +3311,40 @@ mod tests { assert_eq!(body, b"hello world"); } + #[tokio::test] + async fn middleware_chunked_wire_body_at_cap_is_allowed() { + let max_body_bytes = max_middleware_body_bytes().await; + let payload_len = max_body_bytes - 14; + let mut wire = format!("{payload_len:x}\r\n").into_bytes(); + wire.extend(std::iter::repeat_n(b'x', payload_len)); + wire.extend_from_slice(b"\r\n0\r\n\r\n"); + assert_eq!(wire.len(), max_body_bytes); + + let body = collect_chunked_body(&mut tokio::io::empty(), &wire, None, Some(max_body_bytes)) + .await + .expect("wire representation at the cap should be allowed"); + + assert_eq!(body.len(), payload_len); + } + + #[tokio::test] + async fn middleware_chunked_wire_body_over_cap_is_rejected() { + let max_body_bytes = max_middleware_body_bytes().await; + let payload_len = max_body_bytes - 13; + let mut wire = 
format!("{payload_len:x}\r\n").into_bytes(); + wire.extend(std::iter::repeat_n(b'x', payload_len)); + wire.extend_from_slice(b"\r\n0\r\n\r\n"); + assert_eq!(wire.len(), max_body_bytes + 1); + assert!(payload_len < max_body_bytes); + + let error = + collect_chunked_body(&mut tokio::io::empty(), &wire, None, Some(max_body_bytes)) + .await + .expect_err("wire framing over the cap must be rejected"); + + assert!(error.to_string().contains("wire representation")); + } + /// SEC-009: Bare LF in headers enables header injection. #[tokio::test] async fn reject_bare_lf_in_headers() { diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 3efec0212..c4e773996 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -135,17 +135,13 @@ impl TunnelPolicyEngine { &self.engine } - /// Query the ordered middleware chain for a request path within this tunnel. - pub fn query_middleware_chain( - &self, - input: &NetworkInput, - request_path: &str, - ) -> Result> { + /// Query the ordered middleware chain for a destination within this tunnel. + pub fn query_middleware_chain(&self, input: &NetworkInput) -> Result> { let mut engine = self .engine .lock() .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; - query_middleware_chain_locked(&mut engine, input, request_path) + query_middleware_chain_locked(&mut engine, input) } } @@ -208,21 +204,21 @@ impl OpaEngine { /// gap between user-specified symlink paths (e.g., `/usr/bin/python3`) and /// kernel-resolved canonical paths (e.g., `/usr/bin/python3.11`). 
pub fn from_proto_with_pid(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> Result { + if let Err(violations) = openshell_policy::validate_sandbox_policy(proto) { + let errors = violations + .iter() + .map(ToString::to_string) + .collect::>() + .join("\n"); + return Err(miette::miette!("policy validation failed:\n{errors}")); + } + let data_json_str = proto_to_opa_data_json(proto, entrypoint_pid); // Parse back to Value for preprocessing, then re-serialize let mut data: serde_json::Value = serde_json::from_str(&data_json_str) .map_err(|e| miette::miette!("internal: failed to parse proto JSON: {e}"))?; - // Validate BEFORE expanding presets - let middleware_errors = validate_middleware_policies(&data); - if !middleware_errors.is_empty() { - return Err(miette::miette!( - "middleware policy validation failed:\n{}", - middleware_errors.join("\n") - )); - } - let (errors, warnings) = crate::l7::validate_l7_policies(&data); for w in &warnings { openshell_ocsf::ocsf_emit!( @@ -571,18 +567,17 @@ impl OpaEngine { } } - /// Query the ordered middleware chain for a parsed HTTP request path. + /// Query the ordered middleware chain for an admitted destination. 
pub fn query_middleware_chain_with_generation( &self, input: &NetworkInput, - request_path: &str, ) -> Result<(Vec, u64)> { let mut engine = self .engine .lock() .map_err(|_| miette::miette!("OPA engine lock poisoned"))?; let generation = self.current_generation(); - let chain = query_middleware_chain_locked(&mut engine, input, request_path)?; + let chain = query_middleware_chain_locked(&mut engine, input)?; Ok((chain, generation)) } @@ -749,17 +744,9 @@ fn network_input_json(input: &NetworkInput) -> serde_json::Value { }) } -#[derive(Debug, Clone)] -struct MatchedEndpointContext { - policy_name: String, - endpoint_index: usize, - endpoint_path: String, -} - fn query_middleware_chain_locked( engine: &mut regorus::Engine, input: &NetworkInput, - request_path: &str, ) -> Result> { engine .set_input_json(&network_input_json(input).to_string()) @@ -772,37 +759,7 @@ fn query_middleware_chain_locked( if configs.is_empty() { return Ok(Vec::new()); } - let contexts_val = engine - .eval_rule("data.openshell.sandbox._matching_endpoint_contexts".into()) - .map_err(|e| miette::miette!("{e}"))?; - let contexts = parse_endpoint_contexts(&contexts_val); - let policies_val = engine - .eval_rule("data.openshell.sandbox.network_policies".into()) - .map_err(|e| miette::miette!("{e}"))?; - let Some(context) = select_endpoint_context(&contexts, request_path, &policies_val)? 
else { - return global_middleware_entries(&configs, &input.host, &HashSet::new()); - }; - let (policy_middleware, endpoint_middleware) = - middleware_for_endpoint_identity(&policies_val, context)?; - - let mut explicit = Vec::new(); - for name in policy_middleware.iter().chain(endpoint_middleware.iter()) { - if !explicit.contains(name) { - explicit.push(name.clone()); - } - } - let explicit_set: HashSet = explicit.iter().cloned().collect(); - let mut ordered = global_middleware_entries(&configs, &input.host, &explicit_set)?; - for name in explicit { - if !ordered.iter().any(|entry| entry.name == name) { - let config = configs - .iter() - .find(|config| get_str(config, "name").as_deref() == Some(name.as_str())) - .ok_or_else(|| miette::miette!("unknown middleware config '{name}'"))?; - ordered.push(chain_entry_from_value(config)?); - } - } - Ok(ordered) + global_middleware_entries(&configs, &input.host) } fn parse_middleware_configs(value: ®orus::Value) -> Result> { @@ -815,169 +772,37 @@ fn parse_middleware_configs(value: ®orus::Value) -> Result Vec { - let regorus::Value::Array(values) = value else { - return Vec::new(); - }; - values - .iter() - .filter_map(|value| { - let regorus::Value::Object(_) = value else { - return None; - }; - Some(MatchedEndpointContext { - policy_name: get_str(value, "policy").unwrap_or_default(), - endpoint_index: get_usize(value, "endpoint_index").unwrap_or_default(), - endpoint_path: get_str(value, "endpoint_path").unwrap_or_default(), - }) - }) - .collect() -} - -fn middleware_for_endpoint_identity( - policies: ®orus::Value, - context: &MatchedEndpointContext, -) -> Result<(Vec, Vec)> { - let policy = get_field(policies, &context.policy_name).ok_or_else(|| { - miette::miette!( - "matched endpoint policy '{}' was not found in OPA data", - context.policy_name - ) - })?; - let endpoint = get_array(policy, "endpoints") - .and_then(|endpoints| endpoints.get(context.endpoint_index)) - .ok_or_else(|| { - miette::miette!( - "matched 
endpoint {}[{}] was not found in OPA data", - context.policy_name, - context.endpoint_index - ) - })?; - Ok(( - get_str_array(policy, "middleware"), - get_str_array(endpoint, "middleware"), - )) -} - -fn select_endpoint_context<'a>( - contexts: &'a [MatchedEndpointContext], - request_path: &str, - policies: ®orus::Value, -) -> Result> { - let matching: Vec<_> = contexts - .iter() - .filter(|context| crate::l7::endpoint_path_matches(&context.endpoint_path, request_path)) - .map(|context| (endpoint_path_specificity(&context.endpoint_path), context)) - .collect(); - let Some(max_specificity) = matching.iter().map(|(specificity, _)| *specificity).max() else { - return Ok(None); - }; - let best: Vec<_> = matching - .into_iter() - .filter(|(specificity, _)| *specificity == max_specificity) - .map(|(_, context)| context) - .collect(); - if let Some((first, rest)) = best.split_first() { - let first_middleware = explicit_middleware_for_endpoint_identity(policies, first)?; - for context in rest { - if explicit_middleware_for_endpoint_identity(policies, context)? 
!= first_middleware { - let matches = best - .iter() - .map(|context| { - format!( - "{}[{}] path={}", - context.policy_name, - context.endpoint_index, - if context.endpoint_path.is_empty() { - "" - } else { - context.endpoint_path.as_str() - } - ) - }) - .collect::>() - .join(", "); - return Err(miette::miette!( - "ambiguous middleware endpoint match for request path '{request_path}': {matches}" - )); - } - } - } - Ok(best.into_iter().next()) -} - -fn explicit_middleware_for_endpoint_identity( - policies: ®orus::Value, - context: &MatchedEndpointContext, -) -> Result> { - let (policy_middleware, endpoint_middleware) = - middleware_for_endpoint_identity(policies, context)?; - Ok(dedup_middleware_names( - policy_middleware.iter().chain(endpoint_middleware.iter()), - )) -} - -fn dedup_middleware_names<'a>(names: impl IntoIterator) -> Vec { - let mut deduped = Vec::new(); - for name in names { - if !deduped.contains(name) { - deduped.push(name.clone()); - } - } - deduped -} - -fn endpoint_path_specificity(path: &str) -> usize { - if path.is_empty() { - 0 - } else { - path.chars().filter(|c| *c != '*').count() - } -} - -fn global_middleware_entries( - configs: &[regorus::Value], - host: &str, - explicit: &HashSet, -) -> Result> { +fn global_middleware_entries(configs: &[regorus::Value], host: &str) -> Result> { let mut entries = Vec::new(); for config in configs { - let name = get_str(config, "name").unwrap_or_default(); - if explicit.contains(&name) { - continue; - } - if middleware_selector_matches(config, host) { + if middleware_selector_matches(config, host)? 
{ entries.push(chain_entry_from_value(config)?); } } Ok(entries) } -fn middleware_selector_matches(config: ®orus::Value, host: &str) -> bool { +fn middleware_selector_matches(config: ®orus::Value, host: &str) -> Result { let Some(selector) = get_field(config, "endpoints") else { - return false; + return Ok(false); }; let include_patterns = get_str_array(selector, "include"); let exclude_patterns = get_str_array(selector, "exclude"); - let matches_include = !include_patterns.is_empty() - && include_patterns - .iter() - .any(|pattern| host_matches(pattern, host)); + let matches_include = include_patterns + .iter() + .try_fold(false, |matched, pattern| { + openshell_policy::middleware_host_matches(pattern, host) + .map(|matches| matched || matches) + .map_err(|error| miette::miette!(error)) + })?; let matches_exclude = exclude_patterns .iter() - .any(|pattern| host_matches(pattern, host)); - matches_include && !matches_exclude -} - -fn host_matches(pattern: &str, host: &str) -> bool { - if pattern == "*" || pattern == "**" { - return true; - } - if !pattern.contains('*') { - return pattern.eq_ignore_ascii_case(host); - } - glob::Pattern::new(&pattern.to_ascii_lowercase()) - .is_ok_and(|pattern| pattern.matches(&host.to_ascii_lowercase())) + .try_fold(false, |matched, pattern| { + openshell_policy::middleware_host_matches(pattern, host) + .map(|matches| matched || matches) + .map_err(|error| miette::miette!(error)) + })?; + Ok(matches_include && !matches_exclude) } fn chain_entry_from_value(value: ®orus::Value) -> Result { @@ -1003,25 +828,6 @@ fn get_field<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a regorus::Valu } } -fn get_array<'a>(val: &'a regorus::Value, key: &str) -> Option<&'a [regorus::Value]> { - let regorus::Value::Array(values) = get_field(val, key)? 
else { - return None; - }; - Some(values) -} - -fn get_usize(val: ®orus::Value, key: &str) -> Option { - let value = get_field(val, key)?; - let regorus::Value::Number(number) = value else { - return None; - }; - let value = number.as_f64()?; - if !value.is_finite() || value.fract() != 0.0 || value < 0.0 { - return None; - } - format!("{value:.0}").parse::().ok() -} - fn regorus_value_to_struct(value: ®orus::Value) -> prost_types::Struct { let regorus::Value::Object(map) = value else { return prost_types::Struct::default(); @@ -1383,6 +1189,29 @@ fn validate_middleware_policies(data: &serde_json::Value) -> Vec { "middleware config '{name}' has invalid on_error '{on_error}'" )); } + + let Some(selector) = mw.get("endpoints") else { + errors.push(format!( + "middleware config '{name}' requires an endpoint selector" + )); + continue; + }; + let includes = json_string_array(selector.get("include")); + let excludes = json_string_array(selector.get("exclude")); + if includes.is_empty() { + errors.push(format!( + "middleware config '{name}' endpoint selector must include at least one host pattern" + )); + } + for pattern in includes.iter().chain(&excludes) { + if let Err(error) = + openshell_policy::middleware_host_matches(pattern, "validation.invalid") + { + errors.push(format!( + "middleware config '{name}' has invalid endpoint selector pattern '{pattern}': {error}" + )); + } + } } let Some(policies) = data @@ -1393,127 +1222,25 @@ fn validate_middleware_policies(data: &serde_json::Value) -> Vec { }; for (policy_name, policy) in policies { - let policy_middleware = json_string_array(policy.get("middleware")); - for name in &policy_middleware { - if !names.contains(name) { - errors.push(format!( - "network policy '{policy_name}' references unknown middleware config '{name}'" - )); - } - } for endpoint in policy .get("endpoints") .and_then(serde_json::Value::as_array) .map_or(&[][..], Vec::as_slice) { - let endpoint_middleware = 
json_string_array(endpoint.get("middleware")); - for name in &endpoint_middleware { - if !names.contains(name) { - errors.push(format!( - "network policy '{policy_name}' endpoint references unknown middleware config '{name}'" - )); - } - } let tls_skip = endpoint .get("tls") .and_then(serde_json::Value::as_str) .is_some_and(|tls| tls == "skip"); - if tls_skip && (!policy_middleware.is_empty() || !endpoint_middleware.is_empty()) { - errors.push(format!( - "network policy '{policy_name}' attaches middleware to a tls: skip endpoint" - )); - } if tls_skip && global_selector_matches_any_middleware(middlewares, endpoint) { errors.push(format!( "network policy '{policy_name}' tls: skip endpoint matches a global middleware selector" )); } } - validate_ambiguous_middleware_endpoints( - policy_name, - policy, - &policy_middleware, - &mut errors, - ); } errors } -fn validate_ambiguous_middleware_endpoints( - policy_name: &str, - policy: &serde_json::Value, - policy_middleware: &[String], - errors: &mut Vec, -) { - let endpoints = policy - .get("endpoints") - .and_then(serde_json::Value::as_array) - .map_or(&[][..], Vec::as_slice); - let mut seen: Vec<(usize, MiddlewareEndpointKey, Vec)> = Vec::new(); - for (index, endpoint) in endpoints.iter().enumerate() { - let key = middleware_endpoint_key(endpoint); - let endpoint_middleware = json_string_array(endpoint.get("middleware")); - let chain = - dedup_middleware_names(policy_middleware.iter().chain(endpoint_middleware.iter())); - for (previous_index, previous_key, previous_chain) in &seen { - if previous_key == &key && previous_chain != &chain { - errors.push(format!( - "network policy '{policy_name}' endpoints[{previous_index}] and endpoints[{index}] have equivalent middleware selection keys ({key}) but different middleware chains" - )); - } - } - seen.push((index, key, chain)); - } -} - -#[derive(Debug, PartialEq, Eq)] -struct MiddlewareEndpointKey { - host: String, - ports: Vec, - path: String, -} - -impl std::fmt::Display 
for MiddlewareEndpointKey { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!( - f, - "host={} ports={:?} path={}", - if self.host.is_empty() { - "" - } else { - self.host.as_str() - }, - self.ports, - if self.path.is_empty() { - "" - } else { - self.path.as_str() - } - ) - } -} - -fn middleware_endpoint_key(endpoint: &serde_json::Value) -> MiddlewareEndpointKey { - let host = endpoint - .get("host") - .and_then(serde_json::Value::as_str) - .unwrap_or_default() - .to_ascii_lowercase(); - let mut ports: Vec = endpoint - .get("ports") - .and_then(serde_json::Value::as_array) - .map(|ports| ports.iter().filter_map(serde_json::Value::as_u64).collect()) - .unwrap_or_default(); - ports.sort_unstable(); - ports.dedup(); - let path = endpoint - .get("path") - .and_then(serde_json::Value::as_str) - .unwrap_or_default() - .to_string(); - MiddlewareEndpointKey { host, ports, path } -} - fn json_string_array(value: Option<&serde_json::Value>) -> Vec { value .and_then(serde_json::Value::as_array) @@ -1542,8 +1269,12 @@ fn global_selector_matches_any_middleware( let includes = json_string_array(selector.get("include")); let excludes = json_string_array(selector.get("exclude")); !includes.is_empty() - && includes.iter().any(|pattern| host_matches(pattern, host)) - && !excludes.iter().any(|pattern| host_matches(pattern, host)) + && includes.iter().any(|pattern| { + openshell_policy::middleware_host_matches(pattern, host).unwrap_or(false) + }) + && !excludes.iter().any(|pattern| { + openshell_policy::middleware_host_matches(pattern, host).unwrap_or(false) + }) }) } @@ -1908,9 +1639,6 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St allow_all_known_mcp_methods.into(); } } - if !e.middleware.is_empty() { - ep["middleware"] = e.middleware.clone().into(); - } ep }) .collect(); @@ -1936,14 +1664,11 @@ fn proto_to_opa_data_json(proto: &ProtoSandboxPolicy, entrypoint_pid: u32) -> St entries }) .collect(); - let mut policy = 
serde_json::json!({ + let policy = serde_json::json!({ "name": rule.name, "endpoints": endpoints, "binaries": binaries, }); - if !rule.middleware.is_empty() { - policy["middleware"] = rule.middleware.clone().into(); - } (key.clone(), policy) }) .collect(); @@ -2058,7 +1783,6 @@ mod tests { path: "/usr/local/bin/claude".to_string(), ..Default::default() }], - ..Default::default() }, ); network_policies.insert( @@ -2074,7 +1798,6 @@ mod tests { path: "/usr/bin/glab".to_string(), ..Default::default() }], - ..Default::default() }, ); ProtoSandboxPolicy { @@ -3417,7 +3140,6 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -3490,7 +3212,6 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], - middleware: vec![], }, ); @@ -3564,7 +3285,6 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], - middleware: vec![], }, ); @@ -4443,7 +4163,6 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -4502,7 +4221,6 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -4562,7 +4280,6 @@ network_policies: path: "/usr/local/bin/claude".to_string(), ..Default::default() }], - middleware: vec![], }, ); let proto = ProtoSandboxPolicy { @@ -4624,7 +4341,6 @@ network_policies: path: "/usr/local/bin/aws".to_string(), ..Default::default() }], - middleware: vec![], }, ); let proto = ProtoSandboxPolicy { @@ -4685,7 +4401,6 @@ network_policies: path: "/usr/bin/node".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5636,7 +5351,6 @@ process: ..Default::default() }], binaries: vec![proposal_binary], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5692,7 +5406,6 @@ process: path: "/usr/bin/python".to_string(), 
..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5764,7 +5477,6 @@ process: path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -5996,7 +5708,6 @@ network_policies: path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6701,7 +6412,6 @@ network_policies: path: "/usr/bin/python3".to_string(), ..Default::default() }], - ..Default::default() }, ); @@ -7099,7 +6809,7 @@ network_policies: } #[test] - fn middleware_chain_orders_global_policy_endpoint_once() { + fn middleware_chain_uses_matching_selector_declaration_order() { let data = r#" network_middlewares: - name: global-redactor @@ -7108,18 +6818,20 @@ network_middlewares: include: ["api.example.com"] - name: policy-redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] - name: endpoint-redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] network_policies: api: name: api - middleware: ["global-redactor", "policy-redactor"] endpoints: - host: api.example.com port: 443 protocol: rest enforcement: enforce - middleware: ["policy-redactor", "endpoint-redactor"] rules: - allow: { method: POST, path: "/v1/**" } binaries: @@ -7135,7 +6847,7 @@ network_policies: cmdline_paths: vec![], }; let (chain, _) = engine - .query_middleware_chain_with_generation(&input, "/v1/messages") + .query_middleware_chain_with_generation(&input) .unwrap(); let names: Vec<_> = chain.iter().map(|entry| entry.name.as_str()).collect(); assert_eq!( @@ -7144,63 +6856,9 @@ network_policies: ); } - #[test] - fn middleware_validation_rejects_ambiguous_duplicate_endpoint_middleware() { - let data = r#" -network_middlewares: - - name: first-redactor - middleware: openshell/secrets - - name: second-redactor - middleware: openshell/secrets -network_policies: - api: - name: api - endpoints: - - host: api.example.com - 
port: 443 - protocol: rest - enforcement: enforce - middleware: ["first-redactor"] - access: full - - host: api.example.com - port: 443 - protocol: rest - enforcement: enforce - middleware: ["second-redactor"] - access: full - binaries: - - { path: /usr/bin/curl } -"#; - let err = match OpaEngine::from_strings(TEST_POLICY, data) { - Ok(_) => panic!("equivalent endpoints with different middleware should be invalid"), - Err(err) => err, - }; - assert!( - err.to_string() - .contains("equivalent middleware selection keys"), - "{err:?}" - ); - } - #[test] fn middleware_policy_validation_rejects_bad_configs() { let cases = [ - ( - "missing reference", - r#" -network_middlewares: - - name: redactor - middleware: openshell/secrets -network_policies: - api: - middleware: ["missing"] - endpoints: - - { host: api.example.com, port: 443 } - binaries: - - { path: /usr/bin/curl } -"#, - "unknown middleware config 'missing'", - ), ( "invalid on_error", r#" @@ -7208,6 +6866,8 @@ network_middlewares: - name: redactor middleware: openshell/secrets on_error: maybe + endpoints: + include: ["api.example.com"] "#, "invalid on_error", ), @@ -7217,8 +6877,12 @@ network_middlewares: network_middlewares: - name: redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] - name: redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] "#, "duplicate middleware config 'redactor'", ), @@ -7228,22 +6892,45 @@ network_middlewares: network_middlewares: - name: sigv4 middleware: openshell/sigv4 + endpoints: + include: ["api.example.com"] "#, "unsupported built-in", ), ( - "tls skip attachment", + "missing selector", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets +"#, + "requires an endpoint selector", + ), + ( + "malformed selector", + r#" +network_middlewares: + - name: redactor + middleware: openshell/secrets + endpoints: + include: ["api[.example.com"] +"#, + "invalid host pattern", + ), + ( + "tls skip selector", 
r#" network_middlewares: - name: redactor middleware: openshell/secrets + endpoints: + include: ["api.example.com"] network_policies: api: endpoints: - host: api.example.com port: 443 tls: skip - middleware: ["redactor"] binaries: - { path: /usr/bin/curl } "#, @@ -7263,6 +6950,29 @@ network_policies: } } + #[test] + fn from_proto_revalidates_middleware_policy() { + let mut policy = openshell_policy::restrictive_default_policy(); + policy + .network_middlewares + .push(openshell_core::proto::NetworkMiddlewareConfig { + name: "redactor".into(), + middleware: "openshell/secrets".into(), + endpoints: Some(openshell_core::proto::MiddlewareEndpointSelector { + include: vec!["api[.example.com".into()], + exclude: Vec::new(), + }), + ..Default::default() + }); + + let error = OpaEngine::from_proto(&policy) + .err() + .expect("supervisor must reject invalid effective middleware policy") + .to_string(); + assert!(error.contains("policy validation failed"), "{error}"); + assert!(error.contains("invalid host pattern"), "{error}"); + } + #[test] fn l7_head_denied_when_only_post_allowed() { let engine = OpaEngine::from_strings( diff --git a/crates/openshell-supervisor-network/src/policy_local.rs b/crates/openshell-supervisor-network/src/policy_local.rs index fa8029c72..3cbc31502 100644 --- a/crates/openshell-supervisor-network/src/policy_local.rs +++ b/crates/openshell-supervisor-network/src/policy_local.rs @@ -1047,7 +1047,6 @@ fn network_rule_from_json( name: rule.name.unwrap_or_default(), endpoints, binaries, - middleware: Vec::new(), }) } @@ -1134,7 +1133,6 @@ fn network_endpoint_from_json( credential_signing: String::new(), signing_service: String::new(), signing_region: String::new(), - middleware: Vec::new(), }) } @@ -1831,7 +1829,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() }), ..Default::default() }; @@ -1856,7 +1853,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() 
} } @@ -1920,7 +1916,6 @@ mod tests { path: "/usr/bin/curl".to_string(), ..Default::default() }], - ..Default::default() })); }) }; diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index f8310fbdc..afec98666 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -4194,7 +4194,7 @@ async fn handle_forward_proxy( cmdline_paths: decision.cmdline_paths.clone(), }; let (chain, generation) = - opa_engine.query_middleware_chain_with_generation(&middleware_input, middleware_path)?; + opa_engine.query_middleware_chain_with_generation(&middleware_input)?; if generation != forward_generation_guard.captured_generation() { emit_l7_tunnel_close_after_policy_change( &host_lc, @@ -4224,10 +4224,11 @@ async fn handle_forward_proxy( &upstream_target, forward_request_bytes, )?; - forward_request_bytes = match crate::l7::relay::apply_middleware_chain( + forward_request_bytes = match crate::l7::relay::apply_middleware_chain_for_scheme( request, client, &l7_ctx, + &scheme, chain, &forward_generation_guard, ) diff --git a/proto/middleware.proto b/proto/middleware.proto index d5d2ad48d..2944227d8 100644 --- a/proto/middleware.proto +++ b/proto/middleware.proto @@ -25,6 +25,8 @@ message MiddlewareBinding { string id = 1; string operation = 2; string phase = 3; + // Maximum request or replacement body this binding can process. + uint64 max_body_bytes = 4; } message ValidateConfigRequest { diff --git a/proto/sandbox.proto b/proto/sandbox.proto index a73d762e5..04cbd6776 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -63,11 +63,9 @@ message NetworkPolicyRule { repeated NetworkEndpoint endpoints = 2; // Allowed binary identities. repeated NetworkBinary binaries = 3; - // Ordered middleware configs applied to every endpoint in this policy. - repeated string middleware = 4; } -// A reusable middleware config referenced by network policies/endpoints. 
+// A reusable middleware config selected for admitted egress by host. message NetworkMiddlewareConfig { // Policy-local config name. string name = 1; @@ -77,7 +75,7 @@ message NetworkMiddlewareConfig { google.protobuf.Struct config = 3; // Failure behavior: "fail_closed" (default) or "fail_open". string on_error = 4; - // Optional global endpoint selector for this config. + // Host selector controlling which admitted destinations use this config. MiddlewareEndpointSelector endpoints = 5; } @@ -168,8 +166,6 @@ message NetworkEndpoint { uint32 json_rpc_max_body_bytes = 22; // MCP-only policy and inspection options. Only used when protocol is "mcp". McpOptions mcp = 23; - // Ordered middleware configs applied to this endpoint after policy-level middleware. - repeated string middleware = 24; } // MCP options are grouped so MCP-specific policy can grow without adding more From f0004a7dcf423d5973d6fba705883c4535735a68 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 15:28:14 -0700 Subject: [PATCH 07/16] feat(supervisor-middleware): support external services Signed-off-by: Piotr Mlocek --- Cargo.lock | 3 + architecture/gateway.md | 8 + architecture/sandbox.md | 10 + crates/openshell-core/src/grpc_client.rs | 41 +- crates/openshell-sandbox/Cargo.toml | 1 + crates/openshell-sandbox/src/lib.rs | 68 +- crates/openshell-server/Cargo.toml | 1 + crates/openshell-server/src/config_file.rs | 55 ++ crates/openshell-server/src/grpc/policy.rs | 71 ++- crates/openshell-server/src/grpc/sandbox.rs | 1 + crates/openshell-server/src/lib.rs | 25 + crates/openshell-server/src/middleware.rs | 43 ++ .../Cargo.toml | 5 +- .../src/lib.rs | 601 ++++++++++++++++-- .../src/remote.rs | 91 +++ .../src/l7/relay.rs | 266 ++++---- .../openshell-supervisor-network/src/opa.rs | 34 +- .../openshell-supervisor-network/src/proxy.rs | 2 + docs/reference/gateway-config.mdx | 28 + docs/reference/policy-schema.mdx | 32 +- docs/sandboxes/policies.mdx | 36 +- proto/sandbox.proto | 16 + 22 files 
changed, 1246 insertions(+), 192 deletions(-) create mode 100644 crates/openshell-server/src/middleware.rs create mode 100644 crates/openshell-supervisor-middleware/src/remote.rs diff --git a/Cargo.lock b/Cargo.lock index 2de3c0353..eadbef8f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3831,6 +3831,7 @@ dependencies = [ "openshell-core", "openshell-ocsf", "openshell-policy", + "openshell-supervisor-middleware", "openshell-supervisor-network", "openshell-supervisor-process", "rustls", @@ -3885,6 +3886,7 @@ dependencies = [ "openshell-providers", "openshell-router", "openshell-server-macros", + "openshell-supervisor-middleware", "petname", "pin-project-lite", "prost", @@ -3936,6 +3938,7 @@ dependencies = [ "prost-types", "regex", "tokio", + "tokio-stream", "tonic", ] diff --git a/architecture/gateway.md b/architecture/gateway.md index d873b2a10..ba8437b7f 100644 --- a/architecture/gateway.md +++ b/architecture/gateway.md @@ -271,6 +271,14 @@ config path. A gateway-global policy can override sandbox-scoped policy. The sandbox supervisor polls for config revisions and hot-reloads dynamic policy when the policy engine accepts the update. +External supervisor middleware registration is operator-owned gateway +configuration. At startup the gateway connects to each service, validates its +described bindings and operator body limit, and rejects duplicate binding IDs. +Before persisting a policy, the gateway asks each selected implementation to +validate its config. The effective sandbox config contains only the registered +services required by that policy; supervisors invoke those services directly on +the request path. + Provider credential expiry is enforced during gateway-to-sandbox credential resolution and again by the sandbox placeholder resolver. 
This keeps expired credentials from resolving even when a running sandbox still has retained diff --git a/architecture/sandbox.md b/architecture/sandbox.md index 580d8f96d..deec00f32 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -66,6 +66,14 @@ matchers; generic JSON-RPC rules match only the method. JSON-RPC responses and server-to-client MCP messages on response or SSE streams are relayed but are not currently parsed for policy enforcement. +For admitted HTTP requests, the proxy can run an ordered supervisor middleware +chain before credential injection. Host selectors choose the chain independently +of the network rule that admitted the request. Built-ins run in-process; +operator-registered external services are called directly from the supervisor +over the common middleware gRPC contract. The gateway validates external +service capabilities and policy-owned config before delivery. Supervisors keep +the last-known-good service registry when a live config reload fails. + `https://inference.local` is special. It bypasses OPA network policy and is handled by the inference interception path: @@ -176,6 +184,8 @@ quickly. - If gateway config polling fails, the sandbox keeps its last-known-good policy. - If a live policy update is invalid, the supervisor rejects it and keeps the current policy. +- If an external middleware call fails, the selected config's `on_error` + behavior decides whether to deny the request or continue without that stage. - Existing raw byte streams are connection scoped. Dynamic policy changes apply to new connections or the next parsed HTTP request where the proxy can safely re-evaluate. 
diff --git a/crates/openshell-core/src/grpc_client.rs b/crates/openshell-core/src/grpc_client.rs index 96158a1d1..57f72dca6 100644 --- a/crates/openshell-core/src/grpc_client.rs +++ b/crates/openshell-core/src/grpc_client.rs @@ -24,11 +24,11 @@ use std::time::{Duration, SystemTime, UNIX_EPOCH}; use crate::proto::{ DenialSummary, GetDraftPolicyRequest, GetInferenceBundleRequest, GetInferenceBundleResponse, - GetSandboxConfigRequest, GetSandboxProviderEnvironmentRequest, IssueSandboxTokenRequest, - NetworkActivitySummary, PolicyChunk, PolicySource, PolicyStatus, RefreshSandboxTokenRequest, - ReportPolicyStatusRequest, SandboxPolicy as ProtoSandboxPolicy, SubmitPolicyAnalysisRequest, - SubmitPolicyAnalysisResponse, UpdateConfigRequest, inference_client::InferenceClient, - open_shell_client::OpenShellClient, + GetSandboxConfigRequest, GetSandboxConfigResponse, GetSandboxProviderEnvironmentRequest, + IssueSandboxTokenRequest, NetworkActivitySummary, PolicyChunk, PolicySource, PolicyStatus, + RefreshSandboxTokenRequest, ReportPolicyStatusRequest, SandboxPolicy as ProtoSandboxPolicy, + SubmitPolicyAnalysisRequest, SubmitPolicyAnalysisResponse, UpdateConfigRequest, + inference_client::InferenceClient, open_shell_client::OpenShellClient, }; use crate::sandbox_env; use miette::{IntoDiagnostic, Result, WrapErr}; @@ -573,19 +573,36 @@ pub async fn fetch_policy(endpoint: &str, sandbox_id: &str) -> Result Result { + debug!(endpoint = %endpoint, sandbox_id = %sandbox_id, "Connecting to OpenShell server"); + let mut client = connect(endpoint).await?; + fetch_sandbox_config_with_client(&mut client, sandbox_id).await +} + +async fn fetch_sandbox_config_with_client( client: &mut OpenShellClient, sandbox_id: &str, -) -> Result> { - let response = client +) -> Result { + client .get_sandbox_config(GetSandboxConfigRequest { sandbox_id: sandbox_id.to_string(), }) .await - .into_diagnostic()?; + .map(tonic::Response::into_inner) + .into_diagnostic() +} - let inner = response.into_inner(); 
+/// Fetch sandbox policy using an existing client connection. +async fn fetch_policy_with_client( + client: &mut OpenShellClient, + sandbox_id: &str, +) -> Result> { + let inner = fetch_sandbox_config_with_client(client, sandbox_id).await?; // version 0 with no policy means the sandbox was created without one. if inner.version == 0 && inner.policy.is_none() { @@ -711,6 +728,7 @@ pub struct SettingsPollResult { /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, pub provider_env_revision: u64, + pub external_middleware: Vec, } pub struct ProviderEnvironmentResult { @@ -755,6 +773,7 @@ impl CachedOpenShellClient { settings: inner.settings, global_policy_version: inner.global_policy_version, provider_env_revision: inner.provider_env_revision, + external_middleware: inner.external_middleware, }) } diff --git a/crates/openshell-sandbox/Cargo.toml b/crates/openshell-sandbox/Cargo.toml index 086dbe02c..d3c3e7108 100644 --- a/crates/openshell-sandbox/Cargo.toml +++ b/crates/openshell-sandbox/Cargo.toml @@ -19,6 +19,7 @@ openshell-core = { path = "../openshell-core", default-features = false } openshell-ocsf = { path = "../openshell-ocsf" } openshell-policy = { path = "../openshell-policy" } openshell-supervisor-network = { path = "../openshell-supervisor-network" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } openshell-supervisor-process = { path = "../openshell-supervisor-process" } # Async runtime diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index d5967d1f3..af173abc3 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -1376,12 +1376,12 @@ async fn load_policy( endpoint = %endpoint, "Fetching sandbox policy via gRPC" ); - let proto_policy = grpc_retry("Policy fetch", || { - openshell_core::grpc_client::fetch_policy(endpoint, id) + let mut sandbox_config = grpc_retry("Policy fetch", || { + 
openshell_core::grpc_client::fetch_sandbox_config(endpoint, id) }) .await?; - let mut proto_policy = if let Some(p) = proto_policy { + let mut proto_policy = if let Some(p) = sandbox_config.policy.take() { p } else { // No policy configured on the server. Discover from disk or @@ -1409,7 +1409,7 @@ async fn load_policy( // Sync and re-fetch over a single connection to avoid extra // TLS handshakes. - grpc_retry("Policy discovery sync", || { + let synced = grpc_retry("Policy discovery sync", || { openshell_core::grpc_client::discover_and_sync_policy( endpoint, id, @@ -1417,7 +1417,12 @@ async fn load_policy( &discovered, ) }) - .await? + .await?; + sandbox_config = grpc_retry("Policy refetch after discovery", || { + openshell_core::grpc_client::fetch_sandbox_config(endpoint, id) + }) + .await?; + sandbox_config.policy.take().unwrap_or(synced) }; // Ensure baseline filesystem paths are present for proxy-mode @@ -1443,7 +1448,14 @@ async fn load_policy( // container hasn't started yet. After the entrypoint spawns, the // engine is rebuilt with the real PID for symlink resolution. 
info!("Creating OPA engine from proto policy data"); - let opa_engine = Some(Arc::new(OpaEngine::from_proto(&proto_policy)?)); + let engine = OpaEngine::from_proto(&proto_policy)?; + let middleware_registry = + openshell_supervisor_middleware::MiddlewareRegistry::connect_external( + sandbox_config.external_middleware, + ) + .await?; + engine.replace_middleware_registry(middleware_registry)?; + let opa_engine = Some(Arc::new(engine)); let policy = SandboxPolicy::try_from(proto_policy.clone())?; return Ok((policy, opa_engine, Some(proto_policy))); @@ -1593,6 +1605,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { let mut current_config_revision: u64 = 0; let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); + let mut current_external_middleware = Vec::new(); let mut current_settings: std::collections::HashMap< String, openshell_core::proto::EffectiveSetting, @@ -1604,6 +1617,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { apply_ocsf_json_setting(&ctx.ocsf_enabled, &result.settings); current_config_revision = result.config_revision; current_policy_hash = result.policy_hash.clone(); + current_external_middleware = result.external_middleware; current_settings = result.settings; debug!( config_revision = current_config_revision, @@ -1633,6 +1647,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } let policy_changed = result.policy_hash != current_policy_hash; + let middleware_changed = result.external_middleware != current_external_middleware; // Log which settings changed. 
log_setting_changes(&current_settings, &result.settings); @@ -1691,6 +1706,47 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } } + if middleware_changed { + match openshell_supervisor_middleware::MiddlewareRegistry::connect_external( + result.external_middleware.clone(), + ) + .await + .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) + { + Ok(()) => { + current_external_middleware = result.external_middleware.clone(); + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "external_middleware_count", + serde_json::json!(current_external_middleware.len()) + ) + .message(format!( + "External middleware registry reloaded [service_count:{}]", + current_external_middleware.len() + )) + .build() + ); + } + Err(error) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "failed") + .message(format!( + "External middleware registry reload failed, keeping last-known-good registry [error:{error}]" + )) + .build() + ); + continue; + } + } + } + // Only reload OPA when the policy payload actually changed. 
if policy_changed { let Some(policy) = result.policy.as_ref() else { diff --git a/crates/openshell-server/Cargo.toml b/crates/openshell-server/Cargo.toml index b5c9b34d7..fafc72ba7 100644 --- a/crates/openshell-server/Cargo.toml +++ b/crates/openshell-server/Cargo.toml @@ -26,6 +26,7 @@ openshell-prover = { path = "../openshell-prover" } openshell-providers = { path = "../openshell-providers" } openshell-router = { path = "../openshell-router" } openshell-server-macros = { path = "../openshell-server-macros" } +openshell-supervisor-middleware = { path = "../openshell-supervisor-middleware" } # Kubernetes client (used by the `generate-certs` subcommand) kube = { workspace = true } diff --git a/crates/openshell-server/src/config_file.rs b/crates/openshell-server/src/config_file.rs index b65b5f3b0..13c7e9ebb 100644 --- a/crates/openshell-server/src/config_file.rs +++ b/crates/openshell-server/src/config_file.rs @@ -25,6 +25,7 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use openshell_core::config::ComputeDriverKind; +use openshell_core::proto::ExternalMiddlewareService; use openshell_core::{GatewayAuthConfig, GatewayJwtConfig, MtlsAuthConfig, OidcConfig, TlsConfig}; use serde::{Deserialize, Serialize}; @@ -151,6 +152,12 @@ pub struct GatewayFileSection { #[serde(default)] pub gateway_jwt: Option, + // ── Supervisor middleware ───────────────────────────────────────────── + /// Statically registered external middleware services. Registration is + /// operator-owned and changes require a gateway restart. + #[serde(default)] + pub middleware: Vec, + // ── Disallowed-in-file fields ──────────────────────────────────────── // // Captured so we can produce a friendly "set this via env/CLI instead" @@ -160,6 +167,32 @@ pub struct GatewayFileSection { pub database_url: Option, } +/// One `[[openshell.gateway.middleware]]` external middleware registration. 
+#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(deny_unknown_fields)] +pub struct MiddlewareServiceFileConfig { + /// Operator-facing name used for diagnostics. + pub name: String, + /// Plaintext gRPC endpoint reachable by the gateway and supervisors. + pub endpoint: String, + /// Required explicit opt-in to the local-development-only insecure mode. + #[serde(default)] + pub allow_insecure: bool, + /// Operator-owned body limit for every binding exposed by this service. + pub max_body_bytes: u64, +} + +impl From<&MiddlewareServiceFileConfig> for ExternalMiddlewareService { + fn from(config: &MiddlewareServiceFileConfig) -> Self { + Self { + name: config.name.clone(), + endpoint: config.endpoint.clone(), + allow_insecure: config.allow_insecure, + max_body_bytes: config.max_body_bytes, + } + } +} + #[derive(Debug, thiserror::Error)] pub enum ConfigFileError { #[error("failed to read gateway config file '{}': {source}", path.display())] @@ -401,6 +434,28 @@ allow_unauthenticated_users = true assert!(auth.allow_unauthenticated_users); } + #[test] + fn parses_external_middleware_registration() { + let toml = r#" +[[openshell.gateway.middleware]] +name = "local-guard" +endpoint = "http://127.0.0.1:50051" +allow_insecure = true +max_body_bytes = 262144 +"#; + let tmp = write_tmp(toml); + let file = load(tmp.path()).expect("valid middleware registration parses"); + assert_eq!( + file.openshell.gateway.middleware, + vec![MiddlewareServiceFileConfig { + name: "local-guard".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes: 262_144, + }] + ); + } + #[test] fn rejects_database_url_in_file() { let toml = r#" diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index d3bc213ba..c29de71ff 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1218,8 +1218,27 @@ pub(super) async fn handle_get_sandbox_config( } } 
+ if let Some(policy) = policy.as_ref() { + state + .middleware_registry + .ensure_policy_bindings_registered(policy) + .map_err(|error| { + Status::failed_precondition(format!( + "effective policy middleware registration is invalid: {error}" + )) + })?; + } + let settings = merge_effective_settings(&global_settings, &sandbox_settings)?; - let config_revision = compute_config_revision(policy.as_ref(), &settings, policy_source); + let external_middleware = state + .middleware_registry + .required_external_services(policy.as_ref()); + let config_revision = compute_config_revision( + policy.as_ref(), + &settings, + policy_source, + &external_middleware, + ); let provider_env_revision = compute_provider_env_revision(state.store.as_ref(), &sandbox_provider_names).await?; @@ -1232,6 +1251,7 @@ pub(super) async fn handle_get_sandbox_config( policy_source: policy_source.into(), global_policy_version, provider_env_revision, + external_middleware, })) } @@ -1510,6 +1530,8 @@ async fn handle_update_config_inner( openshell_policy::ensure_sandbox_process_identity(&mut new_policy); validate_no_reserved_provider_policy_keys(&new_policy)?; validate_policy_safety(&new_policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), &new_policy) + .await?; let payload = new_policy.encode_to_vec(); let hash = deterministic_policy_hash(&new_policy); @@ -1827,9 +1849,11 @@ async fn handle_update_config_inner( validate_no_reserved_provider_policy_keys(&new_policy)?; } + validate_policy_safety(&new_policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), &new_policy).await?; + if let Some(baseline_policy) = spec.policy.as_ref() { validate_static_fields_unchanged(baseline_policy, &new_policy)?; - validate_policy_safety(&new_policy)?; } else { // Backfill spec.policy using CAS (first-time policy discovery) let _sandbox_sync_guard = state.compute.sandbox_sync_guard().await; @@ -3120,6 +3144,7 @@ fn compute_config_revision( policy: 
Option<&ProtoSandboxPolicy>, settings: &HashMap, policy_source: PolicySource, + external_middleware: &[openshell_core::proto::ExternalMiddlewareService], ) -> u64 { let mut hasher = Sha256::new(); hasher.update((policy_source as i32).to_le_bytes()); @@ -3152,6 +3177,11 @@ fn compute_config_revision( } } } + let mut middleware = external_middleware.iter().collect::>(); + middleware.sort_by(|left, right| left.name.cmp(&right.name)); + for service in middleware { + hasher.update(service.encode_to_vec()); + } let digest = hasher.finalize(); let mut bytes = [0_u8; 8]; @@ -8989,7 +9019,7 @@ mod tests { }, ); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); settings.insert( "mode".to_string(), EffectiveSetting { @@ -8999,7 +9029,7 @@ mod tests { scope: SettingScope::Sandbox.into(), }, ); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); assert_ne!(rev_a, rev_b); } @@ -9264,8 +9294,8 @@ mod tests { }, ); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); assert_eq!(rev_a, rev_b); } @@ -9281,8 +9311,8 @@ mod tests { }; let settings = HashMap::new(); - let rev_a = compute_config_revision(Some(&policy_a), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy_b), &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(Some(&policy_a), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy_b), &settings, PolicySource::Sandbox, 
&[]); assert_ne!(rev_a, rev_b); } @@ -9291,11 +9321,28 @@ mod tests { let policy = ProtoSandboxPolicy::default(); let settings = HashMap::new(); - let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox); - let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Global); + let rev_a = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let rev_b = compute_config_revision(Some(&policy), &settings, PolicySource::Global, &[]); assert_ne!(rev_a, rev_b); } + #[test] + fn config_revision_changes_when_external_middleware_changes() { + let policy = ProtoSandboxPolicy::default(); + let settings = HashMap::new(); + let service = openshell_core::proto::ExternalMiddlewareService { + name: "local-guard".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes: 1024, + }; + + let without = compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[]); + let with = + compute_config_revision(Some(&policy), &settings, PolicySource::Sandbox, &[service]); + assert_ne!(without, with); + } + #[test] fn config_revision_without_policy_still_hashes_settings() { let mut settings = HashMap::new(); @@ -9309,7 +9356,7 @@ mod tests { }, ); - let rev_a = compute_config_revision(None, &settings, PolicySource::Sandbox); + let rev_a = compute_config_revision(None, &settings, PolicySource::Sandbox, &[]); settings.insert( "log_level".to_string(), @@ -9321,7 +9368,7 @@ mod tests { }, ); - let rev_b = compute_config_revision(None, &settings, PolicySource::Sandbox); + let rev_b = compute_config_revision(None, &settings, PolicySource::Sandbox, &[]); assert_ne!(rev_a, rev_b); } diff --git a/crates/openshell-server/src/grpc/sandbox.rs b/crates/openshell-server/src/grpc/sandbox.rs index 04d5a4ed5..203cd7dbe 100644 --- a/crates/openshell-server/src/grpc/sandbox.rs +++ b/crates/openshell-server/src/grpc/sandbox.rs @@ -164,6 +164,7 @@ async fn handle_create_sandbox_inner( 
openshell_policy::ensure_sandbox_process_identity(policy); validate_no_reserved_provider_policy_keys(policy)?; validate_policy_safety(policy)?; + crate::middleware::validate_policy(state.middleware_registry.as_ref(), policy).await?; } let id = uuid::Uuid::new_v4().to_string(); diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index 6462ccbbf..bca8abe2e 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -32,6 +32,7 @@ mod defaults; mod grpc; mod http; mod inference; +mod middleware; mod multiplex; mod persistence; pub(crate) mod policy_store; @@ -53,6 +54,8 @@ mod ws_tunnel; use metrics_exporter_prometheus::PrometheusBuilder; use openshell_core::{ComputeDriverKind, Config, Error, Result}; +use openshell_supervisor_middleware::MiddlewareRegistry; +use serde::Deserialize; use std::collections::HashMap; use std::io::ErrorKind; use std::net::SocketAddr; @@ -126,6 +129,9 @@ pub struct ServerState { /// query session state to surface supervisor readiness. pub supervisor_sessions: Arc, + /// Validated built-in and operator-registered supervisor middleware. + pub middleware_registry: Arc, + /// OIDC JWKS cache for JWT validation. `None` when OIDC is not configured. 
pub oidc_cache: Option>, @@ -192,6 +198,7 @@ impl ServerState { ssh_connections_by_sandbox: Mutex::new(HashMap::new()), settings_mutex: tokio::sync::Mutex::new(()), supervisor_sessions, + middleware_registry: Arc::new(MiddlewareRegistry::default()), oidc_cache, sandbox_jwt_issuer: None, sandbox_jwt_authenticator: None, @@ -223,6 +230,23 @@ pub(crate) async fn run_server( return Err(Error::config("database_url is required")); } + let middleware_registrations = config_file + .as_ref() + .map(|file| { + file.openshell + .gateway + .middleware + .iter() + .map(Into::into) + .collect() + }) + .unwrap_or_default(); + let middleware_registry = Arc::new( + MiddlewareRegistry::connect_external(middleware_registrations) + .await + .map_err(|error| Error::config(format!("middleware registration failed: {error}")))?, + ); + let store = Arc::new(Store::connect(database_url).await?); let oidc_cache = if let Some(ref oidc) = config.oidc { @@ -273,6 +297,7 @@ pub(crate) async fn run_server( supervisor_sessions, oidc_cache, ); + state.middleware_registry = middleware_registry; // Load the gateway-minted sandbox JWT signing key when configured. // Optional so single-driver dev deployments without certgen continue diff --git a/crates/openshell-server/src/middleware.rs b/crates/openshell-server/src/middleware.rs new file mode 100644 index 000000000..4c94f021a --- /dev/null +++ b/crates/openshell-server/src/middleware.rs @@ -0,0 +1,43 @@ +// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use openshell_core::proto::SandboxPolicy; +use openshell_supervisor_middleware::MiddlewareRegistry; +use tonic::Status; + +/// Validate implementation-owned middleware config before accepting a policy. 
+pub async fn validate_policy( + registry: &MiddlewareRegistry, + policy: &SandboxPolicy, +) -> Result<(), Status> { + registry + .validate_policy_configs(policy) + .await + .map_err(|error| { + Status::invalid_argument(format!("policy middleware validation failed: {error}")) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use openshell_core::proto::NetworkMiddlewareConfig; + + #[tokio::test] + async fn unregistered_external_binding_is_rejected_before_admission() { + let policy = SandboxPolicy { + network_middlewares: vec![NetworkMiddlewareConfig { + name: "guard".into(), + middleware: "example/content-guard".into(), + ..Default::default() + }], + ..Default::default() + }; + + let error = validate_policy(&MiddlewareRegistry::default(), &policy) + .await + .expect_err("unregistered binding must fail"); + assert_eq!(error.code(), tonic::Code::InvalidArgument); + assert!(error.message().contains("not registered")); + } +} diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml index 4ae355894..e5e53618d 100644 --- a/crates/openshell-supervisor-middleware/Cargo.toml +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -17,7 +17,10 @@ miette = { workspace = true } prost-types = { workspace = true } regex = { workspace = true } tokio = { workspace = true } -tonic = { workspace = true } +tonic = { workspace = true, features = ["channel", "server"] } + +[dev-dependencies] +tokio-stream = { workspace = true, features = ["net"] } [lints] workspace = true diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index a9cb52434..828179d18 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -1,9 +1,10 @@ // SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 -//! 
In-process supervisor middleware chain execution. +//! Supervisor middleware registration and chain execution. mod builtins; +mod remote; mod service; use std::collections::{BTreeMap, HashMap, HashSet}; @@ -14,14 +15,17 @@ pub use service::InProcessMiddlewareService; use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; use openshell_core::proto::{ - Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, MiddlewareBinding, - MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, + Decision, ExternalMiddlewareService, Finding, HttpRequestEvaluation, HttpRequestTarget, + MiddlewareBinding, MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, SandboxPolicy, + ValidateConfigRequest, }; use tokio::sync::OnceCell; use tonic::Request; pub const API_VERSION: &str = "openshell.middleware.v1"; pub const BUILTIN_SECRETS: &str = builtins::secrets::BINDING_ID; +const HTTP_REQUEST_OPERATION: &str = "HttpRequest"; +const PRE_CREDENTIALS_PHASE: &str = "pre_credentials"; /// Validate the configuration for an in-process middleware implementation. /// @@ -82,9 +86,10 @@ impl TryFrom<&NetworkMiddlewareConfig> for ChainEntry { /// A policy-selected middleware config joined with metadata reported by its /// service's `Describe` call. A missing binding is retained so `on_error` can /// decide whether the request fails open or closed. 
-#[derive(Debug, Clone)] +#[derive(Clone)] pub struct DescribedChainEntry { entry: ChainEntry, + service: Option>, binding: Option, max_body_bytes: usize, } @@ -182,82 +187,380 @@ fn apply_on_error( #[derive(Clone)] pub struct ChainRunner { - state: Arc, + registry: Arc, } struct MiddlewareServiceState { service: Arc, manifest: OnceCell, + operator_max_body_bytes: Option, } static IN_PROCESS_SERVICE: LazyLock> = LazyLock::new(|| { Arc::new(MiddlewareServiceState { service: Arc::new(InProcessMiddlewareService), manifest: OnceCell::new(), + operator_max_body_bytes: None, }) }); -impl Default for ChainRunner { +/// Validated middleware services available to a gateway or one supervisor. +/// +/// The registry always contains the in-process built-ins. External services +/// are connected and described before construction succeeds, so callers never +/// observe a partially registered service set. +#[derive(Clone)] +pub struct MiddlewareRegistry { + services: Arc>>, + external: Arc>, +} + +impl std::fmt::Debug for MiddlewareRegistry { + fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter + .debug_struct("MiddlewareRegistry") + .field("service_count", &self.services.len()) + .field("external_count", &self.external.len()) + .finish() + } +} + +#[derive(Clone)] +struct RegisteredExternalService { + registration: ExternalMiddlewareService, + binding_ids: Vec, +} + +impl Default for MiddlewareRegistry { fn default() -> Self { Self { - state: Arc::clone(&IN_PROCESS_SERVICE), + services: Arc::new(vec![Arc::clone(&IN_PROCESS_SERVICE)]), + external: Arc::new(Vec::new()), } } } +fn validate_registration(registration: &ExternalMiddlewareService) -> Result<()> { + if registration.name.trim().is_empty() { + return Err(miette!( + "external middleware registration name cannot be empty" + )); + } + if !registration.allow_insecure { + return Err(miette!( + "middleware registration '{}' must set allow_insecure = true; TLS is not supported in V1", + 
registration.name + )); + } + if !registration.endpoint.starts_with("http://") { + return Err(miette!( + "middleware registration '{}' endpoint must use http:// in the local-development-only V1", + registration.name + )); + } + if registration.max_body_bytes == 0 { + return Err(miette!( + "middleware registration '{}' max_body_bytes must be greater than zero", + registration.name + )); + } + Ok(()) +} + +fn validate_external_manifest( + registration: &ExternalMiddlewareService, + manifest: &MiddlewareManifest, + operator_max_body_bytes: usize, + known_binding_ids: &mut HashSet, +) -> Result> { + if manifest.api_version != API_VERSION { + return Err(miette!( + "middleware registration '{}' reports unsupported API version '{}'", + registration.name, + manifest.api_version + )); + } + if manifest.bindings.is_empty() { + return Err(miette!( + "middleware registration '{}' describes no bindings", + registration.name + )); + } + + let mut described_ids = Vec::with_capacity(manifest.bindings.len()); + for binding in &manifest.bindings { + if binding.id.trim().is_empty() { + return Err(miette!( + "middleware registration '{}' describes an empty binding id", + registration.name + )); + } + if binding.id.starts_with("openshell/") { + return Err(miette!( + "external middleware registration '{}' cannot claim reserved binding '{}'", + registration.name, + binding.id + )); + } + if binding.operation != HTTP_REQUEST_OPERATION || binding.phase != PRE_CREDENTIALS_PHASE { + return Err(miette!( + "middleware binding '{}' must support {HTTP_REQUEST_OPERATION}/{PRE_CREDENTIALS_PHASE}", + binding.id + )); + } + let advertised = usize::try_from(binding.max_body_bytes).map_err(|_| { + miette!( + "middleware binding '{}' reports a body limit too large for this platform", + binding.id + ) + })?; + if advertised == 0 { + return Err(miette!( + "middleware binding '{}' must advertise a non-zero body limit", + binding.id + )); + } + if operator_max_body_bytes > advertised { + return 
Err(miette!( + "middleware registration '{}' max_body_bytes ({operator_max_body_bytes}) exceeds binding '{}' capability ({advertised})", + registration.name, + binding.id + )); + } + if !known_binding_ids.insert(binding.id.clone()) { + return Err(miette!( + "middleware binding '{}' is described by more than one service", + binding.id + )); + } + described_ids.push(binding.id.clone()); + } + Ok(described_ids) +} + +impl MiddlewareRegistry { + /// Connect and validate every external service registration. + pub async fn connect_external(registrations: Vec) -> Result { + let mut services = vec![Arc::clone(&IN_PROCESS_SERVICE)]; + let mut external = Vec::with_capacity(registrations.len()); + let mut registration_names = HashSet::new(); + let mut binding_ids = HashSet::from([BUILTIN_SECRETS.to_string()]); + + for registration in registrations { + validate_registration(&registration)?; + if !registration_names.insert(registration.name.clone()) { + return Err(miette!( + "duplicate external middleware registration name '{}'", + registration.name + )); + } + + let operator_max_body_bytes = + usize::try_from(registration.max_body_bytes).map_err(|_| { + miette!( + "middleware registration '{}' body limit is too large for this platform", + registration.name + ) + })?; + let service = Arc::new( + remote::RemoteMiddlewareService::connect( + &registration.name, + &registration.endpoint, + ) + .await?, + ); + let manifest = service + .describe(Request::new(())) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware registration '{}' Describe failed: {}", + registration.name, + safe_reason(&error.to_string()) + ) + })?; + let described_ids = validate_external_manifest( + &registration, + &manifest, + operator_max_body_bytes, + &mut binding_ids, + )?; + let manifest_cell = OnceCell::new(); + manifest_cell + .set(manifest) + .map_err(|_| miette!("middleware manifest cache initialized twice"))?; + services.push(Arc::new(MiddlewareServiceState { + service, + 
manifest: manifest_cell, + operator_max_body_bytes: Some(operator_max_body_bytes), + })); + external.push(RegisteredExternalService { + registration, + binding_ids: described_ids, + }); + } + + Ok(Self { + services: Arc::new(services), + external: Arc::new(external), + }) + } + + /// Validate implementation-owned configuration for every middleware entry. + pub async fn validate_policy_configs(&self, policy: &SandboxPolicy) -> Result<()> { + let runner = ChainRunner::from_registry(self.clone()); + for config in &policy.network_middlewares { + runner + .validate_config( + &config.middleware, + config.config.clone().unwrap_or_default(), + ) + .await + .map_err(|error| { + miette!( + "middleware config '{}' is invalid: {}", + config.name, + safe_reason(&error.to_string()) + ) + })?; + } + Ok(()) + } + + /// Check that every policy binding still belongs to the current static + /// registry without making a network call. + pub fn ensure_policy_bindings_registered(&self, policy: &SandboxPolicy) -> Result<()> { + for config in &policy.network_middlewares { + let registered = config.middleware == BUILTIN_SECRETS + || self.external.iter().any(|service| { + service + .binding_ids + .iter() + .any(|binding| binding == &config.middleware) + }); + if !registered { + return Err(miette!( + "middleware binding '{}' used by config '{}' is not registered", + config.middleware, + config.name + )); + } + } + Ok(()) + } + + /// Return only external services referenced by the effective policy. 
+ pub fn required_external_services( + &self, + policy: Option<&SandboxPolicy>, + ) -> Vec { + let Some(policy) = policy else { + return Vec::new(); + }; + let selected: HashSet<&str> = policy + .network_middlewares + .iter() + .map(|config| config.middleware.as_str()) + .collect(); + self.external + .iter() + .filter(|service| { + service + .binding_ids + .iter() + .any(|binding| selected.contains(binding.as_str())) + }) + .map(|service| service.registration.clone()) + .collect() + } +} + +impl Default for ChainRunner { + fn default() -> Self { + Self::from_registry(MiddlewareRegistry::default()) + } +} + impl ChainRunner { pub fn new(service: Arc) -> Self { Self { - state: Arc::new(MiddlewareServiceState { - service, - manifest: OnceCell::new(), + registry: Arc::new(MiddlewareRegistry { + services: Arc::new(vec![Arc::new(MiddlewareServiceState { + service, + manifest: OnceCell::new(), + operator_max_body_bytes: None, + })]), + external: Arc::new(Vec::new()), }), } } - async fn manifest(&self) -> Result<&MiddlewareManifest> { - self.state - .manifest - .get_or_try_init(|| async { - self.state - .service - .describe(Request::new(())) - .await - .map(tonic::Response::into_inner) - .map_err(|error| { - miette!( - "middleware Describe failed: {}", - safe_reason(&error.to_string()) - ) - }) - }) - .await + pub fn from_registry(registry: MiddlewareRegistry) -> Self { + Self { + registry: Arc::new(registry), + } + } + + async fn manifests(&self) -> Result, MiddlewareManifest)>> { + let mut manifests = Vec::with_capacity(self.registry.services.len()); + for state in self.registry.services.iter() { + let manifest = state + .manifest + .get_or_try_init(|| async { + state + .service + .describe(Request::new(())) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware Describe failed: {}", + safe_reason(&error.to_string()) + ) + }) + }) + .await?; + manifests.push((Arc::clone(state), manifest.clone())); + } + Ok(manifests) } pub async fn 
describe_chain(&self, entries: &[ChainEntry]) -> Result> { - let manifest = self.manifest().await?; + let manifests = self.manifests().await?; entries .iter() .map(|entry| { - let binding = manifest - .bindings - .iter() - .find(|binding| binding.id == entry.implementation) - .cloned(); + let described = manifests.iter().find_map(|(state, manifest)| { + manifest + .bindings + .iter() + .find(|binding| binding.id == entry.implementation) + .cloned() + .map(|binding| (Arc::clone(state), binding)) + }); + let (service, binding) = described.map_or((None, None), |(service, binding)| { + (Some(service), Some(binding)) + }); let max_body_bytes = binding .as_ref() .map(|binding| { - usize::try_from(binding.max_body_bytes).map_err(|_| { + let advertised = usize::try_from(binding.max_body_bytes).map_err(|_| { miette!( "middleware binding '{}' reports a body limit too large for this platform", binding.id ) - }) + })?; + Ok::<_, miette::Report>(service + .as_ref() + .and_then(|state| state.operator_max_body_bytes) + .unwrap_or(advertised)) }) .transpose()? 
.unwrap_or(0); Ok(DescribedChainEntry { entry: entry.clone(), + service, binding, max_body_bytes, }) @@ -265,6 +568,44 @@ impl ChainRunner { .collect() } + pub async fn validate_config( + &self, + implementation: &str, + config: prost_types::Struct, + ) -> Result<()> { + let manifests = self.manifests().await?; + let Some((state, _)) = manifests.iter().find(|(_, manifest)| { + manifest + .bindings + .iter() + .any(|binding| binding.id == implementation) + }) else { + return Err(miette!( + "middleware binding '{implementation}' is not registered" + )); + }; + let response = state + .service + .validate_config(Request::new(ValidateConfigRequest { + api_version: API_VERSION.into(), + binding_id: implementation.into(), + config: Some(config), + })) + .await + .map(tonic::Response::into_inner) + .map_err(|error| { + miette!( + "middleware ValidateConfig failed: {}", + safe_reason(&error.to_string()) + ) + })?; + if response.valid { + Ok(()) + } else { + Err(miette!("{}", safe_reason(&response.reason))) + } + } + pub async fn evaluate( &self, entries: &[ChainEntry], @@ -320,8 +661,10 @@ impl ChainRunner { } } let evaluation = build_evaluation(entry, binding, &input, &headers, &body); - let result = match self - .state + let Some(service) = entry.service.as_ref() else { + unreachable!("described binding always has a service") + }; + let result = match service .service .evaluate_http_request(Request::new(evaluation)) .await @@ -545,7 +888,10 @@ pub(crate) fn safe_reason(reason: &str) -> String { #[cfg(test)] mod tests { use super::*; - use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; + use openshell_core::proto::middleware::v1::supervisor_middleware_server::{ + SupervisorMiddleware, SupervisorMiddlewareServer, + }; + use tokio_stream::wrappers::TcpListenerStream; fn entry(name: &str, on_error: OnError) -> ChainEntry { ChainEntry { @@ -736,7 +1082,7 @@ mod tests { async fn validate_config( &self, - _request: Request, + _request: 
Request, ) -> std::result::Result< tonic::Response, tonic::Status, @@ -780,6 +1126,51 @@ mod tests { } } + fn external_registration(max_body_bytes: u64) -> ExternalMiddlewareService { + ExternalMiddlewareService { + name: "local-guard-service".into(), + endpoint: "http://127.0.0.1:50051".into(), + allow_insecure: true, + max_body_bytes, + } + } + + async fn registry_with_external( + service: Arc, + registration: ExternalMiddlewareService, + ) -> MiddlewareRegistry { + let manifest = service + .describe(Request::new(())) + .await + .expect("describe test service") + .into_inner(); + let operator_max_body_bytes = usize::try_from(registration.max_body_bytes).unwrap(); + let mut known = HashSet::from([BUILTIN_SECRETS.to_string()]); + let binding_ids = validate_external_manifest( + ®istration, + &manifest, + operator_max_body_bytes, + &mut known, + ) + .expect("valid external manifest"); + let manifest_cell = OnceCell::new(); + manifest_cell.set(manifest).expect("manifest cache"); + MiddlewareRegistry { + services: Arc::new(vec![ + Arc::clone(&IN_PROCESS_SERVICE), + Arc::new(MiddlewareServiceState { + service, + manifest: manifest_cell, + operator_max_body_bytes: Some(operator_max_body_bytes), + }), + ]), + external: Arc::new(vec![RegisteredExternalService { + registration, + binding_ids, + }]), + } + } + #[tokio::test] async fn descriptors_are_resolved_from_any_middleware_service() { let runner = ChainRunner::new(Arc::new(ScriptedService { @@ -815,6 +1206,138 @@ mod tests { assert!(outcome.allowed); } + #[tokio::test] + async fn mixed_builtin_and_external_chain_uses_operator_limit() { + let external = Arc::new(ScriptedService { + binding_id: "example/content-guard".into(), + max_body_bytes: 4096, + result: allow_result(), + }); + let registry = registry_with_external(external, external_registration(1024)).await; + let runner = ChainRunner::from_registry(registry); + let external_entry = ChainEntry { + name: "external".into(), + implementation: 
"example/content-guard".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + let entries = [entry("builtin", OnError::FailClosed), external_entry]; + + let described = runner + .describe_chain(&entries) + .await + .expect("describe chain"); + assert_eq!(described[0].max_body_bytes(), 256 * 1024); + assert_eq!(described[1].max_body_bytes(), 1024); + + let outcome = runner + .evaluate_described(&described, input(r#"password="top-secret""#)) + .await + .expect("evaluate mixed chain"); + assert!(outcome.allowed); + assert_eq!(outcome.applied.len(), 2); + assert_eq!( + String::from_utf8(outcome.body).expect("utf8"), + r#"password="[REDACTED]""# + ); + } + + #[test] + fn external_manifest_rejects_operator_limit_above_capability() { + let registration = external_registration(4097); + let manifest = MiddlewareManifest { + api_version: API_VERSION.into(), + name: "example/service".into(), + service_version: "test".into(), + bindings: vec![MiddlewareBinding { + id: "example/content-guard".into(), + operation: HTTP_REQUEST_OPERATION.into(), + phase: PRE_CREDENTIALS_PHASE.into(), + max_body_bytes: 4096, + }], + }; + let error = validate_external_manifest( + &registration, + &manifest, + 4097, + &mut HashSet::from([BUILTIN_SECRETS.to_string()]), + ) + .expect_err("operator limit must fit capability"); + assert!(error.to_string().contains("exceeds")); + } + + #[test] + fn external_registration_requires_explicit_insecure_opt_in() { + let mut registration = external_registration(4096); + registration.allow_insecure = false; + let error = validate_registration(&registration).expect_err("opt-in required"); + assert!(error.to_string().contains("allow_insecure")); + } + + #[tokio::test] + async fn external_registry_invokes_remote_service_over_grpc() { + let listener = tokio::net::TcpListener::bind("127.0.0.1:0") + .await + .expect("bind test middleware"); + let address = listener.local_addr().expect("test middleware address"); + let (shutdown_tx, 
shutdown_rx) = tokio::sync::oneshot::channel(); + let server = tonic::transport::Server::builder() + .add_service(SupervisorMiddlewareServer::new(ScriptedService { + binding_id: "example/content-guard".into(), + max_body_bytes: 4096, + result: allow_result(), + })) + .serve_with_incoming_shutdown(TcpListenerStream::new(listener), async { + let _ = shutdown_rx.await; + }); + let server_task = tokio::spawn(server); + + let mut registration = external_registration(1024); + registration.endpoint = format!("http://{address}"); + let registry = MiddlewareRegistry::connect_external(vec![registration.clone()]) + .await + .expect("connect external middleware"); + let policy = SandboxPolicy { + network_middlewares: vec![NetworkMiddlewareConfig { + name: "guard".into(), + middleware: "example/content-guard".into(), + config: Some(prost_types::Struct::default()), + on_error: "fail_closed".into(), + endpoints: None, + }], + ..Default::default() + }; + + registry + .validate_policy_configs(&policy) + .await + .expect("remote config validates"); + assert_eq!( + registry.required_external_services(Some(&policy)), + vec![registration] + ); + + let outcome = ChainRunner::from_registry(registry) + .evaluate( + &[ChainEntry { + name: "guard".into(), + implementation: "example/content-guard".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }], + input("hello"), + ) + .await + .expect("remote evaluation"); + assert!(outcome.allowed); + + let _ = shutdown_tx.send(()); + server_task + .await + .expect("join test middleware") + .expect("serve"); + } + #[tokio::test] async fn deny_decision_short_circuits_chain() { let runner = ChainRunner::new(Arc::new(scripted_service( diff --git a/crates/openshell-supervisor-middleware/src/remote.rs b/crates/openshell-supervisor-middleware/src/remote.rs new file mode 100644 index 000000000..dd147788b --- /dev/null +++ b/crates/openshell-supervisor-middleware/src/remote.rs @@ -0,0 +1,91 @@ +// SPDX-FileCopyrightText: 
Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +// SPDX-License-Identifier: Apache-2.0 + +use std::time::Duration; + +use miette::{IntoDiagnostic, Result, WrapErr}; +use openshell_core::proto::middleware::v1::supervisor_middleware_client::SupervisorMiddlewareClient; +use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; +use openshell_core::proto::{ + HttpRequestEvaluation, HttpRequestResult, MiddlewareManifest, ValidateConfigRequest, + ValidateConfigResponse, +}; +use tonic::transport::{Channel, Endpoint}; +use tonic::{Request, Response, Status}; + +const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); +const RPC_TIMEOUT: Duration = Duration::from_secs(5); + +#[derive(Clone)] +pub struct RemoteMiddlewareService { + registration_name: String, + client: SupervisorMiddlewareClient, +} + +impl RemoteMiddlewareService { + pub async fn connect(registration_name: &str, endpoint: &str) -> Result { + let channel = Endpoint::from_shared(endpoint.to_string()) + .into_diagnostic() + .wrap_err_with(|| { + format!("middleware registration '{registration_name}' has an invalid endpoint") + })? + .connect_timeout(CONNECT_TIMEOUT) + .connect() + .await + .into_diagnostic() + .wrap_err_with(|| { + format!( + "middleware registration '{registration_name}' could not connect to {endpoint}" + ) + })?; + Ok(Self { + registration_name: registration_name.to_string(), + client: SupervisorMiddlewareClient::new(channel), + }) + } + + async fn with_timeout( + &self, + operation: &'static str, + future: impl Future, Status>>, + ) -> std::result::Result, Status> { + tokio::time::timeout(RPC_TIMEOUT, future) + .await + .map_err(|_| { + Status::deadline_exceeded(format!( + "middleware '{}' {operation} timed out", + self.registration_name + )) + })? 
+ } +} + +#[tonic::async_trait] +impl SupervisorMiddleware for RemoteMiddlewareService { + async fn describe( + &self, + request: Request<()>, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("Describe", client.describe(request)) + .await + } + + async fn validate_config( + &self, + request: Request, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("ValidateConfig", client.validate_config(request)) + .await + } + + async fn evaluate_http_request( + &self, + request: Request, + ) -> std::result::Result, Status> { + let mut client = self.client.clone(); + self.with_timeout("EvaluateHttpRequest", client.evaluate_http_request(request)) + .await + } +} diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 8383b6bb2..84853f751 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -454,35 +454,41 @@ where if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; - let req = - match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) - .await? 
- { - MiddlewareApplyResult::Allowed(req) => req, - MiddlewareApplyResult::Denied(reason) => { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &crate::l7::provider::L7Request { - action: request_info.action.clone(), - target: redacted_target.clone(), - query_params: request_info.query_params.clone(), - raw_header: Vec::new(), - body_length: crate::l7::provider::BodyLength::None, - }, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - }; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? + { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_options_guarded( &req, client, @@ -786,9 +792,11 @@ pub(crate) async fn apply_middleware_chain, + runner: &openshell_supervisor_middleware::ChainRunner, generation_guard: &PolicyGenerationGuard, ) -> Result { - apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, generation_guard).await + apply_middleware_chain_for_scheme(req, client, ctx, "https", chain, runner, generation_guard) + .await } pub(crate) async fn apply_middleware_chain_for_scheme( @@ -797,12 
+805,12 @@ pub(crate) async fn apply_middleware_chain_for_scheme, + runner: &openshell_supervisor_middleware::ChainRunner, generation_guard: &PolicyGenerationGuard, ) -> Result { if chain.is_empty() { return Ok(MiddlewareApplyResult::Allowed(req)); } - let runner = openshell_supervisor_middleware::ChainRunner::default(); let chain = runner.describe_chain(&chain).await?; let max_body_bytes = middleware_chain_body_limit(&chain).expect("non-empty middleware chain has a body limit"); @@ -1221,35 +1229,41 @@ where if allowed || config.enforcement == EnforcementMode::Audit { let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; - let req = - match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) - .await? - { - MiddlewareApplyResult::Allowed(req) => req, - MiddlewareApplyResult::Denied(reason) => { - provider - .deny_with_redacted_target( - &crate::l7::provider::L7Request { - action: request_info.action.clone(), - target: redacted_target.clone(), - query_params: request_info.query_params.clone(), - raw_header: Vec::new(), - body_length: crate::l7::provider::BodyLength::None, - }, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - }; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + provider + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let req_with_auth = match crate::l7::token_grant_injection::inject_if_needed(req, ctx).await { Ok(req) => req, @@ -1490,35 +1504,41 @@ where if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; - let req = - match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) - .await? - { - MiddlewareApplyResult::Allowed(req) => req, - MiddlewareApplyResult::Denied(reason) => { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &crate::l7::provider::L7Request { - action: request_info.action.clone(), - target: redacted_target.clone(), - query_params: request_info.query_params.clone(), - raw_header: Vec::new(), - body_length: crate::l7::provider::BodyLength::None, - }, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - }; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; // Future MCP response/SSE introspection or rewrite would hook here // before returning upstream bytes. The current policy schema has no // trusted-annotations or version-profile field, so MCP responses and @@ -1714,35 +1734,41 @@ where if allowed || (config.enforcement == EnforcementMode::Audit && !force_deny) { let chain = engine.query_middleware_chain(&middleware_network_input(ctx))?; - let req = - match apply_middleware_chain(req, client, ctx, chain, engine.generation_guard()) - .await? - { - MiddlewareApplyResult::Allowed(req) => req, - MiddlewareApplyResult::Denied(reason) => { - crate::l7::rest::RestProvider::default() - .deny_with_redacted_target( - &crate::l7::provider::L7Request { - action: request_info.action.clone(), - target: redacted_target.clone(), - query_params: request_info.query_params.clone(), - raw_header: Vec::new(), - body_length: crate::l7::provider::BodyLength::None, - }, - &ctx.policy_name, - &reason, - client, - Some(&redacted_target), - Some(crate::l7::rest::DenyResponseContext { - host: Some(&ctx.host), - port: Some(ctx.port), - binary: Some(&ctx.binary_path), - }), - ) - .await?; - return Ok(()); - } - }; + let req = match apply_middleware_chain( + req, + client, + ctx, + chain, + engine.middleware_runner(), + engine.generation_guard(), + ) + .await? 
+ { + MiddlewareApplyResult::Allowed(req) => req, + MiddlewareApplyResult::Denied(reason) => { + crate::l7::rest::RestProvider::default() + .deny_with_redacted_target( + &crate::l7::provider::L7Request { + action: request_info.action.clone(), + target: redacted_target.clone(), + query_params: request_info.query_params.clone(), + raw_header: Vec::new(), + body_length: crate::l7::provider::BodyLength::None, + }, + &ctx.policy_name, + &reason, + client, + Some(&redacted_target), + Some(crate::l7::rest::DenyResponseContext { + host: Some(&ctx.host), + port: Some(ctx.port), + binary: Some(&ctx.binary_path), + }), + ) + .await?; + return Ok(()); + } + }; let outcome = crate::l7::rest::relay_http_request_with_resolver_guarded( &req, client, @@ -2170,7 +2196,9 @@ where if generation != generation_guard.captured_generation() { return Ok(()); } - match apply_middleware_chain(req, client, ctx, chain, generation_guard).await? { + let runner = engine.middleware_runner()?; + match apply_middleware_chain(req, client, ctx, chain, &runner, generation_guard).await? 
+ { MiddlewareApplyResult::Allowed(req) => req, MiddlewareApplyResult::Denied(reason) => { crate::l7::rest::RestProvider::default() diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index c4e773996..9e1427dcd 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -13,11 +13,11 @@ use openshell_core::policy::{ }; use openshell_core::proto::SandboxPolicy as ProtoSandboxPolicy; use openshell_policy::L7ConfigStanza; -use openshell_supervisor_middleware::ChainEntry; +use openshell_supervisor_middleware::{ChainEntry, ChainRunner, MiddlewareRegistry}; use std::collections::HashSet; use std::path::{Path, PathBuf}; use std::sync::{ - Arc, Mutex, + Arc, Mutex, RwLock, atomic::{AtomicU64, Ordering}, }; @@ -73,6 +73,7 @@ pub struct SandboxConfig { pub struct OpaEngine { engine: Mutex, generation: Arc, + middleware_runner: RwLock, } /// Generation guard captured when an HTTP tunnel or request path starts. @@ -112,6 +113,7 @@ impl PolicyGenerationGuard { pub struct TunnelPolicyEngine { engine: Mutex, generation_guard: PolicyGenerationGuard, + middleware_runner: ChainRunner, } impl TunnelPolicyEngine { @@ -135,6 +137,10 @@ impl TunnelPolicyEngine { &self.engine } + pub(crate) fn middleware_runner(&self) -> &ChainRunner { + &self.middleware_runner + } + /// Query the ordered middleware chain for a destination within this tunnel. 
pub fn query_middleware_chain(&self, input: &NetworkInput) -> Result> { let mut engine = self @@ -164,6 +170,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -182,6 +189,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -254,6 +262,7 @@ impl OpaEngine { Ok(Self { engine: Mutex::new(engine), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }) } @@ -451,6 +460,25 @@ impl OpaEngine { self.generation.load(Ordering::Acquire) } + /// Replace the complete middleware service registry and invalidate + /// existing tunnels so subsequent requests use the new service set. + pub fn replace_middleware_registry(&self, registry: MiddlewareRegistry) -> Result<()> { + let mut runner = self + .middleware_runner + .write() + .map_err(|_| miette::miette!("middleware runner lock poisoned"))?; + *runner = ChainRunner::from_registry(registry); + self.generation.fetch_add(1, Ordering::AcqRel); + Ok(()) + } + + pub(crate) fn middleware_runner(&self) -> Result { + self.middleware_runner + .read() + .map(|runner| runner.clone()) + .map_err(|_| miette::miette!("middleware runner lock poisoned")) + } + /// Return a guard for a previously captured policy generation. 
pub fn generation_guard(&self, expected_generation: u64) -> Result { let generation = self.current_generation(); @@ -662,6 +690,7 @@ impl OpaEngine { captured_generation: generation, current_generation: Arc::clone(&self.generation), }, + middleware_runner: self.middleware_runner()?, }) } } @@ -2941,6 +2970,7 @@ network_policies: let engine = OpaEngine { engine: Mutex::new(rego), generation: Arc::new(AtomicU64::new(0)), + middleware_runner: RwLock::new(ChainRunner::default()), }; let input = l7_websocket_graphql_input( "realtime.graphql.com", diff --git a/crates/openshell-supervisor-network/src/proxy.rs b/crates/openshell-supervisor-network/src/proxy.rs index afec98666..8616c3b2c 100644 --- a/crates/openshell-supervisor-network/src/proxy.rs +++ b/crates/openshell-supervisor-network/src/proxy.rs @@ -4218,6 +4218,7 @@ async fn handle_forward_proxy( return Ok(()); } if !chain.is_empty() { + let middleware_runner = opa_engine.middleware_runner()?; let request = crate::l7::rest::request_from_buffered_http( method, middleware_path, @@ -4230,6 +4231,7 @@ async fn handle_forward_proxy( &l7_ctx, &scheme, chain, + &middleware_runner, &forward_generation_guard, ) .await? diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index 2aaa6e7b0..dc4659410 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -103,6 +103,14 @@ guest_tls_key = "/etc/openshell/certs/client-key.pem" grpc_rate_limit_requests = 120 grpc_rate_limit_window_seconds = 60 +# Local-development-only external supervisor middleware. The endpoint must be +# reachable from both the gateway and sandbox supervisors. +[[openshell.gateway.middleware]] +name = "local-content-guard" +endpoint = "http://host.openshell.internal:50051" +allow_insecure = true +max_body_bytes = 262144 + # Gateway listener TLS (distinct from the per-driver guest_tls_*). 
[openshell.gateway.tls] cert_path = "/etc/openshell/certs/gateway.pem" @@ -140,6 +148,26 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth `[openshell.gateway.auth] allow_unauthenticated_users = true` is an unsafe local-development and trusted-proxy escape hatch. It accepts user-facing CLI/API calls without OIDC or mTLS credentials while sandbox supervisors still authenticate with gateway-minted sandbox JWTs. Leave it false for shared and production gateways. +## External Supervisor Middleware + +Register external supervisor middleware with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. + +```toml +[[openshell.gateway.middleware]] +name = "local-content-guard" +endpoint = "http://host.openshell.internal:50051" +allow_insecure = true +max_body_bytes = 262144 +``` + +Each service implements the supervisor middleware gRPC contract and may expose multiple binding IDs through `Describe`. Policies reference those binding IDs, not the registration `name`. The gateway rejects duplicate binding IDs across services and prevents external services from claiming the reserved `openshell/` namespace. + +The gateway connects to every registered service and validates `Describe` before it starts. The service must therefore be running before the gateway. Policy creation and full policy updates call `ValidateConfig`; an unavailable service or invalid middleware configuration rejects the policy before persistence. + +`max_body_bytes` is the operator limit for every binding exposed by the service. It must be greater than zero and no larger than each binding's advertised limit. OpenShell rejects an oversized value instead of silently clamping it. + +External middleware is a local-development preview. 
The endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. + `image_pull_policy` is intentionally not a shared gateway key. Kubernetes and Docker use `Always`, `IfNotPresent`, or `Never`. Podman uses `always`, `missing`, `never`, or `newer`. Set it inside the relevant driver table. ## Driver References diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 906d55d58..5cd2070f2 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -20,6 +20,7 @@ filesystem_policy: { ... } landlock: { ... } process: { ... } network_policies: { ... } +network_middlewares: [ ... ] ``` | Field | Type | Required | Category | Description | @@ -29,6 +30,7 @@ network_policies: { ... } | `landlock` | object | No | Static | Configures Landlock LSM enforcement behavior. | | `process` | object | No | Static | Sets the user and group the agent process runs as. | | `network_policies` | map | No | Dynamic | Declares which binaries can reach which network endpoints. | +| `network_middlewares` | list | No | Dynamic | Selects ordered HTTP request middleware by destination host. | Static fields are set at sandbox creation time. Changing them requires destroying and recreating the sandbox. Dynamic fields can be updated on a running sandbox with `openshell policy update` for incremental merges or `openshell policy set` for full replacement, and take effect without restarting. @@ -468,7 +470,35 @@ Identifies an executable that is permitted to use the associated endpoints. |---|---|---|---| | `path` | string | Yes | Filesystem path to the executable. 
Supports glob patterns with `*` and `**`. For example, `/sandbox/.vscode-server/**` matches any executable under that directory tree. | -### Full Example +## Network Middleware + +**Category:** Dynamic + +An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. Every matching config runs once in list order before provider credential injection. + +```yaml showLineNumbers={false} +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["*.example.com"] + exclude: ["trusted.example.com"] +``` + +| Field | Type | Required | Description | +|---|---|---|---| +| `name` | string | Yes | Policy-local config name. Names must be unique within the list. | +| `middleware` | string | Yes | Built-in or operator-registered binding ID. `openshell/` is reserved for built-ins. | +| `config` | object | No | Implementation-owned configuration validated by the selected middleware. | +| `on_error` | string | No | `fail_closed` denies the request when the stage fails; `fail_open` skips the failed stage. Defaults to `fail_closed`. | +| `endpoints` | object | Yes | Host selector with required non-empty `include` and optional `exclude` lists. Exclusions take precedence. | + +Host selectors use the same case-insensitive exact and DNS glob semantics as network endpoints. Middleware runs only on HTTP requests the supervisor parses. A selector that can require middleware on a `tls: skip` endpoint is rejected because OpenShell cannot inspect that traffic. 
+ +## Full Example The following policy grants read-only GitHub API access and npm registry access: diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index ea8716422..cf61989e8 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -12,7 +12,7 @@ Use this page to apply and iterate policy changes on running sandboxes. For a fu ## Policy Structure -A policy has static sections `filesystem_policy`, `landlock`, and `process` that are locked at sandbox creation, and a dynamic section `network_policies` that is hot-reloadable on a running sandbox. +A policy has static sections `filesystem_policy`, `landlock`, and `process` that are locked at sandbox creation, and dynamic `network_policies` and `network_middlewares` sections that are hot-reloadable on a running sandbox. ```yaml wordWrap showLineNumbers={false} version: 1 @@ -44,6 +44,17 @@ network_policies: binaries: - path: /usr/bin/curl +# Dynamic: ordered middleware selected independently by admitted host. +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["api.example.com"] + exclude: [] + ``` Static sections are locked at sandbox creation. Changing them requires destroying and recreating the sandbox. @@ -57,6 +68,29 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in | `landlock` | Static | Configures Landlock LSM enforcement behavior. Set `compatibility` to `best_effort` (skip individual inaccessible paths while applying remaining rules) or `hard_requirement` (fail if any path is inaccessible or the required kernel ABI is unavailable). Refer to the [Policy Schema Reference](/reference/policy-schema#landlock) for the full behavior table. | | `process` | Static | Sets the OS-level identity for the agent process. `run_as_user` and `run_as_group` default to `sandbox`. Root (`root` or `0`) is rejected. 
The agent also runs with seccomp filters that block dangerous system calls. | | `network_policies` | Dynamic | Controls network access for ordinary outbound traffic from the sandbox. Each block has a name, a list of endpoints (host, port, protocol, and optional rules), and a list of binaries allowed to use those endpoints.
Every outbound connection except `https://inference.local` goes through the proxy, which queries the [policy engine](/about/how-it-works#core-components) with the destination and calling binary. A connection is allowed only when both match an entry in the same policy block.
For endpoints with `protocol: rest`, the proxy auto-detects TLS and terminates it so each HTTP request can be checked against that endpoint's `rules` (method and path). For endpoints with `protocol: websocket`, the proxy validates the RFC 6455 upgrade and evaluates `GET` rules for the handshake plus either `WEBSOCKET_TEXT` rules for raw client text messages or GraphQL operation rules for GraphQL-over-WebSocket messages. Set `websocket_credential_rewrite: true` only when a WebSocket or REST compatibility endpoint must keep placeholder credentials in sandbox-owned text frames and resolve them at the OpenShell relay boundary.
Endpoints without `protocol` allow the TCP stream through without inspecting payloads.
If no endpoint matches, the connection is denied. Configure managed inference separately through [Inference Routing](/sandboxes/inference-routing). | +| `network_middlewares` | Dynamic | Declares ordered HTTP request middleware configs. After network and L7 policy admit a request, OpenShell matches each config's host selectors independently and runs matching entries in declaration order before credential injection. | + +## Supervisor Middleware + +Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. + +```yaml +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["*.example.com"] + exclude: ["trusted.example.com"] +``` + +Matching entries run once each in top-level declaration order. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. + +`openshell/secrets` is built into the supervisor. External binding IDs must be registered by the gateway operator before a policy can reference them; see [External Supervisor Middleware](/reference/gateway-config#external-supervisor-middleware). The gateway calls the implementation's `ValidateConfig` before accepting the policy. + +`on_error` defaults to `fail_closed`. Use `fail_open` only when skipping a failed middleware is acceptable. Middleware applies only to HTTP traffic the supervisor can parse and inspect; policy validation rejects a required selector that can cover a `tls: skip` endpoint. 
## Baseline Filesystem Paths diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 04cbd6776..afec58723 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -352,4 +352,20 @@ message GetSandboxConfigResponse { // Fingerprint for provider credential inputs attached to this sandbox. // Changes when attached provider names or attached provider records change. uint64 provider_env_revision = 8; + // Operator-registered external middleware services required by the effective + // policy. Built-in middleware is not included. + repeated ExternalMiddlewareService external_middleware = 9; +} + +// Connection details for one operator-registered external middleware service. +// V1 supports only explicitly enabled plaintext gRPC for local development. +message ExternalMiddlewareService { + // Operator-facing registration name used for diagnostics. + string name = 1; + // gRPC endpoint reachable from the sandbox supervisor. + string endpoint = 2; + // Explicit acknowledgement that request content is sent without TLS. + bool allow_insecure = 3; + // Operator-owned body limit applied to every binding exposed by the service. 
+ uint64 max_body_bytes = 4; } From 03ae94e66dc120937546e6f2295124a20d4a7e49 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 15:38:08 -0700 Subject: [PATCH 08/16] refactor(supervisor-middleware): clarify service contract Signed-off-by: Piotr Mlocek --- architecture/sandbox.md | 4 +- crates/openshell-core/src/grpc_client.rs | 4 +- crates/openshell-sandbox/src/lib.rs | 27 ++++---- crates/openshell-server/src/config_file.rs | 10 +-- crates/openshell-server/src/grpc/policy.rs | 17 +++-- crates/openshell-server/src/lib.rs | 2 +- .../src/lib.rs | 62 +++++++++---------- docs/reference/gateway-config.mdx | 12 ++-- docs/reference/policy-schema.mdx | 4 ++ docs/sandboxes/policies.mdx | 6 +- proto/sandbox.proto | 10 +-- 11 files changed, 85 insertions(+), 73 deletions(-) diff --git a/architecture/sandbox.md b/architecture/sandbox.md index deec00f32..d63c4fbaa 100644 --- a/architecture/sandbox.md +++ b/architecture/sandbox.md @@ -69,7 +69,7 @@ are relayed but are not currently parsed for policy enforcement. For admitted HTTP requests, the proxy can run an ordered supervisor middleware chain before credential injection. Host selectors choose the chain independently of the network rule that admitted the request. Built-ins run in-process; -operator-registered external services are called directly from the supervisor +operator-registered services are called directly from the supervisor over the common middleware gRPC contract. The gateway validates external service capabilities and policy-owned config before delivery. Supervisors keep the last-known-good service registry when a live config reload fails. @@ -184,7 +184,7 @@ quickly. - If gateway config polling fails, the sandbox keeps its last-known-good policy. - If a live policy update is invalid, the supervisor rejects it and keeps the current policy. 
-- If an external middleware call fails, the selected config's `on_error` +- If an operator-run middleware call fails, the selected config's `on_error` behavior decides whether to deny the request or continue without that stage. - Existing raw byte streams are connection scoped. Dynamic policy changes apply to new connections or the next parsed HTTP request where the proxy can safely diff --git a/crates/openshell-core/src/grpc_client.rs b/crates/openshell-core/src/grpc_client.rs index 57f72dca6..836c7880c 100644 --- a/crates/openshell-core/src/grpc_client.rs +++ b/crates/openshell-core/src/grpc_client.rs @@ -728,7 +728,7 @@ pub struct SettingsPollResult { /// When `policy_source` is `Global`, the version of the global policy revision. pub global_policy_version: u32, pub provider_env_revision: u64, - pub external_middleware: Vec, + pub supervisor_middleware_services: Vec, } pub struct ProviderEnvironmentResult { @@ -773,7 +773,7 @@ impl CachedOpenShellClient { settings: inner.settings, global_policy_version: inner.global_policy_version, provider_env_revision: inner.provider_env_revision, - external_middleware: inner.external_middleware, + supervisor_middleware_services: inner.supervisor_middleware_services, }) } diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index af173abc3..c1531023a 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -1450,8 +1450,8 @@ async fn load_policy( info!("Creating OPA engine from proto policy data"); let engine = OpaEngine::from_proto(&proto_policy)?; let middleware_registry = - openshell_supervisor_middleware::MiddlewareRegistry::connect_external( - sandbox_config.external_middleware, + openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + sandbox_config.supervisor_middleware_services, ) .await?; engine.replace_middleware_registry(middleware_registry)?; @@ -1605,7 +1605,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> 
Result<()> { let mut current_config_revision: u64 = 0; let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); - let mut current_external_middleware = Vec::new(); + let mut current_middleware_services = Vec::new(); let mut current_settings: std::collections::HashMap< String, openshell_core::proto::EffectiveSetting, @@ -1617,7 +1617,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { apply_ocsf_json_setting(&ctx.ocsf_enabled, &result.settings); current_config_revision = result.config_revision; current_policy_hash = result.policy_hash.clone(); - current_external_middleware = result.external_middleware; + current_middleware_services = result.supervisor_middleware_services; current_settings = result.settings; debug!( config_revision = current_config_revision, @@ -1647,7 +1647,8 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } let policy_changed = result.policy_hash != current_policy_hash; - let middleware_changed = result.external_middleware != current_external_middleware; + let middleware_changed = + result.supervisor_middleware_services != current_middleware_services; // Log which settings changed. 
log_setting_changes(&current_settings, &result.settings); @@ -1707,26 +1708,26 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } if middleware_changed { - match openshell_supervisor_middleware::MiddlewareRegistry::connect_external( - result.external_middleware.clone(), + match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + result.supervisor_middleware_services.clone(), ) .await .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) { Ok(()) => { - current_external_middleware = result.external_middleware.clone(); + current_middleware_services = result.supervisor_middleware_services.clone(); ocsf_emit!( ConfigStateChangeBuilder::new(ocsf_ctx()) .severity(SeverityId::Informational) .status(StatusId::Success) .state(StateId::Enabled, "loaded") .unmapped( - "external_middleware_count", - serde_json::json!(current_external_middleware.len()) + "supervisor_middleware_service_count", + serde_json::json!(current_middleware_services.len()) ) .message(format!( - "External middleware registry reloaded [service_count:{}]", - current_external_middleware.len() + "Supervisor middleware registry reloaded [service_count:{}]", + current_middleware_services.len() )) .build() ); @@ -1738,7 +1739,7 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { .status(StatusId::Failure) .state(StateId::Other, "failed") .message(format!( - "External middleware registry reload failed, keeping last-known-good registry [error:{error}]" + "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" )) .build() ); diff --git a/crates/openshell-server/src/config_file.rs b/crates/openshell-server/src/config_file.rs index 13c7e9ebb..4b0fbc919 100644 --- a/crates/openshell-server/src/config_file.rs +++ b/crates/openshell-server/src/config_file.rs @@ -25,7 +25,7 @@ use std::net::SocketAddr; use std::path::{Path, PathBuf}; use openshell_core::config::ComputeDriverKind; -use
openshell_core::proto::ExternalMiddlewareService; +use openshell_core::proto::SupervisorMiddlewareService; use openshell_core::{GatewayAuthConfig, GatewayJwtConfig, MtlsAuthConfig, OidcConfig, TlsConfig}; use serde::{Deserialize, Serialize}; @@ -153,7 +153,7 @@ pub struct GatewayFileSection { pub gateway_jwt: Option, // ── Supervisor middleware ───────────────────────────────────────────── - /// Statically registered external middleware services. Registration is + /// Statically registered supervisor middleware services. Registration is /// operator-owned and changes require a gateway restart. #[serde(default)] pub middleware: Vec, @@ -167,7 +167,7 @@ pub struct GatewayFileSection { pub database_url: Option, } -/// One `[[openshell.gateway.middleware]]` external middleware registration. +/// One `[[openshell.gateway.middleware]]` supervisor middleware registration. #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(deny_unknown_fields)] pub struct MiddlewareServiceFileConfig { @@ -182,7 +182,7 @@ pub struct MiddlewareServiceFileConfig { pub max_body_bytes: u64, } -impl From<&MiddlewareServiceFileConfig> for ExternalMiddlewareService { +impl From<&MiddlewareServiceFileConfig> for SupervisorMiddlewareService { fn from(config: &MiddlewareServiceFileConfig) -> Self { Self { name: config.name.clone(), @@ -435,7 +435,7 @@ allow_unauthenticated_users = true } #[test] - fn parses_external_middleware_registration() { + fn parses_supervisor_middleware_registration() { let toml = r#" [[openshell.gateway.middleware]] name = "local-guard" diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index c29de71ff..9587e7295 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -1230,14 +1230,13 @@ pub(super) async fn handle_get_sandbox_config( } let settings = merge_effective_settings(&global_settings, &sandbox_settings)?; - let external_middleware = state - 
.middleware_registry - .required_external_services(policy.as_ref()); + let supervisor_middleware_services = + state.middleware_registry.required_services(policy.as_ref()); let config_revision = compute_config_revision( policy.as_ref(), &settings, policy_source, - &external_middleware, + &supervisor_middleware_services, ); let provider_env_revision = compute_provider_env_revision(state.store.as_ref(), &sandbox_provider_names).await?; @@ -1251,7 +1250,7 @@ pub(super) async fn handle_get_sandbox_config( policy_source: policy_source.into(), global_policy_version, provider_env_revision, - external_middleware, + supervisor_middleware_services, })) } @@ -3144,7 +3143,7 @@ fn compute_config_revision( policy: Option<&ProtoSandboxPolicy>, settings: &HashMap, policy_source: PolicySource, - external_middleware: &[openshell_core::proto::ExternalMiddlewareService], + supervisor_middleware_services: &[openshell_core::proto::SupervisorMiddlewareService], ) -> u64 { let mut hasher = Sha256::new(); hasher.update((policy_source as i32).to_le_bytes()); @@ -3177,7 +3176,7 @@ fn compute_config_revision( } } } - let mut middleware = external_middleware.iter().collect::>(); + let mut middleware = supervisor_middleware_services.iter().collect::>(); middleware.sort_by(|left, right| left.name.cmp(&right.name)); for service in middleware { hasher.update(service.encode_to_vec()); @@ -9327,10 +9326,10 @@ mod tests { } #[test] - fn config_revision_changes_when_external_middleware_changes() { + fn config_revision_changes_when_supervisor_middleware_services_change() { let policy = ProtoSandboxPolicy::default(); let settings = HashMap::new(); - let service = openshell_core::proto::ExternalMiddlewareService { + let service = openshell_core::proto::SupervisorMiddlewareService { name: "local-guard".into(), endpoint: "http://127.0.0.1:50051".into(), allow_insecure: true, diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index bca8abe2e..bd974d4b4 100644 --- 
a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -242,7 +242,7 @@ pub(crate) async fn run_server( }) .unwrap_or_default(); let middleware_registry = Arc::new( - MiddlewareRegistry::connect_external(middleware_registrations) + MiddlewareRegistry::connect_services(middleware_registrations) .await .map_err(|error| Error::config(format!("middleware registration failed: {error}")))?, ); diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index 828179d18..ebc5817e4 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -15,9 +15,9 @@ pub use service::InProcessMiddlewareService; use openshell_core::proto::middleware::v1::supervisor_middleware_server::SupervisorMiddleware; use openshell_core::proto::{ - Decision, ExternalMiddlewareService, Finding, HttpRequestEvaluation, HttpRequestTarget, - MiddlewareBinding, MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, SandboxPolicy, - ValidateConfigRequest, + Decision, Finding, HttpRequestEvaluation, HttpRequestTarget, MiddlewareBinding, + MiddlewareManifest, NetworkMiddlewareConfig, RequestContext, SandboxPolicy, + SupervisorMiddlewareService, ValidateConfigRequest, }; use tokio::sync::OnceCell; use tonic::Request; @@ -206,13 +206,13 @@ static IN_PROCESS_SERVICE: LazyLock> = LazyLock::new /// Validated middleware services available to a gateway or one supervisor. /// -/// The registry always contains the in-process built-ins. External services -/// are connected and described before construction succeeds, so callers never +/// The registry always contains the in-process built-ins. Operator-registered +/// services are connected and described before construction succeeds, so callers never /// observe a partially registered service set. 
#[derive(Clone)] pub struct MiddlewareRegistry { services: Arc>>, - external: Arc>, + registered_services: Arc>, } impl std::fmt::Debug for MiddlewareRegistry { @@ -220,14 +220,14 @@ impl std::fmt::Debug for MiddlewareRegistry { formatter .debug_struct("MiddlewareRegistry") .field("service_count", &self.services.len()) - .field("external_count", &self.external.len()) + .field("registered_service_count", &self.registered_services.len()) .finish() } } #[derive(Clone)] -struct RegisteredExternalService { - registration: ExternalMiddlewareService, +struct RegisteredMiddlewareService { + registration: SupervisorMiddlewareService, binding_ids: Vec, } @@ -235,15 +235,15 @@ impl Default for MiddlewareRegistry { fn default() -> Self { Self { services: Arc::new(vec![Arc::clone(&IN_PROCESS_SERVICE)]), - external: Arc::new(Vec::new()), + registered_services: Arc::new(Vec::new()), } } } -fn validate_registration(registration: &ExternalMiddlewareService) -> Result<()> { +fn validate_registration(registration: &SupervisorMiddlewareService) -> Result<()> { if registration.name.trim().is_empty() { return Err(miette!( - "external middleware registration name cannot be empty" + "supervisor middleware registration name cannot be empty" )); } if !registration.allow_insecure { @@ -268,7 +268,7 @@ fn validate_registration(registration: &ExternalMiddlewareService) -> Result<()> } fn validate_external_manifest( - registration: &ExternalMiddlewareService, + registration: &SupervisorMiddlewareService, manifest: &MiddlewareManifest, operator_max_body_bytes: usize, known_binding_ids: &mut HashSet, @@ -339,10 +339,10 @@ fn validate_external_manifest( } impl MiddlewareRegistry { - /// Connect and validate every external service registration. - pub async fn connect_external(registrations: Vec) -> Result { + /// Connect and validate every operator-provided service registration. 
+ pub async fn connect_services(registrations: Vec) -> Result { let mut services = vec![Arc::clone(&IN_PROCESS_SERVICE)]; - let mut external = Vec::with_capacity(registrations.len()); + let mut registered_services = Vec::with_capacity(registrations.len()); let mut registration_names = HashSet::new(); let mut binding_ids = HashSet::from([BUILTIN_SECRETS.to_string()]); @@ -350,7 +350,7 @@ impl MiddlewareRegistry { validate_registration(&registration)?; if !registration_names.insert(registration.name.clone()) { return Err(miette!( - "duplicate external middleware registration name '{}'", + "duplicate supervisor middleware registration name '{}'", registration.name )); } @@ -395,7 +395,7 @@ impl MiddlewareRegistry { manifest: manifest_cell, operator_max_body_bytes: Some(operator_max_body_bytes), })); - external.push(RegisteredExternalService { + registered_services.push(RegisteredMiddlewareService { registration, binding_ids: described_ids, }); @@ -403,7 +403,7 @@ impl MiddlewareRegistry { Ok(Self { services: Arc::new(services), - external: Arc::new(external), + registered_services: Arc::new(registered_services), }) } @@ -433,7 +433,7 @@ impl MiddlewareRegistry { pub fn ensure_policy_bindings_registered(&self, policy: &SandboxPolicy) -> Result<()> { for config in &policy.network_middlewares { let registered = config.middleware == BUILTIN_SECRETS - || self.external.iter().any(|service| { + || self.registered_services.iter().any(|service| { service .binding_ids .iter() @@ -450,11 +450,11 @@ impl MiddlewareRegistry { Ok(()) } - /// Return only external services referenced by the effective policy.
+ pub fn required_services( &self, policy: Option<&SandboxPolicy>, - ) -> Vec { + ) -> Vec { let Some(policy) = policy else { return Vec::new(); }; @@ -463,7 +463,7 @@ impl MiddlewareRegistry { .iter() .map(|config| config.middleware.as_str()) .collect(); - self.external + self.registered_services .iter() .filter(|service| { service @@ -491,7 +491,7 @@ impl ChainRunner { manifest: OnceCell::new(), operator_max_body_bytes: None, })]), - external: Arc::new(Vec::new()), + registered_services: Arc::new(Vec::new()), }), } } @@ -1126,8 +1126,8 @@ mod tests { } } - fn external_registration(max_body_bytes: u64) -> ExternalMiddlewareService { - ExternalMiddlewareService { + fn external_registration(max_body_bytes: u64) -> SupervisorMiddlewareService { + SupervisorMiddlewareService { name: "local-guard-service".into(), endpoint: "http://127.0.0.1:50051".into(), allow_insecure: true, @@ -1137,7 +1137,7 @@ mod tests { async fn registry_with_external( service: Arc, - registration: ExternalMiddlewareService, + registration: SupervisorMiddlewareService, ) -> MiddlewareRegistry { let manifest = service .describe(Request::new(())) @@ -1164,7 +1164,7 @@ mod tests { operator_max_body_bytes: Some(operator_max_body_bytes), }), ]), - external: Arc::new(vec![RegisteredExternalService { + registered_services: Arc::new(vec![RegisteredMiddlewareService { registration, binding_ids, }]), @@ -1294,7 +1294,7 @@ mod tests { let mut registration = external_registration(1024); registration.endpoint = format!("http://{address}"); - let registry = MiddlewareRegistry::connect_external(vec![registration.clone()]) + let registry = MiddlewareRegistry::connect_services(vec![registration.clone()]) .await .expect("connect external middleware"); let policy = SandboxPolicy { @@ -1313,7 +1313,7 @@ mod tests { .await .expect("remote config validates"); assert_eq!( - registry.required_external_services(Some(&policy)), + registry.required_services(Some(&policy)), vec![registration] ); diff --git 
a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index dc4659410..59ce8a544 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -148,9 +148,13 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth `[openshell.gateway.auth] allow_unauthenticated_users = true` is an unsafe local-development and trusted-proxy escape hatch. It accepts user-facing CLI/API calls without OIDC or mTLS credentials while sandbox supervisors still authenticate with gateway-minted sandbox JWTs. Leave it false for shared and production gateways. -## External Supervisor Middleware +## Supervisor Middleware Services -Register external supervisor middleware with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. + + +Register operator-run supervisor middleware services with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. ```toml [[openshell.gateway.middleware]] @@ -160,13 +164,13 @@ allow_insecure = true max_body_bytes = 262144 ``` -Each service implements the supervisor middleware gRPC contract and may expose multiple binding IDs through `Describe`. Policies reference those binding IDs, not the registration `name`. The gateway rejects duplicate binding IDs across services and prevents external services from claiming the reserved `openshell/` namespace. +Each service implements the supervisor middleware gRPC contract and may expose multiple binding IDs through `Describe`. Policies reference those binding IDs, not the registration `name`. 
The gateway rejects duplicate binding IDs across services and prevents operator-run services from claiming the reserved `openshell/` namespace. The gateway connects to every registered service and validates `Describe` before it starts. The service must therefore be running before the gateway. Policy creation and full policy updates call `ValidateConfig`; an unavailable service or invalid middleware configuration rejects the policy before persistence. `max_body_bytes` is the operator limit for every binding exposed by the service. It must be greater than zero and no larger than each binding's advertised limit. OpenShell rejects an oversized value instead of silently clamping it. -External middleware is a local-development preview. The endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. +The service endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. `image_pull_policy` is intentionally not a shared gateway key. Kubernetes and Docker use `Always`, `IfNotPresent`, or `Never`. Podman uses `always`, `missing`, `never`, or `newer`. Set it inside the relevant driver table. 
diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 5cd2070f2..02c905405 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -472,6 +472,10 @@ Identifies an executable that is permitted to use the associated endpoints. ## Network Middleware + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. + + **Category:** Dynamic An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. Every matching config runs once in list order before provider credential injection. diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index cf61989e8..ce4425c20 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -72,6 +72,10 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in ## Supervisor Middleware + +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. + + Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. ```yaml @@ -88,7 +92,7 @@ network_middlewares: Matching entries run once each in top-level declaration order. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. 
-`openshell/secrets` is built into the supervisor. External binding IDs must be registered by the gateway operator before a policy can reference them; see [External Supervisor Middleware](/reference/gateway-config#external-supervisor-middleware). The gateway calls the implementation's `ValidateConfig` before accepting the policy. +`openshell/secrets` is built into the supervisor. Operator-provided binding IDs must be registered before a policy can reference them; see [Supervisor Middleware Services](/reference/gateway-config#supervisor-middleware-services). The gateway calls the implementation's `ValidateConfig` before accepting the policy. `on_error` defaults to `fail_closed`. Use `fail_open` only when skipping a failed middleware is acceptable. Middleware applies only to HTTP traffic the supervisor can parse and inspect; policy validation rejects a required selector that can cover a `tls: skip` endpoint. diff --git a/proto/sandbox.proto b/proto/sandbox.proto index afec58723..644fd86cb 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -352,14 +352,14 @@ message GetSandboxConfigResponse { // Fingerprint for provider credential inputs attached to this sandbox. // Changes when attached provider names or attached provider records change. uint64 provider_env_revision = 8; - // Operator-registered external middleware services required by the effective - // policy. Built-in middleware is not included. - repeated ExternalMiddlewareService external_middleware = 9; + // Operator-registered supervisor middleware services required by the + // effective policy. Built-in middleware is not included. + repeated SupervisorMiddlewareService supervisor_middleware_services = 9; } -// Connection details for one operator-registered external middleware service. +// Connection details for one operator-registered supervisor middleware service. // V1 supports only explicitly enabled plaintext gRPC for local development. 
-message ExternalMiddlewareService { +message SupervisorMiddlewareService { // Operator-facing registration name used for diagnostics. string name = 1; // gRPC endpoint reachable from the sandbox supervisor. From 3a94baad60429887a2630685c9738532a7cb1c44 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Mon, 29 Jun 2026 20:18:37 -0700 Subject: [PATCH 09/16] docs(supervisor-middleware): refine preview warning Signed-off-by: Piotr Mlocek --- docs/reference/gateway-config.mdx | 2 +- docs/reference/policy-schema.mdx | 2 +- docs/sandboxes/policies.mdx | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index 59ce8a544..e9d2e579e 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -151,7 +151,7 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth ## Supervisor Middleware Services -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. Register operator-run supervisor middleware services with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 02c905405..1d24b28df 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -473,7 +473,7 @@ Identifies an executable that is permitted to use the associated endpoints. ## Network Middleware -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. 
Use it only to prototype and evaluate middleware, not for production or long-lived integrations. +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. **Category:** Dynamic diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index ce4425c20..ac089ac94 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -73,7 +73,7 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in ## Supervisor Middleware -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware, not for production or long-lived integrations. +Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. 
From 00f29a113dbc80d9d480c408bf588ce280bafb79 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 10:22:43 -0700 Subject: [PATCH 10/16] docs(extensibility): add supervisor middleware guide Signed-off-by: Piotr Mlocek --- docs/extensibility/supervisor-middleware.mdx | 141 +++++++++++++++++++ docs/index.yml | 2 + docs/reference/gateway-config.mdx | 6 +- docs/reference/policy-schema.mdx | 6 +- docs/sandboxes/policies.mdx | 8 +- 5 files changed, 150 insertions(+), 13 deletions(-) create mode 100644 docs/extensibility/supervisor-middleware.mdx diff --git a/docs/extensibility/supervisor-middleware.mdx b/docs/extensibility/supervisor-middleware.mdx new file mode 100644 index 000000000..320495e59 --- /dev/null +++ b/docs/extensibility/supervisor-middleware.mdx @@ -0,0 +1,141 @@ +--- +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +title: "Supervisor Middleware" +sidebar-title: "Supervisor Middleware" +description: "Configure and operate built-in and operator-run middleware for sandbox HTTP requests." +keywords: "Generative AI, Cybersecurity, AI Agents, Supervisor Middleware, Extensibility, Request Filtering" +--- + +Supervisor middleware adds ordered request-processing stages to allowed HTTP egress. Middleware runs after network and L7 policy admit a request and before OpenShell injects provider credentials. A stage can allow or deny the request, replace its body, add approved headers, and report audit-safe findings. + +Middleware selection is independent of the network policy rule that admitted the request. OpenShell matches middleware by destination host, so the same middleware applies consistently across broad, specific, user-authored, and provider-derived network policies. + +## Request Flow + +For each inspected HTTP request, the supervisor: + +1. Evaluates network and L7 policy. +2. Selects middleware whose host selectors match the admitted destination. +3. 
Buffers the request body using the smallest body limit in the selected chain. +4. Runs matching middleware in policy declaration order. +5. Applies allowed transformations, injects provider credentials, and forwards the request. + +Middleware receives the request before credential injection. Operator-run services cannot inspect OpenShell-managed credentials. + +## Choose a Middleware Type + +| Type | Registration | Body limit | Deployment | +| --- | --- | --- | --- | +| Built-in | None | Defined by OpenShell | Runs inside the supervisor | +| Operator-run service | Required in gateway TOML | Set by the operator, up to the service capability | Runs as a separate service reachable by the gateway and supervisors | + +`openshell/secrets` is the built-in middleware currently available. It identifies common secret patterns in UTF-8 request bodies and replaces matched values before the request leaves the sandbox. + +Operator-run services expose one or more binding IDs. Policies reference a binding ID, such as `example/content-guard`, rather than the gateway registration name. + +## Register a Middleware Service + +Start an operator-run service before starting the gateway, then add a registration to the local gateway TOML: + +```toml +[[openshell.gateway.middleware]] +name = "local-content-guard" +endpoint = "http://host.openshell.internal:50051" +allow_insecure = true +max_body_bytes = 262144 +``` + +| Field | Description | +| --- | --- | +| `name` | Operator-facing registration name used in diagnostics. Policies do not reference this value. | +| `endpoint` | Service address reachable from both the gateway and sandbox supervisors. | +| `allow_insecure` | Required acknowledgement for the currently supported plaintext endpoint. | +| `max_body_bytes` | Operator limit applied to every binding exposed by the service. | + +The gateway connects to every registered service and verifies its capabilities before accepting traffic. 
Gateway startup fails when a service is unavailable, reports an invalid capability, or exposes a binding ID already owned by another service. Operator-run services cannot claim the reserved `openshell/` namespace. + +Registration is static. Restart the gateway after adding, removing, or changing a service. See [Gateway Configuration](/reference/gateway-config#supervisor-middleware-services) for the complete gateway TOML context. + +## Apply Middleware with Policy + +Add middleware configs to the top-level `network_middlewares` list: + +```yaml +network_middlewares: + - name: redact-secrets + middleware: openshell/secrets + config: + secrets: redact + on_error: fail_closed + endpoints: + include: ["*.example.com"] + exclude: ["trusted.example.com"] +``` + +Each config has a policy-local `name`, a built-in or operator-provided binding ID in `middleware`, implementation-owned `config`, failure behavior, and host selectors. + +`include` selects destination hosts. `exclude` takes precedence and removes hosts from that selection. Matching is case-insensitive and uses the same exact-host and DNS glob behavior as network policy endpoints. + +Matching configs run once each in top-level declaration order. Different config names may reference the same binding and run as separate stages. Config names must be unique. + +See [Policy Schema](/reference/policy-schema#network-middleware) for the complete field reference. + +## Configure Failure Behavior + +`on_error` controls what happens when middleware is unavailable, rejects its configuration, returns an invalid result, or exceeds a body limit. + +| Value | Behavior | +| --- | --- | +| `fail_closed` | Denies the request when the middleware stage fails. This is the default. | +| `fail_open` | Skips the failed stage and continues the request through the remaining chain. | + +Use `fail_open` only when bypassing the middleware preserves the intended security policy. 
OpenShell emits a detection finding when a failed stage is bypassed. + +An explicit deny decision always stops the chain and denies the request, regardless of `on_error`. + +## Set Body Limits + +Every middleware binding declares the largest request or replacement body it supports. + +- Built-in middleware uses its OpenShell-defined limit. +- Each operator-run registration sets `max_body_bytes` no higher than the service capability. +- A selected chain buffers using its smallest stage limit. +- The same per-stage limit applies to request bodies and replacement bodies. + +The gateway rejects a registration whose operator limit exceeds the service capability instead of silently clamping it. At request time, exceeding a selected stage's limit is a middleware failure and follows that config's `on_error` behavior. + +## Operate Middleware Services + +Plan startup and updates around these boundaries: + +- Start registered services before the gateway. The gateway validates every registration during startup. +- Keep service endpoints reachable from both the gateway and sandbox supervisors. The supervisors call operator-run services directly on the request path. +- Restart the gateway after changing registrations. +- Keep required services available before creating or updating policies. The gateway validates implementation-owned config before persisting a policy. +- Treat `fail_open` as an explicit availability-over-enforcement decision. + +When the effective sandbox configuration changes, a running supervisor validates the new service registry before installing it. If the reload fails, the supervisor keeps its last-known-good registry and emits a configuration failure event. + +## Observe Middleware + +Middleware activity is emitted through OpenShell's OCSF logging: + +- Each invocation records its policy-local middleware name, binding, decision, transformation state, and failure state. +- A bypass under `fail_open` emits a detection finding. 
+- A required stage that fails closed emits a high-severity detection finding. +- Findings include the service-provided type and label plus aggregate counts. Middleware services should keep those fields audit-safe and omit request content or matched values. +- Registry reload success and failure are emitted as configuration state changes. + +See [Logging](/observability/logging) for log access and [OCSF JSON Export](/observability/ocsf-json-export) for structured export. + +## Current Limitations + +- Middleware applies only to HTTP requests parsed by the supervisor. +- The supported operation and phase are `HttpRequest/pre_credentials`. +- Selection uses destination host include and exclude patterns. +- Required middleware cannot cover `tls: skip` endpoints because OpenShell cannot inspect that traffic. +- Operator-run services currently use explicitly enabled plaintext `http://` endpoints. +- TLS, service authentication, health checks, and runtime registration are not available. + +For a runnable operator workflow, see the [content guard example](https://github.com/NVIDIA/OpenShell/tree/main/examples/supervisor-middleware-content-guard). diff --git a/docs/index.yml b/docs/index.yml index b2443e4af..45db451a7 100644 --- a/docs/index.yml +++ b/docs/index.yml @@ -19,6 +19,8 @@ navigation: title: "Manage OpenShell" - folder: providers title: "Providers" +- folder: extensibility + title: "Extensibility" - folder: observability title: "Observability" - folder: kubernetes diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index e9d2e579e..9e9c44e97 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -150,10 +150,6 @@ Local Docker, Podman, and VM gateways can also set `[openshell.gateway.mtls_auth ## Supervisor Middleware Services - -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. 
Use it only to prototype and evaluate middleware integrations. - - Register operator-run supervisor middleware services with one or more `[[openshell.gateway.middleware]]` entries. Registration is static and operator-owned; changing it requires restarting the gateway. ```toml @@ -172,6 +168,8 @@ The gateway connects to every registered service and validates `Describe` before The service endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. +See [Supervisor Middleware](/extensibility/supervisor-middleware) for selection, failure, body-limit, and operational guidance. + `image_pull_policy` is intentionally not a shared gateway key. Kubernetes and Docker use `Always`, `IfNotPresent`, or `Never`. Podman uses `always`, `missing`, `never`, or `newer`. Set it inside the relevant driver table. ## Driver References diff --git a/docs/reference/policy-schema.mdx b/docs/reference/policy-schema.mdx index 1d24b28df..3ed651564 100644 --- a/docs/reference/policy-schema.mdx +++ b/docs/reference/policy-schema.mdx @@ -472,10 +472,6 @@ Identifies an executable that is permitted to use the associated endpoints. ## Network Middleware - -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. - - **Category:** Dynamic An ordered list of middleware configs selected after network and L7 policy admit an HTTP request. Middleware selection is independent of the network policy entry that admitted the request. 
Every matching config runs once in list order before provider credential injection. @@ -502,6 +498,8 @@ network_middlewares: Host selectors use the same case-insensitive exact and DNS glob semantics as network endpoints. Middleware runs only on HTTP requests the supervisor parses. A selector that can require middleware on a `tls: skip` endpoint is rejected because OpenShell cannot inspect that traffic. +See [Supervisor Middleware](/extensibility/supervisor-middleware) for registration, failure behavior, body limits, and operational guidance. + ## Full Example The following policy grants read-only GitHub API access and npm registry access: diff --git a/docs/sandboxes/policies.mdx b/docs/sandboxes/policies.mdx index ac089ac94..ee07dd7a7 100644 --- a/docs/sandboxes/policies.mdx +++ b/docs/sandboxes/policies.mdx @@ -72,10 +72,6 @@ Raw streams are connection-scoped and outside L7 live-reload guarantees. This in ## Supervisor Middleware - -Supervisor middleware is a research preview. Its policy and service contracts may change without compatibility guarantees. Use it only to prototype and evaluate middleware integrations. - - Supervisor middleware can inspect, deny, or replace admitted HTTP request bodies before provider credentials are injected. Middleware selection is independent of the `network_policies` rule that admitted the request: each `network_middlewares` entry matches the destination host through `endpoints.include` and `endpoints.exclude`. ```yaml @@ -92,10 +88,12 @@ network_middlewares: Matching entries run once each in top-level declaration order. Config names must be unique. Different config names may use the same implementation and run as distinct stages. `exclude` takes precedence over `include`. -`openshell/secrets` is built into the supervisor. Operator-provided binding IDs must be registered before a policy can reference them; see [Supervisor Middleware Services](/reference/gateway-config#supervisor-middleware-services). 
The gateway calls the implementation's `ValidateConfig` before accepting the policy. +`openshell/secrets` is built into the supervisor. Operator-provided binding IDs must be registered before a policy can reference them. The gateway validates implementation-owned config before accepting the policy. `on_error` defaults to `fail_closed`. Use `fail_open` only when skipping a failed middleware is acceptable. Middleware applies only to HTTP traffic the supervisor can parse and inspect; policy validation rejects a required selector that can cover a `tls: skip` endpoint. +See [Supervisor Middleware](/extensibility/supervisor-middleware) for registration, chain ordering, body limits, failure behavior, and operations. + ## Baseline Filesystem Paths When a sandbox runs in proxy mode (the default), OpenShell automatically adds baseline filesystem paths required for the sandbox child process to function: `/usr`, `/lib`, `/etc`, `/var/log` (read-only) and `/sandbox`, `/tmp` (read-write). Paths like `/app` are included in the baseline set but are only added if they exist in the container image. 
From 1fbcdbc347f55905ec764ab3810be2b5d39328fb Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 12:41:09 -0700 Subject: [PATCH 11/16] fix(server): remove stale middleware import Signed-off-by: Piotr Mlocek --- crates/openshell-server/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/openshell-server/src/lib.rs b/crates/openshell-server/src/lib.rs index bd974d4b4..86afe4ef4 100644 --- a/crates/openshell-server/src/lib.rs +++ b/crates/openshell-server/src/lib.rs @@ -55,7 +55,6 @@ mod ws_tunnel; use metrics_exporter_prometheus::PrometheusBuilder; use openshell_core::{ComputeDriverKind, Config, Error, Result}; use openshell_supervisor_middleware::MiddlewareRegistry; -use serde::Deserialize; use std::collections::HashMap; use std::io::ErrorKind; use std::net::SocketAddr; From df8ea82e8f19817d443bfcfda174cb85bf48d4d3 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 13:26:06 -0700 Subject: [PATCH 12/16] fix(network): remove needless test struct updates Signed-off-by: Piotr Mlocek --- crates/openshell-supervisor-network/src/opa.rs | 2 -- 1 file changed, 2 deletions(-) diff --git a/crates/openshell-supervisor-network/src/opa.rs b/crates/openshell-supervisor-network/src/opa.rs index 9e1427dcd..f0aea9c38 100644 --- a/crates/openshell-supervisor-network/src/opa.rs +++ b/crates/openshell-supervisor-network/src/opa.rs @@ -6698,7 +6698,6 @@ network_policies: path: link_path, ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { @@ -6777,7 +6776,6 @@ network_policies: path: link_path, ..Default::default() }], - ..Default::default() }, ); let proto = ProtoSandboxPolicy { From 8d5b4b3b2ebe48a3138c7aea9f061168029b5d10 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 13:48:26 -0700 Subject: [PATCH 13/16] fix(middleware): avoid enabling core telemetry Signed-off-by: Piotr Mlocek --- crates/openshell-supervisor-middleware/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml index e5e53618d..0af6e70c2 100644 --- a/crates/openshell-supervisor-middleware/Cargo.toml +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -11,7 +11,7 @@ repository.workspace = true rust-version.workspace = true [dependencies] -openshell-core = { path = "../openshell-core" } +openshell-core = { path = "../openshell-core", default-features = false } miette = { workspace = true } prost-types = { workspace = true } From 7f887cc9906c2de91304e85bdef26d3dfa3f76db Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 14:47:35 -0700 Subject: [PATCH 14/16] refactor(supervisor-middleware): simplify service endpoints Signed-off-by: Piotr Mlocek --- crates/openshell-server/src/config_file.rs | 14 ++-- crates/openshell-server/src/grpc/policy.rs | 3 +- .../Cargo.toml | 2 +- .../src/lib.rs | 39 ++++++----- .../src/remote.rs | 23 +++++-- docs/extensibility/supervisor-middleware.mdx | 10 ++- docs/reference/gateway-config.mdx | 12 ++-- proto/middleware.proto | 66 +++++++++++++++++++ proto/sandbox.proto | 10 +-- 9 files changed, 127 insertions(+), 52 deletions(-) diff --git a/crates/openshell-server/src/config_file.rs b/crates/openshell-server/src/config_file.rs index 4b0fbc919..7306c80e7 100644 --- a/crates/openshell-server/src/config_file.rs +++ b/crates/openshell-server/src/config_file.rs @@ -174,10 +174,7 @@ pub struct MiddlewareServiceFileConfig { /// Operator-facing name used for diagnostics. pub name: String, /// Plaintext gRPC endpoint reachable by the gateway and supervisors. - pub endpoint: String, - /// Required explicit opt-in to the local-development-only insecure mode. - #[serde(default)] - pub allow_insecure: bool, + pub grpc_endpoint: String, /// Operator-owned body limit for every binding exposed by this service. 
pub max_body_bytes: u64, } @@ -186,8 +183,7 @@ impl From<&MiddlewareServiceFileConfig> for SupervisorMiddlewareService { fn from(config: &MiddlewareServiceFileConfig) -> Self { Self { name: config.name.clone(), - endpoint: config.endpoint.clone(), - allow_insecure: config.allow_insecure, + grpc_endpoint: config.grpc_endpoint.clone(), max_body_bytes: config.max_body_bytes, } } @@ -439,8 +435,7 @@ allow_unauthenticated_users = true let toml = r#" [[openshell.gateway.middleware]] name = "local-guard" -endpoint = "http://127.0.0.1:50051" -allow_insecure = true +grpc_endpoint = "http://127.0.0.1:50051" max_body_bytes = 262144 "#; let tmp = write_tmp(toml); @@ -449,8 +444,7 @@ max_body_bytes = 262144 file.openshell.gateway.middleware, vec![MiddlewareServiceFileConfig { name: "local-guard".into(), - endpoint: "http://127.0.0.1:50051".into(), - allow_insecure: true, + grpc_endpoint: "http://127.0.0.1:50051".into(), max_body_bytes: 262_144, }] ); diff --git a/crates/openshell-server/src/grpc/policy.rs b/crates/openshell-server/src/grpc/policy.rs index 9587e7295..97d459813 100644 --- a/crates/openshell-server/src/grpc/policy.rs +++ b/crates/openshell-server/src/grpc/policy.rs @@ -9331,8 +9331,7 @@ mod tests { let settings = HashMap::new(); let service = openshell_core::proto::SupervisorMiddlewareService { name: "local-guard".into(), - endpoint: "http://127.0.0.1:50051".into(), - allow_insecure: true, + grpc_endpoint: "http://127.0.0.1:50051".into(), max_body_bytes: 1024, }; diff --git a/crates/openshell-supervisor-middleware/Cargo.toml b/crates/openshell-supervisor-middleware/Cargo.toml index 0af6e70c2..f36fc854b 100644 --- a/crates/openshell-supervisor-middleware/Cargo.toml +++ b/crates/openshell-supervisor-middleware/Cargo.toml @@ -17,7 +17,7 @@ miette = { workspace = true } prost-types = { workspace = true } regex = { workspace = true } tokio = { workspace = true } -tonic = { workspace = true, features = ["channel", "server"] } +tonic = { workspace = true, features = 
["channel", "server", "tls-native-roots"] } [dev-dependencies] tokio-stream = { workspace = true, features = ["net"] } diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index ebc5817e4..00324d75f 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -246,15 +246,11 @@ fn validate_registration(registration: &SupervisorMiddlewareService) -> Result<( "supervisor middleware registration name cannot be empty" )); } - if !registration.allow_insecure { - return Err(miette!( - "middleware registration '{}' must set allow_insecure = true; TLS is not supported in V1", - registration.name - )); - } - if !registration.endpoint.starts_with("http://") { + if !registration.grpc_endpoint.starts_with("http://") + && !registration.grpc_endpoint.starts_with("https://") + { return Err(miette!( - "middleware registration '{}' endpoint must use http:// in the local-development-only V1", + "middleware registration '{}' grpc_endpoint must use http:// or https://", registration.name )); } @@ -365,7 +361,7 @@ impl MiddlewareRegistry { let service = Arc::new( remote::RemoteMiddlewareService::connect( &registration.name, - &registration.endpoint, + &registration.grpc_endpoint, ) .await?, ); @@ -1129,8 +1125,7 @@ mod tests { fn external_registration(max_body_bytes: u64) -> SupervisorMiddlewareService { SupervisorMiddlewareService { name: "local-guard-service".into(), - endpoint: "http://127.0.0.1:50051".into(), - allow_insecure: true, + grpc_endpoint: "http://127.0.0.1:50051".into(), max_body_bytes, } } @@ -1267,11 +1262,23 @@ mod tests { } #[test] - fn external_registration_requires_explicit_insecure_opt_in() { + fn external_registration_accepts_http_and_https_grpc_endpoints() { + for grpc_endpoint in [ + "http://127.0.0.1:50051", + "https://middleware.example.com:443", + ] { + let mut registration = external_registration(4096); + registration.grpc_endpoint =
grpc_endpoint.into(); + validate_registration(&registration).expect("supported gRPC endpoint scheme"); + } + } + + #[test] + fn external_registration_rejects_unsupported_grpc_endpoint_scheme() { let mut registration = external_registration(4096); - registration.allow_insecure = false; - let error = validate_registration(&registration).expect_err("opt-in required"); - assert!(error.to_string().contains("allow_insecure")); + registration.grpc_endpoint = "ftp://middleware.example.com".into(); + let error = validate_registration(&registration).expect_err("unsupported scheme"); + assert!(error.to_string().contains("http:// or https://")); } #[tokio::test] @@ -1293,7 +1300,7 @@ mod tests { let server_task = tokio::spawn(server); let mut registration = external_registration(1024); - registration.endpoint = format!("http://{address}"); + registration.grpc_endpoint = format!("http://{address}"); let registry = MiddlewareRegistry::connect_services(vec![registration.clone()]) .await .expect("connect external middleware"); diff --git a/crates/openshell-supervisor-middleware/src/remote.rs b/crates/openshell-supervisor-middleware/src/remote.rs index dd147788b..7645ed811 100644 --- a/crates/openshell-supervisor-middleware/src/remote.rs +++ b/crates/openshell-supervisor-middleware/src/remote.rs @@ -10,7 +10,7 @@ use openshell_core::proto::{ HttpRequestEvaluation, HttpRequestResult, MiddlewareManifest, ValidateConfigRequest, ValidateConfigResponse, }; -use tonic::transport::{Channel, Endpoint}; +use tonic::transport::{Channel, ClientTlsConfig, Endpoint}; use tonic::{Request, Response, Status}; const CONNECT_TIMEOUT: Duration = Duration::from_secs(5); @@ -23,19 +23,30 @@ pub struct RemoteMiddlewareService { } impl RemoteMiddlewareService { - pub async fn connect(registration_name: &str, endpoint: &str) -> Result<Self> { - let channel = Endpoint::from_shared(endpoint.to_string()) + pub async fn connect(registration_name: &str, grpc_endpoint: &str) -> Result<Self> { + let mut endpoint =
Endpoint::from_shared(grpc_endpoint.to_string()) .into_diagnostic() .wrap_err_with(|| { - format!("middleware registration '{registration_name}' has an invalid endpoint") - })? + format!( + "middleware registration '{registration_name}' has an invalid grpc_endpoint" + ) + })?; + if grpc_endpoint.starts_with("https://") { + endpoint = endpoint + .tls_config(ClientTlsConfig::new().with_enabled_roots()) + .into_diagnostic() + .wrap_err_with(|| { + format!("middleware registration '{registration_name}' could not configure TLS") + })?; + } + let channel = endpoint .connect_timeout(CONNECT_TIMEOUT) .connect() .await .into_diagnostic() .wrap_err_with(|| { format!( - "middleware registration '{registration_name}' could not connect to {endpoint}" + "middleware registration '{registration_name}' could not connect to {grpc_endpoint}" ) })?; Ok(Self { diff --git a/docs/extensibility/supervisor-middleware.mdx b/docs/extensibility/supervisor-middleware.mdx index 320495e59..0ebd518bd 100644 --- a/docs/extensibility/supervisor-middleware.mdx +++ b/docs/extensibility/supervisor-middleware.mdx @@ -41,16 +41,14 @@ Start an operator-run service before starting the gateway, then add a registrati ```toml [[openshell.gateway.middleware]] name = "local-content-guard" -endpoint = "http://host.openshell.internal:50051" -allow_insecure = true +grpc_endpoint = "http://host.openshell.internal:50051" max_body_bytes = 262144 ``` | Field | Description | | --- | --- | | `name` | Operator-facing registration name used in diagnostics. Policies do not reference this value. | -| `endpoint` | Service address reachable from both the gateway and sandbox supervisors. | -| `allow_insecure` | Required acknowledgement for the currently supported plaintext endpoint. | +| `grpc_endpoint` | Service address reachable from both the gateway and sandbox supervisors. Supports plaintext `http://` and TLS `https://` with platform trust roots. 
| | `max_body_bytes` | Operator limit applied to every binding exposed by the service. | The gateway connects to every registered service and verifies its capabilities before accepting traffic. Gateway startup fails when a service is unavailable, reports an invalid capability, or exposes a binding ID already owned by another service. Operator-run services cannot claim the reserved `openshell/` namespace. @@ -135,7 +133,7 @@ See [Logging](/observability/logging) for log access and [OCSF JSON Export](/obs - The supported operation and phase are `HttpRequest/pre_credentials`. - Selection uses destination host include and exclude patterns. - Required middleware cannot cover `tls: skip` endpoints because OpenShell cannot inspect that traffic. -- Operator-run services currently use explicitly enabled plaintext `http://` endpoints. -- TLS, service authentication, health checks, and runtime registration are not available. +- Operator-run services support plaintext `http://` and TLS `https://` endpoints. HTTPS certificates must chain to a CA in the platform trust store. +- Custom trust roots, client authentication, health checks, and runtime registration are not available. For a runnable operator workflow, see the [content guard example](https://github.com/NVIDIA/OpenShell/tree/main/examples/supervisor-middleware-content-guard). diff --git a/docs/reference/gateway-config.mdx b/docs/reference/gateway-config.mdx index 9e9c44e97..592839775 100644 --- a/docs/reference/gateway-config.mdx +++ b/docs/reference/gateway-config.mdx @@ -103,12 +103,11 @@ guest_tls_key = "/etc/openshell/certs/client-key.pem" grpc_rate_limit_requests = 120 grpc_rate_limit_window_seconds = 60 -# Local-development-only external supervisor middleware. The endpoint must be -# reachable from both the gateway and sandbox supervisors. +# Operator-run supervisor middleware. The gRPC endpoint must be reachable from +# both the gateway and sandbox supervisors. 
[[openshell.gateway.middleware]] name = "local-content-guard" -endpoint = "http://host.openshell.internal:50051" -allow_insecure = true +grpc_endpoint = "http://host.openshell.internal:50051" max_body_bytes = 262144 # Gateway listener TLS (distinct from the per-driver guest_tls_*). @@ -155,8 +154,7 @@ Register operator-run supervisor middleware services with one or more `[[openshe ```toml [[openshell.gateway.middleware]] name = "local-content-guard" -endpoint = "http://host.openshell.internal:50051" -allow_insecure = true +grpc_endpoint = "http://host.openshell.internal:50051" max_body_bytes = 262144 ``` @@ -166,7 +164,7 @@ The gateway connects to every registered service and validates `Describe` before `max_body_bytes` is the operator limit for every binding exposed by the service. It must be greater than zero and no larger than each binding's advertised limit. OpenShell rejects an oversized value instead of silently clamping it. -The service endpoint must use plaintext `http://`, and `allow_insecure = true` is required as an explicit acknowledgement that inspected request content is sent without transport encryption or peer authentication. TLS, authentication, health checks, and runtime registration are not supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address when both runtimes resolve it. +The service `grpc_endpoint` currently supports plaintext `http://` and TLS `https://` using the platform trust store. Custom trust roots, client authentication, health checks, and runtime registration are not currently supported. The endpoint must be reachable from both the gateway and sandbox supervisors; use `host.openshell.internal` or another shared address that can be resolved in both places. See [Supervisor Middleware](/extensibility/supervisor-middleware) for selection, failure, body-limit, and operational guidance. 
diff --git a/proto/middleware.proto b/proto/middleware.proto index 2944227d8..9b988b930 100644 --- a/proto/middleware.proto +++ b/proto/middleware.proto @@ -8,90 +8,156 @@ package openshell.middleware.v1; import "google/protobuf/empty.proto"; import "google/protobuf/struct.proto"; +// SupervisorMiddleware lets an operator-run service inspect and transform +// sandbox HTTP egress before OpenShell injects credentials. service SupervisorMiddleware { + // Describe returns the service manifest and declared bindings. rpc Describe(google.protobuf.Empty) returns (MiddlewareManifest); + + // ValidateConfig checks service-specific configuration for one binding. rpc ValidateConfig(ValidateConfigRequest) returns (ValidateConfigResponse); + + // EvaluateHttpRequest returns an allow, deny, or mutation decision for one + // buffered HTTP request. rpc EvaluateHttpRequest(HttpRequestEvaluation) returns (HttpRequestResult); } +// MiddlewareManifest describes one service and the bindings it exposes. message MiddlewareManifest { + // Middleware protocol version implemented by the service. string api_version = 1; + // Human-readable service name used for diagnostics. string name = 2; + // Service-defined version string used for diagnostics. string service_version = 3; + // Bindings exposed by this service. repeated MiddlewareBinding bindings = 4; } +// MiddlewareBinding declares one operation and phase supported by a service. message MiddlewareBinding { + // Stable binding id used by policy configuration and audit logs. string id = 1; + // Supported operation name. V1 supports "HttpRequest". string operation = 2; + // Supported evaluation phase. V1 supports "pre_credentials". string phase = 3; // Maximum request or replacement body this binding can process. uint64 max_body_bytes = 4; } +// ValidateConfigRequest contains one policy configuration to validate. message ValidateConfigRequest { + // Middleware protocol version selected by OpenShell. 
string api_version = 1; + // Manifest binding id associated with this configuration. string binding_id = 2; + // Service-specific policy configuration. google.protobuf.Struct config = 3; } +// ValidateConfigResponse reports whether a policy configuration is accepted. message ValidateConfigResponse { + // True when the service accepts the configuration. bool valid = 1; + // Human-readable validation failure reason. Empty when valid is true. string reason = 2; } +// HttpRequestEvaluation contains one buffered HTTP request to evaluate. message HttpRequestEvaluation { + // Middleware protocol version selected by OpenShell. string api_version = 1; + // Manifest binding id selected for this evaluation. string binding_id = 2; + // Evaluation phase selected for this request. string phase = 3; + // Sandbox and request identity available to the supervisor. RequestContext context = 4; + // Validated service-specific policy configuration. google.protobuf.Struct config = 5; + // Destination and HTTP request target. HttpRequestTarget target = 6; + // HTTP request headers before OpenShell injects credentials. map headers = 7; + // Buffered request body. Empty for a bodyless request. bytes body = 8; } +// RequestContext identifies the sandbox request being evaluated. message RequestContext { + // Request id used to correlate middleware and supervisor logs. string request_id = 1; + // Sandbox id that originated the request. string sandbox_id = 2; + // Workload process that originated the request, when available. Process originating_process = 3; } +// HttpRequestTarget describes the admitted HTTP destination and request target. message HttpRequestTarget { + // Request scheme, such as "http" or "https". string scheme = 1; + // Destination hostname selected by network policy. string host = 2; + // Destination TCP port. uint32 port = 3; + // HTTP request method. string method = 4; + // Request path without the query string. 
string path = 5; + // Raw request query string without the leading question mark. string query = 6; } +// Process identifies a workload process and its executable ancestry. message Process { + // Executable path for the originating process. string binary = 1; + // Process id within the sandbox. uint32 pid = 2; + // Executable paths for ancestor processes, nearest parent first. repeated string ancestors = 3; } +// Decision controls whether OpenShell continues processing the request. enum Decision { + // Invalid response value handled according to the policy failure mode. DECISION_UNSPECIFIED = 0; + // Continue processing the request and apply any returned mutations. DECISION_ALLOW = 1; + // Deny the request before credentials are injected or data is sent upstream. DECISION_DENY = 2; } +// Finding is an audit-safe observation produced during evaluation. message Finding { + // Stable, service-defined finding type. string type = 1; + // Human-readable finding label that does not contain request content. string label = 2; + // Number of matching observations represented by this finding. uint32 count = 3; + // Service-defined confidence level. string confidence = 4; + // Service-defined severity level. string severity = 5; } +// HttpRequestResult contains the decision and optional request mutations. message HttpRequestResult { + // Allow or deny decision for this request. Decision decision = 1; + // Human-readable reason used for diagnostics and denied responses. string reason = 2; + // Replacement request body when has_body is true. bytes body = 3; + // True when body should replace the request body, including with an empty body. bool has_body = 4; + // Request headers to add before forwarding. Protected headers are rejected. map add_headers = 5; + // Audit-safe findings produced during evaluation. repeated Finding findings = 6; + // Non-secret service-defined metadata included in diagnostics. 
map metadata = 7; } diff --git a/proto/sandbox.proto b/proto/sandbox.proto index 644fd86cb..c4018573d 100644 --- a/proto/sandbox.proto +++ b/proto/sandbox.proto @@ -79,8 +79,12 @@ message NetworkMiddlewareConfig { MiddlewareEndpointSelector endpoints = 5; } +// Host selector controlling which admitted destinations use a middleware config. message MiddlewareEndpointSelector { + // Exact host or DNS glob patterns included in the selection. repeated string include = 1; + // Exact host or DNS glob patterns removed from the selection. + // Exclusions take precedence over inclusions. repeated string exclude = 2; } @@ -358,14 +362,12 @@ message GetSandboxConfigResponse { } // Connection details for one operator-registered supervisor middleware service. -// V1 supports only explicitly enabled plaintext gRPC for local development. +// V1 supports plaintext and server-authenticated TLS gRPC. message SupervisorMiddlewareService { // Operator-facing registration name used for diagnostics. string name = 1; // gRPC endpoint reachable from the sandbox supervisor. - string endpoint = 2; - // Explicit acknowledgement that request content is sent without TLS. - bool allow_insecure = 3; + string grpc_endpoint = 2; // Operator-owned body limit applied to every binding exposed by the service. uint64 max_body_bytes = 4; } From 71c0898ed1edf752bae28daa2f2133290006ccc6 Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 17:31:09 -0700 Subject: [PATCH 15/16] fix(supervisor-middleware): keep sandbox startup resilient to middleware outages An unreachable operator-registered middleware service previously aborted sandbox startup via a hard error in load_policy, contradicting the per-request on_error contract and the resilient live-reload path. Retry the initial connect and, on failure, degrade to the built-in registry so matched requests are governed by each config's on_error (deny for fail_closed, allow for fail_open) instead of blocking the whole sandbox. 
The policy poll loop now reconciles the registry on every poll while an install is pending, so a recovered service is adopted without waiting for a config change; a failed reconcile also no longer blocks unrelated policy updates. Signed-off-by: Piotr Mlocek --- crates/openshell-sandbox/src/lib.rs | 159 +++++++++++++++++++--------- 1 file changed, 109 insertions(+), 50 deletions(-) diff --git a/crates/openshell-sandbox/src/lib.rs b/crates/openshell-sandbox/src/lib.rs index c1531023a..11395531f 100644 --- a/crates/openshell-sandbox/src/lib.rs +++ b/crates/openshell-sandbox/src/lib.rs @@ -128,7 +128,7 @@ pub async fn run_sandbox( // Load policy and initialize OPA engine let openshell_endpoint_for_proxy = openshell_endpoint.clone(); let sandbox_name_for_agg = sandbox.clone(); - let (policy, opa_engine, retained_proto) = load_policy( + let (policy, opa_engine, retained_proto, middleware_install_pending) = load_policy( sandbox_id.clone(), sandbox, openshell_endpoint.clone(), @@ -390,6 +390,7 @@ pub async fn run_sandbox( ocsf_enabled: poll_ocsf_enabled, provider_credentials: poll_provider_credentials, policy_local_ctx: poll_policy_local, + middleware_install_pending, }; tokio::spawn(async move { @@ -1337,6 +1338,11 @@ async fn load_policy( SandboxPolicy, Option>, Option, + // True when operator-registered middleware could not be connected at + // startup and the engine kept the built-in registry. The policy poll loop + // retries the install so a recovered service is picked up without a config + // change. + bool, )> { // File mode: load OPA engine from rego rules + YAML data (dev override) if let (Some(policy_file), Some(data_file)) = (&policy_rules, &policy_data) { @@ -1366,7 +1372,8 @@ async fn load_policy( process: config.process, }; enrich_sandbox_baseline_paths(&mut policy); - return Ok((policy, Some(Arc::new(engine)), None)); + // File mode has no operator-registered middleware to connect. 
+ return Ok((policy, Some(Arc::new(engine)), None, false)); } // gRPC mode: fetch typed proto policy, construct OPA engine from baked rules + proto data @@ -1449,16 +1456,45 @@ async fn load_policy( // engine is rebuilt with the real PID for symlink resolution. info!("Creating OPA engine from proto policy data"); let engine = OpaEngine::from_proto(&proto_policy)?; - let middleware_registry = + // Connect operator-registered middleware services. A connect/describe + // failure must not abort sandbox startup: unlike the previous hard + // failure, we degrade to the built-in registry and let each request's + // `on_error` policy govern matched traffic (deny for fail_closed, allow + // for fail_open). The policy poll loop retries the install so a + // recovered service is picked up without a config change. This mirrors + // the resilient live-reload path. + let middleware_services = sandbox_config.supervisor_middleware_services.clone(); + let middleware_install_pending = match grpc_retry("Middleware connect", || { openshell_supervisor_middleware::MiddlewareRegistry::connect_services( - sandbox_config.supervisor_middleware_services, + middleware_services.clone(), ) - .await?; - engine.replace_middleware_registry(middleware_registry)?; + }) + .await + .and_then(|registry| engine.replace_middleware_registry(registry)) + { + Ok(()) => false, + Err(error) => { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "degraded") + .unmapped( + "supervisor_middleware_service_count", + serde_json::json!(middleware_services.len()) + ) + .message(format!( + "Supervisor middleware connect failed at startup; continuing with built-in middleware only, per-request on_error governs matched requests [error:{error}]" + )) + .build() + ); + true + } + }; let opa_engine = Some(Arc::new(engine)); let policy = SandboxPolicy::try_from(proto_policy.clone())?; - return Ok((policy, opa_engine, 
Some(proto_policy))); + return Ok((policy, opa_engine, Some(proto_policy), middleware_install_pending)); } // No policy source available @@ -1594,6 +1630,10 @@ struct PolicyPollLoopContext { ocsf_enabled: Arc, provider_credentials: ProviderCredentialState, policy_local_ctx: Option>, + /// True when `load_policy` degraded to the built-in middleware registry + /// because operator services could not be connected at startup. The poll + /// loop retries the install until it succeeds. + middleware_install_pending: bool, } async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { @@ -1606,6 +1646,10 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { let mut current_provider_env_revision: u64 = ctx.provider_credentials.snapshot().revision; let mut current_policy_hash = String::new(); let mut current_middleware_services = Vec::new(); + // Set when a middleware install is outstanding (degraded at startup or a + // failed reload). Drives a retry on every poll, independent of the config + // revision, so a recovered operator service is picked up promptly. + let mut middleware_sync_pending = ctx.middleware_install_pending; let mut current_settings: std::collections::HashMap< String, openshell_core::proto::EffectiveSetting, @@ -1641,14 +1685,70 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } }; + // Reconcile the supervisor middleware registry before evaluating the + // rest of the config. This runs independently of the config revision so + // an install that degraded at startup (or failed on an earlier poll) is + // retried here, letting a recovered operator service be picked up + // without waiting for a policy change. A failure keeps the + // last-known-good registry; the request path stays governed by each + // middleware's `on_error` policy, and a config change is still applied + // below rather than being blocked by the middleware outage. 
+ if middleware_sync_pending + || result.supervisor_middleware_services != current_middleware_services + { + match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( + result.supervisor_middleware_services.clone(), + ) + .await + .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) + { + Ok(()) => { + current_middleware_services = result.supervisor_middleware_services.clone(); + middleware_sync_pending = false; + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Informational) + .status(StatusId::Success) + .state(StateId::Enabled, "loaded") + .unmapped( + "supervisor_middleware_service_count", + serde_json::json!(current_middleware_services.len()) + ) + .message(format!( + "Supervisor middleware registry reloaded [service_count:{}]", + current_middleware_services.len() + )) + .build() + ); + } + Err(error) => { + // Emit only on the transition into the failed state to avoid + // repeating the same finding on every poll during a + // sustained outage. The startup degrade path emits its own + // event. + if !middleware_sync_pending { + ocsf_emit!( + ConfigStateChangeBuilder::new(ocsf_ctx()) + .severity(SeverityId::Medium) + .status(StatusId::Failure) + .state(StateId::Other, "failed") + .message(format!( + "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" + )) + .build() + ); + } + middleware_sync_pending = true; + } + } + } + let provider_env_changed = result.provider_env_revision != current_provider_env_revision; if result.config_revision == current_config_revision && !provider_env_changed { continue; } let policy_changed = result.policy_hash != current_policy_hash; - let middleware_changed = - result.supervisor_middleware_services != current_middleware_services; // Log which settings changed. 
log_setting_changes(&current_settings, &result.settings); @@ -1707,47 +1807,6 @@ async fn run_policy_poll_loop(ctx: PolicyPollLoopContext) -> Result<()> { } } - if middleware_changed { - match openshell_supervisor_middleware::MiddlewareRegistry::connect_services( - result.supervisor_middleware_services.clone(), - ) - .await - .and_then(|registry| ctx.opa_engine.replace_middleware_registry(registry)) - { - Ok(()) => { - current_middleware_services = result.supervisor_middleware_services.clone(); - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Informational) - .status(StatusId::Success) - .state(StateId::Enabled, "loaded") - .unmapped( - "supervisor_middleware_service_count", - serde_json::json!(current_middleware_services.len()) - ) - .message(format!( - "Supervisor middleware registry reloaded [service_count:{}]", - current_middleware_services.len() - )) - .build() - ); - } - Err(error) => { - ocsf_emit!( - ConfigStateChangeBuilder::new(ocsf_ctx()) - .severity(SeverityId::Medium) - .status(StatusId::Failure) - .state(StateId::Other, "failed") - .message(format!( - "Supervisor middleware registry reload failed, keeping last-known-good registry [error:{error}]" - )) - .build() - ); - continue; - } - } - } - // Only reload OPA when the policy payload actually changed. if policy_changed { let Some(policy) = result.policy.as_ref() else { From 9333943b1a67346b50e03140abfbfd2548a60b3d Mon Sep 17 00:00:00 2001 From: Piotr Mlocek Date: Tue, 30 Jun 2026 17:31:17 -0700 Subject: [PATCH 16/16] fix(supervisor-middleware): ignore unresolved bindings in chain body limit A chain entry whose binding did not resolve reported a zero body limit, which dragged the whole chain's buffer cap to zero and spuriously failed body-bearing requests over capacity even when a resolved middleware could have processed them.
Exclude unresolved entries from the limit via a new DescribedChainEntry::is_resolved(); when no entry resolves, skip buffering and apply each entry's on_error directly. Also fix two parallel-test flakes found while validating the change: - Build middleware OCSF events into a Vec and assert on it directly instead of capturing through the global tracing pipeline, whose callsite-interest cache is process-global and raced under parallel runs. - Accumulate the websocket deny response until the reason marker arrives rather than assuming a single read returns the full body. Signed-off-by: Piotr Mlocek --- .../src/lib.rs | 27 +++ .../src/l7/relay.rs | 201 ++++++++++++++---- 2 files changed, 188 insertions(+), 40 deletions(-) diff --git a/crates/openshell-supervisor-middleware/src/lib.rs b/crates/openshell-supervisor-middleware/src/lib.rs index 00324d75f..27302dfe4 100644 --- a/crates/openshell-supervisor-middleware/src/lib.rs +++ b/crates/openshell-supervisor-middleware/src/lib.rs @@ -102,6 +102,14 @@ impl DescribedChainEntry { pub fn on_error(&self) -> OnError { self.entry.on_error } + + /// True when this entry resolved to a registered binding and will be + /// evaluated. When false, the binding is absent from the current registry + /// and the entry is handled entirely by its `on_error` policy, so it + /// imposes no body-buffering limit on the chain. 
+ pub fn is_resolved(&self) -> bool { + self.binding.is_some() + } } #[derive(Debug, Clone)] @@ -1166,6 +1174,25 @@ mod tests { } } + #[tokio::test] + async fn describe_chain_marks_resolved_and_unresolved_entries() { + let unresolved = ChainEntry { + name: "missing".into(), + implementation: "third-party/missing".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + let described = ChainRunner::default() + .describe_chain(&[entry("redact", OnError::FailClosed), unresolved]) + .await + .expect("describe chain"); + // The built-in resolves and reports its real limit; the missing binding + // does not resolve and must not contribute a body limit. + assert!(described[0].is_resolved()); + assert_eq!(described[0].max_body_bytes(), 256 * 1024); + assert!(!described[1].is_resolved()); + } + #[tokio::test] async fn descriptors_are_resolved_from_any_middleware_service() { let runner = ChainRunner::new(Arc::new(ScriptedService { diff --git a/crates/openshell-supervisor-network/src/l7/relay.rs b/crates/openshell-supervisor-network/src/l7/relay.rs index 84853f751..3e92a3218 100644 --- a/crates/openshell-supervisor-network/src/l7/relay.rs +++ b/crates/openshell-supervisor-network/src/l7/relay.rs @@ -778,11 +778,18 @@ pub(crate) enum MiddlewareApplyResult { Denied(String), } +/// Smallest body-buffering limit across the entries that actually resolved to a +/// registered binding. Unresolved entries (`is_resolved() == false`) report a +/// zero limit and are excluded here: they are handled by their `on_error` policy +/// in `evaluate_described` without inspecting the body, so letting a zero drag +/// the chain limit to zero would spuriously fail the whole chain over capacity. +/// Returns `None` when no entry resolved, so the caller can skip buffering. 
fn middleware_chain_body_limit( chain: &[openshell_supervisor_middleware::DescribedChainEntry], ) -> Option { chain .iter() + .filter(|entry| entry.is_resolved()) .map(openshell_supervisor_middleware::DescribedChainEntry::max_body_bytes) .min() } @@ -812,8 +819,20 @@ pub(crate) async fn apply_middleware_chain_for_scheme crate::opa::NetworkInput { } } -fn emit_middleware_events( +/// Build the OCSF events describing a middleware chain outcome, in emission +/// order. Separated from `emit_middleware_events` so tests can assert on the +/// events deterministically without routing through the global tracing pipeline, +/// whose callsite-interest cache is process-global and races under parallel +/// tests. +fn middleware_events( ctx: &L7EvalContext, req: &crate::l7::provider::L7Request, outcome: &openshell_supervisor_middleware::ChainOutcome, -) { +) -> Vec { + let mut events = Vec::new(); for invocation in &outcome.applied { let allowed = invocation.decision == openshell_core::proto::Decision::Allow; let event = HttpActivityBuilder::new(openshell_ocsf::ctx::ctx()) @@ -1000,7 +1025,7 @@ fn emit_middleware_events( invocation.failed )) .build(); - ocsf_emit!(event); + events.push(event); // A middleware that failed but was bypassed under `fail_open` is an // enforcement failure operators must be able to alert on, even though the @@ -1021,7 +1046,7 @@ fn emit_middleware_events( invocation.name )) .build(); - ocsf_emit!(event); + events.push(event); } } if !outcome.allowed && outcome.reason.starts_with("middleware_failed:") { @@ -1033,7 +1058,7 @@ fn emit_middleware_events( )) .message("Required supervisor middleware failed closed") .build(); - ocsf_emit!(event); + events.push(event); } for finding in &outcome.findings { let event = DetectionFindingBuilder::new(openshell_ocsf::ctx::ctx()) @@ -1055,6 +1080,19 @@ fn emit_middleware_events( finding.finding.r#type, finding.finding.count )) .build(); + events.push(event); + } + events +} + +/// Emit the OCSF events describing a 
middleware chain outcome through the +/// tracing pipeline. +fn emit_middleware_events( + ctx: &L7EvalContext, + req: &crate::l7::provider::L7Request, + outcome: &openshell_supervisor_middleware::ChainOutcome, +) { + for event in middleware_events(ctx, req, outcome) { ocsf_emit!(event); } } @@ -3051,6 +3089,101 @@ network_policies: )); } + #[tokio::test] + async fn body_limit_ignores_unresolved_entries() { + use openshell_supervisor_middleware::{ChainEntry, ChainRunner, OnError}; + + let resolved = ChainEntry { + name: "redact".into(), + implementation: openshell_supervisor_middleware::BUILTIN_SECRETS.into(), + config: prost_types::Struct::default(), + on_error: OnError::FailClosed, + }; + let unresolved = ChainEntry { + name: "missing".into(), + implementation: "third-party/missing".into(), + config: prost_types::Struct::default(), + on_error: OnError::FailOpen, + }; + + // A single unresolved (0-limit) entry must not drag the chain limit to + // zero: the buffer limit reflects only the resolved built-in. + let mixed = ChainRunner::default() + .describe_chain(&[resolved, unresolved.clone()]) + .await + .expect("describe mixed chain"); + assert_eq!(middleware_chain_body_limit(&mixed), Some(256 * 1024)); + + // When nothing resolves, there is no body limit and the caller skips + // buffering entirely. + let none = ChainRunner::default() + .describe_chain(std::slice::from_ref(&unresolved)) + .await + .expect("describe unresolved chain"); + assert_eq!(middleware_chain_body_limit(&none), None); + } + + #[tokio::test] + async fn all_unresolved_fail_open_forwards_body_unbuffered() { + // A chain whose only entry is an unregistered binding has no resolvable + // body limit. Under fail_open the request must pass through with its + // body intact rather than being denied over a phantom zero-byte cap. 
+ let (config, tunnel_engine, ctx) = + middleware_relay_context("third-party/missing", "fail_open"); + let (mut app, mut relay_client) = tokio::io::duplex(8192); + let (mut relay_upstream, mut upstream) = tokio::io::duplex(8192); + let relay = tokio::spawn(async move { + relay_with_inspection( + &config, + tunnel_engine, + &mut relay_client, + &mut relay_upstream, + &ctx, + ) + .await + }); + + let body = br#"{"api_key":"sk-1234567890abcdef"}"#; + let request = format!( + "POST /v1/messages HTTP/1.1\r\nHost: api.example.test\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}", + body.len(), + std::str::from_utf8(body).unwrap() + ); + app.write_all(request.as_bytes()).await.unwrap(); + + let mut upstream_request = [0u8; 1024]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + upstream.read(&mut upstream_request), + ) + .await + .expect("request should reach upstream") + .unwrap(); + let upstream_request = String::from_utf8_lossy(&upstream_request[..n]); + // No middleware ran, so the body is forwarded verbatim. + assert!(upstream_request.contains(r#""api_key":"sk-1234567890abcdef""#)); + + upstream + .write_all(b"HTTP/1.1 204 No Content\r\nContent-Length: 0\r\nConnection: close\r\n\r\n") + .await + .unwrap(); + let mut client_response = [0u8; 512]; + let n = tokio::time::timeout( + std::time::Duration::from_secs(1), + app.read(&mut client_response), + ) + .await + .expect("response should reach client") + .unwrap(); + assert!(String::from_utf8_lossy(&client_response[..n]).contains("204 No Content")); + drop(app); + tokio::time::timeout(std::time::Duration::from_secs(1), relay) + .await + .expect("relay should finish") + .unwrap() + .unwrap(); + } + #[test] fn middleware_keeps_the_raw_request_query() { let query = raw_query_from_request_headers( @@ -3095,36 +3228,14 @@ network_policies: assert_eq!(input.scheme, "http"); } - /// Tracing layer that captures emitted `OcsfEvent`s for assertions. 
- struct OcsfCaptureLayer(Arc>>); - - impl tracing_subscriber::Layer for OcsfCaptureLayer { - fn on_event( - &self, - event: &tracing::Event<'_>, - _ctx: tracing_subscriber::layer::Context<'_, S>, - ) { - if event.metadata().target() == openshell_ocsf::OCSF_TARGET - && let Some(ocsf_event) = openshell_ocsf::clone_current_event() - { - self.0.lock().unwrap().push(ocsf_event); - } - } - } - #[test] fn middleware_ocsf_events_are_audit_safe() { use openshell_supervisor_middleware::{ ChainOutcome, MiddlewareInvocation, NamespacedFinding, }; - use tracing_subscriber::layer::SubscriberExt; const RAW_SECRET: &str = "sk-RAWSECRETVALUE0123456789"; - let events = Arc::new(std::sync::Mutex::new(Vec::new())); - let subscriber = tracing_subscriber::registry().with(OcsfCaptureLayer(Arc::clone(&events))); - let _guard = tracing::subscriber::set_default(subscriber); - let ctx = L7EvalContext { host: "api.example.test".into(), port: 443, @@ -3171,23 +3282,26 @@ network_policies: }], }; - emit_middleware_events(&ctx, &req, &outcome); + // Build the events directly rather than routing through the global + // tracing pipeline: its callsite-interest cache is process-global, so a + // parallel test that emits OCSF with no subscriber installed can cache + // the callsite as disabled and make captured-event assertions flaky. + let events = middleware_events(&ctx, &req, &outcome); - let captured = events.lock().unwrap(); // Per-invocation decisions are HTTP Activity (class 4002). assert!( - captured.iter().any(|e| e.class_uid() == 4002), + events.iter().any(|e| e.class_uid() == 4002), "expected an HTTP Activity event for the middleware invocation" ); // Findings are Detection Finding (class 2004) with the finding's severity. - let finding_event = captured + let finding_event = events .iter() .find(|e| e.class_uid() == 2004) .expect("expected a Detection Finding event"); assert_eq!(finding_event.base().severity, SeverityId::Medium); // No raw payload material may appear in any emitted event. 
- let serialized = serde_json::to_string(&*captured).expect("serialize events"); + let serialized = serde_json::to_string(&events).expect("serialize events"); assert!( !serialized.contains(RAW_SECRET), "raw secret leaked into OCSF events: {serialized}" @@ -3364,12 +3478,19 @@ network_policies: .await .unwrap(); - let mut response = [0u8; 512]; - let n = tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut response)) - .await - .expect("denial should reach client") - .unwrap(); - let response = String::from_utf8_lossy(&response[..n]); + // Accumulate until the reason marker arrives: the deny response can be + // delivered in more than one write, so a single read may return only the + // status line and flake the body assertion. + let mut response = Vec::new(); + let mut buf = [0u8; 512]; + while !String::from_utf8_lossy(&response).contains("middleware_failed") { + match tokio::time::timeout(std::time::Duration::from_secs(1), app.read(&mut buf)).await { + Ok(Ok(0)) | Err(_) => break, // clean EOF, or no more data before the deadline + Ok(Ok(n)) => response.extend_from_slice(&buf[..n]), + Ok(Err(e)) => panic!("read from relay failed: {e}"), + } + } + let response = String::from_utf8_lossy(&response); assert!(response.contains("403 Forbidden")); assert!(response.contains("middleware_failed"));