From 9540bc02dda7efa2e7c0d1aee61634686683a8db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Gabriel=20F=C3=A9ron?= Date: Sun, 20 Oct 2024 21:45:36 +0200 Subject: [PATCH] Switch to attachments v4 (#330) * Modify Endpoint enum to allow absolute URLs (necessary for CDN2) * Implement all kinds of attachments v4 * Add fmt::Display for Endpoint --- src/account_manager.rs | 12 +- src/configuration.rs | 110 ++++++++-- src/digeststream.rs | 10 +- src/groups_v2/manager.rs | 3 +- src/profile_service.rs | 3 +- src/push_service/account.rs | 9 +- src/push_service/cdn.rs | 340 ++++++++++++++++++++++++++++--- src/push_service/error.rs | 3 + src/push_service/keys.rs | 23 +-- src/push_service/linking.rs | 6 +- src/push_service/mod.rs | 11 +- src/push_service/profile.rs | 8 +- src/push_service/registration.rs | 24 ++- src/push_service/response.rs | 6 +- src/sender.rs | 31 ++- 15 files changed, 482 insertions(+), 117 deletions(-) diff --git a/src/account_manager.rs b/src/account_manager.rs index 9e69aaccb..af8fc1c88 100644 --- a/src/account_manager.rs +++ b/src/account_manager.rs @@ -225,8 +225,7 @@ impl AccountManager { .service .request( Method::GET, - Endpoint::Service, - "/v1/devices/provisioning/code", + Endpoint::service("/v1/devices/provisioning/code"), HttpAuthOverride::NoOverride, )? .send() @@ -254,8 +253,7 @@ impl AccountManager { self.service .request( Method::PUT, - Endpoint::Service, - format!("/v1/provisioning/{}", destination), + Endpoint::service(format!("/v1/provisioning/{destination}")), HttpAuthOverride::NoOverride, )? .json(&ProvisioningMessage { @@ -594,8 +592,7 @@ impl AccountManager { self.service .request( Method::PUT, - Endpoint::Service, - "/v1/accounts/name", + Endpoint::service("/v1/accounts/name"), HttpAuthOverride::NoOverride, )? .json(&Data { @@ -623,8 +620,7 @@ impl AccountManager { self.service .request( Method::PUT, - Endpoint::Service, - "/v1/challenge", + Endpoint::service("/v1/challenge"), HttpAuthOverride::NoOverride, )? .json(&RecaptchaAttributes { diff --git a/src/configuration.rs b/src/configuration.rs index df0bf4b39..19a1b9423 100644 --- a/src/configuration.rs +++ b/src/configuration.rs @@ -1,5 +1,5 @@ use core::fmt; -use std::{collections::HashMap, str::FromStr}; +use std::{borrow::Cow, collections::HashMap, str::FromStr}; use crate::utils::BASE64_RELAXED; use base64::prelude::*; @@ -74,11 +74,96 @@ pub enum SignalServers { } #[derive(Debug)] -pub enum Endpoint { - Service, - Storage, - Cdn(u32), - ContactDiscovery, +pub enum Endpoint<'a> { + Absolute(Url), + Service { + path: Cow<'a, str>, + }, + Storage { + path: Cow<'a, str>, + }, + Cdn { + cdn_id: u32, + path: Cow<'a, str>, + query: Option>, + }, + ContactDiscovery { + path: Cow<'a, str>, + }, +} + +impl fmt::Display for Endpoint<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Endpoint::Absolute(url) => write!(f, "absolute URL {url}"), + Endpoint::Service { path } => { + write!(f, "service API call to {path}") + }, + Endpoint::Storage { path } => { + write!(f, "storage API call to {path}") + }, + Endpoint::Cdn { cdn_id, path, .. } => { + write!(f, "CDN{cdn_id} call to {path}") + }, + Endpoint::ContactDiscovery { path } => { + write!(f, "Contact discovery API call to {path}") + }, + } + } +} + +impl<'a> Endpoint<'a> { + pub fn service(path: impl Into>) -> Self { + Self::Service { path: path.into() } + } + + pub fn cdn(cdn_id: u32, path: impl Into>) -> Self { + Self::Cdn { + cdn_id, + path: path.into(), + query: None, + } + } + + pub fn cdn_url(cdn_id: u32, url: &'a Url) -> Self { + Self::Cdn { + cdn_id, + path: url.path().into(), + query: url.query().map(Into::into), + } + } + + pub fn storage(path: impl Into>) -> Self { + Self::Storage { path: path.into() } + } + + pub fn into_url( + self, + service_configuration: &ServiceConfiguration, + ) -> Result { + match self { + Endpoint::Service { path } => { + service_configuration.service_url.join(&path) + }, + Endpoint::Storage { path } => { + service_configuration.storage_url.join(&path) + }, + Endpoint::Cdn { + ref cdn_id, + path, + query, + } => { + let mut url = service_configuration.cdn_urls[cdn_id].clone(); + url.set_path(&path); + url.set_query(query.as_deref()); + Ok(url) + }, + Endpoint::ContactDiscovery { path } => { + service_configuration.contact_discovery_url.join(&path) + }, + Endpoint::Absolute(url) => Ok(url), + } + } } impl FromStr for SignalServers { @@ -125,6 +210,7 @@ impl From<&SignalServers> for ServiceConfiguration { let mut map = HashMap::new(); map.insert(0, "https://cdn-staging.signal.org".parse().unwrap()); map.insert(2, "https://cdn2-staging.signal.org".parse().unwrap()); + map.insert(3, "https://cdn3-staging.signal.org".parse().unwrap()); map }, contact_discovery_url: @@ -144,6 +230,7 @@ impl From<&SignalServers> for ServiceConfiguration { let mut map = HashMap::new(); map.insert(0, "https://cdn.signal.org".parse().unwrap()); map.insert(2, "https://cdn2.signal.org".parse().unwrap()); + map.insert(3, "https://cdn3.signal.org".parse().unwrap()); map }, contact_discovery_url: "https://api.directory.signal.org".parse().unwrap(), @@ -156,14 +243,3 @@ impl From<&SignalServers> for ServiceConfiguration { } } } - -impl ServiceConfiguration { - pub fn base_url(&self, endpoint: Endpoint) -> &Url { - match endpoint { - Endpoint::Service => &self.service_url, - Endpoint::Storage => &self.storage_url, - Endpoint::Cdn(ref n) => &self.cdn_urls[n], - Endpoint::ContactDiscovery => &self.contact_discovery_url, - } - } -} diff --git a/src/digeststream.rs b/src/digeststream.rs index 626caeca9..cd9f8fa14 100644 --- a/src/digeststream.rs +++ b/src/digeststream.rs @@ -1,4 +1,4 @@ -use std::io::Read; +use std::io::{self, Read, Seek, SeekFrom}; use sha2::{Digest, Sha256}; @@ -7,7 +7,7 @@ pub struct DigestingReader<'r, R> { digest: Sha256, } -impl<'r, R: Read> Read for DigestingReader<'r, R> { +impl<'r, R: Read + Seek> Read for DigestingReader<'r, R> { fn read(&mut self, tgt: &mut [u8]) -> Result { let amount = self.inner.read(tgt)?; self.digest.update(&tgt[..amount]); @@ -15,7 +15,7 @@ impl<'r, R: Read> Read for DigestingReader<'r, R> { } } -impl<'r, R: Read> DigestingReader<'r, R> { +impl<'r, R: Read + Seek> DigestingReader<'r, R> { pub fn new(inner: &'r mut R) -> Self { Self { inner, @@ -23,6 +23,10 @@ impl<'r, R: Read> DigestingReader<'r, R> { } } + pub fn seek(&mut self, from: SeekFrom) -> io::Result { + self.inner.seek(from) + } + pub fn finalize(self) -> Vec { // XXX representation is not ideal, but this leaks to the public interface and I don't // really like exposing the GenericArray. diff --git a/src/groups_v2/manager.rs b/src/groups_v2/manager.rs index a1c7af0be..567c0edea 100644 --- a/src/groups_v2/manager.rs +++ b/src/groups_v2/manager.rs @@ -170,8 +170,7 @@ impl GroupsManager { .push_service .request( Method::GET, - Endpoint::Service, - &path, + Endpoint::service(path), HttpAuthOverride::NoOverride, )? .send() diff --git a/src/profile_service.rs b/src/profile_service.rs index 016658876..4e7ea9b42 100644 --- a/src/profile_service.rs +++ b/src/profile_service.rs @@ -41,8 +41,7 @@ impl ProfileService { self.push_service .request( Method::GET, - Endpoint::Service, - path, + Endpoint::service(path), HttpAuthOverride::NoOverride, )? .send() diff --git a/src/push_service/account.rs b/src/push_service/account.rs index 0cc99ea0b..0c90ad98d 100644 --- a/src/push_service/account.rs +++ b/src/push_service/account.rs @@ -132,8 +132,7 @@ impl PushService { pub async fn whoami(&mut self) -> Result { self.request( Method::GET, - Endpoint::Service, - "/v1/accounts/whoami", + Endpoint::service("/v1/accounts/whoami"), HttpAuthOverride::NoOverride, )? .send() @@ -157,8 +156,7 @@ impl PushService { let devices: DeviceInfoList = self .request( Method::GET, - Endpoint::Service, - "/v1/devices/", + Endpoint::service("/v1/devices/"), HttpAuthOverride::NoOverride, )? .send() @@ -182,8 +180,7 @@ impl PushService { self.request( Method::PUT, - Endpoint::Service, - "/v1/accounts/attributes/", + Endpoint::service("/v1/accounts/attributes/"), HttpAuthOverride::NoOverride, )? .json(&attributes) diff --git a/src/push_service/cdn.rs b/src/push_service/cdn.rs index b87e3e8a3..cf74212b4 100644 --- a/src/push_service/cdn.rs +++ b/src/push_service/cdn.rs @@ -1,8 +1,17 @@ -use std::io::{self, Read}; +use std::{ + collections::HashMap, + io::{self, Read, SeekFrom}, +}; use futures::TryStreamExt; -use reqwest::{multipart::Part, Method}; -use tracing::debug; +use reqwest::{ + header::{CONTENT_LENGTH, CONTENT_RANGE, CONTENT_TYPE, RANGE}, + multipart::Part, + Method, StatusCode, +}; +use serde::Deserialize; +use tracing::{debug, trace}; +use url::Url; use crate::{ configuration::Endpoint, prelude::AttachmentIdentifier, @@ -21,10 +30,29 @@ pub struct AttachmentV2UploadAttributes { date: String, policy: String, signature: String, - // This is different from Java's implementation, - // and I (Ruben) am unsure why they decide to force-parse at upload-time instead of at registration - // time. - attachment_id: u64, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AttachmentUploadForm { + pub cdn: u32, + pub key: String, + pub headers: HashMap, + pub signed_upload_location: Url, +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct AttachmentDigest { + pub digest: Vec, + pub incremental_digest: Option>, + pub incremental_mac_chunk_size: u64, +} + +#[derive(Debug)] +pub struct ResumeInfo { + pub content_range: Option, + pub content_start: u64, } impl PushService { @@ -57,8 +85,7 @@ impl PushService { let response_stream = self .request( Method::GET, - Endpoint::Cdn(cdn_id), - path, + Endpoint::cdn(cdn_id, path), HttpAuthOverride::Unidentified, // CDN requests are always without authentication )? .send() @@ -72,13 +99,12 @@ impl PushService { Ok(response_stream) } - pub(crate) async fn get_attachment_v2_upload_attributes( + pub(crate) async fn get_attachment_v4_upload_attributes( &mut self, - ) -> Result { + ) -> Result { self.request( Method::GET, - Endpoint::Service, - "/v2/attachments/form/upload", + Endpoint::service("/v4/attachments/form/upload"), HttpAuthOverride::NoOverride, )? .send() @@ -90,26 +116,187 @@ impl PushService { .map_err(Into::into) } - /// Upload attachment to CDN - /// - /// Returns attachment ID and the attachment digest - pub async fn upload_attachment( + #[tracing::instrument(skip(self), level=tracing::Level::TRACE)] + pub(crate) async fn get_attachment_resumable_upload_url( &mut self, - attrs: AttachmentV2UploadAttributes, - mut reader: impl Read + Send, - ) -> Result<(u64, Vec), ServiceError> { - let attachment_id = attrs.attachment_id; - let mut digester = - crate::digeststream::DigestingReader::new(&mut reader); + attachment_upload_form: &AttachmentUploadForm, + ) -> Result { + let mut request = self.request( + Method::POST, + Endpoint::Absolute( + attachment_upload_form.signed_upload_location.clone(), + ), + HttpAuthOverride::Unidentified, + )?; - self.post_to_cdn0("attachments/", attrs, "file".into(), &mut digester) - .await?; + for (key, value) in &attachment_upload_form.headers { + request = request.header(key, value); + } + request = request.header(CONTENT_LENGTH, "0"); + + if attachment_upload_form.cdn == 2 { + request = request.header(CONTENT_TYPE, "application/octet-stream"); + } else if attachment_upload_form.cdn == 3 { + request = request + .header("Upload-Defer-Length", "1") + .header("Tus-Resumable", "1.0.0"); + } else { + return Err(ServiceError::UnknownCdnVersion( + attachment_upload_form.cdn, + )); + }; + + Ok(request + .send() + .await? + .service_error_for_status() + .await? + .headers() + .get("location") + .ok_or_else(|| ServiceError::InvalidFrame { + reason: "missing location header in HTTP response", + })? + .to_str() + .map_err(|_| ServiceError::InvalidFrame { + reason: "invalid location header bytes in HTTP response", + })? + .parse()?) + } + + #[tracing::instrument(skip(self))] + async fn get_attachment_resume_info_cdn2( + &mut self, + resumable_url: &Url, + content_length: u64, + ) -> Result { + let response = self + .request( + Method::PUT, + Endpoint::cdn_url(2, resumable_url), + HttpAuthOverride::Unidentified, + )? + .header(CONTENT_RANGE, format!("bytes */{content_length}")) + .send() + .await? + .error_for_status()?; + + let status = response.status(); + + if status.is_success() { + Ok(ResumeInfo { + content_range: None, + content_start: content_length, + }) + } else if status == StatusCode::PERMANENT_REDIRECT { + let offset = + match response.headers().get(RANGE) { + Some(range) => range + .to_str() + .map_err(|_| ServiceError::InvalidFrame { + reason: "invalid format for Range HTTP header", + })? + .split("-") + .nth(1) + .ok_or_else(|| ServiceError::InvalidFrame { + reason: + "invalid value format for Range HTTP header", + })? + .parse::() + .map_err(|_| ServiceError::InvalidFrame { + reason: + "invalid number format for Range HTTP header", + })? + + 1, + None => 0, + }; - Ok((attachment_id, digester.finalize())) + Ok(ResumeInfo { + content_range: Some(format!( + "bytes {}-{}/{}", + offset, + content_length - 1, + content_length + )), + content_start: offset, + }) + } else { + Err(ServiceError::InvalidFrame { + reason: "failed to get resumable upload data from CDN2", + }) + } + } + + #[tracing::instrument(skip(self))] + async fn get_attachment_resume_info_cdn3( + &mut self, + resumable_url: &Url, + headers: &HashMap, + ) -> Result { + let mut request = self + .request( + Method::HEAD, + Endpoint::cdn_url(3, resumable_url), + HttpAuthOverride::Unidentified, + )? + .header("Tus-Resumable", "1.0.0"); + + for (key, value) in headers { + request = request.header(key, value); + } + + let response = request.send().await?.service_error_for_status().await?; + + let upload_offset = response + .headers() + .get("upload-offset") + .ok_or(ServiceError::InvalidFrame { + reason: "no Upload-Offset header in response", + })? + .to_str() + .map_err(|_| ServiceError::InvalidFrame { + reason: "invalid upload-offset header bytes in HTTP response", + })? + .parse() + .map_err(|_| ServiceError::InvalidFrame { + reason: "invalid integer value for Upload-Offset header", + })?; + + Ok(ResumeInfo { + content_range: None, + content_start: upload_offset, + }) + } + + /// Upload attachment + /// + /// Returns attachment ID and the attachment digest + #[tracing::instrument(skip(self, headers, content))] + pub(crate) async fn upload_attachment_v4( + &mut self, + cdn_id: u32, + resumable_url: &Url, + content_type: &str, + content_length: u64, + headers: HashMap, + content: impl std::io::Read + std::io::Seek + Send, + ) -> Result { + if cdn_id == 2 { + self.upload_to_cdn2(resumable_url, content_length, content) + .await + } else { + self.upload_to_cdn3( + resumable_url, + &headers, + content_type, + content_length, + content, + ) + .await + } } #[tracing::instrument(skip(self, upload_attributes, reader))] - pub async fn post_to_cdn0( + pub async fn upload_to_cdn0( &mut self, path: &str, upload_attributes: AttachmentV2UploadAttributes, @@ -142,8 +329,7 @@ impl PushService { let response = self .request( Method::POST, - Endpoint::Cdn(0), - path, + Endpoint::cdn(0, path), HttpAuthOverride::NoOverride, )? .multipart(form) @@ -156,4 +342,100 @@ impl PushService { Ok(()) } + + #[tracing::instrument(skip(self, content))] + async fn upload_to_cdn2( + &mut self, + resumable_url: &Url, + content_length: u64, + mut content: impl std::io::Read + std::io::Seek + Send, + ) -> Result { + let resume_info = self + .get_attachment_resume_info_cdn2(resumable_url, content_length) + .await?; + + let mut digester = + crate::digeststream::DigestingReader::new(&mut content); + + let mut buf = Vec::new(); + digester.read_to_end(&mut buf)?; + + trace!("digested content"); + + let mut request = self.request( + Method::PUT, + Endpoint::cdn_url(2, resumable_url), + HttpAuthOverride::Unidentified, + )?; + + if let Some(content_range) = resume_info.content_range { + request = request.header(CONTENT_RANGE, content_range); + } + + request.body(buf).send().await?.error_for_status()?; + + Ok(AttachmentDigest { + digest: digester.finalize(), + incremental_digest: None, + incremental_mac_chunk_size: 0, + }) + } + + #[tracing::instrument(skip(self, content))] + async fn upload_to_cdn3( + &mut self, + resumable_url: &Url, + headers: &HashMap, + content_type: &str, + content_length: u64, + mut content: impl std::io::Read + std::io::Seek + Send, + ) -> Result { + let resume_info = self + .get_attachment_resume_info_cdn3(resumable_url, headers) + .await?; + + trace!(?resume_info, "got resume info"); + + if resume_info.content_start == content_length { + let mut digester = + crate::digeststream::DigestingReader::new(&mut content); + let mut buf = Vec::new(); + digester.read_to_end(&mut buf)?; + return Ok(AttachmentDigest { + digest: digester.finalize(), + incremental_digest: None, + incremental_mac_chunk_size: 0, + }); + } + + let mut digester = + crate::digeststream::DigestingReader::new(&mut content); + digester.seek(SeekFrom::Start(resume_info.content_start))?; + + let mut buf = Vec::new(); + digester.read_to_end(&mut buf)?; + + trace!("digested content"); + + self.request( + Method::PATCH, + Endpoint::cdn(3, resumable_url.path()), + HttpAuthOverride::Unidentified, + )? + .header("Tus-Resumable", "1.0.0") + .header("Upload-Offset", resume_info.content_start) + .header("Upload-Length", buf.len()) + .header(CONTENT_TYPE, content_type) + .body(buf) + .send() + .await?; + + trace!("attachment uploaded"); + + Ok(AttachmentDigest { + digest: digester.finalize(), + incremental_digest: None, + incremental_mac_chunk_size: 0, + }) + } } diff --git a/src/push_service/error.rs b/src/push_service/error.rs index f10f67866..2aad42db0 100644 --- a/src/push_service/error.rs +++ b/src/push_service/error.rs @@ -93,6 +93,9 @@ pub enum ServiceError { #[error("invalid device name")] InvalidDeviceName, + #[error("Unknown CDN version {0}")] + UnknownCdnVersion(u32), + #[error("HTTP reqwest error: {0}")] Http(#[from] reqwest::Error), } diff --git a/src/push_service/keys.rs b/src/push_service/keys.rs index f4addb013..d2cf6f091 100644 --- a/src/push_service/keys.rs +++ b/src/push_service/keys.rs @@ -32,8 +32,7 @@ impl PushService { ) -> Result { self.request( Method::GET, - Endpoint::Service, - format!("/v2/keys?identity={}", service_id_type), + Endpoint::service(format!("/v2/keys?identity={}", service_id_type)), HttpAuthOverride::NoOverride, )? .send() @@ -52,8 +51,7 @@ impl PushService { ) -> Result<(), ServiceError> { self.request( Method::PUT, - Endpoint::Service, - format!("/v2/keys?identity={}", service_id_type), + Endpoint::service(format!("/v2/keys?identity={}", service_id_type)), HttpAuthOverride::NoOverride, )? .json(&pre_key_state) @@ -76,8 +74,7 @@ impl PushService { let mut pre_key_response: PreKeyResponse = self .request( Method::GET, - Endpoint::Service, - &path, + Endpoint::service(path), HttpAuthOverride::NoOverride, )? .send() @@ -107,8 +104,7 @@ impl PushService { let pre_key_response: PreKeyResponse = self .request( Method::GET, - Endpoint::Service, - &path, + Endpoint::service(path), HttpAuthOverride::NoOverride, )? .send() @@ -131,8 +127,7 @@ impl PushService { let cert: SenderCertificateJson = self .request( Method::GET, - Endpoint::Service, - "/v1/certificate/delivery", + Endpoint::service("/v1/certificate/delivery"), HttpAuthOverride::NoOverride, )? .send() @@ -150,8 +145,7 @@ impl PushService { let cert: SenderCertificateJson = self .request( Method::GET, - Endpoint::Service, - "/v1/certificate/delivery?includeE164=false", + Endpoint::service("/v1/certificate/delivery?includeE164=false"), HttpAuthOverride::NoOverride, )? .send() @@ -190,8 +184,9 @@ impl PushService { } self.request( Method::PUT, - Endpoint::Service, - "/v2/accounts/phone_number_identity_key_distribution", + Endpoint::service( + "/v2/accounts/phone_number_identity_key_distribution", + ), HttpAuthOverride::NoOverride, )? .json(&PniKeyDistributionRequest { diff --git a/src/push_service/linking.rs b/src/push_service/linking.rs index 5d5026c6c..b79929fee 100644 --- a/src/push_service/linking.rs +++ b/src/push_service/linking.rs @@ -62,8 +62,7 @@ impl PushService { ) -> Result { self.request( Method::PUT, - Endpoint::Service, - "/v1/devices/link", + Endpoint::service("/v1/devices/link"), HttpAuthOverride::Identified(http_auth), )? .json(&link_request) @@ -79,8 +78,7 @@ impl PushService { pub async fn unlink_device(&mut self, id: i64) -> Result<(), ServiceError> { self.request( Method::DELETE, - Endpoint::Service, - format!("/v1/devices/{}", id), + Endpoint::service(format!("/v1/devices/{}", id)), HttpAuthOverride::NoOverride, )? .send() diff --git a/src/push_service/mod.rs b/src/push_service/mod.rs index c6cf31fd5..cae93a75c 100644 --- a/src/push_service/mod.rs +++ b/src/push_service/mod.rs @@ -177,15 +177,14 @@ impl PushService { } } - #[tracing::instrument(skip(self, path), fields(path = %path.as_ref()))] + #[tracing::instrument(skip(self), fields(endpoint = %endpoint))] pub fn request( &self, method: Method, endpoint: Endpoint, - path: impl AsRef, auth_override: HttpAuthOverride, ) -> Result { - let url = self.cfg.base_url(endpoint).join(path.as_ref())?; + let url = endpoint.into_url(&self.cfg)?; let mut builder = self.client.request(method, url); builder = match auth_override { @@ -216,8 +215,7 @@ impl PushService { ) -> Result { let span = debug_span!("websocket"); - let endpoint = self.cfg.base_url(Endpoint::Service); - let mut url = endpoint.join(path).expect("valid url"); + let mut url = Endpoint::service(path).into_url(&self.cfg)?; url.set_scheme("wss").expect("valid https base url"); if let Some(credentials) = credentials { @@ -255,8 +253,7 @@ impl PushService { ) -> Result { self.request( Method::GET, - Endpoint::Storage, - "/v1/groups/", + Endpoint::storage("/v1/groups/"), HttpAuthOverride::Identified(credentials), )? .send() diff --git a/src/push_service/profile.rs b/src/push_service/profile.rs index a1a859ac9..c14b59b91 100644 --- a/src/push_service/profile.rs +++ b/src/push_service/profile.rs @@ -92,7 +92,7 @@ impl PushService { address: ServiceAddress, profile_key: Option, ) -> Result { - let endpoint = if let Some(key) = profile_key { + let path = if let Some(key) = profile_key { let version = bincode::serialize(&key.get_profile_key_version( address.aci().expect("profile by ACI ProtocolAddress"), ))?; @@ -105,8 +105,7 @@ impl PushService { // TODO: set locale to en_US self.request( Method::GET, - Endpoint::Service, - &endpoint, + Endpoint::service(path), HttpAuthOverride::NoOverride, )? .send() @@ -171,8 +170,7 @@ impl PushService { let upload_url: Result = self .request( Method::PUT, - Endpoint::Service, - "/v1/profile", + Endpoint::service("/v1/profile"), HttpAuthOverride::NoOverride, )? .json(&command) diff --git a/src/push_service/registration.rs b/src/push_service/registration.rs index 2ab6ea0c2..8521d682a 100644 --- a/src/push_service/registration.rs +++ b/src/push_service/registration.rs @@ -154,8 +154,7 @@ impl PushService { self.request( Method::POST, - Endpoint::Service, - "/v1/registration", + Endpoint::service("/v1/registration"), HttpAuthOverride::NoOverride, )? .json(&RegistrationSessionRequestBody { @@ -198,8 +197,7 @@ impl PushService { self.request( Method::POST, - Endpoint::Service, - "/v1/verification/session", + Endpoint::service("/v1/verification/session"), HttpAuthOverride::Unidentified, )? .json(&VerificationSessionMetadataRequestBody { @@ -240,8 +238,10 @@ impl PushService { self.request( Method::PATCH, - Endpoint::Service, - format!("/v1/verification/session/{}", session_id), + Endpoint::service(format!( + "/v1/verification/session/{}", + session_id + )), HttpAuthOverride::Unidentified, )? .json(&UpdateVerificationSessionRequestBody { @@ -289,8 +289,10 @@ impl PushService { self.request( Method::POST, - Endpoint::Service, - format!("/v1/verification/session/{}/code", session_id), + Endpoint::service(format!( + "/v1/verification/session/{}/code", + session_id + )), HttpAuthOverride::Unidentified, )? .json(&VerificationCodeRequest { transport, client }) @@ -315,8 +317,10 @@ impl PushService { self.request( Method::PUT, - Endpoint::Service, - format!("/v1/verification/session/{}/code", session_id), + Endpoint::service(format!( + "/v1/verification/session/{}/code", + session_id + )), HttpAuthOverride::Unidentified, )? .json(&VerificationCode { diff --git a/src/push_service/response.rs b/src/push_service/response.rs index 9fc02c6e9..4360a4e56 100644 --- a/src/push_service/response.rs +++ b/src/push_service/response.rs @@ -10,8 +10,10 @@ where ServiceError: From<::Error>, { match response.status_code() { - StatusCode::OK => Ok(response), - StatusCode::NO_CONTENT => Ok(response), + StatusCode::OK + | StatusCode::CREATED + | StatusCode::ACCEPTED + | StatusCode::NO_CONTENT => Ok(response), StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => { Err(ServiceError::Unauthorized) }, diff --git a/src/sender.rs b/src/sender.rs index 691e03939..fd7e2218c 100644 --- a/src/sender.rs +++ b/src/sender.rs @@ -221,16 +221,29 @@ where }); // Request upload attributes - let attrs = self + // TODO: we can actually store the upload spec to be able to resume the upload later + // if it fails or stalls (= we should at least split the API calls so clients can decide what to do) + let attachment_upload_form = self .service - .get_attachment_v2_upload_attributes() + .get_attachment_v4_upload_attributes() .instrument(tracing::trace_span!("requesting upload attributes")) .await?; - let (id, digest) = self + let resumable_upload_url = self .service - .upload_attachment(attrs, &mut std::io::Cursor::new(&contents)) - .instrument(tracing::trace_span!("Uploading attachment")) + .get_attachment_resumable_upload_url(&attachment_upload_form) + .await?; + + let attachment_digest = self + .service + .upload_attachment_v4( + attachment_upload_form.cdn, + &resumable_upload_url, + &spec.content_type, + contents.len() as u64, + attachment_upload_form.headers, + &mut std::io::Cursor::new(&contents), + ) .await?; Ok(AttachmentPointer { @@ -238,7 +251,7 @@ where key: Some(key.to_vec()), size: Some(len as u32), // thumbnail: Option>, - digest: Some(digest), + digest: Some(attachment_digest.digest), file_name: spec.file_name, flags: Some( if spec.voice_note == Some(true) { @@ -261,8 +274,10 @@ where .expect("unix epoch in the past") .as_millis() as u64, ), - cdn_number: Some(0), - attachment_identifier: Some(AttachmentIdentifier::CdnId(id)), + cdn_number: Some(attachment_upload_form.cdn), + attachment_identifier: Some(AttachmentIdentifier::CdnKey( + attachment_upload_form.key, + )), ..Default::default() }) }