Merge pull request #42 from FL33TW00D/feature/qmm

Feature/qmm

FL33TW00D authored Jan 25, 2024
2 parents 90b5a6a + fa2d4c1 commit 6741b07

Showing 19 changed files with 520 additions and 219 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/rust.yml
@@ -34,7 +34,7 @@ jobs:
with:
tool: cargo-nextest

- - name: (linux) install llvmpipe, lavapipe, vulkan sdk, alsa
+ - name: (linux) install lavapipe, vulkan sdk, alsa
if: matrix.os == 'ubuntu-latest'
shell: bash
run: |
1 change: 1 addition & 0 deletions README.md
@@ -3,3 +3,4 @@
### [Discord](https://discord.gg/XFe33KQTG4)

A web-first, cross-platform ML framework.

2 changes: 1 addition & 1 deletion crates/ratchet-core/kernels/qgemm_vec4.wgsl
@@ -28,7 +28,7 @@ struct Meta {
@group(1) @binding(0)
var<uniform> metadata: Meta;

- @compute @workgroup_size({{ workgroup_size_x }}, {{ workgroup_size_y }}, {{ workgroup_size_z }})
+ @compute @workgroup_size(8,8,1)
fn main(
@builtin(global_invocation_id) global_id: vec3<u32>
) {
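
The templated {{ workgroup_size_x }} placeholders above were filled in at pipeline-build time by a template engine; hardcoding 8,8,1 removes that render step, so the shader can be compiled directly from the static source registered in KERNELS (see kernels.rs below). For context, a minimal sketch of how such {{ }} placeholders are typically rendered (assuming a tera-style engine; the engine, if any, is not named in this diff):

// Hypothetical sketch: rendering a templated WGSL kernel with tera.
// Whether ratchet used tera (or another {{ }}-style engine) is not
// shown in this diff; everything here is illustrative.
use tera::{Context, Tera};

fn render_kernel(template: &str) -> tera::Result<String> {
    let mut ctx = Context::new();
    ctx.insert("workgroup_size_x", &8u32);
    ctx.insert("workgroup_size_y", &8u32);
    ctx.insert("workgroup_size_z", &1u32);
    // one_off renders a standalone template string, no registry needed
    Tera::one_off(template, &ctx, false)
}
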
48 changes: 48 additions & 0 deletions crates/ratchet-core/kernels/sgemm_vec2.wgsl
@@ -0,0 +1,48 @@
// Unoptimized, only reaches ~500 GFLOP/s
@group(0) @binding(0)
var<storage, read> A: array<vec2<f32>>;

@group(0) @binding(1)
var<storage, read> B: array<vec2<f32>>;

@group(0) @binding(2)
var<storage, read_write> C: array<vec2<f32>>;

struct Meta {
    M: u32,
    N: u32,
    K: u32,
    MD2: u32,
    ND2: u32,
    KD2: u32,
    MD4: u32,
    ND4: u32,
    KD4: u32,
    A_OFFSET: u32,
    B_OFFSET: u32,
    C_OFFSET: u32,
}

@group(1) @binding(0)
var<uniform> metadata: Meta;

@compute @workgroup_size(8,8,1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let a_offset = global_id.z * metadata.A_OFFSET;
    let b_offset = global_id.z * metadata.B_OFFSET;
    let c_offset = global_id.z * metadata.C_OFFSET;

    let cRow = global_id.x;
    let cCol = global_id.y;
    if (cRow < metadata.M && cCol < metadata.ND2) {
        var tmp = vec2<f32>();
        for (var k = 0u; k < metadata.KD2; k++) {
            let a = A[a_offset + (cRow * metadata.KD2 + k)];
            tmp += vec2<f32>(a.x) * B[b_offset + (k * metadata.N + cCol)];
            tmp += vec2<f32>(a.y) * B[b_offset + (k * metadata.N + cCol + (1u * metadata.ND2))];
        }
        C[c_offset + (cRow * metadata.ND2 + cCol)] = tmp;
    }
}
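
Each thread computes one vec2 of C: a.x and a.y broadcast two consecutive elements of a row of A against two consecutive rows of B (row 2k of B starts at vec2 index k * N, since N = 2 * ND2). A host-side reference with the same index arithmetic is useful for testing the kernel. A minimal sketch, assuming row-major f32 matrices and a single batch (not code from this PR):

// Host-side reference for sgemm_vec2.wgsl's indexing (illustrative only).
// a: M x K, b: K x N, returns c: M x N; N and K must be even so the vec2
// view (ND2 = N / 2, KD2 = K / 2) is exact.
fn sgemm_vec2_reference(a: &[f32], b: &[f32], m: usize, n: usize, k: usize) -> Vec<f32> {
    assert!(n % 2 == 0 && k % 2 == 0);
    let (nd2, kd2) = (n / 2, k / 2);
    let mut c = vec![0.0f32; m * n];
    for row in 0..m {
        for col in 0..nd2 {
            // accumulates one vec2 of C, like the kernel's `tmp`
            let mut tmp = [0.0f32; 2];
            for kk in 0..kd2 {
                let a0 = a[row * k + 2 * kk]; // a.x
                let a1 = a[row * k + 2 * kk + 1]; // a.y
                for lane in 0..2 {
                    // B[k * N + cCol] is row 2k of B viewed as vec2s;
                    // stepping ND2 further lands on row 2k + 1
                    tmp[lane] += a0 * b[2 * kk * n + 2 * col + lane];
                    tmp[lane] += a1 * b[(2 * kk + 1) * n + 2 * col + lane];
                }
            }
            c[row * n + 2 * col] = tmp[0];
            c[row * n + 2 * col + 1] = tmp[1];
        }
    }
    c
}
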
53 changes: 53 additions & 0 deletions crates/ratchet-core/kernels/sgemm_vec4.wgsl
@@ -0,0 +1,53 @@
// Unoptimized, only reaches ~500 GFLOP/s
@group(0) @binding(0)
var<storage, read> A: array<vec4<f32>>;

@group(0) @binding(1)
var<storage, read> B: array<vec4<f32>>;

@group(0) @binding(2)
var<storage, read_write> C: array<vec4<f32>>;

struct Meta {
    M: u32,
    N: u32,
    K: u32,
    MD2: u32,
    ND2: u32,
    KD2: u32,
    MD4: u32,
    ND4: u32,
    KD4: u32,
    A_OFFSET: u32,
    B_OFFSET: u32,
    C_OFFSET: u32,
}

@group(1) @binding(0)
var<uniform> metadata: Meta;

@compute @workgroup_size(8,8,1)
fn main(
    @builtin(global_invocation_id) global_id: vec3<u32>
) {
    let a_offset = global_id.z * metadata.A_OFFSET;
    let b_offset = global_id.z * metadata.B_OFFSET;
    let c_offset = global_id.z * metadata.C_OFFSET;

    let cRow = global_id.x;
    let cCol = global_id.y;
    if (cRow < metadata.M && cCol < metadata.ND4) {
        var tmp = vec4<f32>();
        for (var k = 0u; k < metadata.KD4; k++) {
            let a = A[a_offset + (cRow * metadata.KD4 + k)];
            let b_step = k * metadata.N + cCol; // 4 rows of B per iteration
            let b_stride = metadata.ND4;

            tmp = fma(vec4<f32>(a.x), B[b_offset + b_step], tmp);
            tmp = fma(vec4<f32>(a.y), B[b_offset + (b_step + b_stride)], tmp);
            tmp = fma(vec4<f32>(a.z), B[b_offset + (b_step + (2u * b_stride))], tmp);
            tmp = fma(vec4<f32>(a.w), B[b_offset + (b_step + (3u * b_stride))], tmp);
        }
        C[c_offset + (cRow * metadata.ND4 + cCol)] = tmp;
    }
}
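
Unlike the vec2 kernel, the inner loop here uses fma, which fuses each multiply-add into a single instruction, and consumes four rows of B per iteration. With @workgroup_size(8,8,1) and one thread per vec4 of C, the dispatch must cover M rows, ND4 vec4 columns, and the batch size in z. The dispatch site is not part of this diff; a sketch of the expected workgroup-count arithmetic (function names are assumptions):

// Workgroup counts for the 8x8x1 sgemm_vec4 kernel (illustrative; the
// real dispatch code is not shown in this diff). One thread writes one
// vec4 of C, so each batch entry needs an M x (N / 4) grid of threads.
fn ceil_div(a: u32, b: u32) -> u32 {
    (a + b - 1) / b
}

fn sgemm_vec4_workgroups(m: u32, n: u32, batch: u32) -> (u32, u32, u32) {
    let nd4 = n / 4; // vec4 columns of C
    (
        ceil_div(m, 8),   // global_id.x covers cRow in 0..M
        ceil_div(nd4, 8), // global_id.y covers cCol in 0..ND4
        batch,            // global_id.z selects the A/B/C batch offsets
    )
}
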
22 changes: 7 additions & 15 deletions crates/ratchet-core/src/compiled_op.rs
@@ -1,13 +1,14 @@
use crate::gpu::{
- BindGroupDescriptor, BindGroupEntry, BindGroupLayoutHandle, ComputePipelineHandle,
- GpuBindGroup, WgpuDevice, WorkgroupCount,
+ BindGroupDescriptor, BindGroupLayoutHandle, ComputePipelineHandle, GpuBindGroup, WgpuDevice,
+ WorkgroupCount,
};
use crate::{drvec, rvec, RVec, Tensor};
use derive_new::new;
use wgpu::DynamicOffset;

//Compiled op represents a single kernel invocation
- //TODO: We need to be more general here, enum with encoder.copy_buffer_to_buffer as a COPY
+ //TODO: We need to be more general here, enum with encoder.copy_buffer_to_buffer as a COPY, and
+ //compiledOp as compute
#[derive(Debug, new)]
pub struct CompiledOp {
pipeline_handle: ComputePipelineHandle,
@@ -19,30 +20,21 @@ pub struct CompiledOp {
impl CompiledOp {
const MAX_BINDINGS_PER_GROUP: usize = 4;

+ //TODO: Should return a Result
pub fn create_storage_bind_groups(
srcs: &[&Tensor],
dst: &Tensor,
bind_group_layouts: RVec<BindGroupLayoutHandle>,
device: &WgpuDevice,
) -> RVec<GpuBindGroup> {
- let mut binding_counter: usize = 0;
let mut bind_group_entries = drvec![];

for tensor in srcs.iter().chain(std::iter::once(&dst)) {
- let storage_guard = tensor.storage();
- let storage = storage_guard.as_ref().unwrap();
- let gpu_buf = &storage.try_gpu().unwrap().inner;
- bind_group_entries.push(BindGroupEntry {
-     handle: gpu_buf.handle,
-     offset: 0,
-     size: Some(gpu_buf.size().try_into().unwrap()),
- });
- binding_counter += 1;
+ bind_group_entries.append(&mut tensor.bindings());
}

let mut storage_groups = rvec![];
for (group_index, bind_group_layout) in bind_group_layouts.iter().enumerate() {
- let group_range = Self::group_range(group_index, binding_counter);
+ let group_range = Self::group_range(group_index, bind_group_entries.len());
let entries = bind_group_entries[group_range].into();
let layout = *bind_group_layout;

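
The manual BindGroupEntry construction is replaced by tensor.bindings(), so one tensor can now contribute several bind group entries. That is exactly what the new WQ8 dtype needs: its buffer carries two segments (weights and absmax). The bindings() implementation is not shown in this diff; a hedged sketch of its plausible shape, built on DType::segments from dtype.rs below (method names, return type, and error handling are assumptions):

// Plausible sketch of Tensor::bindings(), not the PR's actual code.
// DType::segments returns one segment for plain dtypes and two for WQ8,
// and each segment becomes its own entry over the same underlying buffer.
impl Tensor {
    pub fn bindings(&self) -> DRVec<BindGroupEntry> {
        let storage_guard = self.storage();
        let storage = storage_guard.as_ref().unwrap();
        let gpu_buf = &storage.try_gpu().unwrap().inner;
        self.dt()
            .segments(gpu_buf.size() as usize)
            .into_iter()
            .map(|segment| BindGroupEntry {
                handle: gpu_buf.handle,
                offset: segment.offset,
                size: segment.size,
            })
            .collect()
    }
}
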
41 changes: 41 additions & 0 deletions crates/ratchet-core/src/dtype.rs
@@ -1,4 +1,9 @@
use std::num::NonZeroU64;

use half::{bf16, f16};
use wgpu::{BufferAddress, BufferSize};

use crate::{rvec, RVec};

#[derive(Debug, Copy, Clone, PartialEq, Eq, Default, Hash)]
pub enum DType {
@@ -9,6 +14,7 @@ pub enum DType {
F32,
I32,
U32,
WQ8, //Packed Q8 (|--4xQ8(u32)--| |--f32--|)
}

impl DType {
@@ -21,7 +27,42 @@ impl DType {
DType::F32 => 4,
DType::I32 => 4,
DType::U32 => 4,
DType::WQ8 => 4, //Only works because they're both 4 bytes
}
}

    pub fn segments(&self, total_bytes: usize) -> RVec<BufferSegment> {
        match self {
            DType::WQ8 => {
                let weights_size = total_bytes / 5 * 4;
                assert!(weights_size % 256 == 0); // storage buffer alignment
                let weights = BufferSegment::new(0, Some(weights_size as u64));

                let absmax_size = total_bytes - weights_size;
                assert!(absmax_size % 256 == 0); // storage buffer alignment
                let absmax = BufferSegment::new(weights_size as u64, Some(absmax_size as u64));
                rvec![weights, absmax]
            }
            _ => {
                rvec![BufferSegment::new(0, Some(total_bytes as u64))]
            }
        }
    }
}

#[derive(Debug)]
pub struct BufferSegment {
    pub offset: BufferAddress,
    pub size: Option<BufferSize>,
}

impl BufferSegment {
    pub fn new(offset: BufferAddress, size: Option<u64>) -> Self {
        if let Some(size) = size {
            assert!(size % 256 == 0); // storage buffer alignment
        }
        // map + unwrap would panic on a None input and nest the Option;
        // and_then collapses it and keeps None as None
        let size = size.and_then(NonZeroU64::new);
        Self { offset, size }
    }
}

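
A worked example of the WQ8 split (an illustration, not code from the PR): the byte ratio is fixed at 4:1 by total_bytes / 5 * 4, so a 1280-byte buffer yields a 1024-byte weights segment followed by a 256-byte absmax segment, and both sizes pass the 256-byte storage-buffer alignment asserts:

let segments = DType::WQ8.segments(1280);
// weights: Q8 values packed into u32s, 4/5 of the buffer
assert_eq!(segments[0].offset, 0);
assert_eq!(segments[0].size.unwrap().get(), 1024);
// absmax: f32 scales in the trailing 1/5
assert_eq!(segments[1].offset, 1024);
assert_eq!(segments[1].size.unwrap().get(), 256);
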
1 change: 1 addition & 0 deletions crates/ratchet-core/src/gpu/pools/buffer_pool.rs
@@ -15,6 +15,7 @@ impl BufferDescriptor {
}
}

//All slotmap keys are COPY
slotmap::new_key_type! { pub struct GpuBufferHandle; }

/// A reference-counter baked buffer.
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/gpu/pools/pipeline_pool.rs
@@ -42,7 +42,7 @@ impl ComputePipelinePool {
) -> ComputePipelineHandle {
self.inner.get_or_create(desc, |desc| {
let kernel_key = desc.build_kernel_key();
- let shader = KERNELS.get(kernel_key.as_str()).unwrap();
+ let shader = KERNELS.get(kernel_key.as_str()).expect("Kernel not found");
let label = Some(kernel_key.as_str());

let shader_module_desc = wgpu::ShaderModuleDescriptor {
12 changes: 12 additions & 0 deletions crates/ratchet-core/src/kernels.rs
@@ -22,6 +22,18 @@ lazy_static! {
r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/add_scalar.wgsl"
),
);
m.insert(
"sgemm_vec2",
include_str!(
r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_vec2.wgsl"
),
);
m.insert(
"sgemm_vec4",
include_str!(
r"/Users/fleetwood/Code/ratchet/crates/ratchet-core/kernels/sgemm_vec4.wgsl"
),
);
m
};
}
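
With both sgemm variants registered, a caller has to pick a key per matmul. That selection logic is not in this diff (it presumably lives in the Matmul op, alongside the KernelElement type imported in binary.rs below); a guess at its shape, not the PR's code:

// Hypothetical kernel-key selection, purely illustrative: prefer the
// vec4 kernel when N is divisible by 4, else fall back to vec2. No
// scalar sgemm is registered in this PR, hence the unimplemented! arm.
fn sgemm_kernel_key(n: usize) -> &'static str {
    if n % 4 == 0 {
        "sgemm_vec4"
    } else if n % 2 == 0 {
        "sgemm_vec2"
    } else {
        unimplemented!("no scalar sgemm kernel registered")
    }
}
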
4 changes: 1 addition & 3 deletions crates/ratchet-core/src/op.rs
@@ -3,7 +3,7 @@ use std::fmt::Debug;
use encase::internal::WriteInto;
use encase::ShaderType;

- use crate::gpu::{BindGroupLayoutHandle, CpuUniform, PoolError, WgpuDevice, UNIFORM_ALIGN};
+ use crate::gpu::{CpuUniform, PoolError, WgpuDevice, UNIFORM_ALIGN};
use crate::{Binary, CompiledOp, InvariantError, Matmul, RVec, StorageView, Tensor};

#[derive(Debug, Clone)]
@@ -83,8 +83,6 @@ pub trait Operation: Debug + 'static {
device: &WgpuDevice,
) -> Result<CompiledOp, OperationError>;

- fn storage_layout(&self, device: &WgpuDevice) -> Result<BindGroupLayoutHandle, OperationError>;

fn check_invariants(srcs: &[&Tensor]) -> Result<(), OperationError>;

/// # Output Inference
11 changes: 4 additions & 7 deletions crates/ratchet-core/src/ops/binary.rs
@@ -3,8 +3,8 @@ use encase::ShaderType;

use crate::{
gpu::{
- BindGroupLayoutDescriptor, BindGroupLayoutHandle, ComputePipelineDescriptor, CpuUniform,
- PipelineLayoutDescriptor, WgpuDevice, WorkgroupCount,
+ BindGroupLayoutDescriptor, ComputePipelineDescriptor, CpuUniform, PipelineLayoutDescriptor,
+ WgpuDevice, WorkgroupCount,
},
rvec, wgc, CompiledOp, Enforcer, KernelElement, OpMetadata, Operation, OperationError, RVec,
StorageView, Tensor,
@@ -61,10 +61,6 @@ impl Operation for Binary {
rvec![&self.lhs, &self.rhs]
}

- fn storage_layout(&self, device: &WgpuDevice) -> Result<BindGroupLayoutHandle, OperationError> {
-     Ok(device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::binary())?)
- }

//TODO: we can refactor this into composite methods and share a single `compile` impl on the
//trait
fn compile(
@@ -79,7 +75,8 @@ impl Operation for Binary {
let offset = uniform.write(&BinaryMeta { M, N })?;
let wgcx = WorkgroupCount::div_ceil(M as _, 64);

- let storage_layout = self.storage_layout(device)?;
+ let storage_layout =
+     device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::binary())?;
let uniform_layout =
device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())?;
let pipeline_layout = device.get_or_create_pipeline_layout(&PipelineLayoutDescriptor {