Skip to content

Commit

Permalink
Replace ServiceAddress by libsignal_protocol::ServiceId
Browse files Browse the repository at this point in the history
  • Loading branch information
gferon committed Oct 18, 2024
1 parent 026d751 commit 1d3e40e
Show file tree
Hide file tree
Showing 13 changed files with 172 additions and 312 deletions.
26 changes: 13 additions & 13 deletions src/account_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,9 @@ use aes::cipher::{KeyIvInit, StreamCipher as _};
use hmac::digest::Output;
use hmac::{Hmac, Mac};
use libsignal_protocol::{
kem, GenericSignedPreKey, IdentityKey, IdentityKeyPair, IdentityKeyStore,
KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore, PublicKey,
SenderKeyStore, SignedPreKeyRecord, Timestamp,
kem, Aci, GenericSignedPreKey, IdentityKey, IdentityKeyPair,
IdentityKeyStore, KeyPair, KyberPreKeyRecord, PrivateKey, ProtocolStore,
PublicKey, SenderKeyStore, SignedPreKeyRecord, Timestamp,
};
use prost::Message;
use serde::{Deserialize, Serialize};
Expand All @@ -33,10 +33,10 @@ use crate::push_service::{
DEFAULT_DEVICE_ID,
};
use crate::sender::OutgoingPushMessage;
use crate::service_address::ServiceIdExt;
use crate::session_store::SessionStoreExt;
use crate::timestamp::TimestampExt as _;
use crate::utils::{random_length_padding, BASE64_RELAXED};
use crate::ServiceAddress;
use crate::{
configuration::{Endpoint, ServiceCredentials},
pre_keys::PreKeyState,
Expand Down Expand Up @@ -486,7 +486,7 @@ impl AccountManager {

pub async fn retrieve_profile(
&mut self,
address: ServiceAddress,
address: Aci,
) -> Result<Profile, ProfileManagerError> {
let profile_key =
self.profile_key.expect("set profile key in AccountManager");
Expand Down Expand Up @@ -628,19 +628,19 @@ impl AccountManager {
/// Should be called as the primary device to migrate from pre-PNI to PNI.
///
/// This is the equivalent of Android's PnpInitializeDevicesJob or iOS' PniHelloWorldManager.
#[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci = %local_aci))]
#[tracing::instrument(skip(self, aci_protocol_store, pni_protocol_store, sender, local_aci, csprng), fields(local_aci =% local_aci.service_id_string()))]
pub async fn pnp_initialize_devices<
// XXX So many constraints here, all imposed by the MessageSender
R: rand::Rng + rand::CryptoRng,
Aci: PreKeysStore + SessionStoreExt,
Pni: PreKeysStore,
AciStore: PreKeysStore + SessionStoreExt,
PniStore: PreKeysStore,
AciOrPni: ProtocolStore + SenderKeyStore + SessionStoreExt + Sync + Clone,
>(
&mut self,
aci_protocol_store: &mut Aci,
pni_protocol_store: &mut Pni,
aci_protocol_store: &mut AciStore,
pni_protocol_store: &mut PniStore,
mut sender: MessageSender<AciOrPni, R>,
local_aci: ServiceAddress,
local_aci: Aci,
e164: PhoneNumber,
csprng: &mut R,
) -> Result<(), MessageSenderError> {
Expand All @@ -651,7 +651,7 @@ impl AccountManager {

// For every linked device, we generate a new set of pre-keys, and send them to the device.
let local_device_ids = aci_protocol_store
.get_sub_device_sessions(&local_aci)
.get_sub_device_sessions(&local_aci.into())
.await?;

let mut device_messages =
Expand Down Expand Up @@ -795,7 +795,7 @@ impl AccountManager {
let content: ContentBody = msg.into();
let msg = sender
.create_encrypted_message(
&local_aci,
&local_aci.into(),
None,
local_device_id.into(),
&content.into_proto().encode_to_vec(),
Expand Down
25 changes: 14 additions & 11 deletions src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@ use libsignal_protocol::{
CiphertextMessageType, DeviceId, IdentityKeyStore, KyberPreKeyStore,
PreKeySignalMessage, PreKeyStore, ProtocolAddress, ProtocolStore,
PublicKey, SealedSenderDecryptionResult, SenderCertificate,
SenderKeyDistributionMessage, SenderKeyStore, SessionStore, SignalMessage,
SignalProtocolError, SignedPreKeyStore, Timestamp,
SenderKeyDistributionMessage, SenderKeyStore, ServiceId, SessionStore,
SignalMessage, SignalProtocolError, SignedPreKeyStore, Timestamp,
};
use prost::Message;
use rand::{CryptoRng, Rng};
Expand All @@ -23,7 +23,7 @@ use crate::{
sender::OutgoingPushMessage,
session_store::SessionStoreExt,
utils::BASE64_RELAXED,
ServiceAddress,
ServiceIdExt,
};
/// Decrypts incoming messages and encrypts outgoing messages.
///
Expand Down Expand Up @@ -277,13 +277,16 @@ where
)
.await?;

let sender = ServiceAddress::try_from(sender_uuid.as_str())
.map_err(|e| {
tracing::error!("{:?}", e);
SignalProtocolError::InvalidSealedSenderMessage(
"invalid sender UUID".to_string(),
)
})?;
let sender =
ServiceId::parse_from_service_id_string(&sender_uuid)
.ok_or_else(|| {
tracing::error!(
"failed to parse ServiceId from string"
);
SignalProtocolError::InvalidSealedSenderMessage(
"invalid sender UUID".to_string(),
)
})?;

let needs_receipt = if envelope.source_service_id.is_some() {
tracing::warn!(?envelope, "Received an unidentified delivery over an identified channel. Marking needs_receipt=false");
Expand Down Expand Up @@ -461,7 +464,7 @@ fn strip_padding(contents: &mut Vec<u8>) -> Result<(), ServiceError> {
/// Equivalent of `SignalServiceCipher::getPreferredProtocolAddress`
pub async fn get_preferred_protocol_address<S: SessionStore>(
session_store: &S,
address: &ServiceAddress,
address: &ServiceId,
device_id: DeviceId,
) -> Result<ProtocolAddress, libsignal_protocol::error::SignalProtocolError> {
let address = address.to_protocol_address(device_id);
Expand Down
7 changes: 4 additions & 3 deletions src/content.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use libsignal_protocol::ProtocolAddress;
use libsignal_protocol::{ProtocolAddress, ServiceId};
use uuid::Uuid;

use crate::ServiceIdExt;
pub use crate::{
proto::{
attachment_pointer::Flags as AttachmentPointerFlags,
Expand All @@ -18,8 +19,8 @@ mod story_message;

#[derive(Clone, Debug)]
pub struct Metadata {
pub sender: crate::ServiceAddress,
pub destination: crate::ServiceAddress,
pub sender: ServiceId,
pub destination: ServiceId,
pub sender_device: u32,
pub timestamp: u64,
pub needs_receipt: bool,
Expand Down
58 changes: 12 additions & 46 deletions src/envelope.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,15 @@
use std::convert::{TryFrom, TryInto};

use aes::cipher::block_padding::Pkcs7;
use aes::cipher::{BlockDecryptMut, KeyIvInit};
use libsignal_protocol::ServiceId;
use prost::Message;

use crate::{
configuration::SignalingKey, push_service::ServiceError,
utils::serde_optional_base64, ParseServiceAddressError, ServiceAddress,
utils::serde_optional_base64,
};

pub use crate::proto::Envelope;

impl TryFrom<EnvelopeEntity> for Envelope {
type Error = ParseServiceAddressError;

fn try_from(entity: EnvelopeEntity) -> Result<Self, Self::Error> {
match entity.source_uuid.as_deref() {
Some(uuid) => {
let address = uuid.try_into()?;
Ok(Envelope::new_with_source(entity, address))
},
None => Ok(Envelope::new_from_entity(entity)),
}
}
}

impl Envelope {
#[tracing::instrument(skip(input, signaling_key), fields(input_size = input.len()))]
pub fn decrypt(
Expand Down Expand Up @@ -85,29 +70,6 @@ impl Envelope {
}
}

fn new_from_entity(entity: EnvelopeEntity) -> Self {
Envelope {
r#type: Some(entity.r#type),
timestamp: Some(entity.timestamp),
server_timestamp: Some(entity.server_timestamp),
server_guid: entity.source_uuid,
content: entity.content,
..Default::default()
}
}

fn new_with_source(entity: EnvelopeEntity, source: ServiceAddress) -> Self {
Envelope {
r#type: Some(entity.r#type),
source_device: Some(entity.source_device),
timestamp: Some(entity.timestamp),
server_timestamp: Some(entity.server_timestamp),
source_service_id: Some(source.uuid.to_string()),
content: entity.content,
..Default::default()
}
}

pub fn is_unidentified_sender(&self) -> bool {
self.r#type() == crate::proto::envelope::Type::UnidentifiedSender
}
Expand All @@ -133,18 +95,22 @@ impl Envelope {
self.story.unwrap_or(false)
}

pub fn source_address(&self) -> ServiceAddress {
pub fn source_address(&self) -> ServiceId {
match self.source_service_id.as_deref() {
Some(service_id) => ServiceAddress::try_from(service_id)
.expect("invalid source ProtocolAddress UUID or prefix"),
Some(service_id) => {
ServiceId::parse_from_service_id_string(service_id)
.expect("invalid source ProtocolAddress UUID or prefix")
},
None => panic!("source_service_id is set"),
}
}

pub fn destination_address(&self) -> ServiceAddress {
pub fn destination_address(&self) -> ServiceId {
match self.destination_service_id.as_deref() {
Some(service_id) => ServiceAddress::try_from(service_id)
.expect("invalid destination ProtocolAddress UUID or prefix"),
Some(service_id) => ServiceId::parse_from_service_id_string(
service_id,
)
.expect("invalid destination ProtocolAddress UUID or prefix"),
None => panic!("destination_address is set"),
}
}
Expand Down
5 changes: 2 additions & 3 deletions src/groups_v2/model.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
use std::{convert::TryFrom, convert::TryInto};

use derivative::Derivative;
use libsignal_protocol::ServiceId;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
use zkgroup::profiles::ProfileKey;

use crate::ServiceAddress;

use super::GroupDecodingError;

#[derive(Copy, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
Expand Down Expand Up @@ -34,7 +33,7 @@ impl PartialEq for Member {

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct PendingMember {
pub address: ServiceAddress,
pub address: ServiceId,
pub role: Role,
pub added_by_uuid: Uuid,
pub timestamp: u64,
Expand Down
1 change: 0 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ pub const GROUP_UPDATE_FLAG: u32 = 1;
pub const GROUP_LEAVE_FLAG: u32 = 2;

pub mod prelude {
pub use super::ServiceAddress;
pub use crate::{
cipher::ServiceCipher,
configuration::{
Expand Down
26 changes: 15 additions & 11 deletions src/profile_service.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use libsignal_protocol::ServiceId;

use crate::{
proto::WebSocketRequestMessage,
push_service::{ServiceError, SignalServiceProfile},
push_service::{ServiceError, ServiceIdType, SignalServiceProfile},
websocket::SignalWebSocket,
ServiceAddress,
};

pub struct ProfileService {
Expand All @@ -16,21 +17,24 @@ impl ProfileService {

pub async fn retrieve_profile_by_id(
&mut self,
address: ServiceAddress,
address: ServiceId,
profile_key: Option<zkgroup::profiles::ProfileKey>,
) -> Result<SignalServiceProfile, ServiceError> {
let endpoint = match profile_key {
Some(key) => {
let endpoint = match (profile_key, address) {
(Some(key), ServiceId::Aci(aci)) => {
let version =
bincode::serialize(&key.get_profile_key_version(
address.aci().expect("profile by ACI ProtocolAddress"),
))?;
bincode::serialize(&key.get_profile_key_version(aci))?;
let version = std::str::from_utf8(&version)
.expect("hex encoded profile key version");
format!("/v1/profile/{}/{}", address.uuid, version)
format!("/v1/profile/{}/{}", address.raw_uuid(), version)
},
(Some(_), ServiceId::Pni(_)) => {
return Err(ServiceError::InvalidAddressType(
ServiceIdType::PhoneNumberIdentity,
))
},
None => {
format!("/v1/profile/{}", address.uuid)
(None, _) => {
format!("/v1/profile/{}", address.raw_uuid())
},
};

Expand Down
11 changes: 6 additions & 5 deletions src/push_service/error.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use libsignal_protocol::SignalProtocolError;
use zkgroup::ZkGroupDeserializationFailure;

use crate::{groups_v2::GroupDecodingError, ParseServiceAddressError};
use crate::groups_v2::GroupDecodingError;

use super::{
MismatchedDevices, ProofRequired, RegistrationLockFailure, StaleDevices,
MismatchedDevices, ProofRequired, RegistrationLockFailure, ServiceIdType,
StaleDevices,
};

#[derive(thiserror::Error, Debug)]
Expand All @@ -15,6 +16,9 @@ pub enum ServiceError {
#[error("invalid URL: {0}")]
InvalidUrl(#[from] url::ParseError),

#[error("wrong address type: {0}")]
InvalidAddressType(ServiceIdType),

#[error("Error sending request: {reason}")]
SendError { reason: String },

Expand Down Expand Up @@ -77,9 +81,6 @@ pub enum ServiceError {
#[error("unsupported content")]
UnsupportedContent,

#[error(transparent)]
ParseServiceAddress(#[from] ParseServiceAddressError),

#[error("Not found.")]
NotFoundError,

Expand Down
Loading

0 comments on commit 1d3e40e

Please sign in to comment.