Add semaphore to limit RAM used by buffered outgoing requests
All checks were successful
continuous-integration/drone/pr Build is passing
continuous-integration/drone/push Build is passing
continuous-integration/drone Build is passing

This commit is contained in:
Alex 2021-11-03 17:00:40 +01:00
parent 8c4f418fe8
commit 6f13d083ab
No known key found for this signature in database
GPG key ID: EDABF9711E244EB1
3 changed files with 34 additions and 7 deletions

View file

@ -7,6 +7,7 @@ use futures::stream::futures_unordered::FuturesUnordered;
use futures::stream::StreamExt; use futures::stream::StreamExt;
use futures_util::future::FutureExt; use futures_util::future::FutureExt;
use tokio::select; use tokio::select;
use tokio::sync::Semaphore;
pub use netapp::endpoint::{Endpoint, EndpointHandler, Message as Rpc}; pub use netapp::endpoint::{Endpoint, EndpointHandler, Message as Rpc};
use netapp::peering::fullmesh::FullMeshPeeringStrategy; use netapp::peering::fullmesh::FullMeshPeeringStrategy;
@ -14,11 +15,16 @@ pub use netapp::proto::*;
pub use netapp::{NetApp, NodeID}; pub use netapp::{NetApp, NodeID};
use garage_util::background::BackgroundRunner; use garage_util::background::BackgroundRunner;
use garage_util::data::Uuid; use garage_util::data::*;
use garage_util::error::Error; use garage_util::error::Error;
const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10);
// Try to never have more than 200MB of outgoing requests
// buffered at the same time. Other requests are queued until
// space is freed.
const REQUEST_BUFFER_SIZE: usize = 200 * 1024 * 1024;
/// Strategy to apply when making RPC /// Strategy to apply when making RPC
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
pub struct RequestStrategy { pub struct RequestStrategy {
@ -64,9 +70,21 @@ impl RequestStrategy {
pub struct RpcHelper { pub struct RpcHelper {
pub(crate) fullmesh: Arc<FullMeshPeeringStrategy>, pub(crate) fullmesh: Arc<FullMeshPeeringStrategy>,
pub(crate) background: Arc<BackgroundRunner>, pub(crate) background: Arc<BackgroundRunner>,
request_buffer_semaphore: Arc<Semaphore>,
} }
impl RpcHelper { impl RpcHelper {
pub(crate) fn new(
fullmesh: Arc<FullMeshPeeringStrategy>,
background: Arc<BackgroundRunner>,
) -> Self {
Self {
fullmesh,
background,
request_buffer_semaphore: Arc::new(Semaphore::new(REQUEST_BUFFER_SIZE)),
}
}
pub async fn call<M, H, S>( pub async fn call<M, H, S>(
&self, &self,
endpoint: &Endpoint<M, H>, endpoint: &Endpoint<M, H>,
@ -92,10 +110,19 @@ impl RpcHelper {
M: Rpc<Response = Result<S, Error>>, M: Rpc<Response = Result<S, Error>>,
H: EndpointHandler<M>, H: EndpointHandler<M>,
{ {
let msg_size = rmp_to_vec_all_named(&msg)?.len() as u32;
let permit = self.request_buffer_semaphore.acquire_many(msg_size).await?;
let node_id = to.into(); let node_id = to.into();
select! { select! {
res = endpoint.call(&node_id, &msg, strat.rs_priority) => Ok(res??), res = endpoint.call(&node_id, &msg, strat.rs_priority) => {
_ = tokio::time::sleep(strat.rs_timeout) => Err(Error::Timeout), drop(permit);
Ok(res??)
}
_ = tokio::time::sleep(strat.rs_timeout) => {
drop(permit);
Err(Error::Timeout)
}
} }
} }

View file

@ -235,10 +235,7 @@ impl System {
node_status: RwLock::new(HashMap::new()), node_status: RwLock::new(HashMap::new()),
netapp: netapp.clone(), netapp: netapp.clone(),
fullmesh: fullmesh.clone(), fullmesh: fullmesh.clone(),
rpc: RpcHelper { rpc: RpcHelper::new(fullmesh, background.clone()),
fullmesh,
background: background.clone(),
},
system_endpoint, system_endpoint,
replication_factor, replication_factor,
rpc_listen_addr: config.rpc_bind_addr, rpc_listen_addr: config.rpc_bind_addr,

View file

@ -41,6 +41,9 @@ pub enum Error {
#[error(display = "Tokio join error: {}", _0)] #[error(display = "Tokio join error: {}", _0)]
TokioJoin(#[error(source)] tokio::task::JoinError), TokioJoin(#[error(source)] tokio::task::JoinError),
#[error(display = "Tokio semaphore acquire error: {}", _0)]
TokioSemAcquire(#[error(source)] tokio::sync::AcquireError),
#[error(display = "Remote error: {}", _0)] #[error(display = "Remote error: {}", _0)]
RemoteError(String), RemoteError(String),