From 2c66fa50d8931c951d9f0fbd99451139ac948052 Mon Sep 17 00:00:00 2001 From: Silas Brack Date: Sat, 7 Mar 2026 17:27:54 +0100 Subject: [PATCH] Idk broq --- src/db.rs | 382 ++++++++++++++++++++++-------------------- src/error.rs | 146 ++++++++-------- src/hasher.rs | 166 +++++++++--------- src/lib.rs | 106 ++++++------ src/main.rs | 13 +- src/rebalance.rs | 353 +++++++++++++++++++++------------------ src/rebuild.rs | 158 ++++++++++-------- src/server.rs | 372 ++++++++++++++++++++++------------------- tests/integration.rs | 389 +++++++++++++++++++++++++------------------ 9 files changed, 1125 insertions(+), 960 deletions(-) diff --git a/src/db.rs b/src/db.rs index cd65587..24a2687 100644 --- a/src/db.rs +++ b/src/db.rs @@ -1,182 +1,200 @@ -use rusqlite::{params, Connection, OpenFlags}; -use std::sync::{Arc, Mutex}; - -use crate::error::AppError; - -#[derive(Debug, Clone)] -pub struct Record { - pub key: String, - pub volumes: Vec, - pub size: Option, -} - -fn apply_pragmas(conn: &Connection) { - conn.execute_batch( - "PRAGMA journal_mode = WAL; - PRAGMA synchronous = NORMAL; - PRAGMA busy_timeout = 5000; - PRAGMA temp_store = memory; - PRAGMA cache_size = -64000; - PRAGMA mmap_size = 268435456;", - ) - .expect("failed to set pragmas"); -} - -fn parse_volumes(s: &str) -> Vec { - serde_json::from_str(s).unwrap_or_default() -} - -fn encode_volumes(v: &[String]) -> String { - serde_json::to_string(v).unwrap() -} - -/// A single SQLite connection behind a mutex, used for both reads and writes. -#[derive(Clone)] -pub struct Db { - conn: Arc>, -} - -impl Db { - pub fn new(path: &str) -> Self { - let conn = Connection::open_with_flags( - path, - OpenFlags::SQLITE_OPEN_READ_WRITE - | OpenFlags::SQLITE_OPEN_CREATE - | OpenFlags::SQLITE_OPEN_NO_MUTEX - | OpenFlags::SQLITE_OPEN_URI, - ) - .expect("failed to open database"); - apply_pragmas(&conn); - conn.execute_batch( - "CREATE TABLE IF NOT EXISTS kv ( - key TEXT PRIMARY KEY, - volumes TEXT NOT NULL, - size INTEGER, - created_at INTEGER DEFAULT (unixepoch()) - );", - ) - .expect("failed to create tables"); - Self { conn: Arc::new(Mutex::new(conn)) } - } - - pub async fn get(&self, key: &str) -> Result { - let conn = self.conn.clone(); - let key = key.to_string(); - tokio::task::spawn_blocking(move || { - let conn = conn.lock().unwrap(); - let mut stmt = conn.prepare_cached("SELECT key, volumes, size FROM kv WHERE key = ?1")?; - Ok(stmt.query_row(params![key], |row| { - let vj: String = row.get(1)?; - Ok(Record { key: row.get(0)?, volumes: parse_volumes(&vj), size: row.get(2)? }) - })?) - }) - .await - .unwrap() - } - - pub async fn list_keys(&self, prefix: &str) -> Result, AppError> { - let conn = self.conn.clone(); - let prefix = prefix.to_string(); - tokio::task::spawn_blocking(move || { - let conn = conn.lock().unwrap(); - if prefix.is_empty() { - let mut stmt = conn.prepare_cached("SELECT key FROM kv ORDER BY key")?; - let keys = stmt - .query_map([], |row| row.get(0))? - .collect::, _>>()?; - return Ok(keys); - } - // Compute exclusive upper bound: increment last non-0xFF byte - let upper = { - let mut bytes = prefix.as_bytes().to_vec(); - let mut result = None; - while let Some(last) = bytes.pop() { - if last < 0xFF { - bytes.push(last + 1); - result = Some(String::from_utf8_lossy(&bytes).into_owned()); - break; - } - } - result - }; - let keys = match &upper { - Some(end) => { - let mut stmt = conn.prepare_cached( - "SELECT key FROM kv WHERE key >= ?1 AND key < ?2 ORDER BY key", - )?; - stmt.query_map(params![prefix, end], |row| row.get(0))? - .collect::, _>>()? - } - None => { - let mut stmt = conn.prepare_cached( - "SELECT key FROM kv WHERE key >= ?1 ORDER BY key", - )?; - stmt.query_map(params![prefix], |row| row.get(0))? - .collect::, _>>()? - } - }; - Ok(keys) - }) - .await - .unwrap() - } - - pub async fn put(&self, key: String, volumes: Vec, size: Option) -> Result<(), AppError> { - let conn = self.conn.clone(); - tokio::task::spawn_blocking(move || { - let conn = conn.lock().unwrap(); - conn.prepare_cached( - "INSERT INTO kv (key, volumes, size) VALUES (?1, ?2, ?3) - ON CONFLICT(key) DO UPDATE SET volumes = ?2, size = ?3", - )? - .execute(params![key, encode_volumes(&volumes), size])?; - Ok(()) - }) - .await - .unwrap() - } - - pub async fn delete(&self, key: String) -> Result<(), AppError> { - let conn = self.conn.clone(); - tokio::task::spawn_blocking(move || { - let conn = conn.lock().unwrap(); - conn.prepare_cached("DELETE FROM kv WHERE key = ?1")? - .execute(params![key])?; - Ok(()) - }) - .await - .unwrap() - } - - pub async fn bulk_put(&self, records: Vec<(String, Vec, Option)>) -> Result<(), AppError> { - let conn = self.conn.clone(); - tokio::task::spawn_blocking(move || { - let conn = conn.lock().unwrap(); - conn.execute_batch("BEGIN")?; - let mut stmt = conn.prepare_cached( - "INSERT INTO kv (key, volumes, size) VALUES (?1, ?2, ?3) - ON CONFLICT(key) DO UPDATE SET volumes = ?2, size = ?3", - )?; - for (key, volumes, size) in &records { - stmt.execute(params![key, encode_volumes(volumes), size])?; - } - drop(stmt); - conn.execute_batch("COMMIT")?; - Ok(()) - }) - .await - .unwrap() - } - - pub fn all_records_sync(&self) -> Result, AppError> { - let conn = self.conn.lock().unwrap(); - let mut stmt = conn.prepare_cached("SELECT key, volumes, size FROM kv")?; - let records = stmt - .query_map([], |row| { - let vj: String = row.get(1)?; - Ok(Record { key: row.get(0)?, volumes: parse_volumes(&vj), size: row.get(2)? }) - })? - .collect::, _>>()?; - Ok(records) - } -} +use rusqlite::{Connection, OpenFlags, params}; +use std::sync::{Arc, Mutex}; + +use crate::error::AppError; + +#[derive(Debug, Clone)] +pub struct Record { + pub key: String, + pub volumes: Vec, + pub size: Option, +} + +fn apply_pragmas(conn: &Connection) { + conn.execute_batch( + "PRAGMA journal_mode = WAL; + PRAGMA synchronous = NORMAL; + PRAGMA busy_timeout = 5000; + PRAGMA temp_store = memory; + PRAGMA cache_size = -64000; + PRAGMA mmap_size = 268435456;", + ) + .expect("failed to set pragmas"); +} + +fn parse_volumes(s: &str) -> Vec { + serde_json::from_str(s).unwrap_or_default() +} + +fn encode_volumes(v: &[String]) -> String { + serde_json::to_string(v).unwrap() +} + +/// A single SQLite connection behind a mutex, used for both reads and writes. +#[derive(Clone)] +pub struct Db { + conn: Arc>, +} + +impl Db { + pub fn new(path: &str) -> Self { + let conn = Connection::open_with_flags( + path, + OpenFlags::SQLITE_OPEN_READ_WRITE + | OpenFlags::SQLITE_OPEN_CREATE + | OpenFlags::SQLITE_OPEN_NO_MUTEX + | OpenFlags::SQLITE_OPEN_URI, + ) + .expect("failed to open database"); + apply_pragmas(&conn); + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS kv ( + key TEXT PRIMARY KEY, + volumes TEXT NOT NULL, + size INTEGER, + created_at INTEGER DEFAULT (unixepoch()) + );", + ) + .expect("failed to create tables"); + Self { + conn: Arc::new(Mutex::new(conn)), + } + } + + pub async fn get(&self, key: &str) -> Result { + let conn = self.conn.clone(); + let key = key.to_string(); + tokio::task::spawn_blocking(move || { + let conn = conn.lock().unwrap(); + let mut stmt = + conn.prepare_cached("SELECT key, volumes, size FROM kv WHERE key = ?1")?; + Ok(stmt.query_row(params![key], |row| { + let vj: String = row.get(1)?; + Ok(Record { + key: row.get(0)?, + volumes: parse_volumes(&vj), + size: row.get(2)?, + }) + })?) + }) + .await + .unwrap() + } + + pub async fn list_keys(&self, prefix: &str) -> Result, AppError> { + let conn = self.conn.clone(); + let prefix = prefix.to_string(); + tokio::task::spawn_blocking(move || { + let conn = conn.lock().unwrap(); + if prefix.is_empty() { + let mut stmt = conn.prepare_cached("SELECT key FROM kv ORDER BY key")?; + let keys = stmt + .query_map([], |row| row.get(0))? + .collect::, _>>()?; + return Ok(keys); + } + // Compute exclusive upper bound: increment last non-0xFF byte + let upper = { + let mut bytes = prefix.as_bytes().to_vec(); + let mut result = None; + while let Some(last) = bytes.pop() { + if last < 0xFF { + bytes.push(last + 1); + result = Some(String::from_utf8_lossy(&bytes).into_owned()); + break; + } + } + result + }; + let keys = match &upper { + Some(end) => { + let mut stmt = conn.prepare_cached( + "SELECT key FROM kv WHERE key >= ?1 AND key < ?2 ORDER BY key", + )?; + stmt.query_map(params![prefix, end], |row| row.get(0))? + .collect::, _>>()? + } + None => { + let mut stmt = + conn.prepare_cached("SELECT key FROM kv WHERE key >= ?1 ORDER BY key")?; + stmt.query_map(params![prefix], |row| row.get(0))? + .collect::, _>>()? + } + }; + Ok(keys) + }) + .await + .unwrap() + } + + pub async fn put( + &self, + key: String, + volumes: Vec, + size: Option, + ) -> Result<(), AppError> { + let conn = self.conn.clone(); + tokio::task::spawn_blocking(move || { + let conn = conn.lock().unwrap(); + conn.prepare_cached( + "INSERT INTO kv (key, volumes, size) VALUES (?1, ?2, ?3) + ON CONFLICT(key) DO UPDATE SET volumes = ?2, size = ?3", + )? + .execute(params![key, encode_volumes(&volumes), size])?; + Ok(()) + }) + .await + .unwrap() + } + + pub async fn delete(&self, key: String) -> Result<(), AppError> { + let conn = self.conn.clone(); + tokio::task::spawn_blocking(move || { + let conn = conn.lock().unwrap(); + conn.prepare_cached("DELETE FROM kv WHERE key = ?1")? + .execute(params![key])?; + Ok(()) + }) + .await + .unwrap() + } + + pub async fn bulk_put( + &self, + records: Vec<(String, Vec, Option)>, + ) -> Result<(), AppError> { + let conn = self.conn.clone(); + tokio::task::spawn_blocking(move || { + let conn = conn.lock().unwrap(); + conn.execute_batch("BEGIN")?; + let mut stmt = conn.prepare_cached( + "INSERT INTO kv (key, volumes, size) VALUES (?1, ?2, ?3) + ON CONFLICT(key) DO UPDATE SET volumes = ?2, size = ?3", + )?; + for (key, volumes, size) in &records { + stmt.execute(params![key, encode_volumes(volumes), size])?; + } + drop(stmt); + conn.execute_batch("COMMIT")?; + Ok(()) + }) + .await + .unwrap() + } + + pub fn all_records_sync(&self) -> Result, AppError> { + let conn = self.conn.lock().unwrap(); + let mut stmt = conn.prepare_cached("SELECT key, volumes, size FROM kv")?; + let records = stmt + .query_map([], |row| { + let vj: String = row.get(1)?; + Ok(Record { + key: row.get(0)?, + volumes: parse_volumes(&vj), + size: row.get(2)?, + }) + })? + .collect::, _>>()?; + Ok(records) + } +} diff --git a/src/error.rs b/src/error.rs index d81de8d..32e25a4 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,70 +1,76 @@ -use axum::http::StatusCode; -use axum::response::{IntoResponse, Response}; - -/// Errors from individual volume HTTP requests — used for logging, not HTTP responses. -#[derive(Debug)] -pub enum VolumeError { - Request { url: String, source: reqwest::Error }, - BadStatus { url: String, status: reqwest::StatusCode }, -} - -impl std::fmt::Display for VolumeError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - VolumeError::Request { url, source } => { - write!(f, "volume request to {url} failed: {source}") - } - VolumeError::BadStatus { url, status } => { - write!(f, "volume {url} returned status {status}") - } - } - } -} - -/// Application-level errors that map to HTTP responses. -#[derive(Debug)] -pub enum AppError { - NotFound, - CorruptRecord { key: String }, - Db(rusqlite::Error), - InsufficientVolumes { need: usize, have: usize }, - PartialWrite, -} - -impl From for AppError { - fn from(e: rusqlite::Error) -> Self { - match e { - rusqlite::Error::QueryReturnedNoRows => AppError::NotFound, - other => AppError::Db(other), - } - } -} - -impl std::fmt::Display for AppError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - AppError::NotFound => write!(f, "not found"), - AppError::CorruptRecord { key } => { - write!(f, "corrupt record for key {key}: no volumes") - } - AppError::Db(e) => write!(f, "database error: {e}"), - AppError::InsufficientVolumes { need, have } => { - write!(f, "need {need} volumes but only {have} available") - } - AppError::PartialWrite => write!(f, "not all volume writes succeeded"), - } - } -} - -impl IntoResponse for AppError { - fn into_response(self) -> Response { - let status = match &self { - AppError::NotFound => StatusCode::NOT_FOUND, - AppError::CorruptRecord { .. } => StatusCode::INTERNAL_SERVER_ERROR, - AppError::Db(_) => StatusCode::INTERNAL_SERVER_ERROR, - AppError::InsufficientVolumes { .. } => StatusCode::SERVICE_UNAVAILABLE, - AppError::PartialWrite => StatusCode::BAD_GATEWAY, - }; - (status, self.to_string()).into_response() - } -} +use axum::http::StatusCode; +use axum::response::{IntoResponse, Response}; + +/// Errors from individual volume HTTP requests — used for logging, not HTTP responses. +#[derive(Debug)] +pub enum VolumeError { + Request { + url: String, + source: reqwest::Error, + }, + BadStatus { + url: String, + status: reqwest::StatusCode, + }, +} + +impl std::fmt::Display for VolumeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + VolumeError::Request { url, source } => { + write!(f, "volume request to {url} failed: {source}") + } + VolumeError::BadStatus { url, status } => { + write!(f, "volume {url} returned status {status}") + } + } + } +} + +/// Application-level errors that map to HTTP responses. +#[derive(Debug)] +pub enum AppError { + NotFound, + CorruptRecord { key: String }, + Db(rusqlite::Error), + InsufficientVolumes { need: usize, have: usize }, + PartialWrite, +} + +impl From for AppError { + fn from(e: rusqlite::Error) -> Self { + match e { + rusqlite::Error::QueryReturnedNoRows => AppError::NotFound, + other => AppError::Db(other), + } + } +} + +impl std::fmt::Display for AppError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + AppError::NotFound => write!(f, "not found"), + AppError::CorruptRecord { key } => { + write!(f, "corrupt record for key {key}: no volumes") + } + AppError::Db(e) => write!(f, "database error: {e}"), + AppError::InsufficientVolumes { need, have } => { + write!(f, "need {need} volumes but only {have} available") + } + AppError::PartialWrite => write!(f, "not all volume writes succeeded"), + } + } +} + +impl IntoResponse for AppError { + fn into_response(self) -> Response { + let status = match &self { + AppError::NotFound => StatusCode::NOT_FOUND, + AppError::CorruptRecord { .. } => StatusCode::INTERNAL_SERVER_ERROR, + AppError::Db(_) => StatusCode::INTERNAL_SERVER_ERROR, + AppError::InsufficientVolumes { .. } => StatusCode::SERVICE_UNAVAILABLE, + AppError::PartialWrite => StatusCode::BAD_GATEWAY, + }; + (status, self.to_string()).into_response() + } +} diff --git a/src/hasher.rs b/src/hasher.rs index 9b684ea..a01bec6 100644 --- a/src/hasher.rs +++ b/src/hasher.rs @@ -1,81 +1,85 @@ -use sha2::{Digest, Sha256}; - -/// Pick `count` volumes for a key by hashing key+volume, sorting by score. -/// Same idea as minikeyvalue's key2volume — stable in volume name, not position. -pub fn volumes_for_key(key: &str, volumes: &[String], count: usize) -> Vec { - let mut scored: Vec<(u64, &String)> = volumes - .iter() - .map(|v| { - let hash = Sha256::digest(format!("{key}:{v}").as_bytes()); - let score = u64::from_be_bytes(hash[..8].try_into().unwrap()); - (score, v) - }) - .collect(); - scored.sort_by_key(|(score, _)| *score); - scored.into_iter().take(count).map(|(_, v)| v.clone()).collect() -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_deterministic() { - let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); - let a = volumes_for_key("my-key", &volumes, 2); - let b = volumes_for_key("my-key", &volumes, 2); - assert_eq!(a, b); - } - - #[test] - fn test_count_capped() { - let volumes: Vec = (1..=2).map(|i| format!("http://vol{i}")).collect(); - let selected = volumes_for_key("key", &volumes, 5); - assert_eq!(selected.len(), 2); - } - - #[test] - fn test_even_distribution() { - let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); - let mut counts = std::collections::HashMap::new(); - for i in 0..3000 { - let key = format!("key-{i}"); - let primary = &volumes_for_key(&key, &volumes, 1)[0]; - *counts.entry(primary.clone()).or_insert(0u32) += 1; - } - for (vol, count) in &counts { - assert!( - *count > 700 && *count < 1300, - "volume {vol} got {count} keys, expected ~1000" - ); - } - } - - #[test] - fn test_stability_on_add() { - let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); - let mut volumes4 = volumes.clone(); - volumes4.push("http://vol4".into()); - - let total = 10000; - let mut moved = 0; - for i in 0..total { - let key = format!("key-{i}"); - let before = &volumes_for_key(&key, &volumes, 1)[0]; - let after = &volumes_for_key(&key, &volumes4, 1)[0]; - if before != after { - moved += 1; - } - } - let pct = moved as f64 / total as f64 * 100.0; - assert!( - pct > 15.0 && pct < 40.0, - "expected ~25% of keys to move, got {pct:.1}%" - ); - } - - #[test] - fn test_empty() { - assert_eq!(volumes_for_key("key", &[], 1), Vec::::new()); - } -} +use sha2::{Digest, Sha256}; + +/// Pick `count` volumes for a key by hashing key+volume, sorting by score. +/// Same idea as minikeyvalue's key2volume — stable in volume name, not position. +pub fn volumes_for_key(key: &str, volumes: &[String], count: usize) -> Vec { + let mut scored: Vec<(u64, &String)> = volumes + .iter() + .map(|v| { + let hash = Sha256::digest(format!("{key}:{v}").as_bytes()); + let score = u64::from_be_bytes(hash[..8].try_into().unwrap()); + (score, v) + }) + .collect(); + scored.sort_by_key(|(score, _)| *score); + scored + .into_iter() + .take(count) + .map(|(_, v)| v.clone()) + .collect() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_deterministic() { + let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); + let a = volumes_for_key("my-key", &volumes, 2); + let b = volumes_for_key("my-key", &volumes, 2); + assert_eq!(a, b); + } + + #[test] + fn test_count_capped() { + let volumes: Vec = (1..=2).map(|i| format!("http://vol{i}")).collect(); + let selected = volumes_for_key("key", &volumes, 5); + assert_eq!(selected.len(), 2); + } + + #[test] + fn test_even_distribution() { + let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); + let mut counts = std::collections::HashMap::new(); + for i in 0..3000 { + let key = format!("key-{i}"); + let primary = &volumes_for_key(&key, &volumes, 1)[0]; + *counts.entry(primary.clone()).or_insert(0u32) += 1; + } + for (vol, count) in &counts { + assert!( + *count > 700 && *count < 1300, + "volume {vol} got {count} keys, expected ~1000" + ); + } + } + + #[test] + fn test_stability_on_add() { + let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); + let mut volumes4 = volumes.clone(); + volumes4.push("http://vol4".into()); + + let total = 10000; + let mut moved = 0; + for i in 0..total { + let key = format!("key-{i}"); + let before = &volumes_for_key(&key, &volumes, 1)[0]; + let after = &volumes_for_key(&key, &volumes4, 1)[0]; + if before != after { + moved += 1; + } + } + let pct = moved as f64 / total as f64 * 100.0; + assert!( + pct > 15.0 && pct < 40.0, + "expected ~25% of keys to move, got {pct:.1}%" + ); + } + + #[test] + fn test_empty() { + assert_eq!(volumes_for_key("key", &[], 1), Vec::::new()); + } +} diff --git a/src/lib.rs b/src/lib.rs index 4e80a27..ba576f6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,53 +1,53 @@ -pub mod db; -pub mod error; -pub mod hasher; -pub mod server; -pub mod rebalance; -pub mod rebuild; - -use std::sync::Arc; - -const DEFAULT_BODY_LIMIT: usize = 256 * 1024 * 1024; // 256 MB - -pub struct Args { - pub db_path: String, - pub volumes: Vec, - pub replicas: usize, -} - -pub fn build_app(args: &Args) -> axum::Router { - if args.replicas > args.volumes.len() { - eprintln!( - "Error: replication factor ({}) exceeds number of volumes ({})", - args.replicas, - args.volumes.len() - ); - std::process::exit(1); - } - - if let Some(parent) = std::path::Path::new(&args.db_path).parent() { - std::fs::create_dir_all(parent).unwrap_or_else(|e| { - eprintln!("Failed to create database directory: {e}"); - std::process::exit(1); - }); - } - - let state = server::AppState { - db: db::Db::new(&args.db_path), - volumes: Arc::new(args.volumes.clone()), - replicas: args.replicas, - http: reqwest::Client::new(), - }; - - axum::Router::new() - .route("/", axum::routing::get(server::list_keys)) - .route( - "/{*key}", - axum::routing::get(server::get_key) - .put(server::put_key) - .delete(server::delete_key) - .head(server::head_key), - ) - .layer(axum::extract::DefaultBodyLimit::max(DEFAULT_BODY_LIMIT)) - .with_state(state) -} +pub mod db; +pub mod error; +pub mod hasher; +pub mod rebalance; +pub mod rebuild; +pub mod server; + +use std::sync::Arc; + +const DEFAULT_BODY_LIMIT: usize = 256 * 1024 * 1024; // 256 MB + +pub struct Args { + pub db_path: String, + pub volumes: Vec, + pub replicas: usize, +} + +pub fn build_app(args: &Args) -> axum::Router { + if args.replicas > args.volumes.len() { + eprintln!( + "Error: replication factor ({}) exceeds number of volumes ({})", + args.replicas, + args.volumes.len() + ); + std::process::exit(1); + } + + if let Some(parent) = std::path::Path::new(&args.db_path).parent() { + std::fs::create_dir_all(parent).unwrap_or_else(|e| { + eprintln!("Failed to create database directory: {e}"); + std::process::exit(1); + }); + } + + let state = server::AppState { + db: db::Db::new(&args.db_path), + volumes: Arc::new(args.volumes.clone()), + replicas: args.replicas, + http: reqwest::Client::new(), + }; + + axum::Router::new() + .route("/", axum::routing::get(server::list_keys)) + .route( + "/{*key}", + axum::routing::get(server::get_key) + .put(server::put_key) + .delete(server::delete_key) + .head(server::head_key), + ) + .layer(axum::extract::DefaultBodyLimit::max(DEFAULT_BODY_LIMIT)) + .with_state(state) +} diff --git a/src/main.rs b/src/main.rs index dc49008..f838428 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,13 @@ struct Cli { #[arg(short, long, env = "MKV_DB", default_value = "/tmp/mkv/index.db")] db: String, - #[arg(short, long, env = "MKV_VOLUMES", required = true, value_delimiter = ',')] + #[arg( + short, + long, + env = "MKV_VOLUMES", + required = true, + value_delimiter = ',' + )] volumes: Vec, #[arg(short, long, env = "MKV_REPLICAS", default_value_t = 2)] @@ -36,9 +42,8 @@ async fn shutdown_signal() { let ctrl_c = tokio::signal::ctrl_c(); #[cfg(unix)] { - let mut sigterm = - tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) - .expect("failed to install SIGTERM handler"); + let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate()) + .expect("failed to install SIGTERM handler"); tokio::select! { _ = ctrl_c => tracing::info!("Received SIGINT, shutting down..."), _ = sigterm.recv() => tracing::info!("Received SIGTERM, shutting down..."), diff --git a/src/rebalance.rs b/src/rebalance.rs index e63ec12..1d9be5d 100644 --- a/src/rebalance.rs +++ b/src/rebalance.rs @@ -1,159 +1,194 @@ -use crate::db; -use crate::Args; - -pub struct KeyMove { - pub key: String, - pub size: Option, - pub current_volumes: Vec, - pub desired_volumes: Vec, - pub to_add: Vec, - pub to_remove: Vec, -} - -pub fn plan_rebalance(records: &[db::Record], volumes: &[String], replication: usize) -> Vec { - let mut moves = Vec::new(); - for record in records { - let desired = crate::hasher::volumes_for_key(&record.key, volumes, replication); - let to_add: Vec = desired.iter().filter(|v| !record.volumes.contains(v)).cloned().collect(); - let to_remove: Vec = record.volumes.iter().filter(|v| !desired.contains(v)).cloned().collect(); - - if !to_add.is_empty() || !to_remove.is_empty() { - moves.push(KeyMove { - key: record.key.clone(), - size: record.size, - current_volumes: record.volumes.clone(), - desired_volumes: desired, - to_add, - to_remove, - }); - } - } - moves -} - -pub async fn run(args: &Args, dry_run: bool) { - let db = db::Db::new(&args.db_path); - let records = db.all_records_sync().expect("failed to read records"); - let moves = plan_rebalance(&records, &args.volumes, args.replicas); - - if moves.is_empty() { - eprintln!("Nothing to rebalance — all keys are already correctly placed."); - return; - } - - let total_bytes: i64 = moves.iter().filter_map(|m| m.size).sum(); - eprintln!("{} keys to move ({} bytes)", moves.len(), total_bytes); - - if dry_run { - for m in &moves { - eprintln!(" {} : add {:?}, remove {:?}", m.key, m.to_add, m.to_remove); - } - return; - } - - let client = reqwest::Client::new(); - let mut moved = 0; - let mut errors = 0; - - for m in &moves { - let Some(src) = m.current_volumes.first() else { - eprintln!(" SKIP {} : no source volume", m.key); - errors += 1; - continue; - }; - let mut copy_ok = true; - - for dst in &m.to_add { - let src_url = format!("{src}/{}", m.key); - match client.get(&src_url).send().await { - Ok(resp) if resp.status().is_success() => { - let data = match resp.bytes().await { - Ok(b) => b, - Err(e) => { - eprintln!(" ERROR read body {} from {}: {}", m.key, src, e); - copy_ok = false; - errors += 1; - break; - } - }; - let dst_url = format!("{dst}/{}", m.key); - match client.put(&dst_url).body(data).send().await { - Ok(resp) if !resp.status().is_success() => { - eprintln!(" ERROR copy {} to {}: status {}", m.key, dst, resp.status()); - copy_ok = false; - errors += 1; - } - Err(e) => { - eprintln!(" ERROR copy {} to {}: {}", m.key, dst, e); - copy_ok = false; - errors += 1; - } - Ok(_) => {} - } - } - Ok(resp) => { - eprintln!(" ERROR read {} from {}: status {}", m.key, src, resp.status()); - copy_ok = false; - errors += 1; - } - Err(e) => { - eprintln!(" ERROR read {} from {}: {}", m.key, src, e); - copy_ok = false; - errors += 1; - } - } - } - - if !copy_ok { continue; } - - db.put(m.key.clone(), m.desired_volumes.clone(), m.size).await.expect("failed to update index"); - - for old in &m.to_remove { - let url = format!("{old}/{}", m.key); - if let Err(e) = client.delete(&url).send().await { - eprintln!(" WARN delete {} from {}: {}", m.key, old, e); - } - } - moved += 1; - } - - eprintln!("Rebalanced {moved}/{} keys ({errors} errors)", moves.len()); -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_plan_rebalance_no_change() { - let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); - let records: Vec = (0..100) - .map(|i| { - let key = format!("key-{i}"); - let vols = crate::hasher::volumes_for_key(&key, &volumes, 2); - db::Record { key, volumes: vols, size: Some(100) } - }) - .collect(); - - let moves = plan_rebalance(&records, &volumes, 2); - assert!(moves.is_empty()); - } - - #[test] - fn test_plan_rebalance_new_volume() { - let volumes3: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); - let records: Vec = (0..1000) - .map(|i| { - let key = format!("key-{i}"); - let vols = crate::hasher::volumes_for_key(&key, &volumes3, 2); - db::Record { key, volumes: vols, size: Some(100) } - }) - .collect(); - - let volumes4: Vec = (1..=4).map(|i| format!("http://vol{i}")).collect(); - let moves = plan_rebalance(&records, &volumes4, 2); - - assert!(!moves.is_empty()); - assert!(moves.len() < 800, "too many moves: {}", moves.len()); - } -} +use crate::Args; +use crate::db; + +pub struct KeyMove { + pub key: String, + pub size: Option, + pub current_volumes: Vec, + pub desired_volumes: Vec, + pub to_add: Vec, + pub to_remove: Vec, +} + +pub fn plan_rebalance( + records: &[db::Record], + volumes: &[String], + replication: usize, +) -> Vec { + let mut moves = Vec::new(); + for record in records { + let desired = crate::hasher::volumes_for_key(&record.key, volumes, replication); + let to_add: Vec = desired + .iter() + .filter(|v| !record.volumes.contains(v)) + .cloned() + .collect(); + let to_remove: Vec = record + .volumes + .iter() + .filter(|v| !desired.contains(v)) + .cloned() + .collect(); + + if !to_add.is_empty() || !to_remove.is_empty() { + moves.push(KeyMove { + key: record.key.clone(), + size: record.size, + current_volumes: record.volumes.clone(), + desired_volumes: desired, + to_add, + to_remove, + }); + } + } + moves +} + +pub async fn run(args: &Args, dry_run: bool) { + let db = db::Db::new(&args.db_path); + let records = db.all_records_sync().expect("failed to read records"); + let moves = plan_rebalance(&records, &args.volumes, args.replicas); + + if moves.is_empty() { + eprintln!("Nothing to rebalance — all keys are already correctly placed."); + return; + } + + let total_bytes: i64 = moves.iter().filter_map(|m| m.size).sum(); + eprintln!("{} keys to move ({} bytes)", moves.len(), total_bytes); + + if dry_run { + for m in &moves { + eprintln!(" {} : add {:?}, remove {:?}", m.key, m.to_add, m.to_remove); + } + return; + } + + let client = reqwest::Client::new(); + let mut moved = 0; + let mut errors = 0; + + for m in &moves { + let Some(src) = m.current_volumes.first() else { + eprintln!(" SKIP {} : no source volume", m.key); + errors += 1; + continue; + }; + let mut copy_ok = true; + + for dst in &m.to_add { + let src_url = format!("{src}/{}", m.key); + match client.get(&src_url).send().await { + Ok(resp) if resp.status().is_success() => { + let data = match resp.bytes().await { + Ok(b) => b, + Err(e) => { + eprintln!(" ERROR read body {} from {}: {}", m.key, src, e); + copy_ok = false; + errors += 1; + break; + } + }; + let dst_url = format!("{dst}/{}", m.key); + match client.put(&dst_url).body(data).send().await { + Ok(resp) if !resp.status().is_success() => { + eprintln!( + " ERROR copy {} to {}: status {}", + m.key, + dst, + resp.status() + ); + copy_ok = false; + errors += 1; + } + Err(e) => { + eprintln!(" ERROR copy {} to {}: {}", m.key, dst, e); + copy_ok = false; + errors += 1; + } + Ok(_) => {} + } + } + Ok(resp) => { + eprintln!( + " ERROR read {} from {}: status {}", + m.key, + src, + resp.status() + ); + copy_ok = false; + errors += 1; + } + Err(e) => { + eprintln!(" ERROR read {} from {}: {}", m.key, src, e); + copy_ok = false; + errors += 1; + } + } + } + + if !copy_ok { + continue; + } + + db.put(m.key.clone(), m.desired_volumes.clone(), m.size) + .await + .expect("failed to update index"); + + for old in &m.to_remove { + let url = format!("{old}/{}", m.key); + if let Err(e) = client.delete(&url).send().await { + eprintln!(" WARN delete {} from {}: {}", m.key, old, e); + } + } + moved += 1; + } + + eprintln!("Rebalanced {moved}/{} keys ({errors} errors)", moves.len()); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_plan_rebalance_no_change() { + let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); + let records: Vec = (0..100) + .map(|i| { + let key = format!("key-{i}"); + let vols = crate::hasher::volumes_for_key(&key, &volumes, 2); + db::Record { + key, + volumes: vols, + size: Some(100), + } + }) + .collect(); + + let moves = plan_rebalance(&records, &volumes, 2); + assert!(moves.is_empty()); + } + + #[test] + fn test_plan_rebalance_new_volume() { + let volumes3: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); + let records: Vec = (0..1000) + .map(|i| { + let key = format!("key-{i}"); + let vols = crate::hasher::volumes_for_key(&key, &volumes3, 2); + db::Record { + key, + volumes: vols, + size: Some(100), + } + }) + .collect(); + + let volumes4: Vec = (1..=4).map(|i| format!("http://vol{i}")).collect(); + let moves = plan_rebalance(&records, &volumes4, 2); + + assert!(!moves.is_empty()); + assert!(moves.len() < 800, "too many moves: {}", moves.len()); + } +} diff --git a/src/rebuild.rs b/src/rebuild.rs index 1ba67ed..f0f3121 100644 --- a/src/rebuild.rs +++ b/src/rebuild.rs @@ -1,72 +1,86 @@ -use std::collections::HashMap; - -use crate::db; -use crate::Args; - -#[derive(serde::Deserialize)] -struct NginxEntry { - name: String, - #[serde(rename = "type")] - entry_type: String, - #[serde(default)] - size: Option, -} - -async fn list_volume_keys(volume_url: &str) -> Result, String> { - let http = reqwest::Client::new(); - let mut keys = Vec::new(); - let mut dirs = vec![String::new()]; - - while let Some(prefix) = dirs.pop() { - let url = format!("{volume_url}/{prefix}"); - let resp = http.get(&url).send().await.map_err(|e| format!("GET {url}: {e}"))?; - if !resp.status().is_success() { - return Err(format!("GET {url}: status {}", resp.status())); - } - let entries: Vec = resp.json().await.map_err(|e| format!("parse {url}: {e}"))?; - for entry in entries { - let full_path = if prefix.is_empty() { entry.name.clone() } else { format!("{prefix}{}", entry.name) }; - match entry.entry_type.as_str() { - "directory" => dirs.push(format!("{full_path}/")), - "file" => keys.push((full_path, entry.size.unwrap_or(0))), - _ => {} - } - } - } - Ok(keys) -} - -pub async fn run(args: &Args) { - let db_path = &args.db_path; - - if let Some(parent) = std::path::Path::new(db_path).parent() { - let _ = std::fs::create_dir_all(parent); - } - - let _ = std::fs::remove_file(db_path); - let _ = std::fs::remove_file(format!("{db_path}-wal")); - let _ = std::fs::remove_file(format!("{db_path}-shm")); - - let db = db::Db::new(db_path); - let mut index: HashMap, i64)> = HashMap::new(); - - for vol_url in &args.volumes { - eprintln!("Scanning {vol_url}..."); - match list_volume_keys(vol_url).await { - Ok(keys) => { - eprintln!(" Found {} keys", keys.len()); - for (key, size) in keys { - let entry = index.entry(key).or_insert_with(|| (Vec::new(), size)); - entry.0.push(vol_url.clone()); - if size > entry.1 { entry.1 = size; } - } - } - Err(e) => eprintln!(" Error scanning {vol_url}: {e}"), - } - } - - let records: Vec<_> = index.into_iter().map(|(k, (v, s))| (k, v, Some(s))).collect(); - let count = records.len(); - db.bulk_put(records).await.expect("bulk_put failed"); - eprintln!("Rebuilt index with {count} keys"); -} +use std::collections::HashMap; + +use crate::Args; +use crate::db; + +#[derive(serde::Deserialize)] +struct NginxEntry { + name: String, + #[serde(rename = "type")] + entry_type: String, + #[serde(default)] + size: Option, +} + +async fn list_volume_keys(volume_url: &str) -> Result, String> { + let http = reqwest::Client::new(); + let mut keys = Vec::new(); + let mut dirs = vec![String::new()]; + + while let Some(prefix) = dirs.pop() { + let url = format!("{volume_url}/{prefix}"); + let resp = http + .get(&url) + .send() + .await + .map_err(|e| format!("GET {url}: {e}"))?; + if !resp.status().is_success() { + return Err(format!("GET {url}: status {}", resp.status())); + } + let entries: Vec = + resp.json().await.map_err(|e| format!("parse {url}: {e}"))?; + for entry in entries { + let full_path = if prefix.is_empty() { + entry.name.clone() + } else { + format!("{prefix}{}", entry.name) + }; + match entry.entry_type.as_str() { + "directory" => dirs.push(format!("{full_path}/")), + "file" => keys.push((full_path, entry.size.unwrap_or(0))), + _ => {} + } + } + } + Ok(keys) +} + +pub async fn run(args: &Args) { + let db_path = &args.db_path; + + if let Some(parent) = std::path::Path::new(db_path).parent() { + let _ = std::fs::create_dir_all(parent); + } + + let _ = std::fs::remove_file(db_path); + let _ = std::fs::remove_file(format!("{db_path}-wal")); + let _ = std::fs::remove_file(format!("{db_path}-shm")); + + let db = db::Db::new(db_path); + let mut index: HashMap, i64)> = HashMap::new(); + + for vol_url in &args.volumes { + eprintln!("Scanning {vol_url}..."); + match list_volume_keys(vol_url).await { + Ok(keys) => { + eprintln!(" Found {} keys", keys.len()); + for (key, size) in keys { + let entry = index.entry(key).or_insert_with(|| (Vec::new(), size)); + entry.0.push(vol_url.clone()); + if size > entry.1 { + entry.1 = size; + } + } + } + Err(e) => eprintln!(" Error scanning {vol_url}: {e}"), + } + } + + let records: Vec<_> = index + .into_iter() + .map(|(k, (v, s))| (k, v, Some(s))) + .collect(); + let count = records.len(); + db.bulk_put(records).await.expect("bulk_put failed"); + eprintln!("Rebuilt index with {count} keys"); +} diff --git a/src/server.rs b/src/server.rs index d2e1845..a4b24d4 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,173 +1,199 @@ -use axum::body::Bytes; -use axum::extract::{Path, Query, State}; -use axum::http::{HeaderMap, StatusCode}; -use axum::response::{IntoResponse, Response}; -use std::sync::Arc; - -use crate::db; -use crate::error::{AppError, VolumeError}; - -#[derive(Clone)] -pub struct AppState { - pub db: db::Db, - pub volumes: Arc>, - pub replicas: usize, - pub http: reqwest::Client, -} - -pub async fn get_key( - State(state): State, - Path(key): Path, -) -> Result { - let record = state.db.get(&key).await?; - let vol = record - .volumes - .first() - .ok_or_else(|| AppError::CorruptRecord { key: key.clone() })?; - let location = format!("{vol}/{key}"); - Ok((StatusCode::FOUND, [(axum::http::header::LOCATION, location)]).into_response()) -} - -pub async fn put_key( - State(state): State, - Path(key): Path, - body: Bytes, -) -> Result { - let target_volumes = crate::hasher::volumes_for_key(&key, &state.volumes, state.replicas); - if target_volumes.len() < state.replicas { - return Err(AppError::InsufficientVolumes { - need: state.replicas, - have: target_volumes.len(), - }); - } - - // Fan out PUTs to all target volumes concurrently - let mut handles = Vec::with_capacity(target_volumes.len()); - for vol in &target_volumes { - let url = format!("{vol}/{key}"); - let handle = tokio::spawn({ - let client = state.http.clone(); - let data = body.clone(); - async move { - let resp = client.put(&url).body(data).send().await.map_err(|e| { - VolumeError::Request { url: url.clone(), source: e } - })?; - if !resp.status().is_success() { - return Err(VolumeError::BadStatus { url, status: resp.status() }); - } - Ok(()) - } - }); - handles.push(handle); - } - - let mut failed = false; - for handle in handles { - match handle.await { - Ok(Err(e)) => { - tracing::error!("{e}"); - failed = true; - } - Err(e) => { - tracing::error!("volume write task failed: {e}"); - failed = true; - } - Ok(Ok(())) => {} - } - } - - if failed { - // Rollback: best-effort delete from volumes - for vol in &target_volumes { - let _ = state.http.delete(format!("{vol}/{key}")).send().await; - } - return Err(AppError::PartialWrite); - } - - let size = Some(body.len() as i64); - if let Err(e) = state.db.put(key.clone(), target_volumes.clone(), size).await { - for vol in &target_volumes { - let _ = state.http.delete(format!("{vol}/{key}")).send().await; - } - return Err(e); - } - Ok(StatusCode::CREATED.into_response()) -} - -pub async fn delete_key( - State(state): State, - Path(key): Path, -) -> Result { - let record = state.db.get(&key).await?; - - let mut handles = Vec::new(); - for vol in &record.volumes { - let url = format!("{vol}/{key}"); - let handle = tokio::spawn({ - let client = state.http.clone(); - async move { - let resp = client.delete(&url).send().await.map_err(|e| { - VolumeError::Request { url: url.clone(), source: e } - })?; - if !resp.status().is_success() { - return Err(VolumeError::BadStatus { url, status: resp.status() }); - } - Ok(()) - } - }); - handles.push(handle); - } - for handle in handles { - match handle.await { - Ok(Err(e)) => tracing::error!("{e}"), - Err(e) => tracing::error!("volume delete task failed: {e}"), - Ok(Ok(())) => {} - } - } - - state.db.delete(key).await?; - Ok(StatusCode::NO_CONTENT.into_response()) -} - -pub async fn head_key( - State(state): State, - Path(key): Path, -) -> Result { - let record = state.db.get(&key).await?; - let mut headers = HeaderMap::new(); - if let Some(size) = record.size { - headers.insert(axum::http::header::CONTENT_LENGTH, size.to_string().parse().unwrap()); - } - Ok((StatusCode::OK, headers).into_response()) -} - -#[derive(serde::Deserialize)] -pub struct ListQuery { - #[serde(default)] - pub prefix: String, -} - -pub async fn list_keys( - State(state): State, - Query(query): Query, -) -> Result { - let keys = state.db.list_keys(&query.prefix).await?; - Ok((StatusCode::OK, keys.join("\n")).into_response()) -} - -#[cfg(test)] -mod tests { - #[test] - fn test_volumes_for_key_sufficient() { - let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); - let selected = crate::hasher::volumes_for_key("test-key", &volumes, 2); - assert_eq!(selected.len(), 2); - } - - #[test] - fn test_volumes_for_key_insufficient() { - let volumes: Vec = vec!["http://vol1".into()]; - let selected = crate::hasher::volumes_for_key("test-key", &volumes, 2); - assert_eq!(selected.len(), 1); - } -} +use axum::body::Bytes; +use axum::extract::{Path, Query, State}; +use axum::http::{HeaderMap, StatusCode}; +use axum::response::{IntoResponse, Response}; +use std::sync::Arc; + +use crate::db; +use crate::error::{AppError, VolumeError}; + +#[derive(Clone)] +pub struct AppState { + pub db: db::Db, + pub volumes: Arc>, + pub replicas: usize, + pub http: reqwest::Client, +} + +pub async fn get_key( + State(state): State, + Path(key): Path, +) -> Result { + let record = state.db.get(&key).await?; + let vol = record + .volumes + .first() + .ok_or_else(|| AppError::CorruptRecord { key: key.clone() })?; + let location = format!("{vol}/{key}"); + Ok(( + StatusCode::FOUND, + [(axum::http::header::LOCATION, location)], + ) + .into_response()) +} + +pub async fn put_key( + State(state): State, + Path(key): Path, + body: Bytes, +) -> Result { + let target_volumes = crate::hasher::volumes_for_key(&key, &state.volumes, state.replicas); + if target_volumes.len() < state.replicas { + return Err(AppError::InsufficientVolumes { + need: state.replicas, + have: target_volumes.len(), + }); + } + + // Fan out PUTs to all target volumes concurrently + let mut handles = Vec::with_capacity(target_volumes.len()); + for vol in &target_volumes { + let url = format!("{vol}/{key}"); + let handle = + tokio::spawn({ + let client = state.http.clone(); + let data = body.clone(); + async move { + let resp = client.put(&url).body(data).send().await.map_err(|e| { + VolumeError::Request { + url: url.clone(), + source: e, + } + })?; + if !resp.status().is_success() { + return Err(VolumeError::BadStatus { + url, + status: resp.status(), + }); + } + Ok(()) + } + }); + handles.push(handle); + } + + let mut failed = false; + for handle in handles { + match handle.await { + Ok(Err(e)) => { + tracing::error!("{e}"); + failed = true; + } + Err(e) => { + tracing::error!("volume write task failed: {e}"); + failed = true; + } + Ok(Ok(())) => {} + } + } + + if failed { + // Rollback: best-effort delete from volumes + for vol in &target_volumes { + let _ = state.http.delete(format!("{vol}/{key}")).send().await; + } + return Err(AppError::PartialWrite); + } + + let size = Some(body.len() as i64); + if let Err(e) = state + .db + .put(key.clone(), target_volumes.clone(), size) + .await + { + for vol in &target_volumes { + let _ = state.http.delete(format!("{vol}/{key}")).send().await; + } + return Err(e); + } + Ok(StatusCode::CREATED.into_response()) +} + +pub async fn delete_key( + State(state): State, + Path(key): Path, +) -> Result { + let record = state.db.get(&key).await?; + + let mut handles = Vec::new(); + for vol in &record.volumes { + let url = format!("{vol}/{key}"); + let handle = tokio::spawn({ + let client = state.http.clone(); + async move { + let resp = client + .delete(&url) + .send() + .await + .map_err(|e| VolumeError::Request { + url: url.clone(), + source: e, + })?; + if !resp.status().is_success() { + return Err(VolumeError::BadStatus { + url, + status: resp.status(), + }); + } + Ok(()) + } + }); + handles.push(handle); + } + for handle in handles { + match handle.await { + Ok(Err(e)) => tracing::error!("{e}"), + Err(e) => tracing::error!("volume delete task failed: {e}"), + Ok(Ok(())) => {} + } + } + + state.db.delete(key).await?; + Ok(StatusCode::NO_CONTENT.into_response()) +} + +pub async fn head_key( + State(state): State, + Path(key): Path, +) -> Result { + let record = state.db.get(&key).await?; + let mut headers = HeaderMap::new(); + if let Some(size) = record.size { + headers.insert( + axum::http::header::CONTENT_LENGTH, + size.to_string().parse().unwrap(), + ); + } + Ok((StatusCode::OK, headers).into_response()) +} + +#[derive(serde::Deserialize)] +pub struct ListQuery { + #[serde(default)] + pub prefix: String, +} + +pub async fn list_keys( + State(state): State, + Query(query): Query, +) -> Result { + let keys = state.db.list_keys(&query.prefix).await?; + Ok((StatusCode::OK, keys.join("\n")).into_response()) +} + +#[cfg(test)] +mod tests { + #[test] + fn test_volumes_for_key_sufficient() { + let volumes: Vec = (1..=3).map(|i| format!("http://vol{i}")).collect(); + let selected = crate::hasher::volumes_for_key("test-key", &volumes, 2); + assert_eq!(selected.len(), 2); + } + + #[test] + fn test_volumes_for_key_insufficient() { + let volumes: Vec = vec!["http://vol1".into()]; + let selected = crate::hasher::volumes_for_key("test-key", &volumes, 2); + assert_eq!(selected.len(), 1); + } +} diff --git a/tests/integration.rs b/tests/integration.rs index 7473358..890b23a 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -1,166 +1,223 @@ -use reqwest::StatusCode; -use std::sync::atomic::{AtomicU32, Ordering}; - -static TEST_COUNTER: AtomicU32 = AtomicU32::new(0); - -async fn start_server() -> String { - let id = TEST_COUNTER.fetch_add(1, Ordering::Relaxed); - let db_path = format!("/tmp/mkv-test/index-{id}.db"); - - let _ = std::fs::remove_file(&db_path); - let _ = std::fs::remove_file(format!("{db_path}-wal")); - let _ = std::fs::remove_file(format!("{db_path}-shm")); - - let args = mkv::Args { - db_path, - volumes: vec![ - "http://localhost:3101".into(), - "http://localhost:3102".into(), - "http://localhost:3103".into(), - ], - replicas: 2, - }; - - let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); - let port = listener.local_addr().unwrap().port(); - let app = mkv::build_app(&args); - - tokio::spawn(async move { - axum::serve(listener, app).await.unwrap(); - }); - - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - format!("http://127.0.0.1:{port}") -} - -fn client() -> reqwest::Client { - reqwest::Client::builder() - .redirect(reqwest::redirect::Policy::none()) - .build() - .unwrap() -} - -#[tokio::test] -async fn test_put_and_head() { - let base = start_server().await; - let c = client(); - - let resp = c.put(format!("{base}/hello")).body("world").send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - let resp = c.head(format!("{base}/hello")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - assert_eq!(resp.headers().get("content-length").unwrap().to_str().unwrap(), "5"); -} - -#[tokio::test] -async fn test_put_and_get_redirect() { - let base = start_server().await; - let c = client(); - - let resp = c.put(format!("{base}/redirect-test")).body("some data").send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - let resp = c.get(format!("{base}/redirect-test")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::FOUND); - - let location = resp.headers().get("location").unwrap().to_str().unwrap(); - assert!(location.starts_with("http://localhost:310"), "got: {location}"); - - let blob_resp = reqwest::get(location).await.unwrap(); - assert_eq!(blob_resp.status(), StatusCode::OK); - assert_eq!(blob_resp.text().await.unwrap(), "some data"); -} - -#[tokio::test] -async fn test_get_nonexistent_returns_404() { - let base = start_server().await; - let c = client(); - let resp = c.get(format!("{base}/does-not-exist")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); -} - -#[tokio::test] -async fn test_put_get_delete_get() { - let base = start_server().await; - let c = client(); - - let resp = c.put(format!("{base}/delete-me")).body("temporary").send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - let resp = c.get(format!("{base}/delete-me")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::FOUND); - - let resp = c.delete(format!("{base}/delete-me")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::NO_CONTENT); - - let resp = c.get(format!("{base}/delete-me")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); -} - -#[tokio::test] -async fn test_delete_nonexistent_returns_404() { - let base = start_server().await; - let c = client(); - let resp = c.delete(format!("{base}/never-existed")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::NOT_FOUND); -} - -#[tokio::test] -async fn test_list_keys() { - let base = start_server().await; - let c = client(); - - for name in ["docs/a", "docs/b", "docs/c", "other/x"] { - c.put(format!("{base}/{name}")).body("data").send().await.unwrap(); - } - - let resp = c.get(format!("{base}/")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.text().await.unwrap(); - assert!(body.contains("docs/a")); - assert!(body.contains("other/x")); - - let resp = c.get(format!("{base}/?prefix=docs/")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - let body = resp.text().await.unwrap(); - let lines: Vec<&str> = body.lines().collect(); - assert_eq!(lines.len(), 3); - assert!(!body.contains("other/x")); -} - -#[tokio::test] -async fn test_put_overwrite() { - let base = start_server().await; - let c = client(); - - c.put(format!("{base}/overwrite")).body("version1").send().await.unwrap(); - - let resp = c.put(format!("{base}/overwrite")).body("version2").send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::CREATED); - - let resp = c.head(format!("{base}/overwrite")).send().await.unwrap(); - assert_eq!(resp.headers().get("content-length").unwrap().to_str().unwrap(), "8"); - - let resp = c.get(format!("{base}/overwrite")).send().await.unwrap(); - let location = resp.headers().get("location").unwrap().to_str().unwrap(); - let body = reqwest::get(location).await.unwrap().text().await.unwrap(); - assert_eq!(body, "version2"); -} - -#[tokio::test] -async fn test_replication_writes_to_multiple_volumes() { - let base = start_server().await; - let c = client(); - - c.put(format!("{base}/replicated")).body("replica-data").send().await.unwrap(); - - let resp = c.head(format!("{base}/replicated")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::OK); - - let resp = c.get(format!("{base}/replicated")).send().await.unwrap(); - assert_eq!(resp.status(), StatusCode::FOUND); - let location = resp.headers().get("location").unwrap().to_str().unwrap(); - let body = reqwest::get(location).await.unwrap().text().await.unwrap(); - assert_eq!(body, "replica-data"); -} +use reqwest::StatusCode; +use std::sync::atomic::{AtomicU32, Ordering}; + +static TEST_COUNTER: AtomicU32 = AtomicU32::new(0); + +async fn start_server() -> String { + let id = TEST_COUNTER.fetch_add(1, Ordering::Relaxed); + let db_path = format!("/tmp/mkv-test/index-{id}.db"); + + let _ = std::fs::remove_file(&db_path); + let _ = std::fs::remove_file(format!("{db_path}-wal")); + let _ = std::fs::remove_file(format!("{db_path}-shm")); + + let args = mkv::Args { + db_path, + volumes: vec![ + "http://localhost:3101".into(), + "http://localhost:3102".into(), + "http://localhost:3103".into(), + ], + replicas: 2, + }; + + let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let app = mkv::build_app(&args); + + tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + format!("http://127.0.0.1:{port}") +} + +fn client() -> reqwest::Client { + reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap() +} + +#[tokio::test] +async fn test_put_and_head() { + let base = start_server().await; + let c = client(); + + let resp = c + .put(format!("{base}/hello")) + .body("world") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + + let resp = c.head(format!("{base}/hello")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + assert_eq!( + resp.headers() + .get("content-length") + .unwrap() + .to_str() + .unwrap(), + "5" + ); +} + +#[tokio::test] +async fn test_put_and_get_redirect() { + let base = start_server().await; + let c = client(); + + let resp = c + .put(format!("{base}/redirect-test")) + .body("some data") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + + let resp = c.get(format!("{base}/redirect-test")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::FOUND); + + let location = resp.headers().get("location").unwrap().to_str().unwrap(); + assert!( + location.starts_with("http://localhost:310"), + "got: {location}" + ); + + let blob_resp = reqwest::get(location).await.unwrap(); + assert_eq!(blob_resp.status(), StatusCode::OK); + assert_eq!(blob_resp.text().await.unwrap(), "some data"); +} + +#[tokio::test] +async fn test_get_nonexistent_returns_404() { + let base = start_server().await; + let c = client(); + let resp = c + .get(format!("{base}/does-not-exist")) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn test_put_get_delete_get() { + let base = start_server().await; + let c = client(); + + let resp = c + .put(format!("{base}/delete-me")) + .body("temporary") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + + let resp = c.get(format!("{base}/delete-me")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::FOUND); + + let resp = c.delete(format!("{base}/delete-me")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::NO_CONTENT); + + let resp = c.get(format!("{base}/delete-me")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn test_delete_nonexistent_returns_404() { + let base = start_server().await; + let c = client(); + let resp = c + .delete(format!("{base}/never-existed")) + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::NOT_FOUND); +} + +#[tokio::test] +async fn test_list_keys() { + let base = start_server().await; + let c = client(); + + for name in ["docs/a", "docs/b", "docs/c", "other/x"] { + c.put(format!("{base}/{name}")) + .body("data") + .send() + .await + .unwrap(); + } + + let resp = c.get(format!("{base}/")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.text().await.unwrap(); + assert!(body.contains("docs/a")); + assert!(body.contains("other/x")); + + let resp = c.get(format!("{base}/?prefix=docs/")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let body = resp.text().await.unwrap(); + let lines: Vec<&str> = body.lines().collect(); + assert_eq!(lines.len(), 3); + assert!(!body.contains("other/x")); +} + +#[tokio::test] +async fn test_put_overwrite() { + let base = start_server().await; + let c = client(); + + c.put(format!("{base}/overwrite")) + .body("version1") + .send() + .await + .unwrap(); + + let resp = c + .put(format!("{base}/overwrite")) + .body("version2") + .send() + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::CREATED); + + let resp = c.head(format!("{base}/overwrite")).send().await.unwrap(); + assert_eq!( + resp.headers() + .get("content-length") + .unwrap() + .to_str() + .unwrap(), + "8" + ); + + let resp = c.get(format!("{base}/overwrite")).send().await.unwrap(); + let location = resp.headers().get("location").unwrap().to_str().unwrap(); + let body = reqwest::get(location).await.unwrap().text().await.unwrap(); + assert_eq!(body, "version2"); +} + +#[tokio::test] +async fn test_replication_writes_to_multiple_volumes() { + let base = start_server().await; + let c = client(); + + c.put(format!("{base}/replicated")) + .body("replica-data") + .send() + .await + .unwrap(); + + let resp = c.head(format!("{base}/replicated")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + + let resp = c.get(format!("{base}/replicated")).send().await.unwrap(); + assert_eq!(resp.status(), StatusCode::FOUND); + let location = resp.headers().get("location").unwrap().to_str().unwrap(); + let body = reqwest::get(location).await.unwrap().text().await.unwrap(); + assert_eq!(body, "replica-data"); +}