diff --git a/examples/basalt.rs b/examples/basalt.rs index 5628056..318e37c 100644 --- a/examples/basalt.rs +++ b/examples/basalt.rs @@ -94,7 +94,7 @@ async fn main() { info!("KYEV SK {}", hex::encode(&privkey)); info!("KYEV PK {}", hex::encode(&privkey.public_key())); - let netapp = NetApp::new(netid, privkey); + let netapp = NetApp::new(0u64, netid, privkey); let mut bootstrap_peers = vec![]; for peer in opt.bootstrap_peers.iter() { diff --git a/examples/fullmesh.rs b/examples/fullmesh.rs index afc4deb..b068410 100644 --- a/examples/fullmesh.rs +++ b/examples/fullmesh.rs @@ -71,7 +71,7 @@ async fn main() { info!("Node public address: {:?}", public_addr); info!("Node listen address: {}", listen_addr); - let netapp = NetApp::new(netid.clone(), privkey.clone()); + let netapp = NetApp::new(0u64, netid.clone(), privkey.clone()); let mut bootstrap_peers = vec![]; for peer in opt.bootstrap_peers.iter() { diff --git a/src/client.rs b/src/client.rs index e84c85e..27cd1b8 100644 --- a/src/client.rs +++ b/src/client.rs @@ -52,6 +52,7 @@ impl ClientConn { let remote_addr = socket.peer_addr()?; let mut socket = socket.compat(); + // Do handshake to authenticate and prove our identity to server let handshake = handshake_client( &mut socket, netapp.netid.clone(), @@ -67,11 +68,25 @@ impl ClientConn { remote_addr ); + // Create BoxStream layer that encodes content let (read, write) = socket.split(); - - let (read, write) = + let (mut read, write) = BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write(); + // Before doing anything, receive version tag and + // check they are running the same version as us + let mut their_version_tag = VersionTag::default(); + read.read_exact(&mut their_version_tag[..]).await?; + if their_version_tag != netapp.version_tag { + let msg = format!( + "Different netapp versions: {:?} (theirs) vs. {:?} (ours)", + their_version_tag, netapp.version_tag + ); + error!("{}", msg); + return Err(Error::VersionMismatch(msg)); + } + + // Build and launch stuff that manages sending requests client-side let (query_send, query_recv) = mpsc::unbounded_channel(); let (stop_recv_loop, stop_recv_loop_recv) = watch::channel(false); diff --git a/src/netapp.rs b/src/netapp.rs index acaed62..1ac5f37 100644 --- a/src/netapp.rs +++ b/src/netapp.rs @@ -24,6 +24,14 @@ use crate::proto::*; use crate::server::*; use crate::util::*; +/// Tag which is exchanged between client and server upon connection establishment +/// to check that they are running compatible versions of Netapp, +/// composed of 8 bytes for Netapp version and 8 bytes for client version +pub(crate) type VersionTag = [u8; 16]; + +/// Value of the Netapp version used in the version tag +pub(crate) const NETAPP_VERSION_TAG: u64 = 0x6e65746170700004; // netapp 0x0004 + #[derive(Serialize, Deserialize, Debug)] pub(crate) struct HelloMessage { pub server_addr: Option, @@ -48,6 +56,8 @@ type OnDisconnectHandler = Box; pub struct NetApp { listen_params: ArcSwapOption, + /// Version tag, 8 bytes for netapp version, 8 bytes for app version + pub version_tag: VersionTag, /// Network secret key pub netid: auth::Key, /// Our peer ID @@ -76,10 +86,15 @@ impl NetApp { /// using `.listen()` /// /// Our Peer ID is the public key associated to the secret key given here. - pub fn new(netid: auth::Key, privkey: ed25519::SecretKey) -> Arc { + pub fn new(app_version_tag: u64, netid: auth::Key, privkey: ed25519::SecretKey) -> Arc { + let mut version_tag = [0u8; 16]; + version_tag[0..8].copy_from_slice(&u64::to_be_bytes(NETAPP_VERSION_TAG)[..]); + version_tag[8..16].copy_from_slice(&u64::to_be_bytes(app_version_tag)[..]); + let id = privkey.public_key(); let netapp = Arc::new(Self { listen_params: ArcSwapOption::new(None), + version_tag, netid, id, privkey, diff --git a/src/proto.rs b/src/proto.rs index 146211b..e843bff 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -1,7 +1,7 @@ use std::collections::{HashMap, VecDeque}; use std::sync::Arc; -use log::{error, trace}; +use log::trace; use futures::{AsyncReadExt, AsyncWriteExt}; use kuska_handshake::async_std::BoxStreamWrite; @@ -12,10 +12,6 @@ use async_trait::async_trait; use crate::error::*; -/// Tag which is exchanged between client and server upon connection establishment -/// to check that they are running compatible versions of Netapp -pub const VERSION_TAG: [u8; 8] = [b'n', b'e', b't', b'a', b'p', b'p', 0x00, 0x04]; - /// Priority of a request (click to read more about priorities). /// /// This priority value is used to priorize messages @@ -118,10 +114,6 @@ pub(crate) trait SendLoop: Sync { where W: AsyncWriteExt + Unpin + Send + Sync, { - // Before anything, send version tag, which is checked in recv_loop - write.write_all(&VERSION_TAG[..]).await?; - write.flush().await?; - let mut sending = SendQueue::new(); let mut should_exit = false; while !should_exit || !sending.is_empty() { @@ -198,17 +190,6 @@ pub(crate) trait RecvLoop: Sync + 'static { where R: AsyncReadExt + Unpin + Send + Sync, { - let mut their_version_tag = [0u8; 8]; - read.read_exact(&mut their_version_tag[..]).await?; - if their_version_tag != VERSION_TAG { - let msg = format!( - "Different netapp versions: {:?} (theirs) vs. {:?} (ours)", - their_version_tag, VERSION_TAG - ); - error!("{}", msg); - return Err(Error::VersionMismatch(msg)); - } - let mut receiving = HashMap::new(); loop { trace!("recv_loop: reading packet"); diff --git a/src/server.rs b/src/server.rs index 31f6ad6..5465307 100644 --- a/src/server.rs +++ b/src/server.rs @@ -20,7 +20,7 @@ use tokio::select; use tokio::sync::{mpsc, watch}; use tokio_util::compat::*; -use futures::io::AsyncReadExt; +use futures::io::{AsyncReadExt, AsyncWriteExt}; use async_trait::async_trait; @@ -67,6 +67,7 @@ impl ServerConn { let remote_addr = socket.peer_addr()?; let mut socket = socket.compat(); + // Do handshake to authenticate client let handshake = handshake_server( &mut socket, netapp.netid.clone(), @@ -82,11 +83,17 @@ impl ServerConn { remote_addr ); + // Create BoxStream layer that encodes content let (read, write) = socket.split(); - - let (read, write) = + let (read, mut write) = BoxStream::from_handshake(read, write, handshake, 0x8000).split_read_write(); + // Before doing anything, send version tag, so that client + // can check and disconnect if version is wrong + write.write_all(&netapp.version_tag[..]).await?; + write.flush().await?; + + // Build and launch stuff that handles requests server-side let (resp_send, resp_recv) = mpsc::unbounded_channel(); let conn = Arc::new(ServerConn {