diff --git a/core/src/ops/cast.rs b/core/src/ops/cast.rs
index fa28127ae6..cca7a569a6 100644
--- a/core/src/ops/cast.rs
+++ b/core/src/ops/cast.rs
@@ -1,4 +1,7 @@
+use tract_data::itertools::Itertools;
+
 use crate::internal::*;
+use crate::plan::eval;
 
 pub fn cast(to: DatumType) -> Cast {
     Cast { to }
@@ -24,7 +27,8 @@ impl Cast {
                 Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
             }
         } else {
-            Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
+            let out = input.cast_to_dt(self.to)?;
+            Ok(tvec!(out.into_owned().into_tvalue()))
         }
     }
 }
@@ -102,5 +106,162 @@ impl TypedOp for Cast {
         Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
     }
 
+    fn codegen(
+        &self,
+        model: &TypedModel,
+        node: &TypedNode,
+    ) -> TractResult<Option<TypedModelPatch>> {
+        let src_dt = model.node_input_facts(node.id)?[0].datum_type;
+        if src_dt.is_quantized() && src_dt.size_of() == 1 && self.to.is_float() {
+            codegen_quant_ew_chain_to_lut(self, model, node)
+        } else {
+            Ok(None)
+        }
+    }
+
     as_op!();
 }
+
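+/// Looks for a chain `Cast(quantized 8-bit -> float)` followed by element-wise unary ops and a
+/// `Cast` back to a quantized 8-bit type, and tries to replace the whole chain with a single
+/// 256-entry lookup table.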
+fn codegen_quant_ew_chain_to_lut(
+    original_dequant: &Cast,
+    model: &TypedModel,
+    origin: &TypedNode,
+) -> TractResult<Option<TypedModelPatch>> {
+    let mut current = origin;
+    let incoming_dt = model.node_input_facts(origin.id)?[0].datum_type;
+    while let Some(next) = model.single_succ(current.id)? {
+        /*
+        let q_params = if let Some(op) = op.op_as::<ElementWiseOp>() {
+            if let Some(mop) = op.0.downcast_ref::<QuantizeLinearU8>() {
+                Some((mop.scale, mop.zero_point as i32, u8::datum_type()))
+            } else {
+                op.0.downcast_ref::<QuantizeLinearI8>()
+                    .map(|mop| (mop.scale, mop.zero_point as i32, i8::datum_type()))
+            }
+        } else {
+            None
+        };
+        */
+        let q_dt_dst: Option<DatumType> =
+            next.op_as::<Cast>().map(|c| c.to).filter(|dt| dt.is_quantized());
+        if let Some(dt) = q_dt_dst {
+            let (zp, scale) = dt.zp_scale();
+            /*
+            // first, try Op::quantize() on all ops in the chain
+            let mut patch = TypedModelPatch::default();
+            let mut wire: OutletId = patch.tap_model(model, origin.inputs[0])?;
+            let mut next = model.single_succ(origin.id)?.unwrap();
+            loop {
+                if let Some(op) = next
+                    .op
+                    .quantize(model, dequant, dt, scale, zero_point)
+                    .with_context(|| format!("Quantizing {next}"))?
+                {
+                    wire = patch.wire_node(&*next.name, op, [wire].as_ref())?[0];
+                } else {
+                    break;
+                }
+                if next.id == current.id {
+                    patch.shunt_outside(model, OutletId::new(op.id, 0), wire)?;
+                    return Ok(Some(patch));
+                } else {
+                    next = model.single_succ(next.id)?.unwrap();
+                }
+            }
+            */
+            // or else make a lookup table
+            if incoming_dt.is_quantized() && incoming_dt.size_of() == 1 {
+                return Ok(Some(
+                    transform_quant_seq_to_lut(model, origin.inputs[0], next.id.into())
+                        .context("Transforming sequence to LUT")?,
+                ));
+            }
+        }
+        let (input_facts, output_facts) = model.node_facts(next.id)?;
+        let invariants = next
+            .op
+            .axes_mapping(&input_facts, &output_facts)
+            .with_context(|| format!("Querying invariants for {next}"))?;
+        if invariants.is_element_wise_unary() {
+            current = next;
+        } else {
+            break;
+        }
+    }
+    Ok(None)
+}
+
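+/// Builds a patch replacing the dequantize `src` -> element-wise chain -> requantize `dst`
+/// sequence with a `lookup_table` op: the chain is replayed on all 256 possible 8-bit input
+/// values in an ad-hoc model, its outputs become the table, and a cast restores the original
+/// output type.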
+fn transform_quant_seq_to_lut(
+    model: &TypedModel,
+    src: OutletId, // wire before the dequant cast
+    dst: OutletId, // wire after the requant cast
+) -> TractResult<TypedModelPatch> {
+    let incoming_dt = model.outlet_fact(src)?.datum_type;
+    let outgoing_dt = model.outlet_fact(dst)?.datum_type;
+    ensure!(incoming_dt.is_quantized() && incoming_dt.size_of() == 1);
+
+    let mut adhoc_model = TypedModel::default();
+    let wire = adhoc_model.add_source("ad-hoc", incoming_dt.fact([256]))?;
+    let mut next = model.single_succ(src.node)?.unwrap();
+    // plug in dequant
+    let dequant = model.node(src.node);
+    let name = &dequant.name;
+    let mut wire: TVec<OutletId> = tvec!(wire);
+    while next.id != dst.node {
+        wire = adhoc_model.wire_node(&*next.name, next.op.clone(), &wire)?;
+        next = model.single_succ(next.id)?.unwrap();
+    }
+    // plug in quant
+    wire = adhoc_model.wire_node(&*next.name, next.op.clone(), &wire)?;
+    adhoc_model.set_output_outlets(&wire)?;
+
+    let input = tensor1(&(0u8..=255).collect_vec());
+    let input = input.cast_to_dt(incoming_dt.unquantized())?.cast_to_dt(incoming_dt)?.into_owned();
+    let output = SimpleState::new(SimplePlan::new(adhoc_model)?)?
+        .run_plan_with_eval(tvec!(input.into_tvalue()), |s, op, node, inputs| {
+            eprintln!("{node} {inputs:?}");
+            eval(s, op, node, inputs)
+        })?
+        .remove(0);
+
+    let table: &[u8] = match incoming_dt.unquantized() {
+        DatumType::I8 => unsafe { std::mem::transmute(output.as_slice::<i8>()?) },
+        DatumType::U8 => output.as_slice::<u8>()?,
+        _ => unreachable!(),
+    };
+    let op = crate::ops::quant::lookup_table((tract_linalg::ops().lut_u8)(table));
+    let mut patch = TypedModelPatch::default();
+    let mut wire = patch.taps(model, &[src])?;
+    wire = patch.wire_node(format!("{name}.lut"), op, &wire)?;
+    wire = patch.wire_node(format!("{name}.cast"), cast(outgoing_dt), &wire)?;
+    patch.shunt_outside(model, dst, wire[0])?;
+    Ok(patch)
+}
+
+#[cfg(test)]
+mod test {
+    use super::*;
+    use crate::ops::nn::sigmoid;
+
+    #[test]
+    fn test_lut() -> TractResult<()> {
+        let mut model = TypedModel::default();
+        let dt = i8::datum_type().with_zp_scale(0, 0.03);
+        let src = model.add_source("src", dt.fact(&[10]))?;
+        let mut wire = model.wire_node("dq", cast(f32::datum_type()), &[src])?;
+        wire = model.wire_node("sigmoid", sigmoid(), &wire)?;
+        wire = model.wire_node("q", cast(dt), &wire)?;
+        model.set_output_outlets(&wire)?;
+
+        let input =
+            tensor1(&(-5i32..5i32).collect_vec()).cast_to::<f32>()?.cast_to_dt(dt)?.into_owned();
+        let ref_output = model.clone().into_runnable()?.run(tvec!(input.clone().into_tvalue()))?;
+        dbg!(&input);
+        dbg!(&ref_output);
+
+        let codegen = model.into_optimized()?;
+        assert!(codegen.nodes.len() == 2); // Source then LookupTable
+        let output = codegen.into_runnable()?.run(tvec!(input.into_tvalue()))?;
+        output[0].close_enough(&ref_output[0], Approximation::Exact)
+    }
+}
diff --git a/core/src/ops/element_wise.rs b/core/src/ops/element_wise.rs
index 2dae760207..a33cb1a947 100644
--- a/core/src/ops/element_wise.rs
+++ b/core/src/ops/element_wise.rs
@@ -165,6 +165,7 @@ macro_rules! element_wise {
         $(; q: $( [$($typ_dt:ident),*] => $f_f32:expr),*)?
         $(; cost: $cost:expr )?
         $(; declutter: $declutter:expr )?
+        $(; eval_override: $eval_override: expr)?
         $(; operating_datum_type: $operating_datum_type:expr )?
         $(; prefix: $prefix:expr )?
         $(; quantize: $quantize:expr )?
@@ -177,6 +178,7 @@ macro_rules! element_wise {
                 format!("{}{}", self.prefix(), stringify!($Op))
             }
             fn eval_in_place(&self, t: &mut Tensor) -> TractResult<()> {
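+                // When an eval_override is supplied, it takes over evaluation for every datum type.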
+                $( return $eval_override(self, t); )?
                 $(
                     $(if t.datum_type() == $typ::datum_type() {
                         let t: &mut[$typ] = t.as_slice_mut::<$typ>()?;
diff --git a/core/src/ops/quant.rs b/core/src/ops/quant.rs
index 864ad9c2e5..df50ff5b1f 100644
--- a/core/src/ops/quant.rs
+++ b/core/src/ops/quant.rs
@@ -11,6 +11,7 @@ use tract_linalg::Scaler;
 use super::binary::TypedBinOp;
 use super::math::round_ties_to_even;
 
+/*
 pub fn quantize_linear_f32_u8(x: f32, scale: f32, zero_point: i32) -> u8 {
     (((x * scale).round() as i32) + zero_point)
         .clamp(u8::min_value() as i32, u8::max_value() as i32) as u8
@@ -20,7 +21,9 @@ pub fn quantize_linear_f32_i8(x: f32, scale: f32, zero_point: i32) -> i8 {
     (((x * scale).round() as i32) + zero_point)
         .clamp(i8::min_value() as i32, i8::max_value() as i32) as i8
 }
+*/
 
+/*
 element_wise_oop!(quantize_linear_u8,
     QuantizeLinearU8 {
         scale: f32,
@@ -250,24 +253,34 @@ impl TypedOp for DequantizeLinearF32 {
     as_op!();
 }
+*/
 
-element_wise_oop!(lookup_table,
+element_wise!(lookup_table,
     LookupTable {
         table: Box<dyn Lut>
-    },
-    [i8] => i8 |op, xs, ys| {
-        ys.copy_from_slice(xs);
+    }, ;
+    eval_override: |op: &LookupTable, xs: &mut Tensor| {
+        // dbg!(&op.table.table());
+        // dbg!(&xs);
+        let bytes = unsafe { xs.as_bytes_mut() };
+        // dbg!(&bytes);
+        op.table.run(bytes);
+        // dbg!(&bytes);
+        Ok(())
+}
+    /*
+    [i8] => |op, xs| {
         unsafe {
-            let casted = std::slice::from_raw_parts_mut(ys.as_mut_ptr() as *mut u8, ys.len());
+            let casted = std::slice::from_raw_parts_mut(xs.as_mut_ptr() as *mut u8, xs.len());
             op.table.run(casted);
         }
         Ok(())
     },
-    [u8] => u8 |op, xs, ys| {
-        ys.copy_from_slice(xs);
-        op.table.run(ys);
+    [u8] => |op, xs| {
+        op.table.run(xs);
         Ok(())
     }
+    */
 );
 
 #[derive(Debug, Clone, Hash)]
diff --git a/onnx/src/ops/quant.rs b/onnx/src/ops/quant.rs
index acd894ca74..88328a7a9a 100644
--- a/onnx/src/ops/quant.rs
+++ b/onnx/src/ops/quant.rs
@@ -1,7 +1,7 @@
 use crate::model::{OnnxOpRegister, ParsingContext};
 use crate::pb::NodeProto;
 use tract_hir::internal::*;
-use tract_hir::ops::quant::*;
+use tract_hir::ops::cast::cast;
 use tract_ndarray::ArrayViewD;
 
 pub fn register_all_ops(reg: &mut OnnxOpRegister) {
@@ -13,7 +13,7 @@ fn quantize_linear(
     _ctx: &ParsingContext,
     node: &NodeProto,
-) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
+    ) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
     let op = QuantizeLinear::new(Some(2).filter(|_| node.input.len() == 3));
     Ok((expand(op), vec![]))
 }
@@ -21,7 +21,7 @@ fn dequantize_linear(
     _ctx: &ParsingContext,
     node: &NodeProto,
-) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
+    ) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
     let op = DequantizeLinear::new(Some(2).filter(|_| node.input.len() == 3));
     Ok((expand(op), vec![]))
 }
@@ -29,7 +29,7 @@ fn dynamic_quantize_linear(
     _ctx: &ParsingContext,
     _node: &NodeProto,
-) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
+    ) -> TractResult<(Box<dyn InferenceOp>, Vec<String>)> {
     let op = DynamicQuantizeLinear::new();
     Ok((expand(op), vec![]))
 }
@@ -39,8 +39,6 @@ pub struct QuantizeLinear {
     optional_zero_point_input: Option<usize>,
 }
 
-
-
 impl Expansion for QuantizeLinear {
     fn name(&self) -> Cow<str> {
         "QuantizeLinear".into()
     }
@@ -51,14 +49,14 @@ impl Expansion for QuantizeLinear {
         s: &mut Solver<'r>,
         inputs: &'p [TensorProxy],
         outputs: &'p [TensorProxy],
-    ) -> TractResult<()> {
+        ) -> TractResult<()> {
         check_input_arity(inputs, 2 + self.optional_zero_point_input.is_some() as usize)?;
         check_output_arity(outputs, 1)?;
         // s.equals(&inputs[1].rank, 0)?; broken in Onnx test suite
         s.equals(&inputs[1].datum_type, f32::datum_type())?;
         if self.optional_zero_point_input.is_some() {
             s.equals(&outputs[0].datum_type, &inputs[2].datum_type)?;
-        // s.equals(&inputs[2].rank, 0)?; // broken in Onnx test suite
+            // s.equals(&inputs[2].rank, 0)?; // broken in Onnx test suite
         } else {
             s.equals(&outputs[0].datum_type, u8::datum_type())?;
         }
@@ -71,15 +69,13 @@ impl Expansion for QuantizeLinear {
         prefix: &str,
         target: &mut TypedModel,
         inputs: &[OutletId],
-    ) -> TractResult<TVec<OutletId>> {
-        use tract_hir::ops::quant::*;
+        ) -> TractResult<TVec<OutletId>> {
         let scale = target
             .outlet_fact(inputs[1])?
             .konst
             .as_ref()
             .context("y_scale must be a const")?
-            .as_slice::<f32>()?[0]
-            .recip();
+            .as_slice::<f32>()?[0];
         let zero_point = if self.optional_zero_point_input.is_some() {
             target
                 .outlet_fact(inputs[2])?
@@ -90,12 +86,10 @@ impl Expansion for QuantizeLinear {
         } else {
             rctensor0(0u8)
         };
-        let op: Box<dyn TypedOp> = if zero_point.datum_type() == u8::datum_type() {
-            Box::new(quantize_linear_u8(scale, zero_point.as_slice::<u8>()?[0]))
-        } else {
-            Box::new(quantize_linear_i8(scale, zero_point.as_slice::<i8>()?[0]))
-        };
-        target.wire_node(prefix, op, &[inputs[0]])
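+        // Quantization is now expressed with casts: casting f32 to the zero-point/scale-annotated
+        // integer type performs the actual quantization.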
+        let dst = zero_point.datum_type().with_zp_scale(zero_point.cast_to_scalar::<i32>()?, scale);
+        let quant = target.wire_node(format!("{prefix}.cvt"), cast(dst), &[inputs[0]])?;
+        // ONNX expects unquantized types
+        target.wire_node(prefix, cast(dst.unquantized()), &quant)
     }
 }
 
@@ -104,8 +98,6 @@ pub struct DequantizeLinear {
     optional_zero_point_input: Option<usize>,
 }
 
-
-
 impl Expansion for DequantizeLinear {
     fn name(&self) -> Cow<str> {
         "DequantizeLinear".into()
     }
@@ -116,7 +108,7 @@ impl Expansion for DequantizeLinear {
         s: &mut Solver<'r>,
         inputs: &'p [TensorProxy],
         outputs: &'p [TensorProxy],
-    ) -> TractResult<()> {
+        ) -> TractResult<()> {
         check_input_arity(inputs, 2 + self.optional_zero_point_input.is_some() as usize)?;
         check_output_arity(outputs, 1)?;
         // s.equals(&inputs[1].rank, 0)?; broken in Onnx test suite
         s.equals(&inputs[1].datum_type, f32::datum_type())?;
@@ -135,7 +127,7 @@ impl Expansion for DequantizeLinear {
         prefix: &str,
         target: &mut TypedModel,
         inputs: &[OutletId],
-    ) -> TractResult<TVec<OutletId>> {
+        ) -> TractResult<TVec<OutletId>> {
         let scale = target
             .outlet_fact(inputs[1])?
             .konst
@@ -152,22 +144,14 @@ impl Expansion for DequantizeLinear {
         } else {
             rctensor0(0u8)
         };
-        let op: Box<dyn TypedOp> = if zero_point.datum_type() == u8::datum_type() {
-            Box::new(DequantizeLinearF32::new(scale, zero_point.as_slice::<u8>()?[0] as i32))
-        } else if zero_point.datum_type() == i8::datum_type() {
-            Box::new(DequantizeLinearF32::new(scale, zero_point.as_slice::<i8>()?[0] as i32))
-        } else {
-            Box::new(DequantizeLinearF32::new(scale, zero_point.as_slice::<i32>()?[0]))
-        };
-        target.wire_node(prefix, op, &[inputs[0]])
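+        // Dequantization also becomes a pair of casts: first to an i32 type annotated with the
+        // zero-point and scale, then from that quantized type to f32.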
+        let q = target.wire_node(format!("{prefix}.ri_cast"), cast(i32::datum_type().with_zp_scale(zero_point.cast_to_scalar::<i32>()?, scale)), &[inputs[0]])?;
+        target.wire_node(prefix, cast(f32::datum_type()), &q)
     }
 }
 
 #[derive(Debug, Clone, new, Default, Hash)]
 pub struct DynamicQuantizeLinear {}
 
-
-
 impl Expansion for DynamicQuantizeLinear {
     fn name(&self) -> Cow<str> {
         "DynamicQuantizeLinear".into()
     }
@@ -182,7 +166,7 @@ impl Expansion for DynamicQuantizeLinear {
         s: &mut Solver<'r>,
         inputs: &'p [TensorProxy],
         outputs: &'p [TensorProxy],
-    ) -> TractResult<()> {
+        ) -> TractResult<()> {
         check_input_arity(inputs, 1)?;
         check_output_arity(outputs, 3)?;
         s.equals(&inputs[0].datum_type, f32::datum_type())?;
@@ -201,7 +185,7 @@ impl Expansion for DynamicQuantizeLinear {
         prefix: &str,
         target: &mut TypedModel,
         inputs: &[OutletId],
-    ) -> TractResult<TVec<OutletId>> {
+        ) -> TractResult<TVec<OutletId>> {
         let op: Box<dyn TypedOp> = Box::new(DynamicQuantizeLinearU8::new());
         target.wire_node(format!("{prefix}.dynamic_quantize"), op, &[inputs[0]])
     }
@@ -267,8 +251,6 @@ impl Op for DynamicQuantizeLinearU8 {
     op_as_typed_op!();
 }
 
-
-
 impl EvalOp for DynamicQuantizeLinearU8 {
     fn is_stateless(&self) -> bool {
         true
     }
@@ -287,7 +269,7 @@
             zero_point,
             input.as_slice::<f32>()?,
             dst.as_slice_mut::<u8>()?,
-        );
+            );
         let quantized_tensor = dst.into_tvalue();
         let scale_tensor = tensor0(scale).into();
 
@@ -342,7 +324,7 @@ mod tests {
             (
                 &[1., 2.1, 1.3, 2.5, 3.34, 4., 1.5, 2.6, 3.9, 4., 3., 2.345],
                 &[64, 134, 83, 159, 213, 255, 96, 166, 249, 255, 191, 149],
-            ),
+                ),
         ];
 
         for (v, quantized_ok) in &data {
@@ -356,7 +338,7 @@
                 zero_point,
                 v.as_slice().unwrap(),
                 quantized.as_slice_mut().unwrap(),
-            );
+                );
             assert_eq!(quantized.as_slice().unwrap(), *quantized_ok);
         }
     }