I am trying to implement DeepPhase in candle, but I am struggling to figure out how to calculate the phase angles from two tensors using an atan2 operation.
Okay, I found out that we can implement custom ops. Is this correct?
use candle_core::{CpuStorage, CustomOp2, Layout, Result, Shape, Tensor};

pub struct Atan2;

impl CustomOp2 for Atan2 {
    fn name(&self) -> &'static str {
        "atan2"
    }

    // Forward pass on CPU: elementwise atan2(y, x) for f32 inputs.
    fn cpu_fwd(
        &self,
        s1: &CpuStorage,
        l1: &Layout,
        s2: &CpuStorage,
        l2: &Layout,
    ) -> Result<(CpuStorage, Shape)> {
        if l1.shape() != l2.shape() {
            candle_core::bail!("operands must have the same shape");
        }
        let s1 = match l1.contiguous_offsets() {
            None => candle_core::bail!("input has to be contiguous"),
            Some((o1, o2)) => &s1.as_slice::<f32>()?[o1..o2],
        };
        let s2 = match l2.contiguous_offsets() {
            None => candle_core::bail!("input has to be contiguous"),
            Some((o1, o2)) => &s2.as_slice::<f32>()?[o1..o2],
        };
        let dst: Vec<f32> = itertools::zip_eq(s1, s2)
            .map(|(&y, &x)| y.atan2(x))
            .collect();
        let storage = candle_core::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, l1.shape().clone()))
    }

    // Backward pass: d/dy atan2(y, x) = x / (x^2 + y^2),
    // d/dx atan2(y, x) = -y / (x^2 + y^2).
    // Each partial is scaled by the incoming gradient (chain rule).
    fn bwd(
        &self,
        y: &Tensor,
        x: &Tensor,
        _res: &Tensor,
        grad_res: &Tensor,
    ) -> Result<(Option<Tensor>, Option<Tensor>)> {
        let d = (x.sqr()? + y.sqr()?)?;
        let grad_y = grad_res.mul(&x.div(&d)?)?;
        let grad_x = grad_res.mul(&y.div(&d)?.neg()?)?;
        Ok((Some(grad_y), Some(grad_x)))
    }
}
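One way to sanity-check the backward formulas without pulling in candle is to compare them against a finite-difference estimate in plain Rust. The following is only a sketch using the standard library; `atan2_forward` and `atan2_grad` are hypothetical helper names chosen for illustration and are not part of candle.

```rust
/// Elementwise phase angle atan2(y, x) for two equally sized slices.
/// Hypothetical helper, mirrors what the custom op's forward pass computes.
fn atan2_forward(y: &[f32], x: &[f32]) -> Vec<f32> {
    assert_eq!(y.len(), x.len(), "operands must have the same length");
    y.iter().zip(x).map(|(&y, &x)| y.atan2(x)).collect()
}

/// Analytical partial derivatives of atan2(y, x), returned as (d/dy, d/dx):
/// d/dy = x / (x^2 + y^2), d/dx = -y / (x^2 + y^2).
fn atan2_grad(y: f32, x: f32) -> (f32, f32) {
    let d = x * x + y * y;
    (x / d, -y / d)
}

fn main() {
    // Phase angles for a few (y, x) pairs in different quadrants.
    let y = [1.0f32, 1.0, -1.0];
    let x = [1.0f32, -1.0, -1.0];
    println!("phases: {:?}", atan2_forward(&y, &x));

    // Central finite differences at an arbitrary test point.
    let (y0, x0) = (0.5f32, 2.0f32);
    let h = 1e-3f32;
    let num_dy = ((y0 + h).atan2(x0) - (y0 - h).atan2(x0)) / (2.0 * h);
    let num_dx = (y0.atan2(x0 + h) - y0.atan2(x0 - h)) / (2.0 * h);

    // The analytical gradients should agree with the numerical estimates.
    let (dy, dx) = atan2_grad(y0, x0);
    assert!((dy - num_dy).abs() < 1e-3);
    assert!((dx - num_dx).abs() < 1e-3);
    println!("analytical: ({dy}, {dx}), numerical: ({num_dy}, {num_dx})");
}
```

If the two estimates agree, the per-element derivatives are right; inside a custom op's `bwd` they would still need to be multiplied by the incoming gradient so the chain rule holds for anything other than a plain sum.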