Skip to content

Commit

Permalink
wip decluttering (de)quant to cast
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Mar 25, 2024
1 parent 9590809 commit c9d2854
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 40 deletions.
6 changes: 5 additions & 1 deletion core/src/ops/cast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ pub struct Cast {

impl Cast {
fn do_eval(&self, input: &Tensor, symbols: &SymbolValues) -> TractResult<TVec<TValue>> {
dbg!(self);
dbg!(input);
if input.datum_type() == self.to {
Ok(tvec!(input.clone().into_tvalue()))
} else if input.datum_type() == TDim::datum_type() {
Expand All @@ -24,7 +26,9 @@ impl Cast {
Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
}
} else {
Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
let out = input.cast_to_dt(self.to)?;
dbg!(&out);
Ok(tvec!(out.into_owned().into_tvalue()))
}
}
}
Expand Down
4 changes: 4 additions & 0 deletions core/src/ops/quant.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use tract_linalg::Scaler;
use super::binary::TypedBinOp;
use super::math::round_ties_to_even;

/*
pub fn quantize_linear_f32_u8(x: f32, scale: f32, zero_point: i32) -> u8 {
(((x * scale).round() as i32) + zero_point)
.clamp(u8::min_value() as i32, u8::max_value() as i32) as u8
Expand All @@ -20,7 +21,9 @@ pub fn quantize_linear_f32_i8(x: f32, scale: f32, zero_point: i32) -> i8 {
(((x * scale).round() as i32) + zero_point)
.clamp(i8::min_value() as i32, i8::max_value() as i32) as i8
}
*/

/*
element_wise_oop!(quantize_linear_u8,
QuantizeLinearU8 {
scale: f32,
Expand Down Expand Up @@ -250,6 +253,7 @@ impl TypedOp for DequantizeLinearF32 {
as_op!();
}
*/

element_wise_oop!(lookup_table,
LookupTable {
Expand Down
60 changes: 21 additions & 39 deletions onnx/src/ops/quant.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use crate::model::{OnnxOpRegister, ParsingContext};
use crate::pb::NodeProto;
use tract_hir::internal::*;
use tract_hir::ops::quant::*;
use tract_hir::ops::cast::cast;
use tract_ndarray::ArrayViewD;

pub fn register_all_ops(reg: &mut OnnxOpRegister) {
Expand All @@ -13,23 +13,23 @@ pub fn register_all_ops(reg: &mut OnnxOpRegister) {
fn quantize_linear(
_ctx: &ParsingContext,
node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let op = QuantizeLinear::new(Some(2).filter(|_| node.input.len() == 3));
Ok((expand(op), vec![]))
}

fn dequantize_linear(
_ctx: &ParsingContext,
node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let op = DequantizeLinear::new(Some(2).filter(|_| node.input.len() == 3));
Ok((expand(op), vec![]))
}

fn dynamic_quantize_linear(
_ctx: &ParsingContext,
_node: &NodeProto,
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
let op = DynamicQuantizeLinear::new();
Ok((expand(op), vec![]))
}
Expand All @@ -39,8 +39,6 @@ pub struct QuantizeLinear {
optional_zero_point_input: Option<usize>,
}



impl Expansion for QuantizeLinear {
fn name(&self) -> Cow<str> {
"QuantizeLinear".into()
Expand All @@ -51,14 +49,14 @@ impl Expansion for QuantizeLinear {
s: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> TractResult<()> {
) -> TractResult<()> {
check_input_arity(inputs, 2 + self.optional_zero_point_input.is_some() as usize)?;
check_output_arity(outputs, 1)?;
// s.equals(&inputs[1].rank, 0)?; broken in Onnx test suite
s.equals(&inputs[1].datum_type, f32::datum_type())?;
if self.optional_zero_point_input.is_some() {
s.equals(&outputs[0].datum_type, &inputs[2].datum_type)?;
// s.equals(&inputs[2].rank, 0)?; // broken in Onnx test suite
// s.equals(&inputs[2].rank, 0)?; // broken in Onnx test suite
} else {
s.equals(&outputs[0].datum_type, u8::datum_type())?;
}
Expand All @@ -71,15 +69,13 @@ impl Expansion for QuantizeLinear {
prefix: &str,
target: &mut TypedModel,
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
use tract_hir::ops::quant::*;
) -> TractResult<TVec<OutletId>> {
let scale = target
.outlet_fact(inputs[1])?
.konst
.as_ref()
.context("y_scale must be a const")?
.as_slice::<f32>()?[0]
.recip();
.as_slice::<f32>()?[0];
let zero_point = if self.optional_zero_point_input.is_some() {
target
.outlet_fact(inputs[2])?
Expand All @@ -90,12 +86,10 @@ impl Expansion for QuantizeLinear {
} else {
rctensor0(0u8)
};
let op: Box<dyn TypedOp> = if zero_point.datum_type() == u8::datum_type() {
Box::new(quantize_linear_u8(scale, zero_point.as_slice::<u8>()?[0]))
} else {
Box::new(quantize_linear_i8(scale, zero_point.as_slice::<i8>()?[0]))
};
target.wire_node(prefix, op, &[inputs[0]])
let dst = zero_point.datum_type().with_zp_scale(zero_point.cast_to_scalar::<i32>()?, scale);
let quant = target.wire_node(format!("{prefix}.cvt"), cast(dst), &[inputs[0]])?;
// ONNX expects unquantized types
target.wire_node(prefix, cast(dst.unquantized()), &quant)
}
}

Expand All @@ -104,8 +98,6 @@ pub struct DequantizeLinear {
optional_zero_point_input: Option<usize>,
}



impl Expansion for DequantizeLinear {
fn name(&self) -> Cow<str> {
"DequantizeLinear".into()
Expand All @@ -116,7 +108,7 @@ impl Expansion for DequantizeLinear {
s: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> TractResult<()> {
) -> TractResult<()> {
check_input_arity(inputs, 2 + self.optional_zero_point_input.is_some() as usize)?;
check_output_arity(outputs, 1)?;
// s.equals(&inputs[1].rank, 0)?; broken in Onnx test suite
Expand All @@ -135,7 +127,7 @@ impl Expansion for DequantizeLinear {
prefix: &str,
target: &mut TypedModel,
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
) -> TractResult<TVec<OutletId>> {
let scale = target
.outlet_fact(inputs[1])?
.konst
Expand All @@ -152,22 +144,14 @@ impl Expansion for DequantizeLinear {
} else {
rctensor0(0u8)
};
let op: Box<dyn TypedOp> = if zero_point.datum_type() == u8::datum_type() {
Box::new(DequantizeLinearF32::new(scale, zero_point.as_slice::<u8>()?[0] as i32))
} else if zero_point.datum_type() == i8::datum_type() {
Box::new(DequantizeLinearF32::new(scale, zero_point.as_slice::<i8>()?[0] as i32))
} else {
Box::new(DequantizeLinearF32::new(scale, zero_point.as_slice::<i32>()?[0]))
};
target.wire_node(prefix, op, &[inputs[0]])
let q = target.wire_node(format!("{prefix}.ri_cast"), cast(i32::datum_type().with_zp_scale(zero_point.cast_to_scalar::<i32>()?, scale)), &[inputs[0]])?;
target.wire_node(prefix, cast(f32::datum_type()), &q)
}
}

#[derive(Debug, Clone, new, Default, Hash)]
pub struct DynamicQuantizeLinear {}



impl Expansion for DynamicQuantizeLinear {
fn name(&self) -> Cow<str> {
"DynamicQuantizeLinear".into()
Expand All @@ -182,7 +166,7 @@ impl Expansion for DynamicQuantizeLinear {
s: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> TractResult<()> {
) -> TractResult<()> {
check_input_arity(inputs, 1)?;
check_output_arity(outputs, 3)?;
s.equals(&inputs[0].datum_type, f32::datum_type())?;
Expand All @@ -201,7 +185,7 @@ impl Expansion for DynamicQuantizeLinear {
prefix: &str,
target: &mut TypedModel,
inputs: &[OutletId],
) -> TractResult<TVec<OutletId>> {
) -> TractResult<TVec<OutletId>> {
let op: Box<dyn TypedOp> = Box::new(DynamicQuantizeLinearU8::new());
target.wire_node(format!("{prefix}.dynamic_quantize"), op, &[inputs[0]])
}
Expand Down Expand Up @@ -267,8 +251,6 @@ impl Op for DynamicQuantizeLinearU8 {
op_as_typed_op!();
}



impl EvalOp for DynamicQuantizeLinearU8 {
fn is_stateless(&self) -> bool {
true
Expand All @@ -287,7 +269,7 @@ impl EvalOp for DynamicQuantizeLinearU8 {
zero_point,
input.as_slice::<f32>()?,
dst.as_slice_mut::<u8>()?,
);
);

let quantized_tensor = dst.into_tvalue();
let scale_tensor = tensor0(scale).into();
Expand Down Expand Up @@ -342,7 +324,7 @@ mod tests {
(
&[1., 2.1, 1.3, 2.5, 3.34, 4., 1.5, 2.6, 3.9, 4., 3., 2.345],
&[64, 134, 83, 159, 213, 255, 96, 166, 249, 255, 191, 149],
),
),
];

for (v, quantized_ok) in &data {
Expand All @@ -356,7 +338,7 @@ mod tests {
zero_point,
v.as_slice().unwrap(),
quantized.as_slice_mut().unwrap(),
);
);
assert_eq!(quantized.as_slice().unwrap(), *quantized_ok);
}
}
Expand Down

0 comments on commit c9d2854

Please sign in to comment.