LinkingManager {
}
}
- pub async fn provision_secondary_device(
+ pub async fn provision_secondary_device(
&mut self,
- ctx: &Context,
+ csprng: &mut R,
signaling_key: SignalingKey,
device_name: &str,
mut tx: Sender,
@@ -297,10 +292,10 @@ impl LinkingManager {
.ws("/v1/websocket/provisioning/", None)
.await?;
- let registration_id = generate_registration_id(&ctx, 0)?;
+ // see libsignal-protocol-c / signal_protocol_key_helper_generate_registration_id
+ let registration_id = csprng.gen_range(1, 16380);
- let provisioning_pipe =
- ProvisioningPipe::from_socket(ws, stream, &ctx)?;
+ let provisioning_pipe = ProvisioningPipe::from_socket(ws, stream)?;
let provision_stream = provisioning_pipe.stream();
pin_mut!(provision_stream);
while let Some(step) = provision_stream.next().await {
@@ -324,8 +319,7 @@ impl LinkingManager {
})
})?;
- let public_key = PublicKey::decode_point(
- &ctx,
+ let public_key = PublicKey::deserialize(
&message.identity_key_public.ok_or(
ProvisioningError::InvalidData {
reason: "missing public key".into(),
@@ -333,8 +327,7 @@ impl LinkingManager {
)?,
)?;
- let private_key = PrivateKey::decode_point(
- &ctx,
+ let private_key = PrivateKey::deserialize(
&message.identity_key_private.ok_or(
ProvisioningError::InvalidData {
reason: "missing public key".into(),
diff --git a/libsignal-service/src/provisioning/mod.rs b/libsignal-service/src/provisioning/mod.rs
index 5d2ee23e9..b20fbb017 100644
--- a/libsignal-service/src/provisioning/mod.rs
+++ b/libsignal-service/src/provisioning/mod.rs
@@ -27,7 +27,13 @@ pub enum ProvisioningError {
#[error("Service error: {0}")]
ServiceError(#[from] ServiceError),
#[error("libsignal-protocol error: {0}")]
- ProtocolError(#[from] libsignal_protocol::Error),
+ ProtocolError(#[from] libsignal_protocol::error::SignalProtocolError),
#[error("ProvisioningCipher in encrypt-only mode")]
EncryptOnlyProvisioningCipher,
}
+
+pub fn generate_registration_id(
+ csprng: &mut R,
+) -> u32 {
+ csprng.gen_range(1, 16380)
+}
diff --git a/libsignal-service/src/provisioning/pipe.rs b/libsignal-service/src/provisioning/pipe.rs
index b2bcc4af3..ccb5f2ce5 100644
--- a/libsignal-service/src/provisioning/pipe.rs
+++ b/libsignal-service/src/provisioning/pipe.rs
@@ -8,8 +8,6 @@ use pin_project::pin_project;
use prost::Message;
use url::Url;
-use libsignal_protocol::Context;
-
pub use crate::proto::{
ProvisionEnvelope, ProvisionMessage, ProvisioningVersion,
};
@@ -43,12 +41,13 @@ impl ProvisioningPipe {
pub fn from_socket(
ws: WS,
stream: WS::Stream,
- ctx: &Context,
) -> Result {
Ok(ProvisioningPipe {
ws,
stream,
- provisioning_cipher: ProvisioningCipher::new(ctx.clone())?,
+ provisioning_cipher: ProvisioningCipher::generate(
+ &mut rand::thread_rng(),
+ )?,
})
}
@@ -150,9 +149,10 @@ impl ProvisioningPipe {
.append_pair("uuid", &uuid.uuid.unwrap())
.append_pair(
"pub_key",
- &format!(
- "{}",
- self.provisioning_cipher.public_key()
+ &base64::encode(
+ self.provisioning_cipher
+ .public_key()
+ .serialize(),
),
);
diff --git a/libsignal-service/src/push_service.rs b/libsignal-service/src/push_service.rs
index feb583b5e..a3fde4c54 100644
--- a/libsignal-service/src/push_service.rs
+++ b/libsignal-service/src/push_service.rs
@@ -12,13 +12,14 @@ use crate::{
ServiceAddress,
};
-use libsignal_protocol::{keys::PublicKey, Context, PreKeyBundle};
-
use aes_gcm::{
aead::{generic_array::GenericArray, Aead},
Aes256Gcm, NewAead,
};
use chrono::prelude::*;
+use libsignal_protocol::{
+ error::SignalProtocolError, IdentityKey, PreKeyBundle, PublicKey,
+};
use prost::Message as ProtobufMessage;
use serde::{Deserialize, Serialize};
use uuid::Uuid;
@@ -61,12 +62,12 @@ pub const STICKER_PATH: &str = "stickers/%s/full/%d";
**/
pub const KEEPALIVE_TIMEOUT_SECONDS: Duration = Duration::from_secs(55);
-pub const DEFAULT_DEVICE_ID: i32 = 1;
+pub const DEFAULT_DEVICE_ID: u32 = 1;
#[derive(Debug, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DeviceId {
- pub device_id: i32,
+ pub device_id: u32,
}
#[derive(Debug, Serialize, Deserialize)]
@@ -163,23 +164,45 @@ pub struct WhoAmIResponse {
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PreKeyResponseItem {
- pub device_id: i32,
+ pub device_id: u32,
pub registration_id: u32,
- pub signed_pre_key: Option,
+ pub signed_pre_key: SignedPreKeyEntity,
pub pre_key: Option,
}
+impl PreKeyResponseItem {
+ fn into_bundle(
+ self,
+ identity: IdentityKey,
+ ) -> Result {
+ PreKeyBundle::new(
+ self.registration_id,
+ self.device_id,
+ self.pre_key
+ .map(|pk| -> Result<_, SignalProtocolError> {
+ Ok((pk.key_id, PublicKey::deserialize(&pk.public_key)?))
+ })
+ .transpose()?,
+ // pre_key: Option<(u32, PublicKey)>,
+ self.signed_pre_key.key_id,
+ PublicKey::deserialize(&self.signed_pre_key.public_key)?,
+ self.signed_pre_key.signature,
+ identity,
+ )
+ }
+}
+
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MismatchedDevices {
- pub missing_devices: Vec,
- pub extra_devices: Vec,
+ pub missing_devices: Vec,
+ pub extra_devices: Vec,
}
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StaleDevices {
- pub stale_devices: Vec,
+ pub stale_devices: Vec,
}
#[derive(Debug, Deserialize)]
@@ -268,7 +291,7 @@ pub enum ServiceError {
MacError,
#[error("Protocol error: {0}")]
- SignalProtocolError(#[from] libsignal_protocol::Error),
+ SignalProtocolError(#[from] SignalProtocolError),
#[error("{0:?}")]
MismatchedDevicesException(MismatchedDevices),
@@ -554,9 +577,8 @@ pub trait PushService {
async fn get_pre_key(
&mut self,
- context: &Context,
destination: &ServiceAddress,
- device_id: i32,
+ device_id: u32,
) -> Result {
let path = if let Some(ref relay) = destination.relay {
format!(
@@ -574,35 +596,15 @@ pub trait PushService {
.await?;
assert!(!pre_key_response.devices.is_empty());
+ let identity = IdentityKey::decode(&pre_key_response.identity_key)?;
let device = pre_key_response.devices.remove(0);
- let mut bundle = PreKeyBundle::builder()
- .identity_key(&PublicKey::decode_point(
- &context,
- &pre_key_response.identity_key,
- )?)
- .device_id(device.device_id)
- .registration_id(device.registration_id);
- if let Some(signed_pre_key) = device.signed_pre_key {
- bundle = bundle.signed_pre_key(
- signed_pre_key.key_id,
- &PublicKey::decode_point(&context, &signed_pre_key.public_key)?,
- );
- bundle = bundle.signature(&signed_pre_key.signature);
- }
- if let Some(pre_key) = device.pre_key {
- bundle = bundle.pre_key(
- pre_key.key_id,
- &PublicKey::decode_point(context, &pre_key.public_key)?,
- );
- }
- Ok(bundle.build()?)
+ Ok(device.into_bundle(identity)?)
}
async fn get_pre_keys(
&mut self,
- context: &Context,
destination: &ServiceAddress,
- device_id: i32,
+ device_id: u32,
) -> Result, ServiceError> {
let path = match (device_id, destination.relay.as_ref()) {
(1, None) => format!("/v2/keys/{}/*", destination.identifier()),
@@ -625,31 +627,9 @@ pub trait PushService {
.get_json(Endpoint::Service, &path, HttpAuthOverride::NoOverride)
.await?;
let mut pre_keys = vec![];
+ let identity = IdentityKey::decode(&pre_key_response.identity_key)?;
for device in pre_key_response.devices {
- let mut bundle = PreKeyBundle::builder()
- .identity_key(&PublicKey::decode_point(
- &context,
- &pre_key_response.identity_key,
- )?)
- .device_id(device.device_id)
- .registration_id(device.registration_id);
- if let Some(signed_pre_key) = device.signed_pre_key {
- bundle = bundle.signed_pre_key(
- signed_pre_key.key_id,
- &PublicKey::decode_point(
- &context,
- &signed_pre_key.public_key,
- )?,
- );
- bundle = bundle.signature(&signed_pre_key.signature);
- }
- if let Some(pre_key) = device.pre_key {
- bundle = bundle.pre_key(
- pre_key.key_id,
- &PublicKey::decode_point(context, &pre_key.public_key)?,
- );
- }
- pre_keys.push(bundle.build()?)
+ pre_keys.push(device.into_bundle(identity)?);
}
Ok(pre_keys)
}
diff --git a/libsignal-service/src/receiver.rs b/libsignal-service/src/receiver.rs
index 2507f92ee..cb50a703d 100644
--- a/libsignal-service/src/receiver.rs
+++ b/libsignal-service/src/receiver.rs
@@ -10,6 +10,7 @@ use crate::{
};
/// Equivalent of Java's `SignalServiceMessageReceiver`.
+#[derive(Clone)]
pub struct MessageReceiver {
service: Service,
}
diff --git a/libsignal-service/src/sealed_session_cipher.rs b/libsignal-service/src/sealed_session_cipher.rs
index 385300618..08d994a0b 100644
--- a/libsignal-service/src/sealed_session_cipher.rs
+++ b/libsignal-service/src/sealed_session_cipher.rs
@@ -1,20 +1,21 @@
-use phonenumber::PhoneNumber;
-use uuid::Uuid;
+use std::convert::TryFrom;
use aes_ctr::{
cipher::stream::{NewStreamCipher, StreamCipher},
Aes256Ctr,
};
-
use hmac::{Hmac, Mac, NewMac};
use libsignal_protocol::{
- keys::{PrivateKey, PublicKey},
- messages::{CiphertextType, PreKeySignalMessage, SignalMessage},
- Address as ProtocolAddress, Context, Deserializable, Serializable,
- SessionCipher, StoreContext,
+ error::SignalProtocolError, message_decrypt_prekey, message_decrypt_signal,
+ message_encrypt, CiphertextMessageType, IdentityKeyStore, KeyPair,
+ PreKeySignalMessage, PreKeyStore, PrivateKey, ProtocolAddress, PublicKey,
+ SessionStore, SignalMessage, SignedPreKeyStore, HKDF,
};
use log::error;
+use phonenumber::PhoneNumber;
+use rand::{CryptoRng, Rng};
use sha2::Sha256;
+use uuid::Uuid;
use crate::{push_service::ProfileKey, ServiceAddress};
@@ -42,7 +43,7 @@ pub enum SealedSessionError {
EncodeError(#[from] prost::EncodeError),
#[error("Protocol error {0}")]
- ProtocolError(#[from] libsignal_protocol::Error),
+ ProtocolError(#[from] SignalProtocolError),
#[error("recipient not trusted")]
NoSessionWithRecipient,
@@ -65,10 +66,12 @@ pub enum MacError {
}
#[derive(Clone)]
-pub(crate) struct SealedSessionCipher {
- context: Context,
- store_context: StoreContext,
- local_address: ServiceAddress,
+pub(crate) struct SealedSessionCipher {
+ session_store: S,
+ identity_key_store: I,
+ signed_pre_key_store: SP,
+ pre_key_store: P,
+ csprng: R,
certificate_validator: CertificateValidator,
}
@@ -93,7 +96,7 @@ struct UnidentifiedSenderMessage {
#[derive(Debug, Clone)]
pub struct UnidentifiedSenderMessageContent {
- r#type: CiphertextType,
+ r#type: CiphertextMessageType,
sender_certificate: SenderCertificate,
content: Vec,
}
@@ -102,7 +105,7 @@ pub struct UnidentifiedSenderMessageContent {
pub struct SenderCertificate {
signer: ServerCertificate,
key: PublicKey,
- sender_device_id: i32,
+ sender_device_id: u32,
sender_uuid: Option,
sender_e164: Option,
expiration: u64,
@@ -140,7 +143,7 @@ pub struct CertificateValidator {
pub(crate) struct DecryptionResult {
pub sender_uuid: Option,
pub sender_e164: Option,
- pub device_id: i32,
+ pub device_id: u32,
pub padded_message: Vec,
pub version: u32,
}
@@ -160,10 +163,7 @@ impl UnidentifiedAccess {
impl UnidentifiedSenderMessage {
const CIPHERTEXT_VERSION: u8 = 1;
- fn from_bytes(
- context: &Context,
- serialized: &[u8],
- ) -> Result {
+ fn from_bytes(serialized: &[u8]) -> Result {
let version = serialized[0] >> 4;
if version > Self::CIPHERTEXT_VERSION {
return Err(SealedSessionError::InvalidMetadataVersionError(
@@ -184,10 +184,7 @@ impl UnidentifiedSenderMessage {
Some(encrypted_static),
Some(encrypted_message),
) => Ok(Self {
- ephemeral: PublicKey::decode_point(
- &context,
- &ephemeral_public,
- )?,
+ ephemeral: PublicKey::deserialize(&ephemeral_public)?,
encrypted_static,
encrypted_message,
}),
@@ -202,9 +199,7 @@ impl UnidentifiedSenderMessage {
let mut buf =
vec![Self::CIPHERTEXT_VERSION << 4 | Self::CIPHERTEXT_VERSION];
crate::proto::UnidentifiedSenderMessage {
- ephemeral_public: Some(
- self.ephemeral.to_bytes()?.as_slice().to_vec(),
- ),
+ ephemeral_public: Some(self.ephemeral.serialize().to_vec()),
encrypted_static: Some(self.encrypted_static),
encrypted_message: Some(self.encrypted_message),
}
@@ -213,17 +208,28 @@ impl UnidentifiedSenderMessage {
}
}
-impl SealedSessionCipher {
+impl SealedSessionCipher
+where
+ S: SessionStore,
+ I: IdentityKeyStore,
+ SP: SignedPreKeyStore,
+ P: PreKeyStore,
+ R: Rng + CryptoRng,
+{
pub(crate) fn new(
- context: Context,
- store_context: StoreContext,
- local_address: ServiceAddress,
+ session_store: S,
+ identity_key_store: I,
+ signed_pre_key_store: SP,
+ pre_key_store: P,
+ csprng: R,
certificate_validator: CertificateValidator,
) -> Self {
Self {
- context,
- store_context,
- local_address,
+ session_store,
+ identity_key_store,
+ signed_pre_key_store,
+ pre_key_store,
+ csprng,
certificate_validator,
}
}
@@ -231,43 +237,47 @@ impl SealedSessionCipher {
/// unused until we make progress on https://github.com/Michael-F-Bryan/libsignal-service-rs/issues/25
/// messages from unidentified senders can only be sent via a unidentifiedPipe
#[allow(dead_code)]
- pub fn encrypt(
- &self,
+ pub async fn encrypt(
+ &mut self,
destination: &ProtocolAddress,
sender_certificate: SenderCertificate,
padded_plaintext: &[u8],
) -> Result, SealedSessionError> {
- let message = SessionCipher::new(
- &self.context,
- &self.store_context,
- &destination,
- )?
- .encrypt(padded_plaintext)?;
-
- let our_identity = &self.store_context.identity_key_pair()?;
+ let message = message_encrypt(
+ padded_plaintext,
+ destination,
+ &mut self.session_store,
+ &mut self.identity_key_store,
+ None,
+ )
+ .await?;
+
+ let our_identity =
+ &self.identity_key_store.get_identity_key_pair(None).await?;
let their_identity = self
- .store_context
- .get_identity(destination.clone())?
+ .identity_key_store
+ .get_identity(destination, None)
+ .await?
.ok_or(SealedSessionError::NoSessionWithRecipient)?;
- let ephemeral = libsignal_protocol::generate_key_pair(&self.context)?;
+ let ephemeral = KeyPair::generate(&mut self.csprng);
let ephemeral_salt = [
b"UnidentifiedDelivery",
- their_identity.to_bytes()?.as_slice(),
- ephemeral.public().to_bytes()?.as_slice(),
+ their_identity.serialize().as_ref(),
+ ephemeral.public_key.serialize().as_ref(),
]
.concat();
let ephemeral_keys = self.calculate_ephemeral_keys(
- &their_identity,
- &ephemeral.private(),
+ &their_identity.public_key(),
+ &ephemeral.private_key,
&ephemeral_salt,
)?;
let static_key_ciphertext = self.encrypt_bytes(
&ephemeral_keys.cipher_key,
&ephemeral_keys.mac_key,
- our_identity.public().to_bytes()?.as_slice(),
+ &our_identity.public_key().serialize(),
)?;
let static_salt = [
@@ -277,15 +287,15 @@ impl SealedSessionCipher {
.concat();
let static_keys = self.calculate_static_keys(
- &their_identity,
- &our_identity.private(),
+ &their_identity.public_key(),
+ &our_identity.private_key(),
&static_salt,
)?;
let content = UnidentifiedSenderMessageContent {
- r#type: message.get_type()?,
+ r#type: message.message_type(),
sender_certificate,
- content: message.serialize()?.as_slice().to_vec(),
+ content: message.serialize().to_vec(),
};
let message_bytes = self.encrypt_bytes(
@@ -295,32 +305,33 @@ impl SealedSessionCipher {
)?;
UnidentifiedSenderMessage {
- ephemeral: ephemeral.public(),
+ ephemeral: ephemeral.public_key,
encrypted_static: static_key_ciphertext,
encrypted_message: message_bytes,
}
.into_bytes()
}
- pub fn decrypt(
- &self,
+ pub async fn decrypt(
+ &mut self,
ciphertext: &[u8],
timestamp: u64,
) -> Result {
- let our_identity = self.store_context.identity_key_pair()?;
- let wrapper =
- UnidentifiedSenderMessage::from_bytes(&self.context, ciphertext)?;
+ let our_identity =
+ self.identity_key_store.get_identity_key_pair(None).await?;
+
+ let wrapper = UnidentifiedSenderMessage::from_bytes(ciphertext)?;
let ephemeral_salt = [
b"UnidentifiedDelivery",
- our_identity.public().to_bytes()?.as_slice(),
- wrapper.ephemeral.to_bytes()?.as_slice(),
+ our_identity.public_key().serialize().as_ref(),
+ wrapper.ephemeral.serialize().as_ref(),
]
.concat();
let ephemeral_keys = self.calculate_ephemeral_keys(
&wrapper.ephemeral,
- &our_identity.private(),
+ &our_identity.private_key(),
&ephemeral_salt,
)?;
@@ -330,13 +341,12 @@ impl SealedSessionCipher {
&wrapper.encrypted_static,
)?;
- let static_key =
- PublicKey::decode_point(&self.context, &static_key_bytes)?;
+ let static_key = PublicKey::deserialize(&static_key_bytes)?;
let static_salt =
[ephemeral_keys.chain_key, wrapper.encrypted_static].concat();
let static_keys = self.calculate_static_keys(
&static_key,
- &our_identity.private(),
+ &our_identity.private_key(),
&static_salt,
)?;
@@ -347,13 +357,12 @@ impl SealedSessionCipher {
)?;
let content = UnidentifiedSenderMessageContent::try_from(
- &self.context,
message_bytes.as_slice(),
)?;
self.certificate_validator
.validate(&content.sender_certificate, timestamp)?;
- self.decrypt_message_content(content)
+ self.decrypt_message_content(content).await
}
fn calculate_ephemeral_keys(
@@ -362,12 +371,13 @@ impl SealedSessionCipher {
private_key: &PrivateKey,
salt: &[u8],
) -> Result {
- let ephemeral_secret = public_key.calculate_agreement(private_key)?;
- let ephemeral_derived = libsignal_protocol::create_hkdf(
- &self.context,
- 3,
- )?
- .derive_secrets(96, &ephemeral_secret, salt, &[])?;
+ let ephemeral_secret = private_key.calculate_agreement(public_key)?;
+ let ephemeral_derived = HKDF::new(3)?.derive_salted_secrets(
+ &ephemeral_secret,
+ salt,
+ &[],
+ 96,
+ )?;
let ephemeral_keys = EphemeralKeys {
chain_key: ephemeral_derived[0..32].into(),
cipher_key: ephemeral_derived[32..64].into(),
@@ -382,9 +392,13 @@ impl SealedSessionCipher {
private_key: &PrivateKey,
salt: &[u8],
) -> Result {
- let static_secret = public_key.calculate_agreement(private_key)?;
- let static_derived = libsignal_protocol::create_hkdf(&self.context, 3)?
- .derive_secrets(96, &static_secret, salt, &[])?;
+ let static_secret = private_key.calculate_agreement(public_key)?;
+ let static_derived = HKDF::new(3)?.derive_salted_secrets(
+ &static_secret,
+ salt,
+ &[],
+ 96,
+ )?;
Ok(StaticKeys {
cipher_key: static_derived[32..64].into(),
mac_key: static_derived[64..96].into(),
@@ -443,8 +457,8 @@ impl SealedSessionCipher {
Ok(decrypted)
}
- fn decrypt_message_content(
- &self,
+ async fn decrypt_message_content(
+ &mut self,
message: UnidentifiedSenderMessageContent,
) -> Result {
let UnidentifiedSenderMessageContent {
@@ -453,29 +467,51 @@ impl SealedSessionCipher {
sender_certificate,
} = message;
let sender = crate::cipher::get_preferred_protocol_address(
- &self.store_context,
- sender_certificate.address(),
+ &self.session_store,
+ &sender_certificate.address(),
sender_certificate.sender_device_id,
- )?;
- let session_cipher =
- SessionCipher::new(&self.context, &self.store_context, &sender)?;
+ )
+ .await?;
+
let msg = match r#type {
- CiphertextType::Signal => {
- let msg = session_cipher.decrypt_message(
- &SignalMessage::deserialize(&self.context, &content)?,
- )?;
+ CiphertextMessageType::Whisper => {
+ let msg = message_decrypt_signal(
+ &SignalMessage::try_from(&content[..])?,
+ &sender,
+ &mut self.session_store,
+ &mut self.identity_key_store,
+ &mut self.csprng,
+ None,
+ )
+ .await?;
msg.as_slice().to_vec()
}
- CiphertextType::PreKey => {
- let msg = session_cipher.decrypt_pre_key_message(
- &PreKeySignalMessage::deserialize(&self.context, &content)?,
- )?;
+ CiphertextMessageType::PreKey => {
+ let msg = message_decrypt_prekey(
+ &PreKeySignalMessage::try_from(&content[..])?,
+ &sender,
+ &mut self.session_store,
+ &mut self.identity_key_store,
+ &mut self.pre_key_store,
+ &mut self.signed_pre_key_store,
+ &mut self.csprng,
+ None,
+ )
+ .await?;
msg.as_slice().to_vec()
}
_ => unreachable!("unknown message from unidentified sender type"),
};
- let version = session_cipher.get_session_version()?;
+ let version = self
+ .session_store
+ .load_session(&sender, None)
+ .await?
+ .ok_or_else(|| {
+ SignalProtocolError::SessionNotFound(format!("{}", sender))
+ })?
+ .session_version()?;
+
Ok(DecryptionResult {
padded_message: msg,
version,
@@ -487,10 +523,7 @@ impl SealedSessionCipher {
}
impl UnidentifiedSenderMessageContent {
- fn try_from(
- context: &Context,
- serialized: &[u8],
- ) -> Result {
+ fn try_from(serialized: &[u8]) -> Result {
use crate::proto::unidentified_sender_message::{self, message};
let message: unidentified_sender_message::Message =
@@ -500,9 +533,11 @@ impl UnidentifiedSenderMessageContent {
(Some(message_type), Some(sender_certificate), Some(content)) => {
Ok(Self {
r#type: match message::Type::from_i32(message_type) {
- Some(message::Type::Message) => CiphertextType::Signal,
+ Some(message::Type::Message) => {
+ CiphertextMessageType::Whisper
+ }
Some(message::Type::PrekeyMessage) => {
- CiphertextType::PreKey
+ CiphertextMessageType::PreKey
}
t => {
return Err(
@@ -513,7 +548,6 @@ impl UnidentifiedSenderMessageContent {
}
},
sender_certificate: SenderCertificate::try_from(
- &context,
sender_certificate,
)?,
content,
@@ -532,8 +566,8 @@ impl UnidentifiedSenderMessageContent {
unidentified_sender_message::Message {
r#type: Some(match self.r#type {
- CiphertextType::PreKey => message::Type::PrekeyMessage,
- CiphertextType::Signal => message::Type::Message,
+ CiphertextMessageType::PreKey => message::Type::PrekeyMessage,
+ CiphertextMessageType::Whisper => message::Type::Message,
_ => {
return Err(
SealedSessionError::InvalidMetadataMessageError(
@@ -556,7 +590,6 @@ impl UnidentifiedSenderMessageContent {
impl SenderCertificate {
fn try_from(
- context: &Context,
wrapper: crate::proto::SenderCertificate,
) -> Result {
use crate::proto::sender_certificate::Certificate;
@@ -591,16 +624,11 @@ impl SenderCertificate {
.transpose()?;
Ok(Self {
- signer: ServerCertificate::try_from(
- &context, signer,
- )?,
- key: PublicKey::decode_point(
- &context,
- &identity_key,
- )?,
+ signer: ServerCertificate::try_from(signer)?,
+ key: PublicKey::deserialize(&identity_key)?,
sender_e164,
sender_uuid,
- sender_device_id: sender_device_id as i32,
+ sender_device_id,
expiration: expires,
certificate,
signature,
@@ -624,7 +652,6 @@ impl SenderCertificate {
impl ServerCertificate {
fn try_from(
- context: &Context,
wrapper: crate::proto::ServerCertificate,
) -> Result {
use crate::proto::server_certificate;
@@ -637,7 +664,7 @@ impl ServerCertificate {
match (server_certificate.id, server_certificate.key) {
(Some(id), Some(key)) => Ok(Self {
key_id: id,
- key: PublicKey::decode_point(context, &key)?,
+ key: PublicKey::deserialize(&key)?,
certificate,
signature,
}),
@@ -660,26 +687,28 @@ impl CertificateValidator {
validation_time: u64,
) -> Result<(), SealedSessionError> {
let server_certificate = &certificate.signer;
- self.trust_root
- .verify_signature(
- &server_certificate.certificate,
- &server_certificate.signature,
- )
- .map_err(|e| {
- error!("failed to verify server certificate: {}", e);
- SealedSessionError::InvalidCertificate
- })?;
- server_certificate
+ match self.trust_root.verify_signature(
+ &server_certificate.certificate,
+ &server_certificate.signature,
+ ) {
+ Err(_) | Ok(false) => {
+ return Err(SealedSessionError::InvalidCertificate)
+ }
+ _ => (),
+ };
+
+ match server_certificate
.key
.verify_signature(&certificate.certificate, &certificate.signature)
- .map_err(|e| {
- error!("failed to verify certificate: {}", e);
- SealedSessionError::InvalidCertificate
- })?;
+ {
+ Err(_) | Ok(false) => {
+ return Err(SealedSessionError::InvalidCertificate)
+ }
+ _ => (),
+ }
if validation_time > certificate.expiration {
- error!("certificate is expired");
return Err(SealedSessionError::ExpiredCertificate);
}
@@ -689,21 +718,17 @@ impl CertificateValidator {
#[cfg(test)]
mod tests {
- use std::time::UNIX_EPOCH;
+ use std::time::{SystemTime, UNIX_EPOCH};
use libsignal_protocol::{
- self as sig,
- crypto::DefaultCrypto,
- keys::PreKey,
- keys::{KeyPair, PublicKey},
- stores::InMemoryPreKeyStore,
- stores::InMemorySessionStore,
- stores::{InMemoryIdentityKeyStore, InMemorySignedPreKeyStore},
- Address as ProtocolAddress, Context, PreKeyBundle, Serializable,
- SessionBuilder, StoreContext,
+ process_prekey_bundle, IdentityKeyPair, IdentityKeyStore,
+ InMemIdentityKeyStore, InMemPreKeyStore, InMemSessionStore,
+ InMemSignedPreKeyStore, KeyPair, PreKeyBundle, PreKeyRecord,
+ PreKeyStore, ProtocolAddress, PublicKey, SignedPreKeyRecord,
+ SignedPreKeyStore,
};
- use crate::ServiceAddress;
+ use crate::{provisioning::generate_registration_id, ServiceAddress};
use super::{
CertificateValidator, SealedSessionCipher, SealedSessionError,
@@ -720,51 +745,62 @@ mod tests {
.unwrap()
}
- fn bob_address() -> ServiceAddress {
- ServiceAddress::parse(
- Some("+14152222222"),
- Some("e80f7bbe-5b94-471e-bd8c-2173654ea3d1"),
- )
- .unwrap()
+ struct Stores {
+ identity_key_store: InMemIdentityKeyStore,
+ session_store: InMemSessionStore,
+ signed_pre_key_store: InMemSignedPreKeyStore,
+ pre_key_store: InMemPreKeyStore,
}
- #[test]
- fn test_encrypt_decrypt() -> anyhow::Result<()> {
- let (ctx, alice_store_context, bob_store_context) = create_contexts()?;
- initialize_session(&ctx, &bob_store_context, &alice_store_context)?;
+ #[tokio::test]
+ async fn test_encrypt_decrypt() -> anyhow::Result<()> {
+ let mut csprng = rand::thread_rng();
+
+ let (alice_stores, bob_stores) = create_stores(&mut csprng).await?;
- let trust_root = libsignal_protocol::generate_key_pair(&ctx)?;
+ let trust_root = KeyPair::generate(&mut csprng);
let certificate_validator =
- CertificateValidator::new(trust_root.public());
+ CertificateValidator::new(trust_root.public_key);
let sender_certificate = create_certificate_for(
- &ctx,
&trust_root,
alice_address(),
1,
- alice_store_context.identity_key_pair()?.public(),
+ *alice_stores
+ .identity_key_store
+ .get_identity_key_pair(None)
+ .await?
+ .public_key(),
31337,
+ &mut csprng,
)?;
- let alice_cipher = SealedSessionCipher::new(
- ctx.clone(),
- alice_store_context,
- alice_address(),
+ let mut alice_cipher = SealedSessionCipher::new(
+ alice_stores.session_store,
+ alice_stores.identity_key_store,
+ alice_stores.signed_pre_key_store,
+ alice_stores.pre_key_store,
+ csprng,
certificate_validator.clone(),
);
- let ciphertext = alice_cipher.encrypt(
- &ProtocolAddress::new("+14152222222", 1),
- sender_certificate,
- "smert za smert".as_bytes(),
- )?;
- let bob_cipher = SealedSessionCipher::new(
- ctx,
- bob_store_context,
- bob_address(),
+ let ciphertext = alice_cipher
+ .encrypt(
+ &ProtocolAddress::new("+14152222222".into(), 1),
+ sender_certificate,
+ "smert za smert".as_bytes(),
+ )
+ .await?;
+
+ let mut bob_cipher = SealedSessionCipher::new(
+ bob_stores.session_store,
+ bob_stores.identity_key_store,
+ bob_stores.signed_pre_key_store,
+ bob_stores.pre_key_store,
+ csprng,
certificate_validator,
);
- let plaintext = bob_cipher.decrypt(&ciphertext, 31335)?;
+ let plaintext = bob_cipher.decrypt(&ciphertext, 31335).await?;
assert_eq!(
String::from_utf8_lossy(&plaintext.padded_message),
@@ -777,49 +813,59 @@ mod tests {
Ok(())
}
- #[test]
- fn test_encrypt_decrypt_untrusted() -> anyhow::Result<()> {
- let (ctx, alice_store_context, bob_store_context) = create_contexts()?;
- initialize_session(&ctx, &bob_store_context, &alice_store_context)?;
+ #[tokio::test]
+ async fn test_encrypt_decrypt_untrusted() -> anyhow::Result<()> {
+ let mut csprng = rand::thread_rng();
+ let (alice_stores, bob_stores) = create_stores(&mut csprng).await?;
- let trust_root = libsignal_protocol::generate_key_pair(&ctx)?;
+ let trust_root = KeyPair::generate(&mut csprng);
let certificate_validator =
- CertificateValidator::new(trust_root.public());
+ CertificateValidator::new(trust_root.public_key);
- let false_trust_root = libsignal_protocol::generate_key_pair(&ctx)?;
+ let false_trust_root = KeyPair::generate(&mut csprng);
let false_certificate_validator =
- CertificateValidator::new(false_trust_root.public());
+ CertificateValidator::new(false_trust_root.public_key);
let sender_certificate = create_certificate_for(
- &ctx,
&trust_root,
alice_address(),
1,
- alice_store_context.identity_key_pair()?.public(),
+ *alice_stores
+ .identity_key_store
+ .get_identity_key_pair(None)
+ .await?
+ .public_key(),
31337,
+ &mut csprng,
)?;
- let alice_cipher = SealedSessionCipher::new(
- ctx.clone(),
- alice_store_context,
- alice_address(),
+ let mut alice_cipher = SealedSessionCipher::new(
+ alice_stores.session_store,
+ alice_stores.identity_key_store,
+ alice_stores.signed_pre_key_store,
+ alice_stores.pre_key_store,
+ csprng,
certificate_validator,
);
- let ciphertext = alice_cipher.encrypt(
- &ProtocolAddress::new("+14152222222", 1),
- sender_certificate,
- "и вот я".as_bytes(),
- )?;
-
- let bob_cipher = SealedSessionCipher::new(
- ctx,
- bob_store_context,
- bob_address(),
+ let ciphertext = alice_cipher
+ .encrypt(
+ &ProtocolAddress::new("+14152222222".into(), 1),
+ sender_certificate,
+ "и вот я".as_bytes(),
+ )
+ .await?;
+
+ let mut bob_cipher = SealedSessionCipher::new(
+ bob_stores.session_store,
+ bob_stores.identity_key_store,
+ bob_stores.signed_pre_key_store,
+ bob_stores.pre_key_store,
+ csprng,
false_certificate_validator,
);
- let plaintext = bob_cipher.decrypt(&ciphertext, 31335);
+ let plaintext = bob_cipher.decrypt(&ciphertext, 31335).await;
match plaintext {
Err(SealedSessionError::InvalidCertificate) => Ok(()),
@@ -827,132 +873,150 @@ mod tests {
}
}
- #[test]
- fn test_encrypt_decrypt_expired() -> anyhow::Result<()> {
- let (ctx, alice_store_context, bob_store_context) = create_contexts()?;
- initialize_session(&ctx, &bob_store_context, &alice_store_context)?;
+ #[tokio::test]
+ async fn test_encrypt_decrypt_expired() -> anyhow::Result<()> {
+ let mut csprng = rand::thread_rng();
+ let (alice_stores, bob_stores) = create_stores(&mut csprng).await?;
- let trust_root = libsignal_protocol::generate_key_pair(&ctx)?;
+ let trust_root = KeyPair::generate(&mut csprng);
let certificate_validator =
- CertificateValidator::new(trust_root.public());
+ CertificateValidator::new(trust_root.public_key);
let sender_certificate = create_certificate_for(
- &ctx,
&trust_root,
alice_address(),
1,
- alice_store_context.identity_key_pair()?.public(),
+ *alice_stores
+ .identity_key_store
+ .get_identity_key_pair(None)
+ .await?
+ .public_key(),
31337,
+ &mut csprng,
)?;
- let alice_cipher = SealedSessionCipher::new(
- ctx.clone(),
- alice_store_context,
- alice_address(),
+ let mut alice_cipher = SealedSessionCipher::new(
+ alice_stores.session_store,
+ alice_stores.identity_key_store,
+ alice_stores.signed_pre_key_store,
+ alice_stores.pre_key_store,
+ csprng,
certificate_validator.clone(),
);
- let ciphertext = alice_cipher.encrypt(
- &ProtocolAddress::new("+14152222222", 1),
- sender_certificate,
- "smert za smert".as_bytes(),
- )?;
-
- let bob_cipher = SealedSessionCipher::new(
- ctx,
- bob_store_context,
- bob_address(),
+ let ciphertext = alice_cipher
+ .encrypt(
+ &ProtocolAddress::new("+14152222222".into(), 1),
+ sender_certificate,
+ "smert za smert".as_bytes(),
+ )
+ .await?;
+
+ let mut bob_cipher = SealedSessionCipher::new(
+ bob_stores.session_store,
+ bob_stores.identity_key_store,
+ bob_stores.signed_pre_key_store,
+ bob_stores.pre_key_store,
+ csprng,
certificate_validator,
);
- match bob_cipher.decrypt(&ciphertext, 31338) {
+ match bob_cipher.decrypt(&ciphertext, 31338).await {
Err(SealedSessionError::ExpiredCertificate) => Ok(()),
_ => panic!("certificate is expired, we should not get decrypted data here!11!")
}
}
- #[test]
- fn test_encrypt_from_wrong_identity() -> anyhow::Result<()> {
- let (ctx, alice_store_context, bob_store_context) = create_contexts()?;
- initialize_session(&ctx, &bob_store_context, &alice_store_context)?;
+ #[tokio::test]
+ async fn test_encrypt_from_wrong_identity() -> anyhow::Result<()> {
+ let mut csprng = rand::thread_rng();
+ let (alice_stores, bob_stores) = create_stores(&mut csprng).await?;
- let trust_root = libsignal_protocol::generate_key_pair(&ctx)?;
- let random_key_pair = libsignal_protocol::generate_key_pair(&ctx)?;
+ let trust_root = KeyPair::generate(&mut csprng);
+ let random_key_pair = KeyPair::generate(&mut csprng);
let certificate_validator =
- CertificateValidator::new(trust_root.public());
+ CertificateValidator::new(trust_root.public_key);
let sender_certificate = create_certificate_for(
- &ctx,
&random_key_pair,
alice_address(),
1,
- alice_store_context.identity_key_pair()?.public(),
+ *alice_stores
+ .identity_key_store
+ .get_identity_key_pair(None)
+ .await?
+ .public_key(),
31337,
+ &mut csprng,
)?;
- let alice_cipher = SealedSessionCipher::new(
- ctx.clone(),
- alice_store_context,
- alice_address(),
+ let mut alice_cipher = SealedSessionCipher::new(
+ alice_stores.session_store,
+ alice_stores.identity_key_store,
+ alice_stores.signed_pre_key_store,
+ alice_stores.pre_key_store,
+ csprng,
certificate_validator.clone(),
);
- let ciphertext = alice_cipher.encrypt(
- &ProtocolAddress::new("+14152222222", 1),
- sender_certificate,
- "smert za smert".as_bytes(),
- )?;
- let bob_cipher = SealedSessionCipher::new(
- ctx,
- bob_store_context,
- bob_address(),
+ let ciphertext = alice_cipher
+ .encrypt(
+ &ProtocolAddress::new("+14152222222".into(), 1),
+ sender_certificate,
+ "smert za smert".as_bytes(),
+ )
+ .await?;
+
+ let mut bob_cipher = SealedSessionCipher::new(
+ bob_stores.session_store,
+ bob_stores.identity_key_store,
+ bob_stores.signed_pre_key_store,
+ bob_stores.pre_key_store,
+ csprng,
certificate_validator,
);
- match bob_cipher.decrypt(&ciphertext, 31335) {
+ match bob_cipher.decrypt(&ciphertext, 31335).await {
Err(SealedSessionError::InvalidCertificate) => Ok(()),
_ => panic!("the certificate is invalid here!11"),
}
}
- fn create_contexts(
- ) -> Result<(Context, StoreContext, StoreContext), SealedSessionError> {
- let ctx = Context::new(DefaultCrypto::default())?;
-
- let alice_identity = sig::generate_identity_key_pair(&ctx).unwrap();
- let alice_store = sig::store_context(
- &ctx,
- InMemoryPreKeyStore::default(),
- InMemorySignedPreKeyStore::default(),
- InMemorySessionStore::default(),
- InMemoryIdentityKeyStore::new(
- sig::generate_registration_id(&ctx, 0).unwrap(),
- &alice_identity,
+ async fn create_stores(
+ csprng: &mut R,
+ ) -> anyhow::Result<(Stores, Stores)> {
+ let mut alice_stores = Stores {
+ identity_key_store: InMemIdentityKeyStore::new(
+ IdentityKeyPair::generate(csprng),
+ generate_registration_id(csprng),
),
- )?;
+ session_store: InMemSessionStore::new(),
+ signed_pre_key_store: InMemSignedPreKeyStore::new(),
+ pre_key_store: InMemPreKeyStore::new(),
+ };
- let bob_identity = sig::generate_identity_key_pair(&ctx).unwrap();
- let bob_store = sig::store_context(
- &ctx,
- InMemoryPreKeyStore::default(),
- InMemorySignedPreKeyStore::default(),
- InMemorySessionStore::default(),
- InMemoryIdentityKeyStore::new(
- sig::generate_registration_id(&ctx, 0).unwrap(),
- &bob_identity,
+ let mut bob_stores = Stores {
+ identity_key_store: InMemIdentityKeyStore::new(
+ IdentityKeyPair::generate(csprng),
+ generate_registration_id(csprng),
),
- )?;
+ session_store: InMemSessionStore::new(),
+ signed_pre_key_store: InMemSignedPreKeyStore::new(),
+ pre_key_store: InMemPreKeyStore::new(),
+ };
- Ok((ctx, alice_store, bob_store))
+ initialize_session(&mut alice_stores, &mut bob_stores, csprng).await?;
+
+ Ok((alice_stores, bob_stores))
}
- fn create_certificate_for(
- context: &Context,
+ fn create_certificate_for(
trust_root: &KeyPair,
addr: ServiceAddress,
device_id: u32,
identity_key: PublicKey,
expires: u64,
+ csprng: &mut R,
) -> Result {
- let server_key = libsignal_protocol::generate_key_pair(&context)?;
+ let server_key = KeyPair::generate(csprng);
let uuid = addr.uuid.as_ref().map(uuid::Uuid::to_string);
let e164 = addr.e164();
@@ -960,22 +1024,17 @@ mod tests {
let mut server_certificate_bytes = vec![];
crate::proto::server_certificate::Certificate {
id: Some(1),
- key: Some(server_key.public().serialize()?.as_slice().to_vec()),
+ key: Some(server_key.public_key.serialize().into_vec()),
}
.encode(&mut server_certificate_bytes)?;
- let server_certificate_signature =
- libsignal_protocol::calculate_signature(
- &context,
- &trust_root.private(),
- &server_certificate_bytes,
- )?
- .as_slice()
- .to_vec();
+ let server_certificate_signature = trust_root
+ .private_key
+ .calculate_signature(&server_certificate_bytes, csprng)?;
let server_certificate = crate::proto::ServerCertificate {
certificate: Some(server_certificate_bytes),
- signature: Some(server_certificate_signature),
+ signature: Some(server_certificate_signature.into_vec()),
};
let mut sender_certificate_bytes = vec![];
@@ -983,62 +1042,83 @@ mod tests {
sender_uuid: uuid,
sender_e164: e164,
sender_device: Some(device_id),
- identity_key: Some(identity_key.serialize()?.as_slice().to_vec()),
+ identity_key: Some(identity_key.serialize().into_vec()),
expires: Some(expires),
signer: Some(server_certificate),
}
.encode(&mut sender_certificate_bytes)?;
- let sender_certificate_signature =
- libsignal_protocol::calculate_signature(
- &context,
- &server_key.private(),
- &sender_certificate_bytes,
- )?
- .as_slice()
- .to_vec();
+ let sender_certificate_signature = server_key
+ .private_key
+ .calculate_signature(&sender_certificate_bytes, csprng)?;
- SenderCertificate::try_from(
- &context,
+ Ok(SenderCertificate::try_from(
crate::proto::SenderCertificate {
certificate: Some(sender_certificate_bytes),
- signature: Some(sender_certificate_signature),
+ signature: Some(sender_certificate_signature.into_vec()),
},
- )
+ )?)
}
- fn initialize_session(
- context: &Context,
- bob_store_context: &StoreContext,
- alice_store_context: &StoreContext,
+ async fn initialize_session(
+ alice_stores: &mut Stores,
+ bob_stores: &mut Stores,
+ csprng: &mut R,
) -> Result<(), SealedSessionError> {
- let bob_pre_key = libsignal_protocol::generate_key_pair(&context)?;
- let bob_identity_key = bob_store_context.identity_key_pair()?;
- let bob_signed_pre_key = libsignal_protocol::generate_signed_pre_key(
- &context,
- &bob_identity_key,
- 2,
- UNIX_EPOCH,
- )?;
+ let bob_pre_key = PreKeyRecord::new(1, &KeyPair::generate(csprng));
+ let bob_identity_key_pair = bob_stores
+ .identity_key_store
+ .get_identity_key_pair(None)
+ .await?;
+
+ // TODO: check
+ let signed_pre_key_pair = KeyPair::generate(csprng);
+ let signed_pre_key_signature = bob_identity_key_pair
+ .private_key()
+ .calculate_signature(
+ &signed_pre_key_pair.public_key.serialize(),
+ csprng,
+ )?
+ .into_vec();
- let bob_bundle = PreKeyBundle::builder()
- .registration_id(1)
- .device_id(1)
- .pre_key(1, &bob_pre_key.public())
- .signed_pre_key(2, &bob_signed_pre_key.key_pair().public())
- .signature(&bob_signed_pre_key.signature())
- .identity_key(&bob_identity_key.public())
- .build()?;
-
- let alice_session_builder = SessionBuilder::new(
- &context,
- &alice_store_context,
- &ProtocolAddress::new("+14152222222", 1),
+ let bob_signed_pre_key_record = SignedPreKeyRecord::new(
+ 2,
+ SystemTime::now()
+ .duration_since(UNIX_EPOCH)
+ .unwrap()
+ .as_millis() as u64,
+ &signed_pre_key_pair,
+ &signed_pre_key_signature,
);
- alice_session_builder.process_pre_key_bundle(&bob_bundle)?;
- bob_store_context.store_signed_pre_key(&bob_signed_pre_key)?;
- bob_store_context.store_pre_key(&PreKey::new(1, &bob_pre_key)?)?;
+ let bob_bundle = PreKeyBundle::new(
+ 1,
+ 1,
+ Some((1, bob_pre_key.public_key()?)),
+ 2,
+ signed_pre_key_pair.public_key,
+ signed_pre_key_signature,
+ *bob_identity_key_pair.identity_key(),
+ )?;
+
+ process_prekey_bundle(
+ &ProtocolAddress::new("+14152222222".into(), 1),
+ &mut alice_stores.session_store,
+ &mut alice_stores.identity_key_store,
+ &bob_bundle,
+ csprng,
+ None,
+ )
+ .await?;
+
+ bob_stores
+ .signed_pre_key_store
+ .save_signed_pre_key(2, &bob_signed_pre_key_record, None)
+ .await?;
+ bob_stores
+ .pre_key_store
+ .save_pre_key(1, &bob_pre_key, None)
+ .await?;
Ok(())
}
}
diff --git a/libsignal-service/src/sender.rs b/libsignal-service/src/sender.rs
index cebdb7105..b293b1218 100644
--- a/libsignal-service/src/sender.rs
+++ b/libsignal-service/src/sender.rs
@@ -1,19 +1,25 @@
use std::time::SystemTime;
-use crate::cipher::get_preferred_protocol_address;
-use crate::proto::{
- attachment_pointer::AttachmentIdentifier,
- attachment_pointer::Flags as AttachmentPointerFlags, sync_message,
- AttachmentPointer, SyncMessage,
-};
-
use chrono::prelude::*;
-use libsignal_protocol::SessionBuilder;
+use libsignal_protocol::{
+ process_prekey_bundle, IdentityKeyStore, PreKeyStore, ProtocolAddress,
+ SessionStore, SignalProtocolError, SignedPreKeyStore,
+};
use log::{info, trace};
+use rand::{CryptoRng, Rng};
use crate::{
- cipher::ServiceCipher, content::ContentBody, push_service::*,
- sealed_session_cipher::UnidentifiedAccess, ServiceAddress,
+ cipher::{get_preferred_protocol_address, ServiceCipher},
+ content::ContentBody,
+ proto::{
+ attachment_pointer::AttachmentIdentifier,
+ attachment_pointer::Flags as AttachmentPointerFlags, sync_message,
+ AttachmentPointer, SyncMessage,
+ },
+ push_service::*,
+ sealed_session_cipher::UnidentifiedAccess,
+ session_store::SessionStoreExt,
+ ServiceAddress,
};
pub use crate::proto::{ContactDetails, GroupDetails};
@@ -22,7 +28,7 @@ pub use crate::proto::{ContactDetails, GroupDetails};
#[serde(rename_all = "camelCase")]
pub struct OutgoingPushMessage {
pub r#type: u32,
- pub destination_device_id: i32,
+ pub destination_device_id: u32,
pub destination_registration_id: u32,
pub content: String,
}
@@ -66,10 +72,14 @@ pub struct AttachmentSpec {
/// Equivalent of Java's `SignalServiceMessageSender`.
#[derive(Clone)]
-pub struct MessageSender {
+pub struct MessageSender {
service: Service,
- cipher: ServiceCipher,
- device_id: i32,
+ cipher: ServiceCipher,
+ csprng: R,
+ session_store: S,
+ identity_key_store: I,
+ local_address: ServiceAddress,
+ device_id: u32,
}
#[derive(thiserror::Error, Debug)]
@@ -86,7 +96,7 @@ pub enum MessageSenderError {
#[error("{0}")]
ServiceError(#[from] ServiceError),
#[error("protocol error: {0}")]
- ProtocolError(#[from] libsignal_protocol::Error),
+ ProtocolError(#[from] SignalProtocolError),
#[error("Failed to upload attachment {0}")]
AttachmentUploadError(#[from] AttachmentUploadError),
@@ -112,18 +122,31 @@ pub enum MessageSenderError {
IdentityFailure { recipient: ServiceAddress },
}
-impl MessageSender
+impl MessageSender
where
Service: PushService + Clone,
+ S: SessionStore + SessionStoreExt + Clone,
+ I: IdentityKeyStore + Clone,
+ SP: SignedPreKeyStore + Clone,
+ P: PreKeyStore + Clone,
+ R: Rng + CryptoRng + Clone,
{
pub fn new(
service: Service,
- cipher: ServiceCipher,
- device_id: i32,
+ cipher: ServiceCipher,
+ csprng: R,
+ session_store: S,
+ identity_key_store: I,
+ local_address: ServiceAddress,
+ device_id: u32,
) -> Self {
MessageSender {
service,
cipher,
+ csprng,
+ session_store,
+ identity_key_store,
+ local_address,
device_id,
}
}
@@ -326,7 +349,7 @@ where
timestamp,
);
self.try_send_message(
- (&self.cipher.local_address).clone(),
+ (&self.local_address).clone(),
None,
&sync_message,
timestamp,
@@ -340,12 +363,12 @@ where
if end_session {
log::debug!("ending session with {}", recipient);
if let Some(ref uuid) = recipient.uuid {
- self.cipher
- .store_context
- .delete_all_sessions(&uuid.to_string())?;
+ self.session_store
+ .delete_all_sessions(&uuid.to_string())
+ .await?;
}
if let Some(e164) = recipient.e164() {
- self.cipher.store_context.delete_all_sessions(&e164)?;
+ self.session_store.delete_all_sessions(&e164).await?;
}
}
@@ -400,7 +423,7 @@ where
);
self.try_send_message(
- self.cipher.local_address.clone(),
+ self.local_address.clone(),
unidentified_access,
&sync_message,
timestamp,
@@ -461,21 +484,21 @@ where
"dropping session with device {}",
extra_device_id
);
- if let Some(ref uuid) = recipient.uuid {
- self.cipher.store_context.delete_session(
- &libsignal_protocol::Address::new(
+ if let Some(uuid) = recipient.uuid {
+ self.session_store
+ .delete_session(&ProtocolAddress::new(
uuid.to_string(),
*extra_device_id,
- ),
- )?;
+ ))
+ .await?;
}
if let Some(e164) = recipient.e164() {
- self.cipher.store_context.delete_session(
- &libsignal_protocol::Address::new(
- &e164,
+ self.session_store
+ .delete_session(&ProtocolAddress::new(
+ e164,
*extra_device_id,
- ),
- )?;
+ ))
+ .await?;
}
}
@@ -484,23 +507,24 @@ where
"creating session with missing device {}",
missing_device_id
);
+ let remote_address = ProtocolAddress::new(
+ recipient.identifier(),
+ *missing_device_id,
+ );
let pre_key = self
.service
- .get_pre_key(
- &self.cipher.context,
- &recipient,
- *missing_device_id,
- )
+ .get_pre_key(&recipient, *missing_device_id)
.await?;
- SessionBuilder::new(
- &self.cipher.context,
- &self.cipher.store_context,
- &libsignal_protocol::Address::new(
- &recipient.identifier(),
- *missing_device_id,
- ),
+
+ process_prekey_bundle(
+ &remote_address,
+ &mut self.session_store,
+ &mut self.identity_key_store,
+ &pre_key,
+ &mut self.csprng,
+ None,
)
- .process_pre_key_bundle(&pre_key)
+ .await
.map_err(|e| {
log::error!("failed to create session: {}", e);
MessageSenderError::UntrustedIdentity {
@@ -517,20 +541,20 @@ where
extra_device_id
);
if let Some(ref uuid) = recipient.uuid {
- self.cipher.store_context.delete_session(
- &libsignal_protocol::Address::new(
+ self.session_store
+ .delete_session(&ProtocolAddress::new(
uuid.to_string(),
*extra_device_id,
- ),
- )?;
+ ))
+ .await?;
}
if let Some(e164) = recipient.e164() {
- self.cipher.store_context.delete_session(
- &libsignal_protocol::Address::new(
+ self.session_store
+ .delete_session(&ProtocolAddress::new(
e164,
*extra_device_id,
- ),
- )?;
+ ))
+ .await?;
}
}
}
@@ -620,7 +644,7 @@ where
) -> Result, MessageSenderError> {
let mut messages = vec![];
- let myself = recipient.matches(&self.cipher.local_address);
+ let myself = recipient.matches(&self.local_address);
if !myself || unidentified_access.is_some() {
trace!("sending message to default device");
messages.push(
@@ -634,30 +658,17 @@ where
);
}
- // XXX maybe refactor this in a method, this is probably something we need on every call to
- // get_sub_device_sessions.
- let mut sub_device_sessions = Vec::new();
- if let Some(uuid) = &recipient.uuid {
- sub_device_sessions.extend(
- self.cipher
- .store_context
- .get_sub_device_sessions(&uuid.to_string())?,
- );
- }
- if let Some(e164) = &recipient.e164() {
- sub_device_sessions.extend(
- self.cipher.store_context.get_sub_device_sessions(&e164)?,
- );
- }
-
- for device_id in sub_device_sessions {
+ for device_id in
+ recipient.sub_device_sessions(&self.session_store).await?
+ {
trace!("sending message to device {}", device_id);
let ppa = get_preferred_protocol_address(
- &self.cipher.store_context,
- recipient.clone(),
+ &self.session_store,
+ recipient,
device_id,
- )?;
- if self.cipher.store_context.contains_session(&ppa)? {
+ )
+ .await?;
+ if self.session_store.load_session(&ppa, None).await?.is_some() {
messages.push(
self.create_encrypted_message(
recipient,
@@ -680,53 +691,57 @@ where
&mut self,
recipient: &ServiceAddress,
unidentified_access: Option<&UnidentifiedAccess>,
- device_id: i32,
+ device_id: u32,
content: &[u8],
) -> Result {
let recipient_address = get_preferred_protocol_address(
- &self.cipher.store_context,
- recipient.clone(),
+ &self.session_store,
+ recipient,
device_id,
- )?;
+ )
+ .await?;
log::trace!("encrypting message for {:?}", recipient_address);
- if !self
- .cipher
- .store_context
- .contains_session(&recipient_address)?
+ if self
+ .session_store
+ .load_session(&recipient_address, None)
+ .await?
+ .is_none()
{
info!("establishing new session with {:?}", recipient_address);
- let pre_keys = self
- .service
- .get_pre_keys(&self.cipher.context, recipient, device_id)
- .await?;
+ let pre_keys =
+ self.service.get_pre_keys(&recipient, device_id).await?;
for pre_key_bundle in pre_keys {
- if recipient.matches(&self.cipher.local_address)
- && self.device_id == pre_key_bundle.device_id()
+ if recipient.matches(&self.local_address)
+ && self.device_id == pre_key_bundle.device_id()?
{
trace!("not establishing a session with myself!");
continue;
}
let pre_key_address = get_preferred_protocol_address(
- &self.cipher.store_context,
- recipient.clone(),
- pre_key_bundle.device_id(),
- )?;
- let session_builder = SessionBuilder::new(
- &self.cipher.context,
- &self.cipher.store_context,
+ &self.session_store,
+ recipient,
+ pre_key_bundle.device_id()?,
+ )
+ .await?;
+
+ process_prekey_bundle(
&pre_key_address,
- );
- session_builder.process_pre_key_bundle(&pre_key_bundle)?;
+ &mut self.session_store,
+ &mut self.identity_key_store,
+ &pre_key_bundle,
+ &mut self.csprng,
+ None,
+ )
+ .await?;
}
}
- let message = self.cipher.encrypt(
- &recipient_address,
- unidentified_access,
- content,
- )?;
+ let message = self
+ .cipher
+ .encrypt(&recipient_address, unidentified_access, content)
+ .await?;
Ok(message)
}
diff --git a/libsignal-service/src/service_address.rs b/libsignal-service/src/service_address.rs
index 74acdc54f..958edede7 100644
--- a/libsignal-service/src/service_address.rs
+++ b/libsignal-service/src/service_address.rs
@@ -1,6 +1,8 @@
use phonenumber::*;
use uuid::Uuid;
+use crate::{push_service::ServiceError, session_store::SessionStoreExt};
+
#[derive(thiserror::Error, Debug)]
pub enum ParseServiceAddressError {
#[error("Supplied phone number could not be parsed in E164 format")]
@@ -27,6 +29,25 @@ impl ServiceAddress {
.as_ref()
.map(|pn| pn.format().mode(phonenumber::Mode::E164).to_string())
}
+
+ pub async fn sub_device_sessions(
+ &self,
+ session_store: &dyn SessionStoreExt,
+ ) -> Result, ServiceError> {
+ let mut sub_device_sessions = Vec::new();
+ if let Some(uuid) = &self.uuid {
+ sub_device_sessions.extend(
+ session_store
+ .get_sub_device_sessions(&uuid.to_string())
+ .await?,
+ );
+ }
+ if let Some(e164) = &self.e164() {
+ sub_device_sessions
+ .extend(session_store.get_sub_device_sessions(&e164).await?);
+ }
+ Ok(sub_device_sessions)
+ }
}
impl std::fmt::Display for ServiceAddress {
diff --git a/libsignal-service/src/session_store.rs b/libsignal-service/src/session_store.rs
new file mode 100644
index 000000000..942a1cff8
--- /dev/null
+++ b/libsignal-service/src/session_store.rs
@@ -0,0 +1,29 @@
+use async_trait::async_trait;
+use libsignal_protocol::{ProtocolAddress, SessionStore, SignalProtocolError};
+
+/// This is additional functions required to handle
+/// session deletion. It might be a candidate for inclusion into
+/// the bigger `SessionStore` trait.
+#[async_trait(?Send)]
+pub trait SessionStoreExt: SessionStore {
+ /// Get the IDs of all known devices with active sessions for a recipient.
+ async fn get_sub_device_sessions(
+ &self,
+ name: &str,
+ ) -> Result, SignalProtocolError>;
+
+ /// Remove a session record for a recipient ID + device ID tuple.
+ async fn delete_session(
+ &self,
+ address: &ProtocolAddress,
+ ) -> Result<(), SignalProtocolError>;
+
+ /// Remove the session records corresponding to all devices of a recipient
+ /// ID.
+ ///
+ /// Returns the number of deleted sessions.
+ async fn delete_all_sessions(
+ &self,
+ address: &str,
+ ) -> Result;
+}
diff --git a/libsignal-service/src/utils.rs b/libsignal-service/src/utils.rs
index 32f39c77b..d84a28898 100644
--- a/libsignal-service/src/utils.rs
+++ b/libsignal-service/src/utils.rs
@@ -57,7 +57,7 @@ pub mod serde_optional_base64 {
}
pub mod serde_public_key {
- use libsignal_protocol::keys::PublicKey;
+ use libsignal_protocol::PublicKey;
use serde::Serializer;
pub fn serialize(
@@ -67,8 +67,7 @@ pub mod serde_public_key {
where
S: Serializer,
{
- use serde::ser::Error;
- serializer
- .serialize_str(&public_key.to_base64().map_err(S::Error::custom)?)
+ let public_key = public_key.serialize();
+ serializer.serialize_str(&base64::encode(&public_key))
}
}