From 2ed12b1a3d81458cddd4af4d07bfc7c83cf74929 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Fri, 18 Oct 2024 11:30:38 +0200 Subject: [PATCH] Replace ServiceAddress by libsignal_protocol::ServiceId --- src/account_manager.rs | 26 +++--- src/cipher.rs | 25 +++--- src/content.rs | 7 +- src/envelope.rs | 58 +++---------- src/groups_v2/model.rs | 5 +- src/lib.rs | 1 - src/profile_service.rs | 10 ++- src/push_service/error.rs | 5 +- src/push_service/keys.rs | 20 +++-- src/push_service/profile.rs | 14 +-- src/sender.rs | 100 ++++++++++----------- src/service_address.rs | 169 +++++------------------------------- src/session_store.rs | 22 ++--- 13 files changed, 158 insertions(+), 304 deletions(-) diff --git a/src/account_manager.rs b/src/account_manager.rs index d8251c92b..10beb8822 100644 --- a/src/account_manager.rs +++ b/src/account_manager.rs @@ -7,9 +7,9 @@ use aes::cipher::{KeyIvInit, StreamCipher as _}; use hmac::digest::Output; use hmac::{Hmac, Mac}; use libsignal_protocol::{ - kem, GenericSignedPreKey, IdentityKey, IdentityKeyPair, IdentityKeyStore, - KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore, PublicKey, - SenderKeyStore, SignedPreKeyRecord, Timestamp, + kem, Aci, GenericSignedPreKey, IdentityKey, IdentityKeyPair, + IdentityKeyStore, KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore, + PublicKey, SenderKeyStore, SignedPreKeyRecord, Timestamp, }; use prost::Message; use serde::{Deserialize, Serialize}; @@ -33,10 +33,10 @@ use crate::push_service::{ DEFAULT_DEVICE_ID, }; use crate::sender::OutgoingPushMessage; +use crate::service_address::ServiceIdExt; use crate::session_store::SessionStoreExt; use crate::timestamp::TimestampExt as _; use crate::utils::{random_length_padding, BASE64_RELAXED}; -use crate::ServiceAddress; use crate::{ configuration::{Endpoint, ServiceCredentials}, pre_keys::PreKeyState, @@ -486,7 +486,7 @@ impl AccountManager { pub async fn retrieve_profile( &mut self, - address: ServiceAddress, + address: Aci, ) -> Result { let profile_key = self.profile_key.expect("set profile key in AccountManager"); @@ -628,19 +628,19 @@ impl AccountManager { /// Should be called as the primary device to migrate from pre-PNI to PNI. /// /// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager. - #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = %local_aci))] + #[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci =% local_aci.service_id_string()))] pub async fn pnp_initialize_devices< // XXX So many constraints here, all imposed by the MessageSender R: rand::Rng + rand::CryptoRng, - Aci: PreKeysStore + SessionStoreExt, - Pni: PreKeysStore, + AciStore: PreKeysStore + SessionStoreExt, + PniStore: PreKeysStore, AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone, >( &mut self, - aci_protocol_store: &mut Aci, - pni_protocol_store: &mut Pni, + aci_protocol_store: &mut AciStore, + pni_protocol_store: &mut PniStore, mut sender: MessageSender, - local_aci: ServiceAddress, + local_aci: Aci, e164: PhoneNumber, csprng: &mut R, ) -> Result<(), MessageSenderError> { @@ -651,7 +651,7 @@ impl AccountManager { // For every linked device, we generate a new set of pre-keys, and send them to the device. let local_device_ids = aci_protocol_store - .get_sub_device_sessions(&local_aci) + .get_sub_device_sessions(&local_aci.into()) .await?; let mut device_messages = @@ -795,7 +795,7 @@ impl AccountManager { let content: ContentBody = msg.into(); let msg = sender .create_encrypted_message( - &local_aci, + &local_aci.into(), None, local_device_id.into(), &content.into_proto().encode_to_vec(), diff --git a/src/cipher.rs b/src/cipher.rs index 09649b679..6c447ae63 100644 --- a/src/cipher.rs +++ b/src/cipher.rs @@ -9,8 +9,8 @@ use libsignal_protocol::{ CiphertextMessageType, DeviceId, IdentityKeyStore, KyberPreKeyStore, PreKeySignalMessage, PreKeyStore, ProtocolAddress, ProtocolStore, PublicKey, SealedSenderDecryptionResult, SenderCertificate, - SenderKeyDistributionMessage, SenderKeyStore, SessionStore, SignalMessage, - SignalProtocolError, SignedPreKeyStore, Timestamp, + SenderKeyDistributionMessage, SenderKeyStore, ServiceId, SessionStore, + SignalMessage, SignalProtocolError, SignedPreKeyStore, Timestamp, }; use prost::Message; use rand::{CryptoRng, Rng}; @@ -23,7 +23,7 @@ use crate::{ sender::OutgoingPushMessage, session_store::SessionStoreExt, utils::BASE64_RELAXED, - ServiceAddress, + ServiceIdExt, }; /// Decrypts incoming messages and encrypts outgoing messages. /// @@ -277,13 +277,16 @@ where ) .await?; - let sender = ServiceAddress::try_from(sender_uuid.as_str()) - .map_err(|e| { - tracing::error!("{:?}", e); - SignalProtocolError::InvalidSealedSenderMessage( - "invalid sender UUID".to_string(), - ) - })?; + let sender = + ServiceId::parse_from_service_id_string(&sender_uuid) + .ok_or_else(|| { + tracing::error!( + "failed to parse ServiceId from string" + ); + SignalProtocolError::InvalidSealedSenderMessage( + "invalid sender UUID".to_string(), + ) + })?; let needs_receipt = if envelope.source_service_id.is_some() { tracing::warn!(?envelope, "Received an unidentified delivery over an identified channel. Marking needs_receipt=false"); @@ -461,7 +464,7 @@ fn strip_padding(contents: &mut Vec) -> Result<(), ServiceError> { /// Equivalent of `SignalServiceCipher::getPreferredProtocolAddress` pub async fn get_preferred_protocol_address( session_store: &S, - address: &ServiceAddress, + address: &ServiceId, device_id: DeviceId, ) -> Result { let address = address.to_protocol_address(device_id); diff --git a/src/content.rs b/src/content.rs index 8fead84b8..d2cf411ba 100644 --- a/src/content.rs +++ b/src/content.rs @@ -1,6 +1,7 @@ -use libsignal_protocol::ProtocolAddress; +use libsignal_protocol::{ProtocolAddress, ServiceId}; use uuid::Uuid; +use crate::ServiceIdExt; pub use crate::{ proto::{ attachment_pointer::Flags as AttachmentPointerFlags, @@ -18,8 +19,8 @@ mod story_message; #[derive(Clone, Debug)] pub struct Metadata { - pub sender: crate::ServiceAddress, - pub destination: crate::ServiceAddress, + pub sender: ServiceId, + pub destination: ServiceId, pub sender_device: u32, pub timestamp: u64, pub needs_receipt: bool, diff --git a/src/envelope.rs b/src/envelope.rs index 1a33de669..563225112 100644 --- a/src/envelope.rs +++ b/src/envelope.rs @@ -1,30 +1,15 @@ -use std::convert::{TryFrom, TryInto}; - use aes::cipher::block_padding::Pkcs7; use aes::cipher::{BlockDecryptMut, KeyIvInit}; +use libsignal_protocol::ServiceId; use prost::Message; use crate::{ configuration::SignalingKey, push_service::ServiceError, - utils::serde_optional_base64, ParseServiceAddressError, ServiceAddress, + utils::serde_optional_base64, }; pub use crate::proto::Envelope; -impl TryFrom for Envelope { - type Error = ParseServiceAddressError; - - fn try_from(entity: EnvelopeEntity) -> Result { - match entity.source_uuid.as_deref() { - Some(uuid) => { - let address = uuid.try_into()?; - Ok(Envelope::new_with_source(entity, address)) - }, - None => Ok(Envelope::new_from_entity(entity)), - } - } -} - impl Envelope { #[tracing::instrument(skip(input, signaling_key), fields(input_size = input.len()))] pub fn decrypt( @@ -85,29 +70,6 @@ impl Envelope { } } - fn new_from_entity(entity: EnvelopeEntity) -> Self { - Envelope { - r#type: Some(entity.r#type), - timestamp: Some(entity.timestamp), - server_timestamp: Some(entity.server_timestamp), - server_guid: entity.source_uuid, - content: entity.content, - ..Default::default() - } - } - - fn new_with_source(entity: EnvelopeEntity, source: ServiceAddress) -> Self { - Envelope { - r#type: Some(entity.r#type), - source_device: Some(entity.source_device), - timestamp: Some(entity.timestamp), - server_timestamp: Some(entity.server_timestamp), - source_service_id: Some(source.uuid.to_string()), - content: entity.content, - ..Default::default() - } - } - pub fn is_unidentified_sender(&self) -> bool { self.r#type() == crate::proto::envelope::Type::UnidentifiedSender } @@ -133,18 +95,22 @@ impl Envelope { self.story.unwrap_or(false) } - pub fn source_address(&self) -> ServiceAddress { + pub fn source_address(&self) -> ServiceId { match self.source_service_id.as_deref() { - Some(service_id) => ServiceAddress::try_from(service_id) - .expect("invalid source ProtocolAddress UUID or prefix"), + Some(service_id) => { + ServiceId::parse_from_service_id_string(service_id) + .expect("invalid source ProtocolAddress UUID or prefix") + }, None => panic!("source_service_id is set"), } } - pub fn destination_address(&self) -> ServiceAddress { + pub fn destination_address(&self) -> ServiceId { match self.destination_service_id.as_deref() { - Some(service_id) => ServiceAddress::try_from(service_id) - .expect("invalid destination ProtocolAddress UUID or prefix"), + Some(service_id) => ServiceId::parse_from_service_id_string( + service_id, + ) + .expect("invalid destination ProtocolAddress UUID or prefix"), None => panic!("destination_address is set"), } } diff --git a/src/groups_v2/model.rs b/src/groups_v2/model.rs index 92be11d0d..75990ba59 100644 --- a/src/groups_v2/model.rs +++ b/src/groups_v2/model.rs @@ -1,12 +1,11 @@ use std::{convert::TryFrom, convert::TryInto}; use derivative::Derivative; +use libsignal_protocol::ServiceId; use serde::{Deserialize, Serialize}; use uuid::Uuid; use zkgroup::profiles::ProfileKey; -use crate::ServiceAddress; - use super::GroupDecodingError; #[derive(Copy, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] @@ -34,7 +33,7 @@ impl PartialEq for Member { #[derive(Clone, Debug, PartialEq, Eq)] pub struct PendingMember { - pub address: ServiceAddress, + pub address: ServiceId, pub role: Role, pub added_by_uuid: Uuid, pub timestamp: u64, diff --git a/src/lib.rs b/src/lib.rs index 2a881ca8d..daa81561a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -49,7 +49,6 @@ pub const GROUP_UPDATE_FLAG: u32 = 1; pub const GROUP_LEAVE_FLAG: u32 = 2; pub mod prelude { - pub use super::ServiceAddress; pub use crate::{ cipher::ServiceCipher, configuration::{ diff --git a/src/profile_service.rs b/src/profile_service.rs index e23a4b51b..d989255b1 100644 --- a/src/profile_service.rs +++ b/src/profile_service.rs @@ -1,8 +1,10 @@ +use libsignal_protocol::ServiceId; + use crate::{ proto::WebSocketRequestMessage, push_service::{ServiceError, SignalServiceProfile}, websocket::SignalWebSocket, - ServiceAddress, + ServiceIdExt, }; pub struct ProfileService { @@ -16,7 +18,7 @@ impl ProfileService { pub async fn retrieve_profile_by_id( &mut self, - address: ServiceAddress, + address: ServiceId, profile_key: Option, ) -> Result { let endpoint = match profile_key { @@ -27,10 +29,10 @@ impl ProfileService { ))?; let version = std::str::from_utf8(&version) .expect("hex encoded profile key version"); - format!("/v1/profile/{}/{}", address.uuid, version) + format!("/v1/profile/{}/{}", address.raw_uuid(), version) }, None => { - format!("/v1/profile/{}", address.uuid) + format!("/v1/profile/{}", address.raw_uuid()) }, }; diff --git a/src/push_service/error.rs b/src/push_service/error.rs index d197667d9..1a7e1675f 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -1,7 +1,7 @@ use libsignal_protocol::SignalProtocolError; use zkgroup::ZkGroupDeserializationFailure; -use crate::{groups_v2::GroupDecodingError, ParseServiceAddressError}; +use crate::groups_v2::GroupDecodingError; use super::{ MismatchedDevices, ProofRequired, RegistrationLockFailure, StaleDevices, @@ -77,9 +77,6 @@ pub enum ServiceError { #[error("unsupported content")] UnsupportedContent, - #[error(transparent)] - ParseServiceAddress(#[from] ParseServiceAddressError), - #[error("Not found.")] NotFoundError, diff --git a/src/push_service/keys.rs b/src/push_service/keys.rs index 50894d1f2..665a1ab88 100644 --- a/src/push_service/keys.rs +++ b/src/push_service/keys.rs @@ -1,6 +1,8 @@ use std::collections::HashMap; -use libsignal_protocol::{IdentityKey, PreKeyBundle, SenderCertificate}; +use libsignal_protocol::{ + IdentityKey, PreKeyBundle, SenderCertificate, ServiceId, +}; use serde::Deserialize; use crate::{ @@ -9,7 +11,6 @@ use crate::{ push_service::PreKeyResponse, sender::OutgoingPushMessage, utils::serde_base64, - ServiceAddress, }; use super::{ @@ -60,11 +61,14 @@ impl PushService { pub async fn get_pre_key( &mut self, - destination: &ServiceAddress, + destination: &ServiceId, device_id: u32, ) -> Result { - let path = - format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id); + let path = format!( + "/v2/keys/{}/{}?pq=true", + destination.raw_uuid(), + device_id + ); let mut pre_key_response: PreKeyResponse = self .get_json( @@ -83,13 +87,13 @@ impl PushService { pub(crate) async fn get_pre_keys( &mut self, - destination: &ServiceAddress, + destination: &ServiceId, device_id: u32, ) -> Result, ServiceError> { let path = if device_id == 1 { - format!("/v2/keys/{}/*?pq=true", destination.uuid) + format!("/v2/keys/{}/*?pq=true", destination.raw_uuid()) } else { - format!("/v2/keys/{}/{}?pq=true", destination.uuid, device_id) + format!("/v2/keys/{}/{}?pq=true", destination.raw_uuid(), device_id) }; let pre_key_response: PreKeyResponse = self .get_json( diff --git a/src/push_service/profile.rs b/src/push_service/profile.rs index 0a444fef6..709355db6 100644 --- a/src/push_service/profile.rs +++ b/src/push_service/profile.rs @@ -1,3 +1,4 @@ +use libsignal_protocol::Aci; use serde::{Deserialize, Serialize}; use zkgroup::profiles::{ProfileKeyCommitment, ProfileKeyVersion}; @@ -7,7 +8,7 @@ use crate::{ profile_cipher::ProfileCipherError, push_service::{AvatarWrite, HttpAuthOverride}, utils::{serde_base64, serde_optional_base64}, - Profile, ServiceAddress, + Profile, }; use super::{DeviceCapabilities, PushService}; @@ -88,18 +89,17 @@ struct SignalServiceProfileWrite<'s> { impl PushService { pub async fn retrieve_profile_by_id( &mut self, - address: ServiceAddress, + address: Aci, profile_key: Option, ) -> Result { let endpoint = if let Some(key) = profile_key { - let version = bincode::serialize(&key.get_profile_key_version( - address.aci().expect("profile by ACI ProtocolAddress"), - ))?; + let version = + bincode::serialize(&key.get_profile_key_version(address))?; let version = std::str::from_utf8(&version) .expect("hex encoded profile key version"); - format!("/v1/profile/{}/{}", address.uuid, version) + format!("/v1/profile/{}/{}", address.service_id_string(), version) } else { - format!("/v1/profile/{}", address.uuid) + format!("/v1/profile/{}", address.service_id_string()) }; // TODO: set locale to en_US self.get_json( diff --git a/src/sender.rs b/src/sender.rs index ae8a862e1..a3c3c53a2 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -2,8 +2,9 @@ use std::{collections::HashSet, time::SystemTime}; use chrono::prelude::*; use libsignal_protocol::{ - process_prekey_bundle, DeviceId, IdentityKey, IdentityKeyPair, - ProtocolStore, SenderCertificate, SenderKeyStore, SignalProtocolError, + process_prekey_bundle, Aci, DeviceId, IdentityKey, IdentityKeyPair, Pni, + ProtocolStore, SenderCertificate, SenderKeyStore, ServiceId, + SignalProtocolError, }; use rand::{CryptoRng, Rng}; use tracing::{error, info, trace}; @@ -24,10 +25,10 @@ use crate::{ AttachmentPointer, SyncMessage, }, push_service::*, + service_address::ServiceIdExt, session_store::SessionStoreExt, unidentified_access::UnidentifiedAccess, websocket::SignalWebSocket, - ServiceAddress, }; pub use crate::proto::{ContactDetails, GroupDetails}; @@ -59,7 +60,7 @@ pub type SendMessageResult = Result; #[derive(Debug, Clone)] pub struct SentMessage { - pub recipient: ServiceAddress, + pub recipient: ServiceId, pub used_identity_key: IdentityKey, pub unidentified: bool, pub needs_sync: bool, @@ -91,8 +92,8 @@ pub struct MessageSender { cipher: ServiceCipher, csprng: R, protocol_store: S, - local_aci: ServiceAddress, - local_pni: ServiceAddress, + local_aci: Aci, + local_pni: Pni, aci_identity: IdentityKeyPair, pni_identity: Option, device_id: DeviceId, @@ -117,7 +118,7 @@ pub enum MessageSenderError { AttachmentUploadError(#[from] AttachmentUploadError), #[error("Untrusted identity key with {address:?}")] - UntrustedIdentity { address: ServiceAddress }, + UntrustedIdentity { address: ServiceId }, #[error("Exceeded maximum number of retries")] MaximumRetriesLimitExceeded, @@ -126,7 +127,7 @@ pub enum MessageSenderError { ProofRequired { token: String, options: Vec }, #[error("Recipient not found: {addr:?}")] - NotFound { addr: ServiceAddress }, + NotFound { addr: ServiceId }, } pub type GroupV2Id = [u8; GROUP_IDENTIFIER_LEN]; @@ -150,8 +151,8 @@ where cipher: ServiceCipher, csprng: R, protocol_store: S, - local_aci: impl Into, - local_pni: impl Into, + local_aci: impl Into, + local_pni: impl Into, aci_identity: IdentityKeyPair, pni_identity: Option, device_id: DeviceId, @@ -304,7 +305,7 @@ where async fn is_multi_device(&self) -> bool { if self.device_id == DEFAULT_DEVICE_ID.into() { self.protocol_store - .get_sub_device_sessions(&self.local_aci) + .get_sub_device_sessions(&self.local_aci.into()) .await .map_or(false, |s| !s.is_empty()) } else { @@ -314,12 +315,12 @@ where /// Send a message `content` to a single `recipient`. #[tracing::instrument( - skip(self, unidentified_access, message, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + skip(self, unidentified_access, message), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] pub async fn send_message( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, mut unidentified_access: Option, message: impl Into, timestamp: u64, @@ -378,7 +379,7 @@ where Some(&result), ); self.try_send_message( - self.local_aci, + self.local_aci.into(), None, &sync_message, timestamp, @@ -390,7 +391,11 @@ where if end_session { let n = self.protocol_store.delete_all_sessions(recipient).await?; - tracing::debug!("ended {} sessions with {}", n, recipient.uuid); + tracing::debug!( + "ended {} sessions with {}", + n, + recipient.raw_uuid() + ); } result @@ -408,7 +413,7 @@ where )] pub async fn send_message_to_group( &mut self, - recipients: impl AsRef<[(ServiceAddress, Option, bool)]>, + recipients: impl AsRef<[(ServiceId, Option, bool)]>, message: impl Into, timestamp: u64, online: bool, @@ -476,7 +481,7 @@ where // See Signal Android `SignalServiceMessageSender.java:2817` if let Err(error) = self .try_send_message( - self.local_aci, + self.local_aci.into(), None, &sync_message, timestamp, @@ -496,11 +501,11 @@ where #[tracing::instrument( level = "trace", skip(self, unidentified_access, content_body, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] async fn try_send_message( &mut self, - recipient: ServiceAddress, + recipient: ServiceId, mut unidentified_access: Option<&UnidentifiedAccess>, content_body: &ContentBody, timestamp: u64, @@ -526,7 +531,7 @@ where .await?; let messages = OutgoingPushMessages { - destination: recipient.uuid, + destination: recipient.raw_uuid(), timestamp, messages, online, @@ -646,11 +651,11 @@ where /// Upload contact details to the CDN and send a sync message #[tracing::instrument( skip(self, unidentified_access, contacts, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] pub async fn send_contact_details( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, unidentified_access: Option, // XXX It may be interesting to use an intermediary type, // instead of ContactDetails directly, @@ -686,10 +691,10 @@ where } /// Send `Configuration` synchronization message - #[tracing::instrument(skip(self, recipient), fields(recipient = %recipient))] + #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))] pub async fn send_configuration( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, configuration: sync_message::Configuration, ) -> Result<(), MessageSenderError> { let msg = SyncMessage { @@ -705,10 +710,10 @@ where } /// Send `MessageRequestResponse` synchronization message with either a recipient ACI or a GroupV2 ID - #[tracing::instrument(skip(self, recipient), fields(recipient = %recipient))] + #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))] pub async fn send_message_request_response( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, thread: &ThreadIdentifier, action: message_request_response::Type, ) -> Result<(), MessageSenderError> { @@ -752,10 +757,10 @@ where } /// Send `Keys` synchronization message - #[tracing::instrument(skip(self))] + #[tracing::instrument(skip(self), fields(recipient = recipient.service_id_string()))] pub async fn send_keys( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, keys: sync_message::Keys, ) -> Result<(), MessageSenderError> { let msg = SyncMessage { @@ -774,7 +779,7 @@ where #[tracing::instrument(skip(self))] pub async fn send_sync_message_request( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, request_type: sync_message::request::Type, ) -> Result<(), MessageSenderError> { if self.device_id == DEFAULT_DEVICE_ID.into() { @@ -813,7 +818,7 @@ where &mut self.csprng, )?; Ok(crate::proto::PniSignatureMessage { - pni: Some(self.local_pni.uuid.as_bytes().to_vec()), + pni: Some(self.local_pni.service_id_binary()), signature: Some(signature.into()), }) } @@ -821,12 +826,12 @@ where // Equivalent with `getEncryptedMessages` #[tracing::instrument( level = "trace", - skip(self, unidentified_access, content, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + skip(self, unidentified_access, content), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] async fn create_encrypted_messages( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, unidentified_access: Option<&SenderCertificate>, content: &[u8], ) -> Result<(Vec, IdentityKey), MessageSenderError> @@ -845,18 +850,14 @@ where devices.insert(DEFAULT_DEVICE_ID.into()); // never try to send messages to the sender device - match recipient.identity { - ServiceIdType::AccountIdentity => { - if recipient.aci().is_some() - && recipient.aci() == self.local_aci.aci() - { + match recipient { + ServiceId::Aci(aci) => { + if *aci == self.local_aci { devices.remove(&self.device_id); } }, - ServiceIdType::PhoneNumberIdentity => { - if recipient.pni().is_some() - && recipient.pni() == self.local_aci.pni() - { + ServiceId::Pni(pni) => { + if *pni == self.local_pni { devices.remove(&self.device_id); } }, @@ -931,12 +932,12 @@ where /// When no session with the recipient exists, we need to create one. #[tracing::instrument( level = "trace", - skip(self, unidentified_access, content, recipient), - fields(unidentified_access = unidentified_access.is_some(), recipient = %recipient), + skip(self, unidentified_access, content), + fields(unidentified_access = unidentified_access.is_some(), recipient = recipient.service_id_string()), )] pub(crate) async fn create_encrypted_message( &mut self, - recipient: &ServiceAddress, + recipient: &ServiceId, unidentified_access: Option<&SenderCertificate>, device_id: DeviceId, content: &[u8], @@ -1016,7 +1017,7 @@ where fn create_multi_device_sent_transcript_content<'a>( &self, - recipient: Option<&ServiceAddress>, + recipient: Option<&ServiceId>, data_message: Option, edit_message: Option, timestamp: u64, @@ -1036,7 +1037,7 @@ where } = sent; UnidentifiedDeliveryStatus { destination_service_id: Some( - recipient.uuid.to_string(), + recipient.service_id_string(), ), unidentified: Some(*unidentified), destination_identity_key: Some( @@ -1047,7 +1048,8 @@ where .collect(); ContentBody::SynchronizeMessage(SyncMessage { sent: Some(sync_message::Sent { - destination_service_id: recipient.map(|r| r.uuid.to_string()), + destination_service_id: recipient + .map(ServiceId::service_id_string), destination_e164: None, expiration_start_timestamp: data_message .as_ref() diff --git a/src/service_address.rs b/src/service_address.rs index de3680879..de83363ab 100644 --- a/src/service_address.rs +++ b/src/service_address.rs @@ -1,160 +1,39 @@ -use std::convert::TryFrom; +use libsignal_protocol::{Aci, DeviceId, Pni, ProtocolAddress, ServiceId}; -use libsignal_protocol::{DeviceId, ProtocolAddress, ServiceId}; -use uuid::Uuid; - -pub use crate::push_service::ServiceIdType; - -#[derive(thiserror::Error, Debug, Clone)] -pub enum ParseServiceAddressError { - #[error("Supplied UUID could not be parsed")] - InvalidUuid(#[from] uuid::Error), - - #[error("Envelope without UUID")] - NoUuid, -} +pub trait ServiceIdExt { + fn to_protocol_address( + self, + device_id: impl Into, + ) -> ProtocolAddress; -#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)] -pub struct ServiceAddress { - pub uuid: Uuid, - pub identity: ServiceIdType, -} + fn aci(self) -> Option; -impl std::fmt::Display for ServiceAddress { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // This is used in ServiceAddress::to_service_id(&self), so keep this consistent. - match self.identity { - ServiceIdType::AccountIdentity => write!(f, "{}", self.uuid), - ServiceIdType::PhoneNumberIdentity => { - write!(f, "PNI:{}", self.uuid) - }, - } - } + fn pni(self) -> Option; } -impl ServiceAddress { - pub fn to_protocol_address( - &self, +impl ServiceIdExt for A +where + A: Into, +{ + fn to_protocol_address( + self, device_id: impl Into, ) -> ProtocolAddress { - match self.identity { - ServiceIdType::AccountIdentity => { - ProtocolAddress::new(self.uuid.to_string(), device_id.into()) - }, - ServiceIdType::PhoneNumberIdentity => ProtocolAddress::new( - format!("PNI:{}", self.uuid), - device_id.into(), - ), - } - } - - #[deprecated] - pub fn new_aci(uuid: Uuid) -> Self { - Self::from_aci(uuid) - } - - pub fn from_aci(uuid: Uuid) -> Self { - Self { - uuid, - identity: ServiceIdType::AccountIdentity, - } - } - - #[deprecated] - pub fn new_pni(uuid: Uuid) -> Self { - Self::from_pni(uuid) - } - - pub fn from_pni(uuid: Uuid) -> Self { - Self { - uuid, - identity: ServiceIdType::PhoneNumberIdentity, - } - } - - pub fn aci(&self) -> Option { - use libsignal_protocol::Aci; - match self.identity { - ServiceIdType::AccountIdentity => { - Some(Aci::from_uuid_bytes(self.uuid.into_bytes())) - }, - ServiceIdType::PhoneNumberIdentity => None, - } - } - - pub fn pni(&self) -> Option { - use libsignal_protocol::Pni; - match self.identity { - ServiceIdType::AccountIdentity => None, - ServiceIdType::PhoneNumberIdentity => { - Some(Pni::from_uuid_bytes(self.uuid.into_bytes())) - }, - } - } - - pub fn to_service_id(&self) -> String { - self.to_string() - } -} - -impl From for ServiceAddress { - fn from(service_id: ServiceId) -> Self { - match service_id { - ServiceId::Aci(service_id) => { - ServiceAddress::from_aci(service_id.into()) - }, - ServiceId::Pni(service_id) => { - ServiceAddress::from_pni(service_id.into()) - }, - } - } -} - -impl TryFrom<&ProtocolAddress> for ServiceAddress { - type Error = ParseServiceAddressError; - - fn try_from(addr: &ProtocolAddress) -> Result { - let value = addr.name(); - if let Some(pni) = value.strip_prefix("PNI:") { - Ok(ServiceAddress::from_pni(Uuid::parse_str(pni)?)) - } else { - Ok(ServiceAddress::from_aci(Uuid::parse_str(value)?)) - } - .map_err(|e| { - tracing::error!("Parsing ServiceAddress from {:?}", addr); - ParseServiceAddressError::InvalidUuid(e) - }) + let service_id: ServiceId = self.into(); + ProtocolAddress::new(service_id.service_id_string(), device_id.into()) } -} - -impl TryFrom<&str> for ServiceAddress { - type Error = ParseServiceAddressError; - fn try_from(value: &str) -> Result { - if let Some(pni) = value.strip_prefix("PNI:") { - Ok(ServiceAddress::from_pni(Uuid::parse_str(pni)?)) - } else { - Ok(ServiceAddress::from_aci(Uuid::parse_str(value)?)) + fn aci(self) -> Option { + match self.into() { + ServiceId::Aci(aci) => Some(aci), + ServiceId::Pni(_) => None, } - .map_err(|e| { - tracing::error!("Parsing ServiceAddress from '{}'", value); - ParseServiceAddressError::InvalidUuid(e) - }) } -} - -impl TryFrom<&[u8]> for ServiceAddress { - type Error = ParseServiceAddressError; - fn try_from(value: &[u8]) -> Result { - if let Some(pni) = value.strip_prefix(b"PNI:") { - Ok(ServiceAddress::from_pni(Uuid::from_slice(pni)?)) - } else { - Ok(ServiceAddress::from_aci(Uuid::from_slice(value)?)) + fn pni(self) -> Option { + match self.into() { + ServiceId::Aci(_) => None, + ServiceId::Pni(pni) => Some(pni), } - .map_err(|e| { - tracing::error!("Parsing ServiceAddress from {:?}", value); - ParseServiceAddressError::InvalidUuid(e) - }) } } diff --git a/src/session_store.rs b/src/session_store.rs index affbff09b..983b18717 100644 --- a/src/session_store.rs +++ b/src/session_store.rs @@ -1,7 +1,9 @@ use async_trait::async_trait; -use libsignal_protocol::{ProtocolAddress, SessionStore, SignalProtocolError}; +use libsignal_protocol::{ + ProtocolAddress, ServiceId, SessionStore, SignalProtocolError, +}; -use crate::{push_service::DEFAULT_DEVICE_ID, ServiceAddress}; +use crate::push_service::DEFAULT_DEVICE_ID; /// This is additional functions required to handle /// session deletion. It might be a candidate for inclusion into @@ -13,7 +15,7 @@ pub trait SessionStoreExt: SessionStore { /// This should return every device except for the main device [DEFAULT_DEVICE_ID]. async fn get_sub_device_sessions( &self, - name: &ServiceAddress, + name: &ServiceId, ) -> Result, SignalProtocolError>; /// Remove a session record for a recipient ID + device ID tuple. @@ -28,7 +30,7 @@ pub trait SessionStoreExt: SessionStore { /// Returns the number of deleted sessions. async fn delete_all_sessions( &self, - address: &ServiceAddress, + address: &ServiceId, ) -> Result; /// Remove a session record for a recipient ID + device ID tuple. @@ -48,10 +50,10 @@ pub trait SessionStoreExt: SessionStore { Ok(count) } - async fn compute_safety_number<'s>( - &'s self, - local_address: &'s ServiceAddress, - address: &'s ServiceAddress, + async fn compute_safety_number( + &self, + local_address: &ServiceId, + address: &ServiceId, ) -> Result where Self: Sized + libsignal_protocol::IdentityKeyStore, @@ -73,9 +75,9 @@ pub trait SessionStoreExt: SessionStore { let fp = libsignal_protocol::Fingerprint::new( 2, 5200, - local_address.uuid.as_bytes(), + local_address.raw_uuid().as_bytes(), local.identity_key(), - address.uuid.as_bytes(), + address.raw_uuid().as_bytes(), &ident, )?; fp.display_string()