format + fix storage bug

This commit is contained in:
Quentin 2024-01-31 11:01:18 +01:00
parent c27919a757
commit 22f0eb901a
Signed by: quentin
GPG key ID: E9602264D639FF68
5 changed files with 291 additions and 177 deletions

View file

@ -1,6 +1,6 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use anyhow::{Result, anyhow, bail}; use anyhow::{anyhow, bail, Result};
use futures::stream::{FuturesUnordered, StreamExt}; use futures::stream::{FuturesUnordered, StreamExt};
use tokio::io::BufStream; use tokio::io::BufStream;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
@ -54,22 +54,20 @@ pub struct AuthServer {
bind_addr: SocketAddr, bind_addr: SocketAddr,
} }
impl AuthServer { impl AuthServer {
pub fn new( pub fn new(config: AuthConfig, login_provider: ArcLoginProvider) -> Self {
config: AuthConfig,
login_provider: ArcLoginProvider,
) -> Self {
Self { Self {
bind_addr: config.bind_addr, bind_addr: config.bind_addr,
login_provider, login_provider,
} }
} }
pub async fn run(self: Self, mut must_exit: watch::Receiver<bool>) -> Result<()> { pub async fn run(self: Self, mut must_exit: watch::Receiver<bool>) -> Result<()> {
let tcp = TcpListener::bind(self.bind_addr).await?; let tcp = TcpListener::bind(self.bind_addr).await?;
tracing::info!("SASL Authentication Protocol listening on {:#}", self.bind_addr); tracing::info!(
"SASL Authentication Protocol listening on {:#}",
self.bind_addr
);
let mut connections = FuturesUnordered::new(); let mut connections = FuturesUnordered::new();
@ -89,8 +87,9 @@ impl AuthServer {
}; };
tracing::info!("AUTH: accepted connection from {}", remote_addr); tracing::info!("AUTH: accepted connection from {}", remote_addr);
let conn = tokio::spawn(NetLoop::new(socket, self.login_provider.clone(), must_exit.clone()).run_error()); let conn = tokio::spawn(
NetLoop::new(socket, self.login_provider.clone(), must_exit.clone()).run_error(),
);
connections.push(conn); connections.push(conn);
} }
@ -106,7 +105,7 @@ impl AuthServer {
struct NetLoop { struct NetLoop {
login: ArcLoginProvider, login: ArcLoginProvider,
stream: BufStream<TcpStream>, stream: BufStream<TcpStream>,
stop: watch::Receiver<bool>, stop: watch::Receiver<bool>,
state: State, state: State,
read_buf: Vec<u8>, read_buf: Vec<u8>,
write_buf: BytesMut, write_buf: BytesMut,
@ -197,45 +196,53 @@ enum State {
Init, Init,
HandshakePart(Version), HandshakePart(Version),
HandshakeDone, HandshakeDone,
AuthPlainProgress { AuthPlainProgress { id: u64 },
id: u64, AuthDone { id: u64, res: AuthRes },
},
AuthDone {
id: u64,
res: AuthRes
},
} }
const SERVER_MAJOR: u64 = 1; const SERVER_MAJOR: u64 = 1;
const SERVER_MINOR: u64 = 2; const SERVER_MINOR: u64 = 2;
impl State { impl State {
async fn progress(&mut self, cmd: ClientCommand, login: &ArcLoginProvider) { async fn progress(&mut self, cmd: ClientCommand, login: &ArcLoginProvider) {
let new_state = 'state: { let new_state = 'state: {
match (std::mem::replace(self, State::Error), cmd) { match (std::mem::replace(self, State::Error), cmd) {
(Self::Init, ClientCommand::Version(v)) => Self::HandshakePart(v), (Self::Init, ClientCommand::Version(v)) => Self::HandshakePart(v),
(Self::HandshakePart(version), ClientCommand::Cpid(_cpid)) => { (Self::HandshakePart(version), ClientCommand::Cpid(_cpid)) => {
if version.major != SERVER_MAJOR { if version.major != SERVER_MAJOR {
tracing::error!(client_major=version.major, server_major=SERVER_MAJOR, "Unsupported client major version"); tracing::error!(
break 'state Self::Error client_major = version.major,
server_major = SERVER_MAJOR,
"Unsupported client major version"
);
break 'state Self::Error;
} }
Self::HandshakeDone Self::HandshakeDone
}, }
(Self::HandshakeDone { .. }, ClientCommand::Auth { id, mech, .. }) | (Self::HandshakeDone { .. }, ClientCommand::Auth { id, mech, .. })
(Self::AuthDone { .. }, ClientCommand::Auth { id, mech, ..}) => { | (Self::AuthDone { .. }, ClientCommand::Auth { id, mech, .. }) => {
if mech != Mechanism::Plain { if mech != Mechanism::Plain {
tracing::error!(mechanism=?mech, "Unsupported Authentication Mechanism"); tracing::error!(mechanism=?mech, "Unsupported Authentication Mechanism");
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) } break 'state Self::AuthDone {
id,
res: AuthRes::Failed(None, None),
};
} }
Self::AuthPlainProgress { id } Self::AuthPlainProgress { id }
}, }
(Self::AuthPlainProgress { id }, ClientCommand::Cont { id: cid, data }) => { (Self::AuthPlainProgress { id }, ClientCommand::Cont { id: cid, data }) => {
// Check that ID matches // Check that ID matches
if cid != id { if cid != id {
tracing::error!(auth_id=id, cont_id=cid, "CONT id does not match AUTH id"); tracing::error!(
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) } auth_id = id,
cont_id = cid,
"CONT id does not match AUTH id"
);
break 'state Self::AuthDone {
id,
res: AuthRes::Failed(None, None),
};
} }
// Check that we can extract user's login+pass // Check that we can extract user's login+pass
@ -243,36 +250,54 @@ impl State {
Ok(([], ([], user, pass))) => (user, pass), Ok(([], ([], user, pass))) => (user, pass),
Ok(_) => { Ok(_) => {
tracing::error!("Impersonating user is not supported"); tracing::error!("Impersonating user is not supported");
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) } break 'state Self::AuthDone {
id,
res: AuthRes::Failed(None, None),
};
} }
Err(e) => { Err(e) => {
tracing::error!(err=?e, "Could not parse the SASL PLAIN data chunk"); tracing::error!(err=?e, "Could not parse the SASL PLAIN data chunk");
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) } break 'state Self::AuthDone {
}, id,
res: AuthRes::Failed(None, None),
};
}
}; };
// Try to convert it to UTF-8 // Try to convert it to UTF-8
let (user, password) = match (std::str::from_utf8(ubin), std::str::from_utf8(pbin)) { let (user, password) =
(Ok(u), Ok(p)) => (u, p), match (std::str::from_utf8(ubin), std::str::from_utf8(pbin)) {
_ => { (Ok(u), Ok(p)) => (u, p),
tracing::error!("Username or password contain invalid UTF-8 characters"); _ => {
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) } tracing::error!(
} "Username or password contain invalid UTF-8 characters"
}; );
break 'state Self::AuthDone {
id,
res: AuthRes::Failed(None, None),
};
}
};
// Try to connect user // Try to connect user
match login.login(user, password).await { match login.login(user, password).await {
Ok(_) => Self::AuthDone { id, res: AuthRes::Success(user.to_string())}, Ok(_) => Self::AuthDone {
id,
res: AuthRes::Success(user.to_string()),
},
Err(e) => { Err(e) => {
tracing::warn!(err=?e, "login failed"); tracing::warn!(err=?e, "login failed");
Self::AuthDone { id, res: AuthRes::Failed(Some(user.to_string()), None) } Self::AuthDone {
id,
res: AuthRes::Failed(Some(user.to_string()), None),
}
} }
} }
}, }
_ => { _ => {
tracing::error!("This command is not valid in this context"); tracing::error!("This command is not valid in this context");
Self::Error Self::Error
}, }
} }
}; };
tracing::debug!(state=?new_state, "Made progress"); tracing::debug!(state=?new_state, "Made progress");
@ -284,7 +309,10 @@ impl State {
match self { match self {
Self::HandshakeDone { .. } => { Self::HandshakeDone { .. } => {
srv_cmd.push(ServerCommand::Version(Version { major: SERVER_MAJOR, minor: SERVER_MINOR })); srv_cmd.push(ServerCommand::Version(Version {
major: SERVER_MAJOR,
minor: SERVER_MINOR,
}));
srv_cmd.push(ServerCommand::Mech { srv_cmd.push(ServerCommand::Mech {
kind: Mechanism::Plain, kind: Mechanism::Plain,
@ -299,16 +327,34 @@ impl State {
srv_cmd.push(ServerCommand::Cookie(cookie)); srv_cmd.push(ServerCommand::Cookie(cookie));
srv_cmd.push(ServerCommand::Done); srv_cmd.push(ServerCommand::Done);
}, }
Self::AuthPlainProgress { id } => { Self::AuthPlainProgress { id } => {
srv_cmd.push(ServerCommand::Cont { id: *id, data: None }); srv_cmd.push(ServerCommand::Cont {
}, id: *id,
Self::AuthDone { id, res: AuthRes::Success(user) } => { data: None,
srv_cmd.push(ServerCommand::Ok { id: *id, user_id: Some(user.to_string()), extra_parameters: vec![]}); });
}, }
Self::AuthDone { id, res: AuthRes::Failed(maybe_user, maybe_failcode) } => { Self::AuthDone {
srv_cmd.push(ServerCommand::Fail { id: *id, user_id: maybe_user.clone(), code: maybe_failcode.clone(), extra_parameters: vec![]}); id,
}, res: AuthRes::Success(user),
} => {
srv_cmd.push(ServerCommand::Ok {
id: *id,
user_id: Some(user.to_string()),
extra_parameters: vec![],
});
}
Self::AuthDone {
id,
res: AuthRes::Failed(maybe_user, maybe_failcode),
} => {
srv_cmd.push(ServerCommand::Fail {
id: *id,
user_id: maybe_user.clone(),
code: maybe_failcode.clone(),
extra_parameters: vec![],
});
}
_ => (), _ => (),
}; };
@ -316,7 +362,6 @@ impl State {
} }
} }
// ----------------------------------------------------------------- // -----------------------------------------------------------------
// //
// DOVECOT AUTH TYPES // DOVECOT AUTH TYPES
@ -329,7 +374,6 @@ enum Mechanism {
Login, Login,
} }
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum AuthOption { enum AuthOption {
/// Unique session ID. Mainly used for logging. /// Unique session ID. Mainly used for logging.
@ -409,14 +453,13 @@ enum ClientCommand {
service: String, service: String,
/// All the optional parameters /// All the optional parameters
options: Vec<AuthOption>, options: Vec<AuthOption>,
}, },
Cont { Cont {
/// The <id> must match the <id> of the AUTH command. /// The <id> must match the <id> of the AUTH command.
id: u64, id: u64,
/// Data that will be serialized to / deserialized from base64 /// Data that will be serialized to / deserialized from base64
data: Vec<u8>, data: Vec<u8>,
} },
} }
#[derive(Debug)] #[derive(Debug)]
@ -464,7 +507,7 @@ enum ServerCommand {
parameters: Vec<MechanismParameters>, parameters: Vec<MechanismParameters>,
}, },
/// COOKIE returns connection-specific 128 bit cookie in hex. It must be given to REQUEST command. (Protocol v1.1+ / Dovecot v2.0+) /// COOKIE returns connection-specific 128 bit cookie in hex. It must be given to REQUEST command. (Protocol v1.1+ / Dovecot v2.0+)
Cookie([u8;16]), Cookie([u8; 16]),
/// DONE finishes the handshake from server. /// DONE finishes the handshake from server.
Done, Done,
@ -493,26 +536,20 @@ enum ServerCommand {
// //
// ------------------------------------------------------------------ // ------------------------------------------------------------------
use nom::{
IResult,
branch::alt,
error::{ErrorKind, Error},
character::complete::{tab, u64, u16},
bytes::complete::{is_not, tag, tag_no_case, take, take_while, take_while1},
multi::{many1, separated_list0},
combinator::{map, opt, recognize, value, rest},
sequence::{pair, preceded, tuple},
};
use base64::Engine; use base64::Engine;
use nom::{
branch::alt,
bytes::complete::{is_not, tag, tag_no_case, take, take_while, take_while1},
character::complete::{tab, u16, u64},
combinator::{map, opt, recognize, rest, value},
error::{Error, ErrorKind},
multi::{many1, separated_list0},
sequence::{pair, preceded, tuple},
IResult,
};
fn version_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> { fn version_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> {
let mut parser = tuple(( let mut parser = tuple((tag_no_case(b"VERSION"), tab, u64, tab, u64));
tag_no_case(b"VERSION"),
tab,
u64,
tab,
u64
));
let (input, (_, _, major, _, minor)) = parser(input)?; let (input, (_, _, major, _, minor)) = parser(input)?;
Ok((input, ClientCommand::Version(Version { major, minor }))) Ok((input, ClientCommand::Version(Version { major, minor })))
@ -521,7 +558,7 @@ fn version_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> {
fn cpid_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> { fn cpid_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> {
preceded( preceded(
pair(tag_no_case(b"CPID"), tab), pair(tag_no_case(b"CPID"), tab),
map(u64, |v| ClientCommand::Cpid(v)) map(u64, |v| ClientCommand::Cpid(v)),
)(input) )(input)
} }
@ -541,10 +578,7 @@ fn is_esc<'a>(input: &'a [u8]) -> IResult<&'a [u8], &[u8]> {
} }
fn parameter<'a>(input: &'a [u8]) -> IResult<&'a [u8], &[u8]> { fn parameter<'a>(input: &'a [u8]) -> IResult<&'a [u8], &[u8]> {
recognize(many1(alt(( recognize(many1(alt((take_while1(is_not_tab_or_esc_or_lf), is_esc))))(input)
take_while1(is_not_tab_or_esc_or_lf),
is_esc
))))(input)
} }
fn parameter_str(input: &[u8]) -> IResult<&[u8], String> { fn parameter_str(input: &[u8]) -> IResult<&[u8], String> {
@ -568,10 +602,7 @@ fn parameter_name(input: &[u8]) -> IResult<&[u8], String> {
} }
fn service<'a>(input: &'a [u8]) -> IResult<&'a [u8], String> { fn service<'a>(input: &'a [u8]) -> IResult<&'a [u8], String> {
preceded( preceded(tag_no_case("service="), parameter_str)(input)
tag_no_case("service="),
parameter_str
)(input)
} }
fn auth_option<'a>(input: &'a [u8]) -> IResult<&'a [u8], AuthOption> { fn auth_option<'a>(input: &'a [u8]) -> IResult<&'a [u8], AuthOption> {
@ -583,31 +614,74 @@ fn auth_option<'a>(input: &'a [u8]) -> IResult<&'a [u8], AuthOption> {
value(ClientId, tag_no_case(b"client_id")), value(ClientId, tag_no_case(b"client_id")),
value(NoLogin, tag_no_case(b"nologin")), value(NoLogin, tag_no_case(b"nologin")),
map(preceded(tag_no_case(b"session="), u64), |id| Session(id)), map(preceded(tag_no_case(b"session="), u64), |id| Session(id)),
map(preceded(tag_no_case(b"lip="), parameter_str), |ip| LocalIp(ip)), map(preceded(tag_no_case(b"lip="), parameter_str), |ip| {
map(preceded(tag_no_case(b"rip="), parameter_str), |ip| RemoteIp(ip)), LocalIp(ip)
map(preceded(tag_no_case(b"lport="), u16), |port| LocalPort(port)), }),
map(preceded(tag_no_case(b"rport="), u16), |port| RemotePort(port)), map(preceded(tag_no_case(b"rip="), parameter_str), |ip| {
map(preceded(tag_no_case(b"real_rip="), parameter_str), |ip| RealRemoteIp(ip)), RemoteIp(ip)
map(preceded(tag_no_case(b"real_lip="), parameter_str), |ip| RealLocalIp(ip)), }),
map(preceded(tag_no_case(b"real_lport="), u16), |port| RealLocalPort(port)), map(preceded(tag_no_case(b"lport="), u16), |port| {
map(preceded(tag_no_case(b"real_rport="), u16), |port| RealRemotePort(port)), LocalPort(port)
}),
map(preceded(tag_no_case(b"rport="), u16), |port| {
RemotePort(port)
}),
map(preceded(tag_no_case(b"real_rip="), parameter_str), |ip| {
RealRemoteIp(ip)
}),
map(preceded(tag_no_case(b"real_lip="), parameter_str), |ip| {
RealLocalIp(ip)
}),
map(preceded(tag_no_case(b"real_lport="), u16), |port| {
RealLocalPort(port)
}),
map(preceded(tag_no_case(b"real_rport="), u16), |port| {
RealRemotePort(port)
}),
)), )),
alt(( alt((
map(preceded(tag_no_case(b"local_name="), parameter_str), |name| LocalName(name)), map(
map(preceded(tag_no_case(b"forward_views="), parameter), |views| ForwardViews(views.into())), preceded(tag_no_case(b"local_name="), parameter_str),
map(preceded(tag_no_case(b"secured="), parameter_str), |info| Secured(Some(info))), |name| LocalName(name),
),
map(
preceded(tag_no_case(b"forward_views="), parameter),
|views| ForwardViews(views.into()),
),
map(preceded(tag_no_case(b"secured="), parameter_str), |info| {
Secured(Some(info))
}),
value(Secured(None), tag_no_case(b"secured")), value(Secured(None), tag_no_case(b"secured")),
value(CertUsername, tag_no_case(b"cert_username")), value(CertUsername, tag_no_case(b"cert_username")),
map(preceded(tag_no_case(b"transport="), parameter_str), |ts| Transport(ts)), map(preceded(tag_no_case(b"transport="), parameter_str), |ts| {
map(preceded(tag_no_case(b"tls_cipher="), parameter_str), |cipher| TlsCipher(cipher)), Transport(ts)
map(preceded(tag_no_case(b"tls_cipher_bits="), parameter_str), |bits| TlsCipherBits(bits)), }),
map(preceded(tag_no_case(b"tls_pfs="), parameter_str), |pfs| TlsPfs(pfs)), map(
map(preceded(tag_no_case(b"tls_protocol="), parameter_str), |proto| TlsProtocol(proto)), preceded(tag_no_case(b"tls_cipher="), parameter_str),
map(preceded(tag_no_case(b"valid-client-cert="), parameter_str), |cert| ValidClientCert(cert)), |cipher| TlsCipher(cipher),
),
map(
preceded(tag_no_case(b"tls_cipher_bits="), parameter_str),
|bits| TlsCipherBits(bits),
),
map(preceded(tag_no_case(b"tls_pfs="), parameter_str), |pfs| {
TlsPfs(pfs)
}),
map(
preceded(tag_no_case(b"tls_protocol="), parameter_str),
|proto| TlsProtocol(proto),
),
map(
preceded(tag_no_case(b"valid-client-cert="), parameter_str),
|cert| ValidClientCert(cert),
),
)), )),
alt(( alt((
map(preceded(tag_no_case(b"resp="), base64), |data| Resp(data)), map(preceded(tag_no_case(b"resp="), base64), |data| Resp(data)),
map(tuple((parameter_name, tag(b"="), parameter)), |(n, _, v)| UnknownPair(n, v.into())), map(
tuple((parameter_name, tag(b"="), parameter)),
|(n, _, v)| UnknownPair(n, v.into()),
),
map(parameter, |v| UnknownBool(v.into())), map(parameter, |v| UnknownBool(v.into())),
)), )),
))(input) ))(input)
@ -622,13 +696,20 @@ fn auth_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> {
mechanism, mechanism,
tab, tab,
service, service,
map( map(opt(preceded(tab, separated_list0(tab, auth_option))), |o| {
opt(preceded(tab, separated_list0(tab, auth_option))), o.unwrap_or(vec![])
|o| o.unwrap_or(vec![]) }),
),
)); ));
let (input, (_, _, id, _, mech, _, service, options)) = parser(input)?; let (input, (_, _, id, _, mech, _, service, options)) = parser(input)?;
Ok((input, ClientCommand::Auth { id, mech, service, options })) Ok((
input,
ClientCommand::Auth {
id,
mech,
service,
options,
},
))
} }
fn is_base64_core(c: u8) -> bool { fn is_base64_core(c: u8) -> bool {
@ -644,10 +725,7 @@ fn is_base64_pad(c: u8) -> bool {
} }
fn base64(input: &[u8]) -> IResult<&[u8], Vec<u8>> { fn base64(input: &[u8]) -> IResult<&[u8], Vec<u8>> {
let (input, (b64, _)) = tuple(( let (input, (b64, _)) = tuple((take_while1(is_base64_core), take_while(is_base64_pad)))(input)?;
take_while1(is_base64_core),
take_while(is_base64_pad),
))(input)?;
let data = base64::engine::general_purpose::STANDARD_NO_PAD let data = base64::engine::general_purpose::STANDARD_NO_PAD
.decode(b64) .decode(b64)
@ -657,26 +735,15 @@ fn base64(input: &[u8]) -> IResult<&[u8], Vec<u8>> {
} }
/// @FIXME Dovecot does not say if base64 content must be padded or not /// @FIXME Dovecot does not say if base64 content must be padded or not
fn cont_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> { fn cont_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> {
let mut parser = tuple(( let mut parser = tuple((tag_no_case(b"CONT"), tab, u64, tab, base64));
tag_no_case(b"CONT"),
tab,
u64,
tab,
base64
));
let (input, (_, _, id, _, data)) = parser(input)?; let (input, (_, _, id, _, data)) = parser(input)?;
Ok((input, ClientCommand::Cont { id, data })) Ok((input, ClientCommand::Cont { id, data }))
} }
fn client_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> { fn client_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> {
alt(( alt((version_command, cpid_command, auth_command, cont_command))(input)
version_command,
cpid_command,
auth_command,
cont_command,
))(input)
} }
/* /*
@ -698,7 +765,13 @@ fn not_null(c: u8) -> bool {
// impersonated user, login, password // impersonated user, login, password
fn auth_plain<'a>(input: &'a [u8]) -> IResult<&'a [u8], (&'a [u8], &'a [u8], &'a [u8])> { fn auth_plain<'a>(input: &'a [u8]) -> IResult<&'a [u8], (&'a [u8], &'a [u8], &'a [u8])> {
map( map(
tuple((take_while(not_null), take(1usize), take_while(not_null), take(1usize), rest)), tuple((
take_while(not_null),
take(1usize),
take_while(not_null),
take(1usize),
rest,
)),
|(imp, _, user, _, pass)| (imp, user, pass), |(imp, _, user, _, pass)| (imp, user, pass),
)(input) )(input)
} }
@ -746,7 +819,6 @@ impl Encode for MechanismParameters {
} }
} }
impl Encode for FailCode { impl Encode for FailCode {
fn encode(&self, out: &mut BytesMut) -> Result<()> { fn encode(&self, out: &mut BytesMut) -> Result<()> {
match self { match self {
@ -762,33 +834,32 @@ impl Encode for FailCode {
impl Encode for ServerCommand { impl Encode for ServerCommand {
fn encode(&self, out: &mut BytesMut) -> Result<()> { fn encode(&self, out: &mut BytesMut) -> Result<()> {
match self { match self {
Self::Version (Version { major, minor }) => { Self::Version(Version { major, minor }) => {
out.put(&b"VERSION"[..]); out.put(&b"VERSION"[..]);
tab_enc(out); tab_enc(out);
out.put(major.to_string().as_bytes()); out.put(major.to_string().as_bytes());
tab_enc(out); tab_enc(out);
out.put(minor.to_string().as_bytes()); out.put(minor.to_string().as_bytes());
lf_enc(out); lf_enc(out);
}, }
Self::Spid(pid) => { Self::Spid(pid) => {
out.put(&b"SPID"[..]); out.put(&b"SPID"[..]);
tab_enc(out); tab_enc(out);
out.put(pid.to_string().as_bytes()); out.put(pid.to_string().as_bytes());
lf_enc(out); lf_enc(out);
}, }
Self::Cuid(pid) => { Self::Cuid(pid) => {
out.put(&b"CUID"[..]); out.put(&b"CUID"[..]);
tab_enc(out); tab_enc(out);
out.put(pid.to_string().as_bytes()); out.put(pid.to_string().as_bytes());
lf_enc(out); lf_enc(out);
}, }
Self::Cookie(cval) => { Self::Cookie(cval) => {
out.put(&b"COOKIE"[..]); out.put(&b"COOKIE"[..]);
tab_enc(out); tab_enc(out);
out.put(hex::encode(cval).as_bytes()); out.put(hex::encode(cval).as_bytes());
lf_enc(out); lf_enc(out);
}
},
Self::Mech { kind, parameters } => { Self::Mech { kind, parameters } => {
out.put(&b"MECH"[..]); out.put(&b"MECH"[..]);
tab_enc(out); tab_enc(out);
@ -798,11 +869,11 @@ impl Encode for ServerCommand {
p.encode(out)?; p.encode(out)?;
} }
lf_enc(out); lf_enc(out);
}, }
Self::Done => { Self::Done => {
out.put(&b"DONE"[..]); out.put(&b"DONE"[..]);
lf_enc(out); lf_enc(out);
}, }
Self::Cont { id, data } => { Self::Cont { id, data } => {
out.put(&b"CONT"[..]); out.put(&b"CONT"[..]);
tab_enc(out); tab_enc(out);
@ -813,8 +884,12 @@ impl Encode for ServerCommand {
out.put(b64.as_bytes()); out.put(b64.as_bytes());
} }
lf_enc(out); lf_enc(out);
}, }
Self::Ok { id, user_id, extra_parameters } => { Self::Ok {
id,
user_id,
extra_parameters,
} => {
out.put(&b"OK"[..]); out.put(&b"OK"[..]);
tab_enc(out); tab_enc(out);
out.put(id.to_string().as_bytes()); out.put(id.to_string().as_bytes());
@ -828,8 +903,13 @@ impl Encode for ServerCommand {
out.put(&p[..]); out.put(&p[..]);
} }
lf_enc(out); lf_enc(out);
}, }
Self::Fail {id, user_id, code, extra_parameters } => { Self::Fail {
id,
user_id,
code,
extra_parameters,
} => {
out.put(&b"FAIL"[..]); out.put(&b"FAIL"[..]);
tab_enc(out); tab_enc(out);
out.put(id.to_string().as_bytes()); out.put(id.to_string().as_bytes());
@ -848,7 +928,7 @@ impl Encode for ServerCommand {
out.put(&p[..]); out.put(&p[..]);
} }
lf_enc(out); lf_enc(out);
}, }
} }
Ok(()) Ok(())
} }

View file

@ -26,8 +26,8 @@ use imap_codec::imap_types::response::{Code, CommandContinuationRequest, Respons
use imap_codec::imap_types::{core::Text, response::Greeting}; use imap_codec::imap_types::{core::Text, response::Greeting};
use imap_flow::server::{ServerFlow, ServerFlowEvent, ServerFlowOptions}; use imap_flow::server::{ServerFlow, ServerFlowEvent, ServerFlowOptions};
use imap_flow::stream::AnyStream; use imap_flow::stream::AnyStream;
use tokio_rustls::TlsAcceptor;
use rustls_pemfile::{certs, private_key}; use rustls_pemfile::{certs, private_key};
use tokio_rustls::TlsAcceptor;
use crate::config::{ImapConfig, ImapUnsecureConfig}; use crate::config::{ImapConfig, ImapUnsecureConfig};
use crate::imap::capability::ServerCapability; use crate::imap::capability::ServerCapability;
@ -53,8 +53,14 @@ struct ClientContext {
} }
pub fn new(config: ImapConfig, login: ArcLoginProvider) -> Result<Server> { pub fn new(config: ImapConfig, login: ArcLoginProvider) -> Result<Server> {
let loaded_certs = certs(&mut std::io::BufReader::new(std::fs::File::open(config.certs)?)).collect::<Result<Vec<_>, _>>()?; let loaded_certs = certs(&mut std::io::BufReader::new(std::fs::File::open(
let loaded_key = private_key(&mut std::io::BufReader::new(std::fs::File::open(config.key)?))?.unwrap(); config.certs,
)?))
.collect::<Result<Vec<_>, _>>()?;
let loaded_key = private_key(&mut std::io::BufReader::new(std::fs::File::open(
config.key,
)?))?
.unwrap();
let tls_config = rustls::ServerConfig::builder() let tls_config = rustls::ServerConfig::builder()
.with_no_client_auth() .with_no_client_auth()
@ -109,7 +115,7 @@ impl Server {
} }
}; };
AnyStream::new(stream) AnyStream::new(stream)
}, }
None => AnyStream::new(socket), None => AnyStream::new(socket),
}; };

View file

@ -34,7 +34,12 @@ struct Args {
#[clap(long)] #[clap(long)]
dev: bool, dev: bool,
#[clap(short, long, env = "AEROGRAMME_CONFIG", default_value = "aerogramme.toml")] #[clap(
short,
long,
env = "AEROGRAMME_CONFIG",
default_value = "aerogramme.toml"
)]
/// Path to the main Aerogramme configuration file /// Path to the main Aerogramme configuration file
config_file: PathBuf, config_file: PathBuf,
} }
@ -187,7 +192,10 @@ async fn main() -> Result<()> {
hostname: "example.tld".to_string(), hostname: "example.tld".to_string(),
}), }),
auth: Some(AuthConfig { auth: Some(AuthConfig {
bind_addr: SocketAddr::new(IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)), 12345), bind_addr: SocketAddr::new(
IpAddr::V6(Ipv6Addr::new(0, 0, 0, 0, 0, 0, 0, 1)),
12345,
),
}), }),
users: UserManagement::Demo, users: UserManagement::Demo,
}) })

View file

@ -7,9 +7,9 @@ use futures::try_join;
use log::*; use log::*;
use tokio::sync::watch; use tokio::sync::watch;
use crate::auth;
use crate::config::*; use crate::config::*;
use crate::imap; use crate::imap;
use crate::auth;
use crate::lmtp::*; use crate::lmtp::*;
use crate::login::ArcLoginProvider; use crate::login::ArcLoginProvider;
use crate::login::{demo_provider::*, ldap_provider::*, static_provider::*}; use crate::login::{demo_provider::*, ldap_provider::*, static_provider::*};
@ -47,9 +47,16 @@ impl Server {
}; };
let lmtp_server = config.lmtp.map(|lmtp| LmtpServer::new(lmtp, login.clone())); let lmtp_server = config.lmtp.map(|lmtp| LmtpServer::new(lmtp, login.clone()));
let imap_unsecure_server = config.imap_unsecure.map(|imap| imap::new_unsecure(imap, login.clone())); let imap_unsecure_server = config
let imap_server = config.imap.map(|imap| imap::new(imap, login.clone())).transpose()?; .imap_unsecure
let auth_server = config.auth.map(|auth| auth::AuthServer::new(auth, login.clone())); .map(|imap| imap::new_unsecure(imap, login.clone()));
let imap_server = config
.imap
.map(|imap| imap::new(imap, login.clone()))
.transpose()?;
let auth_server = config
.auth
.map(|auth| auth::AuthServer::new(auth, login.clone()));
Ok(Self { Ok(Self {
lmtp_server, lmtp_server,

View file

@ -105,6 +105,7 @@ fn causal_to_row_val(row_ref: RowRef, causal_value: k2v_client::CausalValue) ->
#[async_trait] #[async_trait]
impl IStore for GarageStore { impl IStore for GarageStore {
async fn row_fetch<'a>(&self, select: &Selector<'a>) -> Result<Vec<RowVal>, StorageError> { async fn row_fetch<'a>(&self, select: &Selector<'a>) -> Result<Vec<RowVal>, StorageError> {
tracing::trace!(select=%select, command="row_fetch");
let (pk_list, batch_op) = match select { let (pk_list, batch_op) = match select {
Selector::Range { Selector::Range {
shard, shard,
@ -196,21 +197,26 @@ impl IStore for GarageStore {
} }
Ok(v) => v, Ok(v) => v,
}; };
//println!("fetch res -> {:?}", all_raw_res);
let row_vals = all_raw_res let row_vals =
.into_iter() all_raw_res
.fold(vec![], |mut acc, v| { .into_iter()
acc.extend(v.items); .zip(pk_list.into_iter())
acc .fold(vec![], |mut acc, (page, pk)| {
}) page.items
.into_iter() .into_iter()
.zip(pk_list.into_iter()) .map(|(sk, cv)| causal_to_row_val(RowRef::new(&pk, &sk), cv))
.map(|((sk, cv), pk)| causal_to_row_val(RowRef::new(&pk, &sk), cv)) .for_each(|rr| acc.push(rr));
.collect::<Vec<_>>();
acc
});
tracing::debug!(fetch_count = row_vals.len(), command = "row_fetch");
Ok(row_vals) Ok(row_vals)
} }
async fn row_rm<'a>(&self, select: &Selector<'a>) -> Result<(), StorageError> { async fn row_rm<'a>(&self, select: &Selector<'a>) -> Result<(), StorageError> {
tracing::trace!(select=%select, command="row_rm");
let del_op = match select { let del_op = match select {
Selector::Range { Selector::Range {
shard, shard,
@ -280,6 +286,7 @@ impl IStore for GarageStore {
} }
async fn row_insert(&self, values: Vec<RowVal>) -> Result<(), StorageError> { async fn row_insert(&self, values: Vec<RowVal>) -> Result<(), StorageError> {
tracing::trace!(entries=%values.iter().map(|v| v.row_ref.to_string()).collect::<Vec<_>>().join(","), command="row_insert");
let batch_ops = values let batch_ops = values
.iter() .iter()
.map(|v| k2v_client::BatchInsertOp { .map(|v| k2v_client::BatchInsertOp {
@ -307,6 +314,7 @@ impl IStore for GarageStore {
} }
} }
async fn row_poll(&self, value: &RowRef) -> Result<RowVal, StorageError> { async fn row_poll(&self, value: &RowRef) -> Result<RowVal, StorageError> {
tracing::trace!(entry=%value, command="row_poll");
loop { loop {
if let Some(ct) = &value.causality { if let Some(ct) = &value.causality {
match self match self
@ -343,6 +351,7 @@ impl IStore for GarageStore {
} }
async fn blob_fetch(&self, blob_ref: &BlobRef) -> Result<BlobVal, StorageError> { async fn blob_fetch(&self, blob_ref: &BlobRef) -> Result<BlobVal, StorageError> {
tracing::trace!(entry=%blob_ref, command="blob_fetch");
let maybe_out = self let maybe_out = self
.s3 .s3
.get_object() .get_object()
@ -382,6 +391,7 @@ impl IStore for GarageStore {
Ok(bv) Ok(bv)
} }
async fn blob_insert(&self, blob_val: BlobVal) -> Result<(), StorageError> { async fn blob_insert(&self, blob_val: BlobVal) -> Result<(), StorageError> {
tracing::trace!(entry=%blob_val.blob_ref, command="blob_insert");
let streamable_value = s3::primitives::ByteStream::from(blob_val.value); let streamable_value = s3::primitives::ByteStream::from(blob_val.value);
let maybe_send = self let maybe_send = self
@ -406,6 +416,7 @@ impl IStore for GarageStore {
} }
} }
async fn blob_copy(&self, src: &BlobRef, dst: &BlobRef) -> Result<(), StorageError> { async fn blob_copy(&self, src: &BlobRef, dst: &BlobRef) -> Result<(), StorageError> {
tracing::trace!(src=%src, dst=%dst, command="blob_copy");
let maybe_copy = self let maybe_copy = self
.s3 .s3
.copy_object() .copy_object()
@ -433,6 +444,7 @@ impl IStore for GarageStore {
} }
} }
async fn blob_list(&self, prefix: &str) -> Result<Vec<BlobRef>, StorageError> { async fn blob_list(&self, prefix: &str) -> Result<Vec<BlobRef>, StorageError> {
tracing::trace!(prefix = prefix, command = "blob_list");
let maybe_list = self let maybe_list = self
.s3 .s3
.list_objects_v2() .list_objects_v2()
@ -462,6 +474,7 @@ impl IStore for GarageStore {
} }
} }
async fn blob_rm(&self, blob_ref: &BlobRef) -> Result<(), StorageError> { async fn blob_rm(&self, blob_ref: &BlobRef) -> Result<(), StorageError> {
tracing::trace!(entry=%blob_ref, command="blob_rm");
let maybe_delete = self let maybe_delete = self
.s3 .s3
.delete_object() .delete_object()