Merge pull request #249 from huggingface/feature/refactor-quantization
Refactoring quantization
FL33TW00D authored Aug 31, 2024
2 parents 6e9bbf4 + e86e3e8 commit a23fe19
Showing 7 changed files with 239 additions and 212 deletions.
12 changes: 3 additions & 9 deletions crates/ratchet-core/src/cpu/mod.rs
@@ -1,8 +1,8 @@
 pub mod gemm;

 use crate::{
-    Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, DType, IndexSelect, InvariantError, OpGuards,
-    Operation, OperationError, Quantization, Quantizer, RVec, Storage, StorageView, Tensor,
+    dequantize, Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, DType, IndexSelect,
+    InvariantError, OpGuards, Operation, OperationError, RVec, Storage, StorageView, Tensor,
     TensorDType, Unary, UnaryOp,
 };
 use anyhow::anyhow;
@@ -228,13 +228,7 @@ fn qindex_select(op: IndexSelect, dst: Tensor) -> Result<Tensor, OperationError>
     let src = op.src().deep_clone();

     // NOTE: Support for other quantization types is dependent on the corresponding dequantization functions.
-    let src = match src.dt() {
-        DType::Q8_0F(_) => {
-            let quantizer = Quantizer::new(Quantization::SInt8);
-            quantizer.sint8_dequantize(src)
-        }
-        _ => return Err(InvariantError::UnsupportedDType(src.dt()).into()),
-    };
+    let src = dequantize(src);
     let indices = op.indices().clone();
     let dim = op.dim();

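The inline dtype dispatch above collapses into the new `dequantize` free function. This diff only confirms the `dequantize(src)` call site; a minimal sketch of the dispatcher it implies, with hypothetical per-format helper names:

```rust
// Hypothetical sketch, not the commit's implementation: a single dispatch
// point that each per-format dequantization routine plugs into.
pub fn dequantize(src: Tensor) -> Tensor {
    match src.dt() {
        DType::Q8_0F(_) | DType::Q8_0H(_) => dequantize_q8_0(src), // hypothetical helper
        DType::Q4_KF(_) | DType::Q4_KH(_) => dequantize_q4_k(src), // hypothetical helper
        dt => panic!("dequantize: unsupported dtype {dt:?}"),
    }
}
```
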
56 changes: 55 additions & 1 deletion crates/ratchet-core/src/dtype/blocks.rs
@@ -3,9 +3,10 @@
 ///
 /// We closely follow the memory layout of the original GGUF implementation,
 /// but often need 2 variants of each block type for devices that don't support f16.
-use crate::{rvec, Align, BufferSegment, RVec};
+use crate::{rvec, Align, BufferSegment, DType, RVec, TensorDType};
 use derive_new::new;
 use half::f16;
+use num_traits::{AsPrimitive, Float, FromPrimitive, NumAssign};

 /// # Bindings
 ///
@@ -163,3 +164,56 @@ impl Segments for Q4_KH {
         self.0.segments(numel)
     }
 }
+
+pub trait Quantized {
+    type FP: TensorDType + Float + NumAssign + AsPrimitive<i32> + FromPrimitive + Copy + PartialEq;
+    const PACK_SIZE: usize;
+    const GROUP_SIZE: usize;
+    const SF: Self::FP;
+
+    const LSHIFT: usize = Self::GROUP_SIZE / Self::PACK_SIZE;
+    const MASK: i32 = (1 << Self::LSHIFT) - 1;
+    const RSHIFT: usize = Self::GROUP_SIZE - Self::LSHIFT;
+
+    fn dt() -> DType;
+}
+impl Quantized for Q8_0F {
+    type FP = f32;
+    const PACK_SIZE: usize = 4;
+    const GROUP_SIZE: usize = 32;
+    const SF: f32 = ((1 << 7) - 1) as f32;
+
+    fn dt() -> DType {
+        DType::Q8_0F(Q8_0F::default())
+    }
+}
+impl Quantized for Q8_0H {
+    type FP = f16;
+    const PACK_SIZE: usize = 4;
+    const GROUP_SIZE: usize = 32;
+    const SF: f16 = f16::from_f32_const(Q8_0F::SF);
+
+    fn dt() -> DType {
+        DType::Q8_0H(Q8_0H::default())
+    }
+}
+impl Quantized for Q4_KF {
+    type FP = f32;
+    const PACK_SIZE: usize = 8;
+    const GROUP_SIZE: usize = 32;
+    const SF: f32 = 7.0;
+
+    fn dt() -> DType {
+        DType::Q4_KF(Q4_KF::default())
+    }
+}
+impl Quantized for Q4_KH {
+    type FP = f16;
+    const PACK_SIZE: usize = 8;
+    const GROUP_SIZE: usize = 32;
+    const SF: f16 = f16::from_f32_const(7.0);
+
+    fn dt() -> DType {
+        DType::Q4_KH(Q4_KH::default())
+    }
+}
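
A note on the derived constants: since `GROUP_SIZE` is 32 for every impl above, `LSHIFT = GROUP_SIZE / PACK_SIZE` comes out to the bit width of one packed value (8 for Q8, 4 for Q4), `MASK` to its bit mask (`0xFF` / `0xF`), and `RSHIFT` to the complement used for sign extension within an `i32`. A self-contained illustration of the mask-and-shift unpacking these constants appear designed for (Q8_0 flavour; not code from this commit):

```rust
// Unpack one i32 holding PACK_SIZE quantized lanes, using the same derived
// constants as the Quantized trait above.
fn main() {
    const PACK_SIZE: usize = 4; // four i8 lanes per i32
    const GROUP_SIZE: usize = 32; // elements per quantization group
    const LSHIFT: usize = GROUP_SIZE / PACK_SIZE; // 8 bits per lane
    const MASK: i32 = (1 << LSHIFT) - 1; // 0xFF
    const RSHIFT: usize = GROUP_SIZE - LSHIFT; // 24

    let packed = 0x807F_01FEu32 as i32; // lanes, low to high: -2, 1, 127, -128
    for lane in 0..PACK_SIZE {
        let raw = (packed >> (lane * LSHIFT)) & MASK; // isolate the lane's bits
        let val = (raw << RSHIFT) >> RSHIFT; // arithmetic shift sign-extends
        println!("lane {lane}: {val}");
    }
}
```
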
21 changes: 9 additions & 12 deletions crates/ratchet-core/src/ops/matmul/mod.rs
@@ -12,7 +12,7 @@ use std::{cmp::Ordering, mem};

 use crate::{
     gpu::{BindGroupLayoutDescriptor, CpuUniform},
-    rvec, DType, Device, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata,
+    quantize, rvec, DType, Device, GPUOperation, Kernel, KernelElement, KernelKey, KernelMetadata,
     KernelRenderable, KernelSource, OpGuards, Operation, OperationError, RVec, Shape, StorageView,
     Strides, Tensor, WorkgroupSize, Workload, Q4_KF, Q4_KH, Q8_0F, Q8_0H,
 };
@@ -577,7 +577,7 @@ pub enum MatmulKernels {
     GEMM(GEMM),
     SubgroupGEMV(SubgroupGEMV),
     WorkgroupGEMV(WorkgroupGEMV),
-    Quantized(Quantized),
+    Quantized(QMatMul),
 }

 impl KernelRenderable for MatmulKernels {
@@ -741,7 +741,7 @@ impl GPUOperation for Matmul {
             (true, false, false) => {
                 MatmulKernels::WorkgroupGEMV(WorkgroupGEMV::from_matmul(self, spec))
             }
-            (false, true, _) => MatmulKernels::Quantized(Quantized::from_matmul(self, spec)),
+            (false, true, _) => MatmulKernels::Quantized(QMatMul::from_matmul(self, spec)),
             (false, false, _) => MatmulKernels::GEMM(GEMM::from_matmul(self, spec)),
             _ => todo!(),
         }
@@ -754,7 +754,7 @@ mod tests {

     use crate::test_util::run_py_prg;

-    use crate::{shape, Device, DeviceRequest, Quantization, Quantizer};
+    use crate::{shape, Device, DeviceRequest};

     use super::*;

@@ -955,8 +955,7 @@ def matmul(a, b{}):
         let b = Tensor::randn::<f32>(shape![6, 64, 1500], cpu_device.clone());
         let ground = ground_truth(&a, &b, None, false, false, false)?;

-        let quantizer = Quantizer::new(Quantization::SInt8);
-        let aq = quantizer.sint8_quantize(a);
+        let aq = quantize::<Q8_0F>(&a);
         let a_gpu = aq.to(&device)?;
         let b_gpu = b.to(&device)?;
         let c_gpu = a_gpu.matmul(b_gpu, false, false)?.resolve()?;
@@ -976,8 +975,8 @@

         let device = Device::request_device(DeviceRequest::GPU).unwrap();
         let cpu_device = Device::request_device(DeviceRequest::CPU)?;
-        let a = Tensor::randn::<f32>(shape![2, 175, 241], cpu_device.clone());
-        let b = Tensor::randn::<f32>(shape![2, 241, 182], cpu_device.clone());
+        let a = Tensor::randn::<f32>(shape![2, 175, 240], cpu_device.clone());
+        let b = Tensor::randn::<f32>(shape![2, 240, 182], cpu_device.clone());
         let bias = Some(Tensor::randn::<f32>(shape![182], cpu_device.clone()));

         let TRANS_LHS = false;
@@ -988,8 +987,7 @@
         let ground = ground_truth(&a, &b, bias.as_ref(), TRANS_LHS, TRANS_RHS, TRANS_DST)?;

         let a_gpu = if QUANT {
-            let quantizer = Quantizer::new(Quantization::SInt8);
-            let aq = quantizer.sint8_quantize(a);
+            let aq = quantize::<Q8_0F>(&a);
             aq.to(&device)?
         } else {
             a.to(&device)?
@@ -1027,8 +1025,7 @@
         let ground = ground_truth(&a, &b, None, TRANS_LHS, TRANS_RHS, TRANS_DST)?;

         let a_gpu = if QUANT {
-            let quantizer = Quantizer::new(Quantization::SInt8);
-            let aq = quantizer.sint8_quantize(a);
+            let aq = quantize::<Q8_0F>(&a);
             aq.to(&device)?
         } else {
             a.to(&device)?
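
Across the matmul tests, the two-step `Quantizer::new(Quantization::SInt8)` / `sint8_quantize` pattern becomes a single generic `quantize::<Q8_0F>(&a)`, parameterized by the `Quantized` impls added in blocks.rs; the inner dimension of one test also shrinks from 241 to 240, presumably so the shape divides evenly into the pack/group sizes. A hedged sketch of the round-trip the new API implies (assuming `quantize` and `dequantize` are the free functions imported above — their exact signatures beyond these call sites are not shown in this diff):

```rust
// Sketch of the new test-side usage, mirroring the call sites in this diff.
let a = Tensor::randn::<f32>(shape![64, 64], Device::CPU);
let aq = quantize::<Q8_0F>(&a); // generic over any type implementing Quantized
let a_approx = dequantize(aq);  // dtype-dispatched inverse; lossy by design
```

The struct rename in the next file, `Quantized` → `QMatMul`, presumably frees the `Quantized` name for the new trait.
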
8 changes: 4 additions & 4 deletions crates/ratchet-core/src/ops/matmul/quantized.rs
@@ -16,7 +16,7 @@ pub struct QuantizedMeta {
 }

 #[derive(Debug, Clone)]
-pub struct Quantized {
+pub struct QMatMul {
     lhs: Tensor,
     rhs: Tensor,
     bias: Option<Tensor>,
@@ -26,7 +26,7 @@ pub struct Quantized {
     spec: MatmulSpec,
 }

-impl Quantized {
+impl QMatMul {
     pub fn from_matmul(matmul: &Matmul, spec: MatmulSpec) -> Self {
         let Matmul {
             lhs,
@@ -48,7 +48,7 @@ impl Quantized {
     }
 }

-impl Kernel for Quantized {
+impl Kernel for QMatMul {
     type Metadata = QuantizedMeta;

     fn kernel_name(&self) -> String {
@@ -103,7 +103,7 @@ impl Kernel for Quantized {
     }
 }

-impl KernelRenderable for Quantized {
+impl KernelRenderable for QMatMul {
     fn register_bindings<P: WgslPrimitive>(
         &self,
         builder: &mut WgslKernelBuilder,
15 changes: 7 additions & 8 deletions crates/ratchet-core/src/ops/select.rs
@@ -291,7 +291,7 @@ mod tests {
     use test_strategy::proptest;

     use crate::test_util::run_py_prg;
-    use crate::{rvec, shape, Device, DeviceRequest, Quantization, Quantizer, Shape, Tensor};
+    use crate::{quantize, rvec, shape, Device, DeviceRequest, Shape, Tensor, Q8_0F};

     impl Arbitrary for IndexSelectProblem {
         type Parameters = ();
@@ -324,17 +324,16 @@ def index_select(input, indices):
         run_py_prg(prg.to_string(), &[input, indices], &[], input.dt())
     }

-    fn run_index_select_trial(problem: IndexSelectProblem, device: Device, quantize: bool) {
+    fn run_index_select_trial(problem: IndexSelectProblem, device: Device, quant: bool) {
         let IndexSelectProblem {
             input_shape,
             indices,
         } = problem;
         let mut input = Tensor::randn::<f32>(input_shape, Device::CPU);

         let ground_truth = ground_truth(&input, &indices, 0).unwrap();
-        if quantize {
-            let quantizer = Quantizer::new(Quantization::SInt8);
-            input = quantizer.quantize(input);
+        if quant {
+            input = quantize::<Q8_0F>(&input);
         }

         let input = input.to(&device).unwrap();
@@ -348,10 +347,10 @@
     }

     #[test]
-    fn qindex_select() {
+    fn test_qindex_select() {
         let prob = IndexSelectProblem {
-            input_shape: shape![52000, 1280],
-            indices: Tensor::from_data(vec![50258, 50259, 50360], shape![3], Device::CPU),
+            input_shape: shape![256, 32],
+            indices: Tensor::from_data(vec![64, 192, 255], shape![3], Device::CPU),
         };
         let device = Device::request_device(DeviceRequest::GPU).unwrap();
         run_index_select_trial(prob.clone(), device, true);