diff --git a/Cargo.toml b/Cargo.toml index 7b00ce902..e1bbb0460 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ ndk-glue = "0.7" [target.'cfg(target_os = "windows")'.dependencies] windows = { version = "0.54.0", features = [ + "implement", "Win32_Media_Audio", "Win32_Foundation", "Win32_Devices_Properties", diff --git a/src/host/wasapi/device.rs b/src/host/wasapi/device.rs index d0c03548b..6ab154dab 100644 --- a/src/host/wasapi/device.rs +++ b/src/host/wasapi/device.rs @@ -11,20 +11,21 @@ use std::mem; use std::os::windows::ffi::OsStringExt; use std::ptr; use std::slice; +use std::sync::mpsc::Sender; use std::sync::OnceLock; use std::sync::{Arc, Mutex, MutexGuard}; use std::time::Duration; use super::com; use super::{windows_err_to_cpal_err, windows_err_to_cpal_err_message}; -use windows::core::Interface; use windows::core::GUID; +use windows::core::{implement, IUnknown, Interface, HRESULT, PCWSTR, PROPVARIANT}; use windows::Win32::Devices::Properties; use windows::Win32::Foundation; use windows::Win32::Media::Audio::IAudioRenderClient; use windows::Win32::Media::{Audio, KernelStreaming, Multimedia}; use windows::Win32::System::Com; -use windows::Win32::System::Com::{StructuredStorage, STGM_READ}; +use windows::Win32::System::Com::{CoTaskMemFree, StringFromIID, StructuredStorage, STGM_READ}; use windows::Win32::System::Threading; use windows::Win32::System::Variant::VT_LPWSTR; @@ -40,10 +41,17 @@ struct IAudioClientWrapper(Audio::IAudioClient); unsafe impl Send for IAudioClientWrapper {} unsafe impl Sync for IAudioClientWrapper {} +#[derive(Debug, Clone)] +enum DeviceType { + DefaultOutput, + DefaultInput, + Specific(Audio::IMMDevice), +} + /// An opaque type that identifies an end point. #[derive(Clone)] pub struct Device { - device: Audio::IMMDevice, + device: DeviceType, /// We cache an uninitialized `IAudioClient` so that we can call functions from it without /// having to create/destroy audio clients all the time. future_audio_client: Arc>>, // TODO: add NonZero around the ptr @@ -275,66 +283,133 @@ unsafe fn format_from_waveformatex_ptr( Some(format) } +#[implement(Audio::IActivateAudioInterfaceCompletionHandler)] +struct CompletionHandler(Sender>); + +fn retrieve_result( + operation: &Audio::IActivateAudioInterfaceAsyncOperation, +) -> windows::core::Result { + let mut result = HRESULT::default(); + let mut interface: Option = None; + unsafe { + operation.GetActivateResult(&mut result, &mut interface)?; + } + result.ok()?; + Ok(interface.unwrap()) +} + +impl Audio::IActivateAudioInterfaceCompletionHandler_Impl for CompletionHandler { + fn ActivateCompleted( + &self, + operation: Option<&Audio::IActivateAudioInterfaceAsyncOperation>, + ) -> windows::core::Result<()> { + let result = retrieve_result(operation.unwrap()); + let _ = self.0.send(result); + Ok(()) + } +} + +#[allow(non_snake_case)] +unsafe fn ActivateAudioInterfaceSync( + deviceinterfacepath: P0, + activationparams: Option<*const PROPVARIANT>, +) -> windows::core::Result +where + P0: windows::core::IntoParam, + T: Interface, +{ + let (sender, receiver) = std::sync::mpsc::channel(); + let completion: Audio::IActivateAudioInterfaceCompletionHandler = + CompletionHandler(sender).into(); + Audio::ActivateAudioInterfaceAsync( + deviceinterfacepath, + &T::IID, + activationparams, + &completion, + )?; + let result = receiver.recv_timeout(Duration::from_secs(2)).unwrap()?; + result.cast() +} + unsafe impl Send for Device {} unsafe impl Sync for Device {} impl Device { pub fn name(&self) -> Result { - unsafe { - // Open the device's property store. - let property_store = self - .device - .OpenPropertyStore(STGM_READ) - .expect("could not open property store"); - - // Get the endpoint's friendly-name property. - let mut property_value = property_store - .GetValue(&Properties::DEVPKEY_Device_FriendlyName as *const _ as *const _) - .map_err(|err| { - let description = - format!("failed to retrieve name from property store: {}", err); - let err = BackendSpecificError { description }; - DeviceNameError::from(err) - })?; + match &self.device { + DeviceType::DefaultOutput => Ok("Default Ouput".to_string()), + DeviceType::DefaultInput => Ok("Default Input".to_string()), + DeviceType::Specific(device) => unsafe { + // Open the device's property store. + let property_store = device + .OpenPropertyStore(STGM_READ) + .expect("could not open property store"); + + // Get the endpoint's friendly-name property. + let mut property_value = property_store + .GetValue(&Properties::DEVPKEY_Device_FriendlyName as *const _ as *const _) + .map_err(|err| { + let description = + format!("failed to retrieve name from property store: {}", err); + let err = BackendSpecificError { description }; + DeviceNameError::from(err) + })?; - let prop_variant = &property_value.as_raw().Anonymous.Anonymous; + let prop_variant = &property_value.as_raw().Anonymous.Anonymous; - // Read the friendly-name from the union data field, expecting a *const u16. - if prop_variant.vt != VT_LPWSTR.0 { - let description = format!( - "property store produced invalid data: {:?}", - prop_variant.vt - ); - let err = BackendSpecificError { description }; - return Err(err.into()); - } - let ptr_utf16 = *(&prop_variant.Anonymous as *const _ as *const *const u16); + // Read the friendly-name from the union data field, expecting a *const u16. + if prop_variant.vt != VT_LPWSTR.0 { + let description = format!( + "property store produced invalid data: {:?}", + prop_variant.vt + ); + let err = BackendSpecificError { description }; + return Err(err.into()); + } + let ptr_utf16 = *(&prop_variant.Anonymous as *const _ as *const *const u16); - // Find the length of the friendly name. - let mut len = 0; - while *ptr_utf16.offset(len) != 0 { - len += 1; - } + // Find the length of the friendly name. + let mut len = 0; + while *ptr_utf16.offset(len) != 0 { + len += 1; + } - // Create the utf16 slice and convert it into a string. - let name_slice = slice::from_raw_parts(ptr_utf16, len as usize); - let name_os_string: OsString = OsStringExt::from_wide(name_slice); - let name_string = match name_os_string.into_string() { - Ok(string) => string, - Err(os_string) => os_string.to_string_lossy().into(), - }; + // Create the utf16 slice and convert it into a string. + let name_slice = slice::from_raw_parts(ptr_utf16, len as usize); + let name_os_string: OsString = OsStringExt::from_wide(name_slice); + let name_string = match name_os_string.into_string() { + Ok(string) => string, + Err(os_string) => os_string.to_string_lossy().into(), + }; - // Clean up the property. - StructuredStorage::PropVariantClear(&mut property_value).ok(); + // Clean up the property. + StructuredStorage::PropVariantClear(&mut property_value).ok(); - Ok(name_string) + Ok(name_string) + }, } } #[inline] fn from_immdevice(device: Audio::IMMDevice) -> Self { Device { - device, + device: DeviceType::Specific(device), + future_audio_client: Arc::new(Mutex::new(None)), + } + } + + #[inline] + fn default_output() -> Self { + Device { + device: DeviceType::DefaultOutput, + future_audio_client: Arc::new(Mutex::new(None)), + } + } + + #[inline] + fn default_input() -> Self { + Device { + device: DeviceType::DefaultInput, future_audio_client: Arc::new(Mutex::new(None)), } } @@ -349,9 +424,25 @@ impl Device { } let audio_client: Audio::IAudioClient = unsafe { - // can fail if the device has been disconnected since we enumerated it, or if - // the device doesn't support playback for some reason - self.device.Activate(Com::CLSCTX_ALL, None)? + match &self.device { + DeviceType::DefaultOutput => { + let default_audio = StringFromIID(&Audio::DEVINTERFACE_AUDIO_RENDER)?; + let result = ActivateAudioInterfaceSync(PCWSTR(default_audio.as_ptr()), None); + CoTaskMemFree(Some(default_audio.as_ptr() as _)); + result? + } + DeviceType::DefaultInput => { + let default_audio = StringFromIID(&Audio::DEVINTERFACE_AUDIO_CAPTURE)?; + let result = ActivateAudioInterfaceSync(PCWSTR(default_audio.as_ptr()), None); + CoTaskMemFree(Some(default_audio.as_ptr() as _)); + result? + } + DeviceType::Specific(device) => { + // can fail if the device has been disconnected since we enumerated it, or if + // the device doesn't support playback for some reason + device.Activate(Com::CLSCTX_ALL, None)? + } + } }; *lock = Some(IAudioClientWrapper(audio_client)); @@ -518,8 +609,14 @@ impl Device { } pub(crate) fn data_flow(&self) -> Audio::EDataFlow { - let endpoint = Endpoint::from(self.device.clone()); - endpoint.data_flow() + match &self.device { + DeviceType::DefaultOutput => Audio::eRender, + DeviceType::DefaultInput => Audio::eCapture, + DeviceType::Specific(device) => { + let endpoint = Endpoint::from(device.clone()); + endpoint.data_flow() + } + } } pub fn default_input_config(&self) -> Result { @@ -769,40 +866,47 @@ impl Device { impl PartialEq for Device { #[inline] fn eq(&self, other: &Device) -> bool { - // Use case: In order to check whether the default device has changed - // the client code might need to compare the previous default device with the current one. - // The pointer comparison (`self.device == other.device`) don't work there, - // because the pointers are different even when the default device stays the same. - // - // In this code section we're trying to use the GetId method for the device comparison, cf. - // https://docs.microsoft.com/en-us/windows/desktop/api/mmdeviceapi/nf-mmdeviceapi-immdevice-getid - unsafe { - struct IdRAII(windows::core::PWSTR); - /// RAII for device IDs. - impl Drop for IdRAII { - fn drop(&mut self) { - unsafe { Com::CoTaskMemFree(Some(self.0 .0 as *mut _)) } - } - } - // GetId only fails with E_OUTOFMEMORY and if it does, we're probably dead already. - // Plus it won't do to change the device comparison logic unexpectedly. - let id1 = self.device.GetId().expect("cpal: GetId failure"); - let id1 = IdRAII(id1); - let id2 = other.device.GetId().expect("cpal: GetId failure"); - let id2 = IdRAII(id2); - // 16-bit null-terminated comparison. - let mut offset = 0; - loop { - let w1: u16 = *(id1.0).0.offset(offset); - let w2: u16 = *(id2.0).0.offset(offset); - if w1 == 0 && w2 == 0 { - return true; - } - if w1 != w2 { - return false; + match (&self.device, &other.device) { + (DeviceType::DefaultOutput, DeviceType::DefaultOutput) => true, + (DeviceType::DefaultInput, DeviceType::DefaultInput) => true, + (DeviceType::Specific(dev1), DeviceType::Specific(dev2)) => { + // Use case: In order to check whether the default device has changed + // the client code might need to compare the previous default device with the current one. + // The pointer comparison (`self.device == other.device`) don't work there, + // because the pointers are different even when the default device stays the same. + // + // In this code section we're trying to use the GetId method for the device comparison, cf. + // https://docs.microsoft.com/en-us/windows/desktop/api/mmdeviceapi/nf-mmdeviceapi-immdevice-getid + unsafe { + struct IdRAII(windows::core::PWSTR); + /// RAII for device IDs. + impl Drop for IdRAII { + fn drop(&mut self) { + unsafe { Com::CoTaskMemFree(Some(self.0 .0 as *mut _)) } + } + } + // GetId only fails with E_OUTOFMEMORY and if it does, we're probably dead already. + // Plus it won't do to change the device comparison logic unexpectedly. + let id1 = dev1.GetId().expect("cpal: GetId failure"); + let id1 = IdRAII(id1); + let id2 = dev2.GetId().expect("cpal: GetId failure"); + let id2 = IdRAII(id2); + // 16-bit null-terminated comparison. + let mut offset = 0; + loop { + let w1: u16 = *(id1.0).0.offset(offset); + let w2: u16 = *(id2.0).0.offset(offset); + if w1 == 0 && w2 == 0 { + return true; + } + if w1 != w2 { + return false; + } + offset += 1; + } } - offset += 1; } + _ => false, } } } @@ -914,23 +1018,25 @@ impl Iterator for Devices { } } -fn default_device(data_flow: Audio::EDataFlow) -> Option { - unsafe { - let device = get_enumerator() - .0 - .GetDefaultAudioEndpoint(data_flow, Audio::eConsole) - .ok()?; - // TODO: check specifically for `E_NOTFOUND`, and panic otherwise - Some(Device::from_immdevice(device)) - } -} +//fn default_device(data_flow: Audio::EDataFlow) -> Option { +// unsafe { +// let device = get_enumerator() +// .0 +// .GetDefaultAudioEndpoint(data_flow, Audio::eConsole) +// .ok()?; +// // TODO: check specifically for `E_NOTFOUND`, and panic otherwise +// Some(Device::from_immdevice(device)) +// } +//} pub fn default_input_device() -> Option { - default_device(Audio::eCapture) + //default_device(Audio::eCapture) + Some(Device::default_input()) } pub fn default_output_device() -> Option { - default_device(Audio::eRender) + //default_device(Audio::eRender) + Some(Device::default_output()) } /// Get the audio clock used to produce `StreamInstant`s.