diff --git a/TODO b/TODO
index a362f35d..3c3682cc 100644
--- a/TODO
+++ b/TODO
@@ -10,8 +10,6 @@ Improvements
 
 Membership: keep IP addresses of failed nodes and try to reping them regularly
 
-RPC client/server: do not go through the serialization+HTTP+TLS+deserialization when doing a request to ourself.
-
 Attaining S3 compatibility
 --------------------------
 
diff --git a/src/admin_rpc.rs b/src/admin_rpc.rs
index aa6bd82e..458df360 100644
--- a/src/admin_rpc.rs
+++ b/src/admin_rpc.rs
@@ -172,7 +172,7 @@ impl AdminRpcHandler {
 			if self
 				.rpc_client
 				.call(
-					node,
+					*node,
 					AdminRPC::LaunchRepair(opt_to_send.clone()),
 					ADMIN_RPC_TIMEOUT,
 				)
diff --git a/src/block.rs b/src/block.rs
index 21ddf837..23222a7f 100644
--- a/src/block.rs
+++ b/src/block.rs
@@ -96,19 +96,27 @@ impl BlockManager {
 	}
 
 	fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
+		let self2 = self.clone();
 		rpc_server.add_handler::<Message, _, _>(path, move |msg, _addr| {
-			let self2 = self.clone();
-			async move {
-				match msg {
-					Message::PutBlock(m) => self2.write_block(&m.hash, &m.data).await,
-					Message::GetBlock(h) => self2.read_block(&h).await,
-					Message::NeedBlockQuery(h) => {
-						self2.need_block(&h).await.map(Message::NeedBlockReply)
-					}
-					_ => Err(Error::BadRequest(format!("Unexpected RPC message"))),
-				}
-			}
+			let self2 = self2.clone();
+			async move { self2.handle(&msg).await }
 		});
+
+		let self2 = self.clone();
+		self.rpc_client
+			.set_local_handler(self.system.id, move |msg| {
+				let self2 = self2.clone();
+				async move { self2.handle(&msg).await }
+			});
+	}
+
+	async fn handle(self: Arc<Self>, msg: &Message) -> Result<Message, Error> {
+		match msg {
+			Message::PutBlock(m) => self.write_block(&m.hash, &m.data).await,
+			Message::GetBlock(h) => self.read_block(h).await,
+			Message::NeedBlockQuery(h) => self.need_block(h).await.map(Message::NeedBlockReply),
+			_ => Err(Error::BadRequest(format!("Unexpected RPC message"))),
+		}
 	}
 
 	pub async fn spawn_background_worker(self: Arc<Self>) {
@@ -299,7 +307,7 @@ impl BlockManager {
 		let msg = Arc::new(Message::NeedBlockQuery(*hash));
 		let who_needs_fut = who.iter().map(|to| {
 			self.rpc_client
-				.call(to, msg.clone(), NEED_BLOCK_QUERY_TIMEOUT)
+				.call_arc(*to, msg.clone(), NEED_BLOCK_QUERY_TIMEOUT)
 		});
 		let who_needs = join_all(who_needs_fut).await;
 
diff --git a/src/membership.rs b/src/membership.rs
index f9ffa3b4..c0c88a43 100644
--- a/src/membership.rs
+++ b/src/membership.rs
@@ -297,7 +297,7 @@ impl System {
 		let (update_ring, ring) = watch::channel(Arc::new(ring));
 
 		let rpc_http_client = Arc::new(
-			RpcHttpClient::new(config.max_concurrent_requests, &config.rpc_tls)
+			RpcHttpClient::new(config.max_concurrent_rpc_requests, &config.rpc_tls)
 				.expect("Could not create RPC client"),
 		);
 
@@ -633,7 +633,7 @@ impl System {
 			async move {
 				let resp = self
 					.rpc_client
-					.call(&peer, Message::PullStatus, PING_TIMEOUT)
+					.call(peer, Message::PullStatus, PING_TIMEOUT)
 					.await;
 				if let Ok(Message::AdvertiseNodesUp(nodes)) = resp {
 					let _: Result<_, _> = self.handle_advertise_nodes_up(&nodes).await;
@@ -644,7 +644,7 @@ impl System {
 	pub async fn pull_config(self: Arc<Self>, peer: UUID) {
 		let resp = self
 			.rpc_client
-			.call(&peer, Message::PullConfig, PING_TIMEOUT)
+			.call(peer, Message::PullConfig, PING_TIMEOUT)
 			.await;
 		if let Ok(Message::AdvertiseConfig(config)) = resp {
 			let _: Result<_, _> = self.handle_advertise_config(&config).await;
diff --git a/src/rpc_client.rs b/src/rpc_client.rs
index 8bc3fe50..e78079c2 100644
--- a/src/rpc_client.rs
+++ b/src/rpc_client.rs
@@ -1,10 +1,13 @@
 use std::borrow::Borrow;
 use std::marker::PhantomData;
 use std::net::SocketAddr;
+use std::pin::Pin;
 use std::sync::Arc;
 use std::time::Duration;
 
+use arc_swap::ArcSwapOption;
 use bytes::IntoBuf;
+use futures::future::Future;
 use futures::stream::futures_unordered::FuturesUnordered;
 use futures::stream::StreamExt;
 use futures_util::future::FutureExt;
@@ -47,10 +50,15 @@ impl RequestStrategy {
 	}
 }
 
+pub type LocalHandlerFn<M> =
+	Box<dyn Fn(Arc<M>) -> Pin<Box<dyn Future<Output = Result<M, Error>> + Send>> + Send + Sync>;
+
 pub struct RpcClient<M: RpcMessage> {
 	status: watch::Receiver<Arc<Status>>,
 	background: Arc<BackgroundRunner>,
 
+	local_handler: ArcSwapOption<(UUID, LocalHandlerFn<M>)>,
+
 	pub rpc_addr_client: RpcAddrClient<M>,
 }
 
@@ -64,19 +72,38 @@ impl<M: RpcMessage + 'static> RpcClient<M> {
 			rpc_addr_client: rac,
 			background,
 			status,
+			local_handler: ArcSwapOption::new(None),
 		})
 	}
 
+	pub fn set_local_handler<F, Fut>(&self, my_id: UUID, handler: F)
+	where
+		F: Fn(Arc<M>) -> Fut + Send + Sync + 'static,
+		Fut: Future<Output = Result<M, Error>> + Send + 'static,
+	{
+		let handler_arc = Arc::new(handler);
+		let handler: LocalHandlerFn<M> = Box::new(move |msg| {
+			let handler_arc2 = handler_arc.clone();
+			Box::pin(async move { handler_arc2(msg).await })
+		});
+		self.local_handler.swap(Some(Arc::new((my_id, handler))));
+	}
+
 	pub fn by_addr(&self) -> &RpcAddrClient<M> {
 		&self.rpc_addr_client
 	}
 
-	pub async fn call<MB: Borrow<M>, N: Borrow<UUID>>(
-		&self,
-		to: N,
-		msg: MB,
-		timeout: Duration,
-	) -> Result<M, Error> {
+	pub async fn call(&self, to: UUID, msg: M, timeout: Duration) -> Result<M, Error> {
+		self.call_arc(to, Arc::new(msg), timeout).await
+	}
+
+	pub async fn call_arc(&self, to: UUID, msg: Arc<M>, timeout: Duration) -> Result<M, Error> {
+		if let Some(lh) = self.local_handler.load_full() {
+			let (my_id, local_handler) = lh.as_ref();
+			if to.borrow() == my_id {
+				return local_handler(msg).await;
+			}
+		}
 		let addr = {
 			let status = self.status.borrow().clone();
 			match status.nodes.get(to.borrow()) {
@@ -96,7 +123,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> {
 		let msg = Arc::new(msg);
 		let mut resp_stream = to
 			.iter()
-			.map(|to| self.call(to, msg.clone(), timeout))
+			.map(|to| self.call_arc(*to, msg.clone(), timeout))
 			.collect::<FuturesUnordered<_>>();
 
 		let mut results = vec![];
@@ -121,7 +148,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> {
 			.map(|to| {
 				let self2 = self.clone();
 				let msg = msg.clone();
-				async move { self2.call(to, msg, timeout).await }
+				async move { self2.call_arc(to, msg, timeout).await }
 			})
 			.collect::<FuturesUnordered<_>>();
 
@@ -155,7 +182,7 @@ impl<M: RpcMessage + 'static> RpcClient<M> {
 					resp_stream.collect::<Vec<_>>().await;
 					Ok(())
 				});
-				self.clone().background.spawn(wait_finished_fut.map(|x| {
+				self.background.spawn(wait_finished_fut.map(|x| {
 					x.unwrap_or_else(|e| Err(Error::Message(format!("Await failed: {}", e))))
 				}));
 			}
diff --git a/src/server.rs b/src/server.rs
index 7b6f2240..3ea29105 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -35,8 +35,8 @@ pub struct Config {
 
 	pub bootstrap_peers: Vec<SocketAddr>,
 
-	#[serde(default = "default_max_concurrent_requests")]
-	pub max_concurrent_requests: usize,
+	#[serde(default = "default_max_concurrent_rpc_requests")]
+	pub max_concurrent_rpc_requests: usize,
 
 	#[serde(default = "default_block_size")]
 	pub block_size: usize,
@@ -53,7 +53,7 @@ pub struct Config {
 	pub rpc_tls: Option<TlsConfig>,
 }
 
-fn default_max_concurrent_requests() -> usize {
+fn default_max_concurrent_rpc_requests() -> usize {
 	12
 }
 fn default_block_size() -> usize {
@@ -262,7 +262,7 @@ pub async fn run_server(config_file: PathBuf) -> Result<(), Error> {
 
 	info!("Initializing background runner...");
 	let (send_cancel, watch_cancel) = watch::channel(false);
-	let background = BackgroundRunner::new(8, watch_cancel.clone());
+	let background = BackgroundRunner::new(16, watch_cancel.clone());
 
 	let garage = Garage::new(config, id, db, background.clone(), &mut rpc_server).await;
 
diff --git a/src/table.rs b/src/table.rs
index 53e17396..a3d02d0c 100644
--- a/src/table.rs
+++ b/src/table.rs
@@ -204,7 +204,7 @@ where
 
 		let call_futures = call_list.drain().map(|(node, entries)| async move {
 			let rpc = TableRPC::<F>::Update(entries);
-			let resp = self.rpc_client.call(&node, rpc, TABLE_RPC_TIMEOUT).await?;
+			let resp = self.rpc_client.call(node, rpc, TABLE_RPC_TIMEOUT).await?;
 			Ok::<_, Error>((node, resp))
 		});
 		let mut resps = call_futures.collect::<FuturesUnordered<_>>();
@@ -358,20 +358,28 @@ where
 	// =============== HANDLERS FOR RPC OPERATIONS (SERVER SIDE) ==============
 
 	fn register_handler(self: Arc<Self>, rpc_server: &mut RpcServer, path: String) {
+		let self2 = self.clone();
 		rpc_server.add_handler::<TableRPC<F>, _, _>(path, move |msg, _addr| {
-			let self2 = self.clone();
-			async move { self2.handle(msg).await }
-		})
+			let self2 = self2.clone();
+			async move { self2.handle(&msg).await }
+		});
+
+		let self2 = self.clone();
+		self.rpc_client
+			.set_local_handler(self.system.id, move |msg| {
+				let self2 = self2.clone();
+				async move { self2.handle(&msg).await }
+			});
 	}
 
-	async fn handle(self: &Arc<Self>, msg: TableRPC<F>) -> Result<TableRPC<F>, Error> {
+	async fn handle(self: &Arc<Self>, msg: &TableRPC<F>) -> Result<TableRPC<F>, Error> {
 		match msg {
 			TableRPC::ReadEntry(key, sort_key) => {
-				let value = self.handle_read_entry(&key, &sort_key)?;
+				let value = self.handle_read_entry(key, sort_key)?;
 				Ok(TableRPC::ReadEntryResponse(value))
 			}
 			TableRPC::ReadRange(key, begin_sort_key, filter, limit) => {
-				let values = self.handle_read_range(&key, &begin_sort_key, &filter, limit)?;
+				let values = self.handle_read_range(key, begin_sort_key, filter, *limit)?;
 				Ok(TableRPC::Update(values))
 			}
 			TableRPC::Update(pairs) => {
@@ -381,7 +389,7 @@ where
 			TableRPC::SyncRPC(rpc) => {
 				let syncer = self.syncer.load_full().unwrap();
 				let response = syncer
-					.handle_rpc(&rpc, self.system.background.stop_signal.clone())
+					.handle_rpc(rpc, self.system.background.stop_signal.clone())
 					.await?;
 				Ok(TableRPC::SyncRPC(response))
 			}
@@ -433,14 +441,11 @@ where
 		Ok(ret)
 	}
 
-	pub async fn handle_update(
-		self: &Arc<Self>,
-		mut entries: Vec<Arc<ByteBuf>>,
-	) -> Result<(), Error> {
+	pub async fn handle_update(self: &Arc<Self>, entries: &[Arc<ByteBuf>]) -> Result<(), Error> {
 		let syncer = self.syncer.load_full().unwrap();
 		let mut epidemic_propagate = vec![];
 
-		for update_bytes in entries.drain(..) {
+		for update_bytes in entries.iter() {
 			let update = rmp_serde::decode::from_read_ref::<_, F::E>(update_bytes.as_slice())?;
 
 			let tree_key = self.tree_key(update.partition_key(), update.sort_key());
diff --git a/src/table_sync.rs b/src/table_sync.rs
index 603c7aa6..60d5c4df 100644
--- a/src/table_sync.rs
+++ b/src/table_sync.rs
@@ -457,8 +457,8 @@ where
 			.table
 			.rpc_client
 			.call(
-				&who,
-				&TableRPC::<F>::SyncRPC(SyncRPC::GetRootChecksumRange(
+				who,
+				TableRPC::<F>::SyncRPC(SyncRPC::GetRootChecksumRange(
 					partition.begin.clone(),
 					partition.end.clone(),
 				)),
@@ -496,8 +496,8 @@ where
 			.table
 			.rpc_client
 			.call(
-				&who,
-				&TableRPC::<F>::SyncRPC(SyncRPC::Checksums(step, retain)),
+				who,
+				TableRPC::<F>::SyncRPC(SyncRPC::Checksums(step, retain)),
 				TABLE_SYNC_RPC_TIMEOUT,
 			)
 			.await?;
@@ -523,7 +523,7 @@ where
 			}
 		}
 		if retain && diff_items.len() > 0 {
-			self.table.handle_update(diff_items).await?;
+			self.table.handle_update(&diff_items[..]).await?;
 		}
 		if items_to_send.len() > 0 {
 			self.send_items(who, items_to_send).await?;
@@ -555,7 +555,7 @@ where
 		let rpc_resp = self
 			.table
 			.rpc_client
-			.call(&who, &TableRPC::<F>::Update(values), TABLE_SYNC_RPC_TIMEOUT)
+			.call(who, TableRPC::<F>::Update(values), TABLE_SYNC_RPC_TIMEOUT)
 			.await?;
 		if let TableRPC::<F>::Ok = rpc_resp {
 			Ok(())
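
Note for readers skimming the patch: below is a minimal, self-contained sketch of the pattern this change introduces in RpcClient. A (node id, handler) pair is registered once with set_local_handler(), and call_arc() dispatches any request addressed to the local node straight to that handler, skipping serialization, HTTP and TLS. The NodeId, Msg and Error aliases, the stub error on the remote path, and the main() driver are illustrative stand-ins only, not Garage types; the sketch assumes the arc-swap and tokio crates as dependencies.

    use std::future::Future;
    use std::pin::Pin;
    use std::sync::Arc;

    use arc_swap::ArcSwapOption;

    // Stand-in types: Garage uses UUID node identifiers and typed RPC messages.
    type NodeId = u64;
    type Msg = String;
    type Error = String;

    // Boxed handler returning a boxed future, as in the patch's LocalHandlerFn.
    type LocalHandlerFn =
        Box<dyn Fn(Arc<Msg>) -> Pin<Box<dyn Future<Output = Result<Msg, Error>> + Send>> + Send + Sync>;

    struct RpcClient {
        local_handler: ArcSwapOption<(NodeId, LocalHandlerFn)>,
    }

    impl RpcClient {
        fn new() -> Self {
            Self {
                local_handler: ArcSwapOption::new(None),
            }
        }

        // Register the handler used for requests addressed to this node.
        fn set_local_handler<F, Fut>(&self, my_id: NodeId, handler: F)
        where
            F: Fn(Arc<Msg>) -> Fut + Send + Sync + 'static,
            Fut: Future<Output = Result<Msg, Error>> + Send + 'static,
        {
            // Wrap the handler in an Arc so the boxed Fn can be called repeatedly
            // while each returned future owns its own clone.
            let handler_arc = Arc::new(handler);
            let boxed: LocalHandlerFn = Box::new(move |msg| {
                let handler_arc2 = handler_arc.clone();
                Box::pin(async move { handler_arc2(msg).await })
            });
            self.local_handler.swap(Some(Arc::new((my_id, boxed))));
        }

        async fn call_arc(&self, to: NodeId, msg: Arc<Msg>) -> Result<Msg, Error> {
            // Short-circuit: a request to ourself is handled in-process,
            // with no serialization, HTTP or TLS involved.
            if let Some(lh) = self.local_handler.load_full() {
                let (my_id, local_handler) = lh.as_ref();
                if to == *my_id {
                    return local_handler(msg).await;
                }
            }
            // The remote path would go through the HTTP+TLS client here.
            Err(format!("no remote transport in this sketch (node {})", to))
        }
    }

    #[tokio::main]
    async fn main() -> Result<(), Error> {
        let client = RpcClient::new();
        client.set_local_handler(1, |msg: Arc<Msg>| async move {
            Ok::<Msg, Error>(format!("handled locally: {}", msg))
        });

        // A call to node 1 (ourself) is served by the local handler.
        let reply = client.call_arc(1, Arc::new("ping".to_string())).await?;
        println!("{}", reply);

        // A call to any other node would take the network path (not implemented here).
        assert!(client.call_arc(2, Arc::new("ping".to_string())).await.is_err());
        Ok(())
    }

As in the patch, ArcSwapOption lets the handler be installed after the client is constructed (register_handler runs later than RpcClient::new) and be read lock-free on every call.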