Skip to content

Commit

Permalink
Merge pull request #100 from mineshp-mecha/chore/networking
Browse files Browse the repository at this point in the history
chore: header enabled for handshake channel
  • Loading branch information
shoaibmerchant authored Apr 25, 2024
2 parents 7f7451b + d766068 commit 37abb67
Show file tree
Hide file tree
Showing 11 changed files with 163 additions and 104 deletions.
16 changes: 15 additions & 1 deletion commons/nats-client/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ pub use async_nats::Subscriber;
pub use bytes::Bytes;
use events::Event;
use nkeys::KeyPair;
use std::collections::HashMap;
use std::{str::FromStr, sync::Arc};
use tokio::sync::broadcast::Sender;
use tracing::{debug, error, info, trace};
Expand Down Expand Up @@ -147,7 +148,12 @@ impl NatsClient {
Ok(client)
}

pub async fn publish(&self, subject: &str, data: Bytes) -> Result<bool> {
pub async fn publish(
&self,
subject: &str,
req_headers: Option<HashMap<String, String>>,
data: Bytes,
) -> Result<bool> {
trace!(
func = "publish",
package = PACKAGE_NAME,
Expand All @@ -174,6 +180,14 @@ impl NatsClient {
"X-Agent",
async_nats::HeaderValue::from_str(version_detail.as_str()).unwrap(),
);
if req_headers.is_some() {
for (k, v) in req_headers.unwrap() {
headers.insert(
k.as_str(),
async_nats::HeaderValue::from_str(v.as_str()).unwrap(),
);
}
}
debug!(
func = "publish",
package = PACKAGE_NAME,
Expand Down
1 change: 1 addition & 0 deletions grpc-server/src/services/messaging.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ impl MessagingService for MessagingServiceHandler {
reply_to: tx,
message: message_request.message,
subject: message_request.subject,
headers: None,
})
.await;

Expand Down
7 changes: 5 additions & 2 deletions messaging/src/handler.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use crate::errors::{MessagingError, MessagingErrorCodes};
use crate::service::{get_machine_id, Messaging};
use anyhow::{bail, Result};
Expand Down Expand Up @@ -32,6 +34,7 @@ pub enum MessagingMessage {
reply_to: oneshot::Sender<Result<bool>>,
message: String,
subject: String,
headers: Option<HashMap<String, String>>,
},
Request {
reply_to: oneshot::Sender<Result<Bytes>>,
Expand Down Expand Up @@ -78,8 +81,8 @@ impl MessagingHandler {
continue;
}
match msg.unwrap() {
MessagingMessage::Send{reply_to, message, subject} => {
let res = self.messaging_client.publish(&subject.as_str(), Bytes::from(message)).await;
MessagingMessage::Send{reply_to, message, subject, headers} => {
let res = self.messaging_client.publish(&subject.as_str(), headers, Bytes::from(message)).await;
let _ = reply_to.send(res);
}
MessagingMessage::Request{reply_to, message, subject} => {
Expand Down
12 changes: 10 additions & 2 deletions messaging/src/service.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use std::collections::HashMap;

use agent_settings::read_settings_yml;
use agent_settings::{messaging::MessagingSettings, AgentSettings};
use anyhow::{bail, Result};
Expand Down Expand Up @@ -197,7 +199,12 @@ impl Messaging {
);
Ok(true)
}
pub async fn publish(&self, subject: &str, data: Bytes) -> Result<bool> {
pub async fn publish(
&self,
subject: &str,
headers: Option<HashMap<String, String>>,
data: Bytes,
) -> Result<bool> {
let fn_name = "publish";
debug!(
func = fn_name,
Expand All @@ -219,7 +226,7 @@ impl Messaging {
}

let nats_client = self.nats_client.as_ref().unwrap();
let is_published = match nats_client.publish(subject, data).await {
let is_published = match nats_client.publish(subject, headers, data).await {
Ok(s) => s,
Err(e) => {
error!(
Expand Down Expand Up @@ -408,6 +415,7 @@ pub async fn authenticate(
package = PACKAGE_NAME,
"authentication token obtained!"
);
println!("Auth token: -{:?}", token);
Ok(token)
}

Expand Down
4 changes: 4 additions & 0 deletions networking/src/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ pub enum NetworkingErrorCodes {
MessageAcknowledgeError,
NetworkingInitError,
NetworkingDiscoSocketBindError,
ExtractMessageHeadersError,
}

impl fmt::Display for NetworkingErrorCodes {
Expand Down Expand Up @@ -58,6 +59,9 @@ impl fmt::Display for NetworkingErrorCodes {
NetworkingErrorCodes::NetworkingDiscoSocketBindError => {
write!(f, "NetworkingErrorCodes: NetworkingDiscoSocketBindError")
}
NetworkingErrorCodes::ExtractMessageHeadersError => {
write!(f, "NetworkingErrorCodes: ExtractMessageHeadersError")
}
}
}
}
Expand Down
19 changes: 16 additions & 3 deletions networking/src/handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use wireguard::Wireguard;

use crate::errors::{NetworkingError, NetworkingErrorCodes};
use crate::handshake_handler::{
await_networking_handshake_request, HandshakeChannelHandler, HandshakeMessage, Manifest,
await_networking_handshake_message, HandshakeChannelHandler, HandshakeMessage, Manifest,
};
use crate::service::{
await_consumer_message, configure_wireguard, create_channel_sync_consumer,
await_consumer_message, configure_wireguard, create_channel_sync_consumer, get_machine_id,
get_networking_subscriber, publish_networking_channel, reconnect_messaging_service,
};

Expand Down Expand Up @@ -92,7 +92,7 @@ impl NetworkingHandler {
}
};
let mut futures = JoinSet::new();
futures.spawn(await_networking_handshake_request(
futures.spawn(await_networking_handshake_message(
subscribers.handshake_request.unwrap(),
handshake_handler.handshake_tx.clone(),
));
Expand Down Expand Up @@ -127,6 +127,18 @@ impl NetworkingHandler {
if exist_consumer_token.is_some() {
let _ = exist_consumer_token.as_ref().unwrap().cancel();
}
let machine_id = match get_machine_id(self.identity_tx.clone()).await {
Ok(id) => id,
Err(e) => {
error!(
func = fn_name,
package = PACKAGE_NAME,
error = e.to_string().as_str(),
"Error getting machine id"
);
bail!(e)
}
};
//TODO: handle this error unwrap
let handshake_handler = self.handshake_handler.as_ref().unwrap();
// create a new token
Expand Down Expand Up @@ -161,6 +173,7 @@ impl NetworkingHandler {
messaging_tx.clone(),
self.settings_tx.clone(),
handshake_handler.channel_id.clone(),
machine_id,
));
// create spawn for timer
let _: JoinHandle<Result<()>> = tokio::task::spawn(async move {
Expand Down
119 changes: 93 additions & 26 deletions networking/src/handshake_handler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@ use std::str::FromStr;

use agent_settings::{read_settings_yml, AgentSettings};
use anyhow::{bail, Result};
use chrono::format;
use crypto::random::generate_random_alphanumeric;
use futures::StreamExt;
use local_ip_address::list_afinet_netifas;
use messaging::async_nats::Message;
use messaging::async_nats::{HeaderMap, Message};
use messaging::handler::MessagingMessage;
use messaging::Subscriber as NatsSubscriber;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -187,6 +188,8 @@ impl HandshakeChannelHandler {
txn_id: txn_id,
candidates: Some(candidates),
};
let mut header_map: HashMap<String, String> = HashMap::new();
header_map.insert(String::from("Message-Type"), String::from("REPLY"));
println!("manifest: {:?}", manifest);
// send reply to NATS
let (tx, _rx) = oneshot::channel();
Expand All @@ -196,6 +199,7 @@ impl HandshakeChannelHandler {
subject: reply_subject,
message: json!(manifest).to_string(),
reply_to: tx,
headers: Some(header_map),
})
.await;
Ok(true)
Expand Down Expand Up @@ -241,7 +245,7 @@ pub fn discover_endpoints() -> Result<Vec<Ipv4Addr>> {
Ok(ipv4_addr)
}

pub async fn await_networking_handshake_request(
pub async fn await_networking_handshake_message(
mut subscriber: NatsSubscriber,
handshake_tx: mpsc::Sender<HandshakeMessage>,
) -> Result<()> {
Expand Down Expand Up @@ -281,37 +285,100 @@ async fn process_handshake_request(
))
}
};
let request_payload: ChannelDetails = match serde_json::from_str(&payload_str) {
Ok(s) => s,
Err(e) => bail!(NetworkingError::new(
NetworkingErrorCodes::PayloadDeserializationError,
format!("error while deserializing message payload {}", e),
true
)),
let message_type =
match get_header_by_key(message.headers.clone(), String::from("Message-Type")) {
Ok(s) => s,
Err(e) => {
error! {
func = fn_name,
package = PACKAGE_NAME,
"error getting message type from headers - {}",
e
};
bail!(e)
}
};
match message_type.as_str() {
"REQUEST" => {
let request_payload: ChannelDetails = match serde_json::from_str(&payload_str) {
Ok(s) => s,
Err(e) => bail!(NetworkingError::new(
NetworkingErrorCodes::PayloadDeserializationError,
format!("error while deserializing message payload {}", e),
true
)),
};
info!(
func = fn_name,
package = PACKAGE_NAME,
"received handshake request: {:?}",
request_payload
);
let reply_subject = format!(
"network.{}.node.handshake.{}",
sha256::digest(request_payload.network_id.clone()),
request_payload.channel.clone()
);
let _ = handshake_tx
.send(HandshakeMessage::Request {
machine_id: request_payload.machine_id.clone(),
reply_subject: reply_subject,
})
.await;
}
"REPLY" => {
let reply_payload: Manifest = match serde_json::from_str(&payload_str) {
Ok(s) => s,
Err(e) => bail!(NetworkingError::new(
NetworkingErrorCodes::PayloadDeserializationError,
format!("error while deserializing message payload {}", e),
true
)),
};
println!("manifest received: {:?}", reply_payload);
}
_ => {
warn!(
func = fn_name,
package = PACKAGE_NAME,
"Unknown message type: {}",
message_type
);
}
}
Ok(true)
}

fn get_header_by_key(headers: Option<HeaderMap>, header_key: String) -> Result<String> {
let fn_name = "get_header_by_key";
let message_headers = match headers {
Some(h) => h,
None => {
warn!(
func = fn_name,
package = PACKAGE_NAME,
"No headers found in message",
);
bail!(NetworkingError::new(
NetworkingErrorCodes::ExtractMessageHeadersError,
String::from("no headers found in message"),
false
))
}
};
info!(
func = fn_name,
package = PACKAGE_NAME,
"received handshake request: {:?}",
request_payload
);
let reply_to_subject = match message.reply {
Some(subject) => subject.to_string(),
let message_type = match message_headers.get(header_key.as_str()) {
Some(v) => v.to_string(),
None => {
warn!(
func = fn_name,
package = PACKAGE_NAME,
"No reply subject found in message: {:?}", message
"No message type found in message headers: {:?}",
message_headers
);
String::from("") //TODO: need to discuss
String::from("")
}
};
let _ = handshake_tx
.send(HandshakeMessage::Request {
machine_id: request_payload.machine_id.clone(),
reply_subject: reply_to_subject,
})
.await;
Ok(true)
Ok(message_type)
}
pub async fn create_disco_socket(addr: String) -> Result<UdpSocket> {
info!(func = "create_disco_socket", package = PACKAGE_NAME, "init");
Expand Down
Loading

0 comments on commit 37abb67

Please sign in to comment.