diff --git a/src/api_server.rs b/src/api_server.rs index cf70dbdf..c1b4d81d 100644 --- a/src/api_server.rs +++ b/src/api_server.rs @@ -1,68 +1,29 @@ use std::sync::Arc; +use std::net::SocketAddr; +use std::collections::VecDeque; -use futures_util::TryStreamExt; +use futures::stream::StreamExt; use hyper::service::{make_service_fn, service_fn}; +use hyper::server::conn::AddrStream; use hyper::{Body, Method, Request, Response, Server, StatusCode}; use futures::future::Future; use crate::error::Error; use crate::membership::System; - -/// This is our service handler. It receives a Request, routes on its -/// path, and returns a Future of a Response. -async fn handler(sys: Arc, req: Request) -> Result, Error> { - match (req.method(), req.uri().path()) { - // Serve some instructions at / - (&Method::GET, "/") => Ok(Response::new(Body::from( - "Try POSTing data to /echo such as: `curl localhost:3000/echo -XPOST -d 'hello world'`", - ))), - - // Simply echo the body back to the client. - (&Method::POST, "/echo") => Ok(Response::new(req.into_body())), - - // Convert to uppercase before sending back to client using a stream. - (&Method::POST, "/echo/uppercase") => { - let chunk_stream = req.into_body().map_ok(|chunk| { - chunk - .iter() - .map(|byte| byte.to_ascii_uppercase()) - .collect::>() - }); - Ok(Response::new(Body::wrap_stream(chunk_stream))) - } - - // Reverse the entire body before sending back to the client. - // - // Since we don't know the end yet, we can't simply stream - // the chunks as they arrive as we did with the above uppercase endpoint. - // So here we do `.await` on the future, waiting on concatenating the full body, - // then afterwards the content can be reversed. Only then can we return a `Response`. - (&Method::POST, "/echo/reversed") => { - let whole_body = hyper::body::to_bytes(req.into_body()).await?; - - let reversed_body = whole_body.iter().rev().cloned().collect::>(); - Ok(Response::new(Body::from(reversed_body))) - } - - // Return the 404 Not Found for other routes. - _ => { - let mut not_found = Response::default(); - *not_found.status_mut() = StatusCode::NOT_FOUND; - Ok(not_found) - } - } -} +use crate::data::*; +use crate::proto::*; +use crate::rpc_client::*; pub async fn run_api_server(sys: Arc, shutdown_signal: impl Future) -> Result<(), hyper::Error> { let addr = ([0, 0, 0, 0], sys.config.api_port).into(); - let service = make_service_fn(|_| { + let service = make_service_fn(|conn: &AddrStream| { let sys = sys.clone(); + let client_addr = conn.remote_addr(); async move { - let sys = sys.clone(); Ok::<_, Error>(service_fn(move |req: Request| { let sys = sys.clone(); - handler(sys, req) + handler(sys, req, client_addr) })) } }); @@ -74,3 +35,180 @@ pub async fn run_api_server(sys: Arc, shutdown_signal: impl Future, req: Request, addr: SocketAddr) -> Result, Error> { + match handler_inner(sys, req, addr).await { + Ok(x) => Ok(x), + Err(Error::BadRequest(e)) => { + let mut bad_request = Response::new(Body::from(format!("{}\n", e))); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + Ok(bad_request) + } + Err(e) => { + let mut ise = Response::new(Body::from( + format!("Internal server error: {}\n", e))); + *ise.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(ise) + } + } +} + +async fn handler_inner(sys: Arc, req: Request, addr: SocketAddr) -> Result, Error> { + eprintln!("{} {} {}", addr, req.method(), req.uri()); + + let bucket = req.headers() + .get(hyper::header::HOST) + .map(|x| x.to_str().map_err(Error::from)) + .unwrap_or(Err(Error::BadRequest(format!("Host: header missing"))))? + .to_lowercase(); + let key = req.uri().path().to_string(); + + match req.method() { + &Method::GET => { + Ok(Response::new(Body::from( + "TODO: implement GET object", + ))) + } + &Method::PUT => { + let mime_type = req.headers() + .get(hyper::header::CONTENT_TYPE) + .map(|x| x.to_str()) + .unwrap_or(Ok("blob"))? + .to_string(); + let version_uuid = handle_put(sys, &mime_type, &bucket, &key, req.into_body()).await?; + Ok(Response::new(Body::from( + format!("Version UUID: {:?}", version_uuid), + ))) + } + _ => Err(Error::BadRequest(format!("Invalid method"))), + } +} + +async fn handle_put(sys: Arc, + mime_type: &str, + bucket: &str, key: &str, body: Body) + -> Result +{ + let version_uuid = gen_uuid(); + + let mut chunker = BodyChunker::new(body, sys.config.block_size); + let first_block = match chunker.next().await? { + Some(x) => x, + None => return Err(Error::BadRequest(format!("Empty body"))), + }; + + let mut version = VersionMeta{ + bucket: bucket.to_string(), + key: key.to_string(), + timestamp: now_msec(), + uuid: version_uuid.clone(), + mime_type: mime_type.to_string(), + size: first_block.len() as u64, + is_complete: false, + data: VersionData::DeleteMarker, + }; + let version_who = sys.members.read().await + .walk_ring(&version_uuid, sys.config.meta_replication_factor); + + if first_block.len() < INLINE_THRESHOLD { + version.data = VersionData::Inline(first_block); + version.is_complete = true; + rpc_try_call_many(sys.clone(), + &version_who[..], + &Message::AdvertiseVersion(version), + (sys.config.meta_replication_factor+1)/2, + DEFAULT_TIMEOUT).await?; + return Ok(version_uuid) + } + + let first_block_hash = hash(&first_block[..]); + version.data = VersionData::FirstBlock(first_block_hash); + rpc_try_call_many(sys.clone(), + &version_who[..], + &Message::AdvertiseVersion(version.clone()), + (sys.config.meta_replication_factor+1)/2, + DEFAULT_TIMEOUT).await?; + + let block_meta = BlockMeta{ + version_uuid: version_uuid.clone(), + offset: 0, + hash: hash(&first_block[..]), + }; + let mut next_offset = first_block.len(); + let mut put_curr_block = put_block(sys.clone(), block_meta, first_block); + loop { + let (_, next_block) = futures::try_join!(put_curr_block, chunker.next())?; + if let Some(block) = next_block { + let block_meta = BlockMeta{ + version_uuid: version_uuid.clone(), + offset: next_offset as u64, + hash: hash(&block[..]), + }; + next_offset += block.len(); + put_curr_block = put_block(sys.clone(), block_meta, block); + } else { + break; + } + } + + // TODO: if at any step we have an error, we should undo everything we did + + version.is_complete = true; + rpc_try_call_many(sys.clone(), + &version_who[..], + &Message::AdvertiseVersion(version), + (sys.config.meta_replication_factor+1)/2, + DEFAULT_TIMEOUT).await?; + Ok(version_uuid) +} + +async fn put_block(sys: Arc, meta: BlockMeta, data: Vec) -> Result<(), Error> { + let who = sys.members.read().await + .walk_ring(&meta.hash, sys.config.meta_replication_factor); + rpc_try_call_many(sys.clone(), + &who[..], + &Message::PutBlock(PutBlockMessage{ + meta, + data, + }), + (sys.config.meta_replication_factor+1)/2, + DEFAULT_TIMEOUT).await?; + Ok(()) +} + +struct BodyChunker { + body: Body, + block_size: usize, + buf: VecDeque, +} + +impl BodyChunker { + fn new(body: Body, block_size: usize) -> Self { + Self{ + body, + block_size, + buf: VecDeque::new(), + } + } + async fn next(&mut self) -> Result>, Error> { + while self.buf.len() < self.block_size { + if let Some(block) = self.body.next().await { + let bytes = block?; + self.buf.extend(&bytes[..]); + } else { + break; + } + } + if self.buf.len() == 0 { + Ok(None) + } else if self.buf.len() <= self.block_size { + let block = self.buf.drain(..) + .collect::>(); + Ok(Some(block)) + } else { + let block = self.buf.drain(..self.block_size) + .collect::>(); + Ok(Some(block)) + } + } +} diff --git a/src/data.rs b/src/data.rs index f54c4cc1..3c71b782 100644 --- a/src/data.rs +++ b/src/data.rs @@ -1,7 +1,10 @@ +use std::time::{SystemTime, UNIX_EPOCH}; use std::fmt; use std::collections::HashMap; use serde::{Serializer, Deserializer, Serialize, Deserialize}; use serde::de::{self, Visitor}; +use rand::Rng; +use sha2::{Sha256, Digest}; #[derive(Default, PartialOrd, Ord, Clone, Hash, PartialEq)] pub struct FixedBytes32([u8; 32]); @@ -69,6 +72,23 @@ impl FixedBytes32 { pub type UUID = FixedBytes32; pub type Hash = FixedBytes32; +pub fn hash(data: &[u8]) -> Hash { + let mut hasher = Sha256::new(); + hasher.input(data); + let mut hash = [0u8; 32]; + hash.copy_from_slice(&hasher.result()[..]); + hash.into() +} + +pub fn gen_uuid() -> UUID { + rand::thread_rng().gen::<[u8; 32]>().into() +} + +pub fn now_msec() -> u64 { + SystemTime::now().duration_since(UNIX_EPOCH) + .expect("Fix your clock :o") + .as_millis() as u64 +} // Network management @@ -86,47 +106,49 @@ pub struct NetworkConfigEntry { // Data management +pub const INLINE_THRESHOLD: usize = 2048; + #[derive(Debug, Serialize, Deserialize)] pub struct SplitpointMeta { - bucket: String, - key: String, + pub bucket: String, + pub key: String, - timestamp: u64, - uuid: UUID, - deleted: bool, + pub timestamp: u64, + pub uuid: UUID, + pub deleted: bool, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct VersionMeta { - bucket: String, - key: String, + pub bucket: String, + pub key: String, - timestamp: u64, - uuid: UUID, - deleted: bool, + pub timestamp: u64, + pub uuid: UUID, - mime_type: String, - size: u64, - is_complete: bool, + pub mime_type: String, + pub size: u64, + pub is_complete: bool, - data: VersionData, + pub data: VersionData, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub enum VersionData { - Inline(Vec), + DeleteMarker, + Inline(#[serde(with="serde_bytes")] Vec), FirstBlock(Hash), } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Clone, Debug, Serialize, Deserialize)] pub struct BlockMeta { - version_uuid: UUID, - offset: u64, - hash: Hash, + pub version_uuid: UUID, + pub offset: u64, + pub hash: Hash, } #[derive(Debug, Serialize, Deserialize)] pub struct BlockReverseMeta { - versions: Vec, - deleted_versions: Vec, + pub versions: Vec, + pub deleted_versions: Vec, } diff --git a/src/error.rs b/src/error.rs index fd717638..30f7dac6 100644 --- a/src/error.rs +++ b/src/error.rs @@ -12,6 +12,9 @@ pub enum Error { #[error(display = "HTTP error: {}", _0)] HTTP(#[error(source)] http::Error), + #[error(display = "Invalid HTTP header value: {}", _0)] + HTTPHeader(#[error(source)] http::header::ToStrError), + #[error(display = "Messagepack encode error: {}", _0)] RMPEncode(#[error(source)] rmp_serde::encode::Error), #[error(display = "Messagepack decode error: {}", _0)] @@ -26,6 +29,9 @@ pub enum Error { #[error(display = "RPC error: {}", _0)] RPCError(String), + #[error(display = "{}", _0)] + BadRequest(String), + #[error(display = "{}", _0)] Message(String), } diff --git a/src/membership.rs b/src/membership.rs index b7b99bb1..69805f2a 100644 --- a/src/membership.rs +++ b/src/membership.rs @@ -7,10 +7,10 @@ use std::collections::HashMap; use std::time::Duration; use std::net::{IpAddr, SocketAddr}; +use sha2::{Sha256, Digest}; use tokio::prelude::*; use futures::future::join_all; use tokio::sync::RwLock; -use sha2::{Sha256, Digest}; use crate::server::Config; use crate::error::Error; @@ -96,10 +96,7 @@ impl Members { } for i in 0..config.n_tokens { - let mut location_hasher = Sha256::new(); - location_hasher.input(format!("{} {}", hex::encode(&id), i)); - let mut location = [0u8; 32]; - location.copy_from_slice(&location_hasher.result()[..]); + let location = hash(format!("{} {}", hex::encode(&id), i).as_bytes()); new_ring.push(RingEntry{ location: location.into(), @@ -114,7 +111,7 @@ impl Members { self.n_datacenters = datacenters.len(); } - fn walk_ring(&self, from: &Hash, n: usize) -> Vec { + pub fn walk_ring(&self, from: &Hash, n: usize) -> Vec { if n >= self.config.members.len() { return self.config.members.keys().cloned().collect::>(); } @@ -222,7 +219,7 @@ impl System { let members = self.members.read().await; let to = members.status.keys().filter(|x| **x != self.id).cloned().collect::>(); drop(members); - rpc_call_many(self.clone(), &to[..], &msg, None, timeout).await; + rpc_call_many(self.clone(), &to[..], &msg, timeout).await; } pub async fn bootstrap(self: Arc) { diff --git a/src/proto.rs b/src/proto.rs index d1d4fb59..04b8e2b2 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -16,6 +16,9 @@ pub enum Message { PullConfig, AdvertiseNodesUp(Vec), AdvertiseConfig(NetworkConfig), + + PutBlock(PutBlockMessage), + AdvertiseVersion(VersionMeta), } #[derive(Debug, Serialize, Deserialize)] @@ -32,3 +35,11 @@ pub struct AdvertisedNode { pub id: UUID, pub addr: SocketAddr, } + +#[derive(Debug, Serialize, Deserialize)] +pub struct PutBlockMessage { + pub meta: BlockMeta, + + #[serde(with="serde_bytes")] + pub data: Vec, +} diff --git a/src/rpc_client.rs b/src/rpc_client.rs index 057c19e8..7995cdfa 100644 --- a/src/rpc_client.rs +++ b/src/rpc_client.rs @@ -16,7 +16,6 @@ use crate::membership::System; pub async fn rpc_call_many(sys: Arc, to: &[UUID], msg: &Message, - stop_after: Option, timeout: Duration) -> Vec> { @@ -25,19 +24,49 @@ pub async fn rpc_call_many(sys: Arc, .collect::>(); let mut results = vec![]; - let mut n_ok = 0; while let Some(resp) = resp_stream.next().await { - if resp.is_ok() { - n_ok += 1 - } results.push(resp); - if let Some(n) = stop_after { - if n_ok >= n { - break + } + results +} + +pub async fn rpc_try_call_many(sys: Arc, + to: &[UUID], + msg: &Message, + stop_after: usize, + timeout: Duration) + -> Result, Error> +{ + let mut resp_stream = to.iter() + .map(|to| rpc_call(sys.clone(), to, msg, timeout)) + .collect::>(); + + let mut results = vec![]; + let mut errors = vec![]; + + while let Some(resp) = resp_stream.next().await { + match resp { + Ok(msg) => { + results.push(msg); + if results.len() >= stop_after { + break + } + } + Err(e) => { + errors.push(e); } } } - results + + if results.len() >= stop_after { + Ok(results) + } else { + let mut msg = "Too many failures:".to_string(); + for e in errors { + msg += &format!("\n{}", e); + } + Err(Error::Message(msg)) + } } pub async fn rpc_call(sys: Arc, diff --git a/src/server.rs b/src/server.rs index 5cac1c70..d5da8c17 100644 --- a/src/server.rs +++ b/src/server.rs @@ -4,14 +4,20 @@ use std::net::SocketAddr; use std::path::PathBuf; use futures::channel::oneshot; use serde::Deserialize; -use rand::Rng; -use crate::data::UUID; +use crate::data::*; use crate::error::Error; use crate::membership::System; use crate::api_server; use crate::rpc_server; +fn default_block_size() -> usize { + 1048576 +} +fn default_meta_replication_factor() -> usize { + 3 +} + #[derive(Deserialize, Debug)] pub struct Config { pub metadata_dir: PathBuf, @@ -21,6 +27,12 @@ pub struct Config { pub rpc_port: u16, pub bootstrap_peers: Vec, + + #[serde(default = "default_block_size")] + pub block_size: usize, + + #[serde(default = "default_meta_replication_factor")] + pub meta_replication_factor: usize, } fn read_config(config_file: PathBuf) -> Result { @@ -49,11 +61,11 @@ fn gen_node_id(metadata_dir: &PathBuf) -> Result { id.copy_from_slice(&d[..]); Ok(id.into()) } else { - let id = rand::thread_rng().gen::<[u8; 32]>(); + let id = gen_uuid(); let mut f = std::fs::File::create(id_file.as_path())?; - f.write_all(&id[..])?; - Ok(id.into()) + f.write_all(id.as_slice())?; + Ok(id) } }