From d5272ce486075673f6f8f6ed6034b823a47b9542 Mon Sep 17 00:00:00 2001 From: Alex Auvolat Date: Fri, 1 Jul 2022 19:23:29 +0200 Subject: [PATCH] Incomming locking loop --- src/mail/incoming.rs | 179 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 145 insertions(+), 34 deletions(-) diff --git a/src/mail/incoming.rs b/src/mail/incoming.rs index 398f555..1d16d38 100644 --- a/src/mail/incoming.rs +++ b/src/mail/incoming.rs @@ -1,8 +1,10 @@ use std::collections::HashMap; +use std::pin::Pin; use std::sync::{Arc, Weak}; use std::time::Duration; -use anyhow::Result; +use anyhow::{anyhow, bail, Result}; +use futures::{future::BoxFuture, Future, FutureExt}; use k2v_client::{CausalValue, CausalityToken, K2vClient, K2vValue}; use rusoto_s3::{PutObjectRequest, S3Client, S3}; use tokio::sync::watch; @@ -20,6 +22,12 @@ const INCOMING_PK: &str = "incoming"; const INCOMING_LOCK_SK: &str = "lock"; const INCOMING_WATCH_SK: &str = "watch"; +// When a lock is held, it is held for LOCK_DURATION (here 5 minutes) +// It is renewed every LOCK_DURATION/3 +// If we are at 2*LOCK_DURATION/3 and haven't renewed, we assume we +// lost the lock. +const LOCK_DURATION: Duration = Duration::from_secs(300); + pub async fn incoming_mail_watch_process( user: Weak, creds: Credentials, @@ -124,55 +132,158 @@ async fn handle_incoming_mail(user: &Arc, s3: &S3Client, inbox: &Arc watch::Receiver { let (held_tx, held_rx) = watch::channel(false); - tokio::spawn(async move { - let _ = k2v_lock_loop_internal(k2v, pk, sk, held_tx).await; - }); + tokio::spawn(k2v_lock_loop_internal(k2v, pk, sk, held_tx)); held_rx } +#[derive(Clone, Debug)] +enum LockState { + Unknown, + Empty, + Held(UniqueIdent, u64, CausalityToken), +} + async fn k2v_lock_loop_internal( k2v: K2vClient, pk: &'static str, sk: &'static str, held_tx: watch::Sender, -) -> std::result::Result<(), watch::error::SendError> { - let pid = gen_ident(); +) { + let (state_tx, mut state_rx) = watch::channel::(LockState::Unknown); + let mut state_rx_2 = state_rx.clone(); - let mut state: Option<(UniqueIdent, u64, CausalityToken)> = None; - loop { - let held_until = match &state { - None => None, - Some((_holder, expiration_time, _ct)) => Some(expiration_time), - }; + let our_pid = gen_ident(); - let now = now_msec(); - let wait_half_held_time = async { - match held_until { - None => futures::future::pending().await, - Some(t) => tokio::time::sleep(Duration::from_millis((now_msec() - t) / 2)).await, - } - }; - - unimplemented!(); - - /* - tokio::select! { - ret = k2v_wait_value_changed(&k2v, pk, sk, &state.as_ref().map(|(_, _, ct)| ct.clone())) => { - match ret { - Err(e) => { - held_tx.send(false)?; - tokio::time::sleep(Duration::from_secs(30)).await; - continue; - } - Ok(cv) => { - unimplemented!(); + // Loop 1: watch state of lock in K2V, save that in corresponding watch channel + let watch_lock_loop: BoxFuture> = async { + let mut ct = None; + loop { + match k2v_wait_value_changed(&k2v, pk, sk, &ct).await { + Err(e) => { + error!( + "Error in k2v wait value changed: {} ; assuming we no longer hold lock.", + e + ); + state_tx.send(LockState::Unknown)?; + tokio::time::sleep(Duration::from_secs(30)).await; + } + Ok(cv) => { + let mut lock_state = None; + for v in cv.value.iter() { + if let K2vValue::Value(vbytes) = v { + if vbytes.len() == 32 { + let ts = u64::from_be_bytes(vbytes[..8].try_into().unwrap()); + let pid = UniqueIdent(vbytes[8..].try_into().unwrap()); + if lock_state + .map(|(pid2, ts2)| ts > ts2 || (ts == ts2 && pid > pid2)) + .unwrap_or(true) + { + lock_state = Some((pid, ts)); + } + } + } } + state_tx.send( + lock_state + .map(|(pid, ts)| LockState::Held(pid, ts, cv.causality.clone())) + .unwrap_or(LockState::Empty), + )?; + ct = Some(cv.causality); } } + info!("Stopping lock state watch"); } - */ } + .boxed(); + + // Loop 2: notify user whether we are holding the lock or not + let lock_notify_loop: BoxFuture> = async { + loop { + let now = now_msec(); + let held_with_expiration_time = match &*state_rx.borrow_and_update() { + LockState::Held(pid, ts, _ct) if *pid == our_pid => { + let expiration_time = *ts - (LOCK_DURATION / 3).as_millis() as u64; + if now < expiration_time { + Some(expiration_time) + } else { + None + } + } + _ => None, + }; + held_tx.send(held_with_expiration_time.is_some())?; + + let await_expired = async { + match held_with_expiration_time { + None => futures::future::pending().await, + Some(expiration_time) => { + tokio::time::sleep(Duration::from_millis(expiration_time - now)).await + } + }; + }; + + tokio::select!( + r = state_rx.changed() => { + r?; + } + _ = held_tx.closed() => bail!("held_tx closed, don't need to hold lock anymore"), + _ = await_expired => continue, + ); + } + } + .boxed(); + + // Loop 3: acquire lock when relevant + let take_lock_loop: BoxFuture> = async { + loop { + let now = now_msec(); + let state: LockState = state_rx_2.borrow_and_update().clone(); + let (acquire_at, ct) = match state { + LockState::Unknown => { + // If state of the lock is unknown, don't try to acquire + state_rx_2.changed().await?; + continue; + } + LockState::Empty => (now, None), + LockState::Held(pid, ts, ct) => { + if pid == our_pid { + (ts - (2 * LOCK_DURATION / 3).as_millis() as u64, Some(ct)) + } else { + (ts, Some(ct)) + } + } + }; + + // Wait until it is time to acquire lock + if acquire_at > now { + tokio::select!( + r = state_rx_2.changed() => { + // If lock state changed in the meantime, don't acquire and loop around + r?; + continue; + } + _ = tokio::time::sleep(Duration::from_millis(acquire_at - now)) => () + ); + } + + // Acquire lock + let mut lock = vec![0u8; 32]; + lock[..8].copy_from_slice(&u64::to_be_bytes(now_msec())); + lock[8..].copy_from_slice(&our_pid.0); + if let Err(e) = k2v.insert_item(pk, sk, lock, ct).await { + error!("Could not take lock: {}", e); + tokio::time::sleep(Duration::from_secs(30)).await; + } + + // Wait for new information to loop back + state_rx_2.changed().await?; + } + } + .boxed(); + + let res = futures::try_join!(watch_lock_loop, lock_notify_loop, take_lock_loop); + info!("lock loop exited: {:?}", res); } async fn k2v_wait_value_changed<'a>(