implemented business logic

This commit is contained in:
Quentin 2024-01-24 21:36:46 +01:00
parent bbb050e399
commit b86acd5ed0
Signed by: quentin
GPG key ID: E9602264D639FF68

View file

@ -1,7 +1,6 @@
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc;
use anyhow::{Result, anyhow}; use anyhow::{Result, anyhow, bail};
use futures::stream::{FuturesUnordered, StreamExt}; use futures::stream::{FuturesUnordered, StreamExt};
use tokio::io::BufStream; use tokio::io::BufStream;
use tokio::io::{AsyncBufReadExt, AsyncWriteExt}; use tokio::io::{AsyncBufReadExt, AsyncWriteExt};
@ -82,7 +81,7 @@ impl AuthServer {
}; };
tracing::info!("AUTH: accepted connection from {}", remote_addr); tracing::info!("AUTH: accepted connection from {}", remote_addr);
let conn = tokio::spawn(NetLoop::new(socket, must_exit.clone()).run_error()); let conn = tokio::spawn(NetLoop::new(socket, self.login_provider.clone(), must_exit.clone()).run_error());
connections.push(conn); connections.push(conn);
@ -97,15 +96,23 @@ impl AuthServer {
} }
struct NetLoop { struct NetLoop {
login: ArcLoginProvider,
stream: BufStream<TcpStream>, stream: BufStream<TcpStream>,
stop: watch::Receiver<bool>, stop: watch::Receiver<bool>,
state: State,
read_buf: Vec<u8>,
write_buf: BytesMut,
} }
impl NetLoop { impl NetLoop {
fn new(stream: TcpStream, stop: watch::Receiver<bool>) -> Self { fn new(stream: TcpStream, login: ArcLoginProvider, stop: watch::Receiver<bool>) -> Self {
Self { Self {
login,
stream: BufStream::new(stream), stream: BufStream::new(stream),
state: State::Init,
stop, stop,
read_buf: Vec::new(),
write_buf: BytesMut::new(),
} }
} }
@ -117,25 +124,39 @@ impl NetLoop {
} }
async fn run(mut self) -> Result<()> { async fn run(mut self) -> Result<()> {
let mut resp_buff = BytesMut::new();
let mut buff: Vec<u8> = Vec::new();
loop { loop {
buff.clear();
tokio::select! { tokio::select! {
read_res = self.stream.read_until(b'\n', &mut buff) => { read_res = self.stream.read_until(b'\n', &mut self.read_buf) => {
// Detect EOF / socket close
let bread = read_res?; let bread = read_res?;
if bread == 0 { if bread == 0 {
tracing::info!("Reading buffer empty, connection has been closed. Exiting AUTH session."); tracing::info!("Reading buffer empty, connection has been closed. Exiting AUTH session.");
return Ok(()) return Ok(())
} }
let (input, cmd) = client_command(&buff).map_err(|_| anyhow!("Unable to parse command"))?;
println!("input: {:?}, cmd: {:?}", input, cmd); // Parse command
ServerCommand::Version { let (_, cmd) = client_command(&self.read_buf).map_err(|_| anyhow!("Unable to parse command"))?;
major: 1, tracing::debug!(cmd=?cmd, "Received command");
minor: 2,
}.encode(&mut resp_buff)?; // Make some progress in our local state
self.stream.write_all(&resp_buff).await?; self.state.progress(cmd, &self.login).await;
if matches!(self.state, State::Error) {
bail!("Internal state is in error, previous logs explain what went wrong");
}
// Build response
let srv_cmds = self.state.response();
srv_cmds.iter().try_for_each(|r| r.encode(&mut self.write_buf))?;
// Send responses if at least one command response has been generated
if !srv_cmds.is_empty() {
self.stream.write_all(&self.write_buf).await?;
self.stream.flush().await?; self.stream.flush().await?;
}
// Reset buffers
self.read_buf.clear();
self.write_buf.clear();
}, },
_ = self.stop.changed() => { _ = self.stop.changed() => {
tracing::debug!("Server is stopping, quitting this runner"); tracing::debug!("Server is stopping, quitting this runner");
@ -146,13 +167,150 @@ impl NetLoop {
} }
} }
// -----------------------------------------------------------------
//
// BUSINESS LOGIC
//
// -----------------------------------------------------------------
use rand::prelude::*;
#[derive(Debug)]
enum AuthRes {
Success(String),
Failed(Option<String>, Option<FailCode>),
}
#[derive(Debug)]
enum State {
Error,
Init,
HandshakePart(Version),
HandshakeDone,
AuthPlainProgress {
id: u64,
},
AuthDone {
id: u64,
res: AuthRes
},
}
const SERVER_MAJOR: u64 = 1;
const SERVER_MINOR: u64 = 2;
impl State {
async fn progress(&mut self, cmd: ClientCommand, login: &ArcLoginProvider) {
let new_state = 'state: {
match (std::mem::replace(self, State::Error), cmd) {
(Self::Init, ClientCommand::Version(v)) => Self::HandshakePart(v),
(Self::HandshakePart(version), ClientCommand::Cpid(_cpid)) => {
if version.major != SERVER_MAJOR {
tracing::error!(client_major=version.major, server_major=SERVER_MAJOR, "Unsupported client major version");
break 'state Self::Error
}
Self::HandshakeDone
},
(Self::HandshakeDone { .. }, ClientCommand::Auth { id, mech, .. }) |
(Self::AuthDone { .. }, ClientCommand::Auth { id, mech, ..}) => {
if mech != Mechanism::Plain {
tracing::error!(mechanism=?mech, "Unsupported Authentication Mechanism");
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) }
}
Self::AuthPlainProgress { id }
},
(Self::AuthPlainProgress { id }, ClientCommand::Cont { id: cid, data }) => {
// Check that ID matches
if cid != id {
tracing::error!(auth_id=id, cont_id=cid, "CONT id does not match AUTH id");
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) }
}
// Check that we can extract user's login+pass
let (ubin, pbin) = match auth_plain(&data) {
Ok(([], ([], user, pass))) => (user, pass),
Ok(_) => {
tracing::error!("Impersonating user is not supported");
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) }
}
Err(e) => {
tracing::error!(err=?e, "Could not parse the SASL PLAIN data chunk");
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) }
},
};
// Try to convert it to UTF-8
let (user, password) = match (std::str::from_utf8(ubin), std::str::from_utf8(pbin)) {
(Ok(u), Ok(p)) => (u, p),
_ => {
tracing::error!("Username or password contain invalid UTF-8 characters");
break 'state Self::AuthDone { id, res: AuthRes::Failed(None, None) }
}
};
// Try to connect user
match login.login(user, password).await {
Ok(_) => Self::AuthDone { id, res: AuthRes::Success(user.to_string())},
Err(e) => {
tracing::warn!(err=?e, "login failed");
Self::AuthDone { id, res: AuthRes::Failed(Some(user.to_string()), None) }
}
}
},
_ => {
tracing::error!("This command is not valid in this context");
Self::Error
},
}
};
tracing::debug!(state=?new_state, "Made progress");
*self = new_state;
}
fn response(&self) -> Vec<ServerCommand> {
let mut srv_cmd: Vec<ServerCommand> = Vec::new();
match self {
Self::HandshakeDone { .. } => {
srv_cmd.push(ServerCommand::Version(Version { major: SERVER_MAJOR, minor: SERVER_MINOR }));
srv_cmd.push(ServerCommand::Spid(1u64));
srv_cmd.push(ServerCommand::Cuid(1u64));
let mut cookie = [0u8; 16];
thread_rng().fill(&mut cookie);
srv_cmd.push(ServerCommand::Cookie(cookie));
srv_cmd.push(ServerCommand::Mech {
kind: Mechanism::Plain,
parameters: vec![MechanismParameters::PlainText],
});
srv_cmd.push(ServerCommand::Done);
},
Self::AuthPlainProgress { id } => {
srv_cmd.push(ServerCommand::Cont { id: *id, data: None });
},
Self::AuthDone { id, res: AuthRes::Success(user) } => {
srv_cmd.push(ServerCommand::Ok { id: *id, user_id: Some(user.to_string()), extra_parameters: vec![]});
},
Self::AuthDone { id, res: AuthRes::Failed(maybe_user, maybe_failcode) } => {
srv_cmd.push(ServerCommand::Fail { id: *id, user_id: maybe_user.clone(), code: maybe_failcode.clone(), extra_parameters: vec![]});
},
_ => (),
};
srv_cmd
}
}
// ----------------------------------------------------------------- // -----------------------------------------------------------------
// //
// DOVECOT AUTH TYPES // DOVECOT AUTH TYPES
// //
// ------------------------------------------------------------------ // -----------------------------------------------------------------
#[derive(Debug, Clone)] #[derive(Debug, Clone, PartialEq)]
enum Mechanism { enum Mechanism {
Plain, Plain,
Login, Login,
@ -214,13 +372,16 @@ enum AuthOption {
Resp(Vec<u8>), Resp(Vec<u8>),
} }
#[derive(Debug, Clone)]
struct Version {
major: u64,
minor: u64,
}
#[derive(Debug)] #[derive(Debug)]
enum ClientCommand { enum ClientCommand {
/// Both client and server should check that they support the same major version number. If they dont, the other side isnt expected to be talking the same protocol and should be disconnected. Minor version can be ignored. This document specifies the version number 1.2. /// Both client and server should check that they support the same major version number. If they dont, the other side isnt expected to be talking the same protocol and should be disconnected. Minor version can be ignored. This document specifies the version number 1.2.
Version { Version(Version),
major: u64,
minor: u64,
},
/// CPID finishes the handshake from client. /// CPID finishes the handshake from client.
Cpid(u64), Cpid(u64),
Auth { Auth {
@ -261,7 +422,7 @@ enum MechanismParameters {
Private, Private,
} }
#[derive(Debug)] #[derive(Debug, Clone)]
enum FailCode { enum FailCode {
/// This is a temporary internal failure, e.g. connection was lost to SQL database. /// This is a temporary internal failure, e.g. connection was lost to SQL database.
TempFail, TempFail,
@ -276,10 +437,7 @@ enum FailCode {
#[derive(Debug)] #[derive(Debug)]
enum ServerCommand { enum ServerCommand {
/// Both client and server should check that they support the same major version number. If they dont, the other side isnt expected to be talking the same protocol and should be disconnected. Minor version can be ignored. This document specifies the version number 1.2. /// Both client and server should check that they support the same major version number. If they dont, the other side isnt expected to be talking the same protocol and should be disconnected. Minor version can be ignored. This document specifies the version number 1.2.
Version { Version(Version),
major: u64,
minor: u64,
},
/// CPID and SPID specify client and server Process Identifiers (PIDs). They should be unique identifiers for the specific process. UNIX process IDs are good choices. /// CPID and SPID specify client and server Process Identifiers (PIDs). They should be unique identifiers for the specific process. UNIX process IDs are good choices.
/// SPID can be used by authentication client to tell master which server process handled the authentication. /// SPID can be used by authentication client to tell master which server process handled the authentication.
Spid(u64), Spid(u64),
@ -298,18 +456,19 @@ enum ServerCommand {
Fail { Fail {
id: u64, id: u64,
user_id: Option<String>, user_id: Option<String>,
code: FailCode, code: Option<FailCode>,
extra_parameters: Vec<Vec<u8>>,
}, },
Cont { Cont {
id: u64, id: u64,
data: Vec<u8>, data: Option<Vec<u8>>,
}, },
/// FAIL and OK may contain multiple unspecified parameters which authentication client may handle specially. /// FAIL and OK may contain multiple unspecified parameters which authentication client may handle specially.
/// The only one specified here is user=<userid> parameter, which should always be sent if the userid is known. /// The only one specified here is user=<userid> parameter, which should always be sent if the userid is known.
Ok { Ok {
id: u64, id: u64,
user_id: Option<String>, user_id: Option<String>,
parameters: Vec<u8>, extra_parameters: Vec<Vec<u8>>,
}, },
} }
@ -324,9 +483,9 @@ use nom::{
branch::alt, branch::alt,
error::{ErrorKind, Error}, error::{ErrorKind, Error},
character::complete::{tab, u64, u16}, character::complete::{tab, u64, u16},
bytes::complete::{tag, tag_no_case, take, take_while, take_while1}, bytes::complete::{is_not, tag, tag_no_case, take, take_while, take_while1},
multi::{many1, separated_list0}, multi::{many1, separated_list0},
combinator::{map, opt, recognize, value,}, combinator::{map, opt, recognize, value, rest},
sequence::{pair, preceded, tuple}, sequence::{pair, preceded, tuple},
}; };
use base64::Engine; use base64::Engine;
@ -341,7 +500,7 @@ fn version_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> {
)); ));
let (input, (_, _, major, _, minor)) = parser(input)?; let (input, (_, _, major, _, minor)) = parser(input)?;
Ok((input, ClientCommand::Version { major, minor })) Ok((input, ClientCommand::Version(Version { major, minor })))
} }
fn cpid_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> { fn cpid_command<'a>(input: &'a [u8]) -> IResult<&'a [u8], ClientCommand> {
@ -510,6 +669,17 @@ fn server_command(buf: &u8) -> IResult<&u8, ServerCommand> {
} }
*/ */
// -----------------------------------------------------------------
//
// SASL DECODING
//
// -----------------------------------------------------------------
// impersonated user, login, password
fn auth_plain<'a>(input: &'a [u8]) -> IResult<&'a [u8], (&'a [u8], &'a [u8], &'a [u8])> {
tuple((is_not([0x0]), is_not([0x0]), rest))(input)
}
// ----------------------------------------------------------------- // -----------------------------------------------------------------
// //
// DOVECOT AUTH ENCODING // DOVECOT AUTH ENCODING
@ -531,7 +701,7 @@ fn lf_enc(out: &mut BytesMut) {
impl Encode for ServerCommand { impl Encode for ServerCommand {
fn encode(&self, out: &mut BytesMut) -> Result<()> { fn encode(&self, out: &mut BytesMut) -> Result<()> {
match self { match self {
Self::Version { major, minor } => { Self::Version (Version { major, minor }) => {
out.put(&b"VERSION"[..]); out.put(&b"VERSION"[..]);
tab_enc(out); tab_enc(out);
out.put(major.to_string().as_bytes()); out.put(major.to_string().as_bytes());
@ -544,9 +714,9 @@ impl Encode for ServerCommand {
Self::Mech { kind, parameters } => unimplemented!(), Self::Mech { kind, parameters } => unimplemented!(),
Self::Cookie(v) => unimplemented!(), Self::Cookie(v) => unimplemented!(),
Self::Done => unimplemented!(), Self::Done => unimplemented!(),
Self::Fail {id, user_id, code } => unimplemented!(),
Self::Cont { id, data } => unimplemented!(), Self::Cont { id, data } => unimplemented!(),
Self::Ok { id, user_id, parameters } => unimplemented!(), Self::Ok { id, user_id, extra_parameters } => unimplemented!(),
Self::Fail {id, user_id, code, extra_parameters } => unimplemented!(),
} }
Ok(()) Ok(())
} }