Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions crates/provider-proxy/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ path = "tests/e2e/tests.rs"
[dependencies]
anyhow = { workspace = true }
bytes = { version = "1.10.0", features = ["serde"] }
chrono = { version = "0.4", default-features = false, features = ["clock"] }
clap = { version = "4.6.0", features = ["derive"] }
hex = "0.4.3"
http = "1.4.0"
Expand All @@ -22,6 +23,7 @@ hyper-util = "0.1.19"
moka = { version = "0.12.15", features = ["sync"] }
pin-project = "1.1.11"
rcgen = "0.14.7"
regex = "1"
reqwest.workspace = true
rustls = { version = "0.23.37", features = ["ring"] }
serde = { workspace = true }
Expand Down
109 changes: 105 additions & 4 deletions crates/provider-proxy/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ use std::future::Future;
use std::io::Write;
use std::net::SocketAddr;
use std::path::PathBuf;
use std::sync::Arc;
use std::sync::{Arc, Mutex};

use anyhow::Context as _;
use bytes::{Bytes, BytesMut};
Expand All @@ -34,6 +34,45 @@ use tracing::level_filters::LevelFilter;

const CACHE_HEADER_NAME: &str = "x-tensorzero-provider-proxy-cache";

/// A single cache-miss event, serialized as one element of the JSON manifest
/// produced by `CacheMissTracker::write_to_file`.
#[derive(Serialize, Clone)]
struct CacheMissEntry {
    /// Cache key of the missed request (the hash computed over the
    /// serialized — and possibly sanitized — request).
    cache_key: String,
    /// Target host of the proxied request; `"unknown"` when the URI has no host.
    host: String,
    /// HTTP method of the request (e.g. `"POST"`).
    method: String,
    /// URI path of the request.
    path: String,
    /// RFC 3339 timestamp of when the miss was recorded.
    timestamp: String,
}

/// Thread-safe collector for cache miss events.
///
/// Cloning is cheap: every clone shares the same underlying `Vec` through the
/// `Arc`, so per-request handler clones all record into one collector.
#[derive(Default, Clone)]
struct CacheMissTracker {
    // Mutex-guarded list of recorded misses, shared across clones via `Arc`.
    entries: Arc<Mutex<Vec<CacheMissEntry>>>,
}

impl CacheMissTracker {
fn record(&self, entry: CacheMissEntry) {
self.entries.lock().expect("lock poisoned").push(entry);
}

fn write_to_file(&self, path: &std::path::Path) -> Result<(), anyhow::Error> {
let entries = self.entries.lock().expect("lock poisoned");
if entries.is_empty() {
return Ok(());
}
let json = serde_json::to_string_pretty(&*entries)
.with_context(|| "Failed to serialize cache miss manifest")?;
std::fs::write(path, json).with_context(|| {
format!("Failed to write cache miss manifest to {}", path.display())
})?;
tracing::info!(
"Wrote {} cache miss entries to {}",
entries.len(),
path.display()
);
Ok(())
}
}

fn make_root_cert() -> rcgen::Issuer<'static, rcgen::KeyPair> {
let mut param = rcgen::CertificateParams::default();

Expand All @@ -52,6 +91,23 @@ fn make_root_cert() -> rcgen::Issuer<'static, rcgen::KeyPair> {
rcgen::Issuer::new(param, key_pair)
}

/// Sanitize the request body for cache key computation.
///
/// Replaces non-deterministic values (UUIDs, random localhost ports) with
/// placeholders so that test runs produce the same cache key.
fn sanitize_body_for_cache_key(body: &str) -> String {
    use regex::Regex;
    // Compile each pattern once per process (`LazyLock`) instead of on every
    // call: this function runs on the per-request cache-key path, and regex
    // compilation is far more expensive than matching.
    //
    // UUID pattern: 8-4-4-4-12 hex digits (matches v4/v7 and any other version).
    static UUID_RE: std::sync::LazyLock<Regex> = std::sync::LazyLock::new(|| {
        Regex::new(
            r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}",
        )
        .expect("Invalid UUID regex")
    });
    // Localhost with a (typically ephemeral) port: 127.0.0.1:PORT or localhost:PORT.
    // The port digits need no capture group — only the host (`${1}`) is re-inserted.
    static PORT_RE: std::sync::LazyLock<Regex> = std::sync::LazyLock::new(|| {
        Regex::new(r"(127\.0\.0\.1|localhost):\d{4,5}").expect("Invalid localhost port regex")
    });

    let without_uuids = UUID_RE.replace_all(body, "TENSORZERO_SANITIZED_UUID");
    // `${1}` re-inserts the matched host verbatim; the port is normalized to `:0`.
    PORT_RE.replace_all(&without_uuids, "${1}:0").into_owned()
}

fn hash_value(request: &serde_json::Value) -> Result<String, anyhow::Error> {
let mut hasher = Sha256::new();
hasher.update(
Expand Down Expand Up @@ -146,6 +202,7 @@ async fn check_cache<
>(
start_time: std::time::SystemTime,
args: &Args,
cache_miss_tracker: &CacheMissTracker,
mut request: hyper::Request<Bytes>,
missing: F,
) -> Result<hyper::Response<BoxBody<Bytes, E>>, anyhow::Error> {
Expand Down Expand Up @@ -207,8 +264,24 @@ async fn check_cache<
sanitized_header = true;
}
}
let json_request = http_serde_ext::request::serialize(&request, serde_json::value::Serializer)
.with_context(|| "Failed to serialize request")?;
let mut json_request =
http_serde_ext::request::serialize(&request, serde_json::value::Serializer)
.with_context(|| "Failed to serialize request")?;

if args.sanitize_body
&& let Some(body_array) = json_request.get("body").and_then(|b| b.as_array())
{
let body_bytes: Vec<u8> = body_array
.iter()
.filter_map(|v| v.as_u64().map(|n| n as u8))
.collect();
if let Ok(body_str) = String::from_utf8(body_bytes) {
let sanitized = sanitize_body_for_cache_key(&body_str);
json_request["body"] =
serde_json::Value::Array(sanitized.bytes().map(|b| b.into()).collect());
}
}

let hash = hash_value(&json_request)?;

// Capture the serialized request for potential debugging storage
Expand Down Expand Up @@ -306,6 +379,13 @@ async fn check_cache<
"Cache miss: {}",
path_str,
);
cache_miss_tracker.record(CacheMissEntry {
cache_key: hash.clone(),
host: request.uri().host().unwrap_or("unknown").to_string(),
method: request.method().to_string(),
path: request.uri().path().to_string(),
timestamp: chrono::Utc::now().to_rfc3339(),
});
if matches!(args.mode, CacheMode::ReadOnlyRequireHit) {
tracing::error!("Cache miss in ReadOnlyRequireHit mode: {path_str}");
let body = Full::new(Bytes::from(format!(
Expand Down Expand Up @@ -432,10 +512,17 @@ pub struct Args {
pub remove_user_agent_non_amazon: bool,
#[arg(long, default_value = "read-old-write-new")]
pub mode: CacheMode,
/// If `true`, normalizes UUIDs and random localhost ports in request bodies
/// before computing cache keys, making them deterministic across test runs.
#[arg(long, default_value = "true")]
pub sanitize_body: bool,
/// If `true`, saves the request body in the cached output for debugging purposes.
/// The saved request body is not used when reading from the cache.
#[arg(long, default_value = "true")]
pub save_request_body: bool,
/// Path to write a JSON manifest of all cache misses. If unset, no manifest is written.
#[arg(long)]
pub cache_miss_manifest: Option<PathBuf>,
}

fn find_duplicate_header(headers: &http::HeaderMap) -> Option<HeaderName> {
Expand Down Expand Up @@ -519,14 +606,18 @@ pub async fn run_server(args: Args, server_started: oneshot::Sender<SocketAddr>)
Some(Cache::new(128)),
);

let cache_miss_tracker = CacheMissTracker::default();

let client = reqwest::Client::new();
let args_clone = args.clone();
let tracker_clone = cache_miss_tracker.clone();
let (server_addr, server) = proxy
.bind(
("0.0.0.0", args.port),
service_fn(move |req: hyper::Request<hyper::body::Incoming>| {
let client = client.clone();
let args = args_clone.clone();
let tracker = tracker_clone.clone();
async move {
let (parts, body) = req.into_parts();

Expand Down Expand Up @@ -577,7 +668,7 @@ pub async fn run_server(args: Args, server_started: oneshot::Sender<SocketAddr>)
// Add 1ms delay to simulate network latency
tokio::time::sleep(std::time::Duration::from_millis(1)).await;

let response = check_cache(start_time, &args, bytes_request.clone(), || async {
let response = check_cache(start_time, &args, &tracker, bytes_request.clone(), || async {
let mut request: reqwest::Request =
bytes_request.try_into().with_context(|| {
"Failed to convert Request from `hyper` to `reqwest`"
Expand Down Expand Up @@ -611,4 +702,14 @@ pub async fn run_server(args: Args, server_started: oneshot::Sender<SocketAddr>)
.send(server_addr)
.expect("Failed to send server started signal");
server.await;

// Write cache miss manifest if requested
if let Some(manifest_path) = &args.cache_miss_manifest
&& let Err(e) = cache_miss_tracker.write_to_file(manifest_path)
{
tracing::error!(
err = e.as_ref() as &dyn std::error::Error,
"Failed to write cache miss manifest"
);
}
}
14 changes: 14 additions & 0 deletions crates/provider-proxy/tests/e2e/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ async fn test_provider_proxy() {
remove_user_agent_non_amazon: false,
health_port: 0,
mode: CacheMode::ReadWrite,
sanitize_body: true,
save_request_body: true,
cache_miss_manifest: None,
},
server_started_tx,
));
Expand Down Expand Up @@ -249,7 +251,9 @@ async fn test_read_old_write_new() {
sanitize_model_headers: true,
remove_user_agent_non_amazon: false,
mode: CacheMode::ReadOldWriteNew,
sanitize_body: true,
save_request_body: true,
cache_miss_manifest: None,
},
server_started_tx,
));
Expand Down Expand Up @@ -361,7 +365,9 @@ async fn test_read_old_write_new() {
sanitize_model_headers: true,
remove_user_agent_non_amazon: false,
mode: CacheMode::ReadOldWriteNew,
sanitize_body: true,
save_request_body: true,
cache_miss_manifest: None,
},
server_started_tx,
));
Expand Down Expand Up @@ -411,7 +417,9 @@ async fn test_dropped_stream_body() {
sanitize_model_headers: true,
remove_user_agent_non_amazon: false,
mode: CacheMode::ReadOldWriteNew,
sanitize_body: true,
save_request_body: true,
cache_miss_manifest: None,
},
server_started_tx,
));
Expand Down Expand Up @@ -529,7 +537,9 @@ async fn test_read_only_require_hit() {
sanitize_model_headers: true,
remove_user_agent_non_amazon: false,
mode: CacheMode::ReadWrite,
sanitize_body: true,
save_request_body: true,
cache_miss_manifest: None,
},
server_started_tx,
));
Expand Down Expand Up @@ -595,7 +605,9 @@ async fn test_read_only_require_hit() {
sanitize_model_headers: true,
remove_user_agent_non_amazon: false,
mode: CacheMode::ReadOnlyRequireHit,
sanitize_body: true,
save_request_body: true,
cache_miss_manifest: None,
},
server_started_tx,
));
Expand Down Expand Up @@ -664,7 +676,9 @@ async fn test_stream_body() {
sanitize_model_headers: true,
remove_user_agent_non_amazon: false,
mode: CacheMode::ReadOldWriteNew,
sanitize_body: true,
save_request_body: true,
cache_miss_manifest: None,
},
server_started_tx,
));
Expand Down
2 changes: 1 addition & 1 deletion crates/tensorzero-core/src/http.rs
Original file line number Diff line number Diff line change
Expand Up @@ -709,7 +709,7 @@ fn build_client(global_outbound_http_timeout: Duration) -> Result<Client, Error>
})
})?
.no_proxy(NoProxy::from_string(
"localhost,0.0.0.0,127.0.0.1,minio,mock-provider-api,gateway,provider-proxy,clickhouse",
"localhost,0.0.0.0,127.0.0.1,minio,mock-provider-api,gateway,provider-proxy,clickhouse,raw.githubusercontent.com",
)),
)
// When running e2e tests, we use `provider-proxy` as an MITM proxy
Expand Down
Loading
Loading