Skip to content

Commit

Permalink
Various cleanups after merging ACI/PNI (#298)
Browse files Browse the repository at this point in the history
 * Use enum variants in a bunch of errors instead of static strings
 * Rename VerifyAccountResponse::uuid to VerifyAccountResponse::aci for completion (but use a serde alias)
 * Use IdentityKey in some places where PublicKey was used before (and change the relevant serde modules)
  • Loading branch information
gferon authored Apr 13, 2024
1 parent 93c23cf commit 26c036e
Show file tree
Hide file tree
Showing 12 changed files with 202 additions and 176 deletions.
17 changes: 15 additions & 2 deletions libsignal-service/examples/storage.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use libsignal_service::pre_keys::{PreKeysStore, ServiceKyberPreKeyStore};
use libsignal_service::pre_keys::{KyberPreKeyStoreExt, PreKeysStore};
use libsignal_service::protocol::{
Direction, IdentityKey, IdentityKeyPair, IdentityKeyStore, KyberPreKeyId,
KyberPreKeyRecord, KyberPreKeyStore, PreKeyId, PreKeyRecord, PreKeyStore,
Expand Down Expand Up @@ -92,7 +92,7 @@ impl SignedPreKeyStore for ExampleStore {
}

#[async_trait::async_trait(?Send)]
impl ServiceKyberPreKeyStore for ExampleStore {
impl KyberPreKeyStoreExt for ExampleStore {
async fn store_last_resort_kyber_pre_key(
&mut self,
_kyber_prekey_id: KyberPreKeyId,
Expand Down Expand Up @@ -227,6 +227,19 @@ impl PreKeysStore for ExampleStore {
) -> Result<(), SignalProtocolError> {
todo!()
}

async fn signed_pre_keys_count(
&self,
) -> Result<usize, SignalProtocolError> {
todo!()
}

async fn kyber_pre_keys_count(
&self,
_last_resort: bool,
) -> Result<usize, SignalProtocolError> {
todo!()
}
}

#[allow(dead_code)]
Expand Down
44 changes: 17 additions & 27 deletions libsignal-service/src/account_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use hmac::{Hmac, Mac};
use libsignal_protocol::{
kem, GenericSignedPreKey, IdentityKey, IdentityKeyStore, KeyPair,
KyberPreKeyRecord, PrivateKey, ProtocolStore, PublicKey, SenderKeyStore,
SignalProtocolError, SignedPreKeyRecord,
SignedPreKeyRecord,
};
use prost::Message;
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -64,20 +64,6 @@ pub enum ProfileManagerError {
ProfileCipherError(#[from] ProfileCipherError),
}

#[derive(thiserror::Error, Debug)]
pub enum LinkError {
#[error(transparent)]
ServiceError(#[from] ServiceError),
#[error("TsUrl has an invalid UUID field")]
InvalidUuid,
#[error("TsUrl has an invalid pub_key field")]
InvalidPublicKey,
#[error("Protocol error {0}")]
ProtocolError(#[from] SignalProtocolError),
#[error(transparent)]
ProvisioningError(#[from] ProvisioningError),
}

#[derive(Debug, Default, Serialize, Deserialize, Clone)]
pub struct Profile {
pub name: Option<ProfileName<String>>,
Expand Down Expand Up @@ -111,7 +97,6 @@ impl<Service: PushService> AccountManager<Service> {
service_id_type: ServiceIdType,
csprng: &mut R,
use_last_resort_key: bool,
force: bool,
) -> Result<(), ServiceError> {
let prekey_status = match self
.service
Expand Down Expand Up @@ -143,7 +128,9 @@ impl<Service: PushService> AccountManager<Service> {
if prekey_status.count >= PRE_KEY_MINIMUM
&& prekey_status.pq_count >= PRE_KEY_MINIMUM
{
if !force {
if protocol_store.signed_pre_keys_count().await? > 0
&& protocol_store.kyber_pre_keys_count(true).await? > 0
{
tracing::debug!("Available keys sufficient");
return Ok(());
}
Expand All @@ -159,6 +146,7 @@ impl<Service: PushService> AccountManager<Service> {
.load_last_resort_kyber_pre_keys()
.instrument(tracing::trace_span!("fetch last resort key"))
.await?;

// XXX: Maybe this check should be done in the generate_pre_keys function?
let has_last_resort_key = !last_resort_keys.is_empty();

Expand Down Expand Up @@ -186,7 +174,7 @@ impl<Service: PushService> AccountManager<Service> {
.transpose()?
};

let identity_key = *identity_key_pair.identity_key().public_key();
let identity_key = *identity_key_pair.identity_key();

let pre_keys: Vec<_> = pre_keys
.into_iter()
Expand Down Expand Up @@ -289,16 +277,18 @@ impl<Service: PushService> AccountManager<Service> {
aci_identity_store: &dyn IdentityKeyStore,
pni_identity_store: &dyn IdentityKeyStore,
credentials: ServiceCredentials,
) -> Result<(), LinkError> {
) -> Result<(), ProvisioningError> {
let query: HashMap<_, _> = url.query_pairs().collect();
let ephemeral_id = query.get("uuid").ok_or(LinkError::InvalidUuid)?;
let pub_key =
query.get("pub_key").ok_or(LinkError::InvalidPublicKey)?;
let ephemeral_id =
query.get("uuid").ok_or(ProvisioningError::MissingUuid)?;
let pub_key = query
.get("pub_key")
.ok_or(ProvisioningError::MissingPublicKey)?;
let pub_key = BASE64_RELAXED
.decode(&**pub_key)
.map_err(|_e| LinkError::InvalidPublicKey)?;
.map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;
let pub_key = PublicKey::deserialize(&pub_key)
.map_err(|_e| LinkError::InvalidPublicKey)?;
.map_err(|e| ProvisioningError::InvalidPublicKey(e.into()))?;

let aci_identity_key_pair =
aci_identity_store.get_identity_key_pair().await?;
Expand Down Expand Up @@ -361,7 +351,7 @@ impl<Service: PushService> AccountManager<Service> {
aci_protocol_store: &mut Aci,
pni_protocol_store: &mut Pni,
skip_device_transfer: bool,
) -> Result<VerifyAccountResponse, LinkError> {
) -> Result<VerifyAccountResponse, ProvisioningError> {
let aci_identity_key_pair = aci_protocol_store
.get_identity_key_pair()
.instrument(tracing::trace_span!("get ACI identity key pair"))
Expand Down Expand Up @@ -421,8 +411,8 @@ impl<Service: PushService> AccountManager<Service> {
registration_method,
account_attributes,
skip_device_transfer,
*aci_identity_key,
*pni_identity_key,
aci_identity_key,
pni_identity_key,
dar,
)
.await?;
Expand Down
3 changes: 2 additions & 1 deletion libsignal-service/src/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::{
envelope::Envelope,
push_service::ServiceError,
sender::OutgoingPushMessage,
session_store::SessionStoreExt,
utils::BASE64_RELAXED,
ServiceAddress,
};
Expand Down Expand Up @@ -75,7 +76,7 @@ fn debug_envelope(envelope: &Envelope) -> String {

impl<S, R> ServiceCipher<S, R>
where
S: ProtocolStore + KyberPreKeyStore + SenderKeyStore + Clone,
S: ProtocolStore + SenderKeyStore + SessionStoreExt + Clone,
R: Rng + CryptoRng,
{
pub fn new(
Expand Down
1 change: 1 addition & 0 deletions libsignal-service/src/groups_v2/manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ pub fn decrypt_group(
fn current_days_seconds() -> (u64, u64) {
let days_seconds = |date: NaiveDate| {
date.and_time(NaiveTime::from_hms_opt(0, 0, 0).unwrap())
.and_utc()
.timestamp() as u64
};

Expand Down
5 changes: 2 additions & 3 deletions libsignal-service/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ pub mod utils;
pub mod websocket;

pub use crate::account_manager::{
decrypt_device_name, AccountManager, LinkError, Profile,
ProfileManagerError,
decrypt_device_name, AccountManager, Profile, ProfileManagerError,
};
pub use crate::service_address::*;

Expand Down Expand Up @@ -96,7 +95,7 @@ pub mod prelude {
profiles::ProfileKey,
};

pub use libsignal_protocol::DeviceId;
pub use libsignal_protocol::{DeviceId, IdentityKeyStore};
}

pub use libsignal_protocol as protocol;
Expand Down
28 changes: 19 additions & 9 deletions libsignal-service/src/pre_keys.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use std::{convert::TryFrom, time::SystemTime};

use crate::utils::{serde_base64, serde_public_key};
use crate::utils::{serde_base64, serde_identity_key};
use async_trait::async_trait;
use libsignal_protocol::{
error::SignalProtocolError, kem, GenericSignedPreKey, IdentityKeyPair,
IdentityKeyStore, KeyPair, KyberPreKeyId, KyberPreKeyRecord,
KyberPreKeyStore, PreKeyRecord, PreKeyStore, PublicKey, SignedPreKeyRecord,
SignedPreKeyStore,
error::SignalProtocolError, kem, GenericSignedPreKey, IdentityKey,
IdentityKeyPair, IdentityKeyStore, KeyPair, KyberPreKeyId,
KyberPreKeyRecord, KyberPreKeyStore, PreKeyRecord, PreKeyStore,
SignedPreKeyRecord, SignedPreKeyStore,
};

use serde::{Deserialize, Serialize};
Expand All @@ -16,7 +16,7 @@ use tracing::Instrument;
/// Additional methods for the Kyber pre key store
///
/// Analogue of Android's ServiceKyberPreKeyStore
pub trait ServiceKyberPreKeyStore: KyberPreKeyStore {
pub trait KyberPreKeyStoreExt: KyberPreKeyStore {
async fn store_last_resort_kyber_pre_key(
&mut self,
kyber_prekey_id: KyberPreKeyId,
Expand Down Expand Up @@ -55,7 +55,7 @@ pub trait PreKeysStore:
+ IdentityKeyStore
+ SignedPreKeyStore
+ KyberPreKeyStore
+ ServiceKyberPreKeyStore
+ KyberPreKeyStoreExt
{
/// ID of the next pre key
async fn next_pre_key_id(&self) -> Result<u32, SignalProtocolError>;
Expand Down Expand Up @@ -83,6 +83,16 @@ pub trait PreKeysStore:
&mut self,
id: u32,
) -> Result<(), SignalProtocolError>;

/// number of signed pre-keys we currently have in store
async fn signed_pre_keys_count(&self)
-> Result<usize, SignalProtocolError>;

/// number of kyber pre-keys we currently have in store
async fn kyber_pre_keys_count(
&self,
last_resort: bool,
) -> Result<usize, SignalProtocolError>;
}

#[derive(Debug, Deserialize, Serialize)]
Expand Down Expand Up @@ -169,8 +179,8 @@ impl TryFrom<KyberPreKeyRecord> for KyberPreKeyEntity {
pub struct PreKeyState {
pub pre_keys: Vec<PreKeyEntity>,
pub signed_pre_key: SignedPreKeyEntity,
#[serde(with = "serde_public_key")]
pub identity_key: PublicKey,
#[serde(with = "serde_identity_key")]
pub identity_key: IdentityKey,
#[serde(skip_serializing_if = "Option::is_none")]
pub pq_last_resort_key: Option<KyberPreKeyEntity>,
pub pq_pre_keys: Vec<KyberPreKeyEntity>,
Expand Down
12 changes: 3 additions & 9 deletions libsignal-service/src/provisioning/cipher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -138,9 +138,7 @@ impl ProvisioningCipher {
.body
.expect("no body in ProvisionMessage");
if body[0] != VERSION {
return Err(ProvisioningError::InvalidData {
reason: "Bad version number".into(),
});
return Err(ProvisioningError::BadVersionNumber);
}

let iv = &body[IV_OFFSET..(IV_LENGTH + IV_OFFSET)];
Expand All @@ -166,9 +164,7 @@ impl ProvisioningCipher {
let our_mac = verifier.finalize().into_bytes();
debug_assert_eq!(our_mac.len(), mac.len());
if &our_mac[..32] != mac {
return Err(ProvisioningError::InvalidData {
reason: "wrong MAC".into(),
});
return Err(ProvisioningError::MismatchedMac);
}

// libsignal-service-java uses Pkcs5,
Expand All @@ -177,9 +173,7 @@ impl ProvisioningCipher {
let cipher = cbc::Decryptor::<Aes256>::new(parts1.into(), iv.into());
let input = cipher
.decrypt_padded_vec_mut::<Pkcs7>(cipher_text)
.map_err(|e| ProvisioningError::InvalidData {
reason: format!("CBC/Padding error: {:?}", e).into(),
})?;
.map_err(ProvisioningError::AesPaddingError)?;

Ok(prost::Message::decode(Bytes::from(input))?)
}
Expand Down
Loading

0 comments on commit 26c036e

Please sign in to comment.