diff --git a/src/api_server.rs b/src/api_server.rs index c6d52d16..f213b4dd 100644 --- a/src/api_server.rs +++ b/src/api_server.rs @@ -9,7 +9,6 @@ use hyper::server::conn::AddrStream; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Method, Request, Response, Server, StatusCode}; -use crate::block::*; use crate::data::*; use crate::error::Error; use crate::http_util::*; @@ -151,7 +150,9 @@ async fn handle_put( let mut next_offset = first_block.len(); let mut put_curr_version_block = put_block_meta(garage.clone(), &version, 0, first_block_hash.clone()); - let mut put_curr_block = rpc_put_block(&garage.system, first_block_hash, first_block); + let mut put_curr_block = garage + .block_manager + .rpc_put_block(first_block_hash, first_block); loop { let (_, _, next_block) = @@ -165,7 +166,7 @@ async fn handle_put( next_offset as u64, block_hash.clone(), ); - put_curr_block = rpc_put_block(&garage.system, block_hash, block); + put_curr_block = garage.block_manager.rpc_put_block(block_hash, block); next_offset += block_len; } else { break; @@ -300,7 +301,7 @@ async fn handle_get( Ok(resp_builder.body(body)?) } ObjectVersionData::FirstBlock(first_block_hash) => { - let read_first_block = rpc_get_block(&garage.system, &first_block_hash); + let read_first_block = garage.block_manager.rpc_get_block(&first_block_hash); let get_next_blocks = garage.version_table.get(&last_v.uuid, &EmptySortKey); let (first_block, version) = futures::try_join!(read_first_block, get_next_blocks)?; @@ -323,7 +324,11 @@ async fn handle_get( if let Some(data) = data_opt { Ok(Bytes::from(data)) } else { - rpc_get_block(&garage.system, &hash).await.map(Bytes::from) + garage + .block_manager + .rpc_get_block(&hash) + .await + .map(Bytes::from) } } }) diff --git a/src/block.rs b/src/block.rs index 6add24b7..879cff2c 100644 --- a/src/block.rs +++ b/src/block.rs @@ -5,6 +5,7 @@ use std::time::Duration; use arc_swap::ArcSwapOption; use futures::future::*; use futures::stream::*; +use serde::{Deserialize, Serialize}; use tokio::fs; use tokio::prelude::*; use tokio::sync::{watch, Mutex}; @@ -15,22 +16,40 @@ use crate::error::Error; use crate::membership::System; use crate::proto::*; use crate::rpc_client::*; +use crate::rpc_server::*; use crate::server::Garage; const NEED_BLOCK_QUERY_TIMEOUT: Duration = Duration::from_secs(5); const RESYNC_RETRY_TIMEOUT: Duration = Duration::from_secs(10); +#[derive(Debug, Serialize, Deserialize)] +pub enum Message { + Ok, + GetBlock(Hash), + PutBlock(PutBlockMessage), + NeedBlockQuery(Hash), + NeedBlockReply(bool), +} + +impl RpcMessage for Message {} + pub struct BlockManager { pub data_dir: PathBuf, pub rc: sled::Tree, pub resync_queue: sled::Tree, pub lock: Mutex<()>, pub system: Arc, + rpc_client: Arc>, pub garage: ArcSwapOption, } impl BlockManager { - pub fn new(db: &sled::Db, data_dir: PathBuf, system: Arc) -> Arc { + pub fn new( + db: &sled::Db, + data_dir: PathBuf, + system: Arc, + rpc_server: &mut RpcServer, + ) -> Arc { let rc = db .open_tree("block_local_rc") .expect("Unable to open block_local_rc tree"); @@ -40,14 +59,38 @@ impl BlockManager { .open_tree("block_local_resync_queue") .expect("Unable to open block_local_resync_queue tree"); - Arc::new(Self { + let rpc_path = "block_manager"; + let rpc_client = system.rpc_client::(rpc_path); + + let block_manager = Arc::new(Self { rc, resync_queue, data_dir, lock: Mutex::new(()), system, + rpc_client, garage: ArcSwapOption::from(None), - }) + }); + block_manager + .clone() + .register_handler(rpc_server, rpc_path.into()); + block_manager + } + + fn register_handler(self: Arc, rpc_server: &mut RpcServer, path: String) { + rpc_server.add_handler::(path, move |msg, _addr| { + let self2 = self.clone(); + async move { + match msg { + Message::PutBlock(m) => self2.write_block(&m.hash, &m.data).await, + Message::GetBlock(h) => self2.read_block(&h).await, + Message::NeedBlockQuery(h) => { + self2.need_block(&h).await.map(Message::NeedBlockReply) + } + _ => Err(Error::Message(format!("Invalid RPC"))), + } + } + }); } pub async fn spawn_background_worker(self: Arc) { @@ -214,10 +257,11 @@ impl BlockManager { if needed_by_others { let ring = garage.system.ring.borrow().clone(); let who = ring.walk_ring(&hash, garage.system.config.data_replication_factor); - let msg = Message::NeedBlockQuery(hash.clone()); - let who_needs_fut = who - .iter() - .map(|to| rpc_call(garage.system.clone(), to, &msg, NEED_BLOCK_QUERY_TIMEOUT)); + let msg = Arc::new(Message::NeedBlockQuery(hash.clone())); + let who_needs_fut = who.iter().map(|to| { + self.rpc_client + .call(to, msg.clone(), NEED_BLOCK_QUERY_TIMEOUT) + }); let who_needs = join_all(who_needs_fut).await; let mut need_nodes = vec![]; @@ -242,13 +286,10 @@ impl BlockManager { if need_nodes.len() > 0 { let put_block_message = self.read_block(hash).await?; - let put_responses = rpc_call_many( - garage.system.clone(), - &need_nodes[..], - put_block_message, - BLOCK_RW_TIMEOUT, - ) - .await; + let put_responses = self + .rpc_client + .call_many(&need_nodes[..], put_block_message, BLOCK_RW_TIMEOUT) + .await; for resp in put_responses { resp?; } @@ -262,12 +303,48 @@ impl BlockManager { // TODO find a way to not do this if they are sending it to us // Let's suppose this isn't an issue for now with the BLOCK_RW_TIMEOUT delay // between the RC being incremented and this part being called. - let block_data = rpc_get_block(&self.system, &hash).await?; + let block_data = self.rpc_get_block(&hash).await?; self.write_block(hash, &block_data[..]).await?; } Ok(()) } + + pub async fn rpc_get_block(&self, hash: &Hash) -> Result, Error> { + let ring = self.system.ring.borrow().clone(); + let who = ring.walk_ring(&hash, self.system.config.data_replication_factor); + let msg = Arc::new(Message::GetBlock(hash.clone())); + let mut resp_stream = who + .iter() + .map(|to| self.rpc_client.call(to, msg.clone(), BLOCK_RW_TIMEOUT)) + .collect::>(); + + while let Some(resp) = resp_stream.next().await { + if let Ok(Message::PutBlock(msg)) = resp { + if data::hash(&msg.data[..]) == *hash { + return Ok(msg.data); + } + } + } + Err(Error::Message(format!( + "Unable to read block {:?}: no valid blocks returned", + hash + ))) + } + + pub async fn rpc_put_block(&self, hash: Hash, data: Vec) -> Result<(), Error> { + let ring = self.system.ring.borrow().clone(); + let who = ring.walk_ring(&hash, self.system.config.data_replication_factor); + self.rpc_client + .try_call_many( + &who[..], + Message::PutBlock(PutBlockMessage { hash, data }), + (self.system.config.data_replication_factor + 1) / 2, + BLOCK_RW_TIMEOUT, + ) + .await?; + Ok(()) + } } fn u64_from_bytes(bytes: &[u8]) -> u64 { @@ -297,39 +374,3 @@ fn rc_merge(_key: &[u8], old: Option<&[u8]>, new: &[u8]) -> Option> { Some(u64::to_be_bytes(new).to_vec()) } } - -pub async fn rpc_get_block(system: &Arc, hash: &Hash) -> Result, Error> { - let ring = system.ring.borrow().clone(); - let who = ring.walk_ring(&hash, system.config.data_replication_factor); - let msg = Message::GetBlock(hash.clone()); - let mut resp_stream = who - .iter() - .map(|to| rpc_call(system.clone(), to, &msg, BLOCK_RW_TIMEOUT)) - .collect::>(); - - while let Some(resp) = resp_stream.next().await { - if let Ok(Message::PutBlock(msg)) = resp { - if data::hash(&msg.data[..]) == *hash { - return Ok(msg.data); - } - } - } - Err(Error::Message(format!( - "Unable to read block {:?}: no valid blocks returned", - hash - ))) -} - -pub async fn rpc_put_block(system: &Arc, hash: Hash, data: Vec) -> Result<(), Error> { - let ring = system.ring.borrow().clone(); - let who = ring.walk_ring(&hash, system.config.data_replication_factor); - rpc_try_call_many( - system.clone(), - &who[..], - Message::PutBlock(PutBlockMessage { hash, data }), - (system.config.data_replication_factor + 1) / 2, - BLOCK_RW_TIMEOUT, - ) - .await?; - Ok(()) -} diff --git a/src/main.rs b/src/main.rs index ebf97a29..84b8c2bc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -22,12 +22,14 @@ mod tls_util; use std::collections::HashSet; use std::net::SocketAddr; use std::path::PathBuf; +use std::sync::Arc; use structopt::StructOpt; use data::*; use error::Error; +use membership::Message; use proto::*; -use rpc_client::RpcClient; +use rpc_client::*; use server::TlsConfig; #[derive(StructOpt, Debug)] @@ -113,7 +115,9 @@ async fn main() { } }; - let rpc_cli = RpcClient::new(&tls_config).expect("Could not create RPC client"); + let rpc_http_cli = + Arc::new(RpcHttpClient::new(&tls_config).expect("Could not create RPC client")); + let rpc_cli = RpcAddrClient::new(rpc_http_cli, "_membership".into()); let resp = match opt.cmd { Command::Server(server_opt) => { @@ -137,7 +141,7 @@ async fn main() { } } -async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Error> { +async fn cmd_status(rpc_cli: RpcAddrClient, rpc_host: SocketAddr) -> Result<(), Error> { let status = match rpc_cli .call(&rpc_host, &Message::PullStatus, DEFAULT_TIMEOUT) .await? @@ -196,7 +200,7 @@ async fn cmd_status(rpc_cli: RpcClient, rpc_host: SocketAddr) -> Result<(), Erro } async fn cmd_configure( - rpc_cli: RpcClient, + rpc_cli: RpcAddrClient, rpc_host: SocketAddr, args: ConfigureOpt, ) -> Result<(), Error> { @@ -249,7 +253,7 @@ async fn cmd_configure( } async fn cmd_remove( - rpc_cli: RpcClient, + rpc_cli: RpcAddrClient, rpc_host: SocketAddr, args: RemoveOpt, ) -> Result<(), Error> { diff --git a/src/membership.rs b/src/membership.rs index 6d758c59..499637fb 100644 --- a/src/membership.rs +++ b/src/membership.rs @@ -10,6 +10,7 @@ use std::time::Duration; use futures::future::join_all; use futures::select; use futures_util::future::*; +use serde::{Deserialize, Serialize}; use sha2::{Digest, Sha256}; use tokio::prelude::*; use tokio::sync::watch; @@ -20,17 +21,31 @@ use crate::data::*; use crate::error::Error; use crate::proto::*; use crate::rpc_client::*; +use crate::rpc_server::*; use crate::server::Config; const PING_INTERVAL: Duration = Duration::from_secs(10); const PING_TIMEOUT: Duration = Duration::from_secs(2); const MAX_FAILED_PINGS: usize = 3; +#[derive(Debug, Serialize, Deserialize)] +pub enum Message { + Ok, + Ping(PingMessage), + PullStatus, + PullConfig, + AdvertiseNodesUp(Vec), + AdvertiseConfig(NetworkConfig), +} + +impl RpcMessage for Message {} + pub struct System { pub config: Config, pub id: UUID, - pub rpc_client: RpcClient, + pub rpc_http_client: Arc, + rpc_client: Arc>, pub status: watch::Receiver>, pub ring: watch::Receiver>, @@ -199,7 +214,12 @@ fn read_network_config(metadata_dir: &PathBuf) -> Result { } impl System { - pub fn new(config: Config, id: UUID, background: Arc) -> Self { + pub fn new( + config: Config, + id: UUID, + background: Arc, + rpc_server: &mut RpcServer, + ) -> Arc { let net_config = match read_network_config(&config.metadata_dir) { Ok(x) => x, Err(e) => { @@ -228,17 +248,54 @@ impl System { ring.rebuild_ring(); let (update_ring, ring) = watch::channel(Arc::new(ring)); - let rpc_client = RpcClient::new(&config.rpc_tls).expect("Could not create RPC client"); + let rpc_http_client = + Arc::new(RpcHttpClient::new(&config.rpc_tls).expect("Could not create RPC client")); - System { + let rpc_path = "_membership"; + let rpc_client = RpcClient::new( + RpcAddrClient::::new(rpc_http_client.clone(), rpc_path.into()), + background.clone(), + status.clone(), + ); + + let sys = Arc::new(System { config, id, + rpc_http_client, rpc_client, status, ring, update_lock: Mutex::new((update_status, update_ring)), background, - } + }); + sys.clone().register_handler(rpc_server, rpc_path.into()); + sys + } + + fn register_handler(self: Arc, rpc_server: &mut RpcServer, path: String) { + rpc_server.add_handler::(path, move |msg, addr| { + let self2 = self.clone(); + async move { + match msg { + Message::Ping(ping) => self2.handle_ping(&addr, &ping).await, + + Message::PullStatus => self2.handle_pull_status(), + Message::PullConfig => self2.handle_pull_config(), + Message::AdvertiseNodesUp(adv) => self2.handle_advertise_nodes_up(&adv).await, + Message::AdvertiseConfig(adv) => self2.handle_advertise_config(&adv).await, + + _ => Err(Error::Message(format!("Unexpected RPC message"))), + } + } + }); + } + + pub fn rpc_client(self: &Arc, path: &str) -> Arc> { + RpcClient::new( + RpcAddrClient::new(self.rpc_http_client.clone(), path.to_string()), + self.background.clone(), + self.status.clone(), + ) } async fn save_network_config(self: Arc) -> Result<(), Error> { @@ -272,7 +329,7 @@ impl System { .filter(|x| **x != self.id) .cloned() .collect::>(); - rpc_call_many(self.clone(), &to[..], msg, timeout).await; + self.rpc_client.call_many(&to[..], msg, timeout).await; } pub async fn bootstrap(self: Arc) { @@ -299,7 +356,10 @@ impl System { ( id_option, addr.clone(), - sys.rpc_client.call(&addr, ping_msg_ref, PING_TIMEOUT).await, + sys.rpc_client + .by_addr() + .call(&addr, ping_msg_ref, PING_TIMEOUT) + .await, ) } })) @@ -509,7 +569,10 @@ impl System { peer: UUID, ) -> impl futures::future::Future + Send + 'static { async move { - let resp = rpc_call(self.clone(), &peer, &Message::PullStatus, PING_TIMEOUT).await; + let resp = self + .rpc_client + .call(&peer, Message::PullStatus, PING_TIMEOUT) + .await; if let Ok(Message::AdvertiseNodesUp(nodes)) = resp { let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await; } @@ -517,7 +580,10 @@ impl System { } pub async fn pull_config(self: Arc, peer: UUID) { - let resp = rpc_call(self.clone(), &peer, &Message::PullConfig, PING_TIMEOUT).await; + let resp = self + .rpc_client + .call(&peer, Message::PullConfig, PING_TIMEOUT) + .await; if let Ok(Message::AdvertiseConfig(config)) = resp { let _: Result<_, _> = self.handle_advertise_config(&config).await; } diff --git a/src/proto.rs b/src/proto.rs index cf7ed1cc..d51aa36b 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -7,25 +7,6 @@ use crate::data::*; pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); pub const BLOCK_RW_TIMEOUT: Duration = Duration::from_secs(42); -#[derive(Debug, Serialize, Deserialize)] -pub enum Message { - Ok, - Error(String), - - Ping(PingMessage), - PullStatus, - PullConfig, - AdvertiseNodesUp(Vec), - AdvertiseConfig(NetworkConfig), - - GetBlock(Hash), - PutBlock(PutBlockMessage), - NeedBlockQuery(Hash), - NeedBlockReply(bool), - - TableRPC(String, #[serde(with = "serde_bytes")] Vec), -} - #[derive(Debug, Serialize, Deserialize)] pub struct PingMessage { pub id: UUID, diff --git a/src/rpc_client.rs b/src/rpc_client.rs index f8da778c..6d26d86a 100644 --- a/src/rpc_client.rs +++ b/src/rpc_client.rs @@ -1,4 +1,5 @@ use std::borrow::Borrow; +use std::marker::PhantomData; use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; @@ -9,110 +10,166 @@ use futures::stream::StreamExt; use futures_util::future::FutureExt; use hyper::client::{Client, HttpConnector}; use hyper::{Body, Method, Request, StatusCode}; +use tokio::sync::watch; +use crate::background::*; use crate::data::*; use crate::error::Error; -use crate::membership::System; -use crate::proto::Message; +use crate::membership::Status; +use crate::rpc_server::RpcMessage; use crate::server::*; use crate::tls_util; -pub async fn rpc_call_many( - sys: Arc, - to: &[UUID], - msg: Message, - timeout: Duration, -) -> Vec> { - let msg = Arc::new(msg); - let mut resp_stream = to - .iter() - .map(|to| rpc_call(sys.clone(), to, msg.clone(), timeout)) - .collect::>(); +pub struct RpcClient { + status: watch::Receiver>, + background: Arc, - let mut results = vec![]; - while let Some(resp) = resp_stream.next().await { - results.push(resp); - } - results + pub rpc_addr_client: RpcAddrClient, } -pub async fn rpc_try_call_many( - sys: Arc, - to: &[UUID], - msg: Message, - stop_after: usize, - timeout: Duration, -) -> Result, Error> { - let sys2 = sys.clone(); - let msg = Arc::new(msg); - let mut resp_stream = to - .to_vec() - .into_iter() - .map(move |to| rpc_call(sys2.clone(), to.clone(), msg.clone(), timeout)) - .collect::>(); +impl RpcClient { + pub fn new( + rac: RpcAddrClient, + background: Arc, + status: watch::Receiver>, + ) -> Arc { + Arc::new(Self { + rpc_addr_client: rac, + background, + status, + }) + } - let mut results = vec![]; - let mut errors = vec![]; + pub fn by_addr(&self) -> &RpcAddrClient { + &self.rpc_addr_client + } - while let Some(resp) = resp_stream.next().await { - match resp { - Ok(msg) => { - results.push(msg); - if results.len() >= stop_after { - break; + pub async fn call, N: Borrow>( + &self, + to: N, + msg: MB, + timeout: Duration, + ) -> Result { + let addr = { + let status = self.status.borrow().clone(); + match status.nodes.get(to.borrow()) { + Some(status) => status.addr.clone(), + None => { + return Err(Error::Message(format!( + "Peer ID not found: {:?}", + to.borrow() + ))) } } - Err(e) => { - errors.push(e); - } - } + }; + self.rpc_addr_client.call(&addr, msg, timeout).await } - if results.len() >= stop_after { - // Continue requests in background - // TODO: make this optionnal (only usefull for write requests) - sys.background.spawn(async move { - resp_stream.collect::>().await; - Ok(()) - }); + pub async fn call_many(&self, to: &[UUID], msg: M, timeout: Duration) -> Vec> { + let msg = Arc::new(msg); + let mut resp_stream = to + .iter() + .map(|to| self.call(to, msg.clone(), timeout)) + .collect::>(); - Ok(results) - } else { - let mut msg = "Too many failures:".to_string(); - for e in errors { - msg += &format!("\n{}", e); + let mut results = vec![]; + while let Some(resp) = resp_stream.next().await { + results.push(resp); + } + results + } + + pub async fn try_call_many( + self: &Arc, + to: &[UUID], + msg: M, + stop_after: usize, + timeout: Duration, + ) -> Result, Error> { + let msg = Arc::new(msg); + let mut resp_stream = to + .to_vec() + .into_iter() + .map(|to| { + let self2 = self.clone(); + let msg = msg.clone(); + async move { self2.call(to.clone(), msg, timeout).await } + }) + .collect::>(); + + let mut results = vec![]; + let mut errors = vec![]; + + while let Some(resp) = resp_stream.next().await { + match resp { + Ok(msg) => { + results.push(msg); + if results.len() >= stop_after { + break; + } + } + Err(e) => { + errors.push(e); + } + } + } + + if results.len() >= stop_after { + // Continue requests in background + // TODO: make this optionnal (only usefull for write requests) + self.clone().background.spawn(async move { + resp_stream.collect::>().await; + Ok(()) + }); + + Ok(results) + } else { + let mut msg = "Too many failures:".to_string(); + for e in errors { + msg += &format!("\n{}", e); + } + Err(Error::Message(msg)) } - Err(Error::Message(msg)) } } -pub async fn rpc_call, N: Borrow>( - sys: Arc, - to: N, - msg: M, - timeout: Duration, -) -> Result { - let addr = { - let status = sys.status.borrow().clone(); - match status.nodes.get(to.borrow()) { - Some(status) => status.addr.clone(), - None => { - return Err(Error::Message(format!( - "Peer ID not found: {:?}", - to.borrow() - ))) - } - } - }; - sys.rpc_client.call(&addr, msg, timeout).await +pub struct RpcAddrClient { + phantom: PhantomData, + + pub http_client: Arc, + pub path: String, } -pub enum RpcClient { +impl RpcAddrClient { + pub fn new(http_client: Arc, path: String) -> Self { + Self { + phantom: PhantomData::default(), + http_client: http_client, + path, + } + } + + pub async fn call( + &self, + to_addr: &SocketAddr, + msg: MB, + timeout: Duration, + ) -> Result + where + MB: Borrow, + { + self.http_client + .call(&self.path, to_addr, msg, timeout) + .await + } +} + +pub enum RpcHttpClient { HTTP(Client), HTTPS(Client, hyper::Body>), } -impl RpcClient { +impl RpcHttpClient { pub fn new(tls_config: &Option) -> Result { if let Some(cf) = tls_config { let ca_certs = tls_util::load_certs(&cf.ca_cert)?; @@ -130,21 +187,26 @@ impl RpcClient { let connector = tls_util::HttpsConnectorFixedDnsname::::new(config, "garage"); - Ok(RpcClient::HTTPS(Client::builder().build(connector))) + Ok(RpcHttpClient::HTTPS(Client::builder().build(connector))) } else { - Ok(RpcClient::HTTP(Client::new())) + Ok(RpcHttpClient::HTTP(Client::new())) } } - pub async fn call>( + async fn call( &self, + path: &str, to_addr: &SocketAddr, - msg: M, + msg: MB, timeout: Duration, - ) -> Result { + ) -> Result + where + MB: Borrow, + M: RpcMessage, + { let uri = match self { - RpcClient::HTTP(_) => format!("http://{}/rpc", to_addr), - RpcClient::HTTPS(_) => format!("https://{}/rpc", to_addr), + RpcHttpClient::HTTP(_) => format!("http://{}/{}", to_addr, path), + RpcHttpClient::HTTPS(_) => format!("https://{}/{}", to_addr, path), }; let req = Request::builder() @@ -153,8 +215,8 @@ impl RpcClient { .body(Body::from(rmp_to_vec_all_named(msg.borrow())?))?; let resp_fut = match self { - RpcClient::HTTP(client) => client.request(req).fuse(), - RpcClient::HTTPS(client) => client.request(req).fuse(), + RpcHttpClient::HTTP(client) => client.request(req).fuse(), + RpcHttpClient::HTTPS(client) => client.request(req).fuse(), }; let resp = tokio::time::timeout(timeout, resp_fut) .await? @@ -168,11 +230,8 @@ impl RpcClient { if resp.status() == StatusCode::OK { let body = hyper::body::to_bytes(resp.into_body()).await?; - let msg = rmp_serde::decode::from_read::<_, Message>(body.into_buf())?; - match msg { - Message::Error(e) => Err(Error::RPCError(e)), - x => Ok(x), - } + let msg = rmp_serde::decode::from_read::<_, Result>(body.into_buf())?; + msg.map_err(Error::RPCError) } else { Err(Error::RPCError(format!("Status code {}", resp.status()))) } diff --git a/src/rpc_server.rs b/src/rpc_server.rs index 3410ab97..83f8ddc9 100644 --- a/src/rpc_server.rs +++ b/src/rpc_server.rs @@ -1,4 +1,6 @@ +use std::collections::HashMap; use std::net::SocketAddr; +use std::pin::Pin; use std::sync::Arc; use bytes::IntoBuf; @@ -8,175 +10,197 @@ use futures_util::stream::*; use hyper::server::conn::AddrStream; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Method, Request, Response, Server, StatusCode}; +use serde::{Deserialize, Serialize}; use tokio::net::{TcpListener, TcpStream}; use tokio_rustls::server::TlsStream; use tokio_rustls::TlsAcceptor; use crate::data::*; use crate::error::Error; -use crate::proto::Message; -use crate::server::Garage; +use crate::server::TlsConfig; use crate::tls_util; -fn err_to_msg(x: Result) -> Message { - match x { - Err(e) => Message::Error(format!("{}", e)), - Ok(msg) => msg, - } +pub trait RpcMessage: Serialize + for<'de> Deserialize<'de> + Send + Sync {} + +type ResponseFuture = Pin, Error>> + Send>>; +type Handler = Box, SocketAddr) -> ResponseFuture + Send + Sync>; + +pub struct RpcServer { + pub bind_addr: SocketAddr, + pub tls_config: Option, + + handlers: HashMap, } -async fn handler( - garage: Arc, +async fn handle_func( + handler: Arc, req: Request, - addr: SocketAddr, -) -> Result, Error> { - if req.method() != &Method::POST { - let mut bad_request = Response::default(); - *bad_request.status_mut() = StatusCode::BAD_REQUEST; - return Ok(bad_request); - } - + sockaddr: SocketAddr, +) -> Result, Error> +where + M: RpcMessage + 'static, + F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, +{ let whole_body = hyper::body::to_bytes(req.into_body()).await?; - let msg = rmp_serde::decode::from_read::<_, Message>(whole_body.into_buf())?; - - // eprintln!( - // "RPC from {}: {} ({} bytes)", - // addr, - // debug_serialize(&msg), - // whole_body.len() - // ); - - let sys = garage.system.clone(); - let resp = err_to_msg(match msg { - Message::Ping(ping) => sys.handle_ping(&addr, &ping).await, - - Message::PullStatus => sys.handle_pull_status(), - Message::PullConfig => sys.handle_pull_config(), - Message::AdvertiseNodesUp(adv) => sys.handle_advertise_nodes_up(&adv).await, - Message::AdvertiseConfig(adv) => sys.handle_advertise_config(&adv).await, - - Message::PutBlock(m) => { - // A RPC can be interrupted in the middle, however we don't want to write partial blocks, - // which might happen if the write_block() future is cancelled in the middle. - // To solve this, the write itself is in a spawned task that has its own separate lifetime, - // and the request handler simply sits there waiting for the task to finish. - // (if it's cancelled, that's not an issue) - // (TODO FIXME except if garage happens to shut down at that point) - let write_fut = async move { garage.block_manager.write_block(&m.hash, &m.data).await }; - tokio::spawn(write_fut).await? + let msg = rmp_serde::decode::from_read::<_, M>(whole_body.into_buf())?; + match handler(msg, sockaddr).await { + Ok(resp) => { + let resp_bytes = rmp_to_vec_all_named::>(&Ok(resp))?; + Ok(Response::new(Body::from(resp_bytes))) } - Message::GetBlock(h) => garage.block_manager.read_block(&h).await, - Message::NeedBlockQuery(h) => garage - .block_manager - .need_block(&h) - .await - .map(Message::NeedBlockReply), - - Message::TableRPC(table, msg) => { - // Same trick for table RPCs than for PutBlock - let op_fut = async move { - if let Some(rpc_handler) = garage.table_rpc_handlers.get(&table) { - rpc_handler - .handle(&msg[..]) - .await - .map(|rep| Message::TableRPC(table.to_string(), rep)) - } else { - Ok(Message::Error(format!("Unknown table: {}", table))) - } - }; - tokio::spawn(op_fut).await? + Err(e) => { + let err_str = format!("{}", e); + let rep_bytes = rmp_to_vec_all_named::>(&Err(err_str))?; + let mut err_response = Response::new(Body::from(rep_bytes)); + *err_response.status_mut() = e.http_status_code(); + Ok(err_response) } - - _ => Ok(Message::Error(format!("Unexpected message: {:?}", msg))), - }); - - // eprintln!("reply to {}: {}", addr, debug_serialize(&resp)); - - Ok(Response::new(Body::from(rmp_to_vec_all_named(&resp)?))) + } } -pub async fn run_rpc_server( - garage: Arc, - shutdown_signal: impl Future, -) -> Result<(), Error> { - let bind_addr = ([0, 0, 0, 0, 0, 0, 0, 0], garage.system.config.rpc_port).into(); - - if let Some(tls_config) = &garage.system.config.rpc_tls { - let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?; - let node_certs = tls_util::load_certs(&tls_config.node_cert)?; - let node_key = tls_util::load_private_key(&tls_config.node_key)?; - - let mut ca_store = rustls::RootCertStore::empty(); - for crt in ca_certs.iter() { - ca_store.add(crt)?; +impl RpcServer { + pub fn new(bind_addr: SocketAddr, tls_config: Option) -> Self { + Self { + bind_addr, + tls_config, + handlers: HashMap::new(), } - - let mut config = - rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store)); - config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; - let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config))); - - let mut listener = TcpListener::bind(&bind_addr).await?; - let incoming = listener.incoming().filter_map(|socket| async { - match socket { - Ok(stream) => match tls_acceptor.clone().accept(stream).await { - Ok(x) => Some(Ok::<_, hyper::Error>(x)), - Err(e) => { - eprintln!("RPC server TLS error: {}", e); - None - } - }, - Err(_) => None, - } - }); - let incoming = hyper::server::accept::from_stream(incoming); - - let service = make_service_fn(|conn: &TlsStream| { - let client_addr = conn - .get_ref() - .0 - .peer_addr() - .unwrap_or(([0, 0, 0, 0], 0).into()); - let garage = garage.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request| { - let garage = garage.clone(); - handler(garage, req, client_addr).map_err(|e| { - eprintln!("RPC handler error: {}", e); - e - }) - })) - } - }); - - let server = Server::builder(incoming).serve(service); - - let graceful = server.with_graceful_shutdown(shutdown_signal); - println!("RPC server listening on http://{}", bind_addr); - - graceful.await?; - } else { - let service = make_service_fn(|conn: &AddrStream| { - let client_addr = conn.remote_addr(); - let garage = garage.clone(); - async move { - Ok::<_, Error>(service_fn(move |req: Request| { - let garage = garage.clone(); - handler(garage, req, client_addr).map_err(|e| { - eprintln!("RPC handler error: {}", e); - e - }) - })) - } - }); - - let server = Server::bind(&bind_addr).serve(service); - - let graceful = server.with_graceful_shutdown(shutdown_signal); - println!("RPC server listening on http://{}", bind_addr); - - graceful.await?; } - Ok(()) + pub fn add_handler(&mut self, name: String, handler: F) + where + M: RpcMessage + 'static, + F: Fn(M, SocketAddr) -> Fut + Send + Sync + 'static, + Fut: Future> + Send + 'static, + { + let handler_arc = Arc::new(handler); + let handler = Box::new(move |req: Request, sockaddr: SocketAddr| { + let handler2 = handler_arc.clone(); + let b: ResponseFuture = Box::pin(handle_func(handler2, req, sockaddr)); + b + }); + self.handlers.insert(name, handler); + } + + async fn handler( + self: Arc, + req: Request, + addr: SocketAddr, + ) -> Result, Error> { + if req.method() != &Method::POST { + let mut bad_request = Response::default(); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + return Ok(bad_request); + } + + let path = &req.uri().path()[1..]; + let handler = match self.handlers.get(path) { + Some(h) => h, + None => { + let mut not_found = Response::default(); + *not_found.status_mut() = StatusCode::NOT_FOUND; + return Ok(not_found); + } + }; + + let resp_waiter = tokio::spawn(handler(req, addr)); + match resp_waiter.await { + Err(_err) => { + let mut ise = Response::default(); + *ise.status_mut() = StatusCode::INTERNAL_SERVER_ERROR; + Ok(ise) + } + Ok(Err(err)) => { + let mut bad_request = Response::new(Body::from(format!("{}", err))); + *bad_request.status_mut() = StatusCode::BAD_REQUEST; + Ok(bad_request) + } + Ok(Ok(resp)) => Ok(resp), + } + } + + pub async fn run( + self: Arc, + shutdown_signal: impl Future, + ) -> Result<(), Error> { + if let Some(tls_config) = self.tls_config.as_ref() { + let ca_certs = tls_util::load_certs(&tls_config.ca_cert)?; + let node_certs = tls_util::load_certs(&tls_config.node_cert)?; + let node_key = tls_util::load_private_key(&tls_config.node_key)?; + + let mut ca_store = rustls::RootCertStore::empty(); + for crt in ca_certs.iter() { + ca_store.add(crt)?; + } + + let mut config = + rustls::ServerConfig::new(rustls::AllowAnyAuthenticatedClient::new(ca_store)); + config.set_single_cert([&node_certs[..], &ca_certs[..]].concat(), node_key)?; + let tls_acceptor = Arc::new(TlsAcceptor::from(Arc::new(config))); + + let mut listener = TcpListener::bind(&self.bind_addr).await?; + let incoming = listener.incoming().filter_map(|socket| async { + match socket { + Ok(stream) => match tls_acceptor.clone().accept(stream).await { + Ok(x) => Some(Ok::<_, hyper::Error>(x)), + Err(e) => { + eprintln!("RPC server TLS error: {}", e); + None + } + }, + Err(_) => None, + } + }); + let incoming = hyper::server::accept::from_stream(incoming); + + let self_arc = self.clone(); + let service = make_service_fn(|conn: &TlsStream| { + let client_addr = conn + .get_ref() + .0 + .peer_addr() + .unwrap_or(([0, 0, 0, 0], 0).into()); + let self_arc = self_arc.clone(); + async move { + Ok::<_, Error>(service_fn(move |req: Request| { + self_arc.clone().handler(req, client_addr).map_err(|e| { + eprintln!("RPC handler error: {}", e); + e + }) + })) + } + }); + + let server = Server::builder(incoming).serve(service); + + let graceful = server.with_graceful_shutdown(shutdown_signal); + println!("RPC server listening on http://{}", self.bind_addr); + + graceful.await?; + } else { + let self_arc = self.clone(); + let service = make_service_fn(move |conn: &AddrStream| { + let client_addr = conn.remote_addr(); + let self_arc = self_arc.clone(); + async move { + Ok::<_, Error>(service_fn(move |req: Request| { + self_arc.clone().handler(req, client_addr).map_err(|e| { + eprintln!("RPC handler error: {}", e); + e + }) + })) + } + }); + + let server = Server::bind(&self.bind_addr).serve(service); + + let graceful = server.with_graceful_shutdown(shutdown_signal); + println!("RPC server listening on http://{}", self.bind_addr); + + graceful.await?; + } + + Ok(()) + } } diff --git a/src/server.rs b/src/server.rs index 591a7bf9..57faea21 100644 --- a/src/server.rs +++ b/src/server.rs @@ -1,4 +1,3 @@ -use std::collections::HashMap; use std::io::{Read, Write}; use std::net::SocketAddr; use std::path::PathBuf; @@ -15,7 +14,7 @@ use crate::data::*; use crate::error::Error; use crate::membership::System; use crate::proto::*; -use crate::rpc_server; +use crate::rpc_server::RpcServer; use crate::table::*; #[derive(Deserialize, Debug, Clone)] @@ -53,8 +52,6 @@ pub struct Garage { pub system: Arc, pub block_manager: Arc, - pub table_rpc_handlers: HashMap>, - pub object_table: Arc>, pub version_table: Arc>, pub block_ref_table: Arc>, @@ -66,12 +63,14 @@ impl Garage { id: UUID, db: sled::Db, background: Arc, + rpc_server: &mut RpcServer, ) -> Arc { println!("Initialize membership management system..."); - let system = Arc::new(System::new(config.clone(), id, background.clone())); + let system = System::new(config.clone(), id, background.clone(), rpc_server); println!("Initialize block manager..."); - let block_manager = BlockManager::new(&db, config.data_dir.clone(), system.clone()); + let block_manager = + BlockManager::new(&db, config.data_dir.clone(), system.clone(), rpc_server); let data_rep_param = TableReplicationParams { replication_factor: system.config.data_replication_factor, @@ -97,6 +96,7 @@ impl Garage { &db, "block_ref".to_string(), data_rep_param.clone(), + rpc_server, ) .await; @@ -110,6 +110,7 @@ impl Garage { &db, "version".to_string(), meta_rep_param.clone(), + rpc_server, ) .await; @@ -123,35 +124,20 @@ impl Garage { &db, "object".to_string(), meta_rep_param.clone(), + rpc_server, ) .await; println!("Initialize Garage..."); - let mut garage = Self { + let garage = Arc::new(Self { db, system: system.clone(), block_manager, background, - table_rpc_handlers: HashMap::new(), object_table, version_table, block_ref_table, - }; - - garage.table_rpc_handlers.insert( - garage.object_table.name.clone(), - garage.object_table.clone().rpc_handler(), - ); - garage.table_rpc_handlers.insert( - garage.version_table.name.clone(), - garage.version_table.clone().rpc_handler(), - ); - garage.table_rpc_handlers.insert( - garage.block_ref_table.name.clone(), - garage.block_ref_table.clone().rpc_handler(), - ); - - let garage = Arc::new(garage); + }); println!("Start block manager background thread..."); garage.block_manager.garage.swap(Some(garage.clone())); @@ -232,20 +218,23 @@ pub async fn run_server(config_file: PathBuf) -> Result<(), Error> { db_path.push("db"); let db = sled::open(db_path).expect("Unable to open DB"); - let (send_cancel, watch_cancel) = watch::channel(false); + println!("Initialize RPC server..."); + let rpc_bind_addr = ([0, 0, 0, 0, 0, 0, 0, 0], config.rpc_port).into(); + let mut rpc_server = RpcServer::new(rpc_bind_addr, config.rpc_tls.clone()); println!("Initializing background runner..."); + let (send_cancel, watch_cancel) = watch::channel(false); let background = BackgroundRunner::new(8, watch_cancel.clone()); - let garage = Garage::new(config, id, db, background.clone()).await; + let garage = Garage::new(config, id, db, background.clone(), &mut rpc_server).await; println!("Initializing RPC and API servers..."); - let rpc_server = rpc_server::run_rpc_server(garage.clone(), wait_from(watch_cancel.clone())); + let run_rpc_server = Arc::new(rpc_server).run(wait_from(watch_cancel.clone())); let api_server = api_server::run_api_server(garage.clone(), wait_from(watch_cancel.clone())); futures::try_join!( garage.system.clone().bootstrap().map(Ok), - rpc_server, + run_rpc_server, api_server, background.run().map(Ok), shutdown_signal(send_cancel), diff --git a/src/table.rs b/src/table.rs index 3ad08cff..f7354376 100644 --- a/src/table.rs +++ b/src/table.rs @@ -11,14 +11,15 @@ use serde_bytes::ByteBuf; use crate::data::*; use crate::error::Error; use crate::membership::System; -use crate::proto::*; use crate::rpc_client::*; +use crate::rpc_server::*; use crate::table_sync::*; pub struct Table { pub instance: F, pub name: String, + pub rpc_client: Arc>>, pub system: Arc, pub store: sled::Tree, @@ -35,24 +36,6 @@ pub struct TableReplicationParams { pub timeout: Duration, } -#[async_trait] -pub trait TableRpcHandler { - async fn handle(&self, rpc: &[u8]) -> Result, Error>; -} - -struct TableRpcHandlerAdapter { - table: Arc>, -} - -#[async_trait] -impl TableRpcHandler for TableRpcHandlerAdapter { - async fn handle(&self, rpc: &[u8]) -> Result, Error> { - let msg = rmp_serde::decode::from_read_ref::<_, TableRPC>(rpc)?; - let rep = self.table.handle(msg).await?; - Ok(rmp_to_vec_all_named(&rep)?) - } -} - #[derive(Serialize, Deserialize)] pub enum TableRPC { Ok, @@ -67,6 +50,8 @@ pub enum TableRPC { SyncRPC(SyncRPC), } +impl RpcMessage for TableRPC {} + pub trait PartitionKey { fn hash(&self) -> Hash; } @@ -136,18 +121,27 @@ impl Table { db: &sled::Db, name: String, param: TableReplicationParams, + rpc_server: &mut RpcServer, ) -> Arc { let store = db.open_tree(&name).expect("Unable to open DB tree"); + + let rpc_path = format!("table_{}", name); + let rpc_client = system.rpc_client::>(&rpc_path); + let table = Arc::new(Self { instance, name, + rpc_client, system, store, param, syncer: ArcSwapOption::from(None), }); + table.clone().register_handler(rpc_server, rpc_path); + let syncer = TableSyncer::launch(table.clone()).await; table.syncer.swap(Some(syncer)); + table } @@ -158,9 +152,10 @@ impl Table { //eprintln!("insert who: {:?}", who); let e_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(e)?)); - let rpc = &TableRPC::::Update(vec![e_enc]); + let rpc = TableRPC::::Update(vec![e_enc]); - self.rpc_try_call_many(&who[..], &rpc, self.param.write_quorum) + self.rpc_client + .try_call_many(&who[..], rpc, self.param.write_quorum, self.param.timeout) .await?; Ok(()) } @@ -183,10 +178,8 @@ impl Table { let call_futures = call_list.drain().map(|(node, entries)| async move { let rpc = TableRPC::::Update(entries); - let rpc_bytes = rmp_to_vec_all_named(&rpc)?; - let message = Message::TableRPC(self.name.to_string(), rpc_bytes); - let resp = rpc_call(self.system.clone(), &node, &message, self.param.timeout).await?; + let resp = self.rpc_client.call(&node, rpc, self.param.timeout).await?; Ok::<_, Error>((node, resp)) }); let mut resps = call_futures.collect::>(); @@ -214,9 +207,10 @@ impl Table { let who = ring.walk_ring(&hash, self.param.replication_factor); //eprintln!("get who: {:?}", who); - let rpc = &TableRPC::::ReadEntry(partition_key.clone(), sort_key.clone()); + let rpc = TableRPC::::ReadEntry(partition_key.clone(), sort_key.clone()); let resps = self - .rpc_try_call_many(&who[..], &rpc, self.param.read_quorum) + .rpc_client + .try_call_many(&who[..], rpc, self.param.read_quorum, self.param.timeout) .await?; let mut ret = None; @@ -264,9 +258,10 @@ impl Table { let who = ring.walk_ring(&hash, self.param.replication_factor); let rpc = - &TableRPC::::ReadRange(partition_key.clone(), begin_sort_key.clone(), filter, limit); + TableRPC::::ReadRange(partition_key.clone(), begin_sort_key.clone(), filter, limit); let resps = self - .rpc_try_call_many(&who[..], &rpc, self.param.read_quorum) + .rpc_client + .try_call_many(&who[..], rpc, self.param.read_quorum, self.param.timeout) .await?; let mut ret = BTreeMap::new(); @@ -315,71 +310,24 @@ impl Table { async fn repair_on_read(&self, who: &[UUID], what: F::E) -> Result<(), Error> { let what_enc = Arc::new(ByteBuf::from(rmp_to_vec_all_named(&what)?)); - self.rpc_try_call_many(&who[..], &TableRPC::::Update(vec![what_enc]), who.len()) + self.rpc_client + .try_call_many( + &who[..], + TableRPC::::Update(vec![what_enc]), + who.len(), + self.param.timeout, + ) .await?; Ok(()) } - async fn rpc_try_call_many( - &self, - who: &[UUID], - rpc: &TableRPC, - quorum: usize, - ) -> Result>, Error> { - //eprintln!("Table RPC to {:?}: {}", who, serde_json::to_string(&rpc)?); - - let rpc_bytes = rmp_to_vec_all_named(rpc)?; - let rpc_msg = Message::TableRPC(self.name.to_string(), rpc_bytes); - - let resps = rpc_try_call_many( - self.system.clone(), - who, - rpc_msg, - quorum, - self.param.timeout, - ) - .await?; - - let mut resps_vals = vec![]; - for resp in resps { - if let Message::TableRPC(tbl, rep_by) = &resp { - if *tbl == self.name { - resps_vals.push(rmp_serde::decode::from_read_ref(&rep_by)?); - continue; - } - } - return Err(Error::Message(format!( - "Invalid reply to TableRPC: {:?}", - resp - ))); - } - //eprintln!( - // "Table RPC responses: {}", - // serde_json::to_string(&resps_vals)? - //); - Ok(resps_vals) - } - - pub async fn rpc_call(&self, who: &UUID, rpc: &TableRPC) -> Result, Error> { - let rpc_bytes = rmp_to_vec_all_named(rpc)?; - let rpc_msg = Message::TableRPC(self.name.to_string(), rpc_bytes); - - let resp = rpc_call(self.system.clone(), who, &rpc_msg, self.param.timeout).await?; - if let Message::TableRPC(tbl, rep_by) = &resp { - if *tbl == self.name { - return Ok(rmp_serde::decode::from_read_ref(&rep_by)?); - } - } - Err(Error::Message(format!( - "Invalid reply to TableRPC: {:?}", - resp - ))) - } - // =============== HANDLERS FOR RPC OPERATIONS (SERVER SIDE) ============== - pub fn rpc_handler(self: Arc) -> Box { - Box::new(TableRpcHandlerAdapter:: { table: self }) + fn register_handler(self: Arc, rpc_server: &mut RpcServer, path: String) { + rpc_server.add_handler::, _, _>(path, move |msg, _addr| { + let self2 = self.clone(); + async move { self2.handle(msg).await } + }) } async fn handle(self: &Arc, msg: TableRPC) -> Result, Error> { diff --git a/src/table_sync.rs b/src/table_sync.rs index 024e239f..3ba2fc6a 100644 --- a/src/table_sync.rs +++ b/src/table_sync.rs @@ -360,12 +360,14 @@ impl TableSyncer { // If their root checksum has level > than us, use that as a reference let root_cks_resp = self .table - .rpc_call( + .rpc_client + .call( &who, &TableRPC::::SyncRPC(SyncRPC::GetRootChecksumRange( partition.begin.clone(), partition.end.clone(), )), + self.table.param.timeout, ) .await?; if let TableRPC::::SyncRPC(SyncRPC::RootChecksumRange(range)) = root_cks_resp { @@ -392,9 +394,11 @@ impl TableSyncer { let rpc_resp = self .table - .rpc_call( + .rpc_client + .call( &who, &TableRPC::::SyncRPC(SyncRPC::Checksums(step, retain)), + self.table.param.timeout, ) .await?; if let TableRPC::::SyncRPC(SyncRPC::Difference(mut diff_ranges, diff_items)) = @@ -451,7 +455,12 @@ impl TableSyncer { } let rpc_resp = self .table - .rpc_call(&who, &TableRPC::::Update(values)) + .rpc_client + .call( + &who, + &TableRPC::::Update(values), + self.table.param.timeout, + ) .await?; if let TableRPC::::Ok = rpc_resp { Ok(())