Merge pull request #64 from FL33TW00D/feature/norm
feature: Normalization
FL33TW00D authored Feb 2, 2024
2 parents 4815d4e + f887b67 commit 9b03876
Showing 18 changed files with 476 additions and 37 deletions.
9 changes: 8 additions & 1 deletion .github/workflows/rust.yml
@@ -21,6 +21,13 @@ jobs:
include:
- name: Windows x86_64
os: windows-2022
target: x86_64-pc-windows-msvc
kind: native

- name: MacOS aarch64
os: macos-14
target: aarch64-apple-darwin
kind: native
#TODO: android?
steps:
- uses: actions/checkout@v4
@@ -74,7 +81,7 @@ jobs:
- name: Setup python
uses: actions/setup-python@v5
with:
python-version: '3.10.6'
python-version: '3.10.11'
cache: 'pip'
- run: pip install -r requirements.txt

3 changes: 2 additions & 1 deletion Cargo.toml
@@ -3,7 +3,8 @@ members = [
"crates/ratchet-core",
"crates/ratchet-integration-tests",
"crates/ratchet-loader",
"crates/ratchet-models",
"crates/ratchet-models",
"crates/ratchet-nn",
]
resolver = "2"

56 changes: 47 additions & 9 deletions crates/ratchet-core/build.rs
@@ -1,5 +1,5 @@
use anyhow::Context as anyhowCtx;
use pathdiff;

use std::fs::File;
use std::io::Write;
use std::path::{Path, PathBuf};
@@ -127,24 +127,36 @@ impl std::fmt::Display for ReindexOp {
impl ReindexOp {
pub fn func_body(&self) -> String {
match self {
ReindexOp::Permute => format!(
r#"
ReindexOp::Permute => r#"
var src_index = vec4<u32>(0u);
src_index[metadata.perm[0]] = dst_index[0];
src_index[metadata.perm[1]] = dst_index[1];
src_index[metadata.perm[2]] = dst_index[2];
src_index[metadata.perm[3]] = dst_index[3];
"#,
),
ReindexOp::Slice => format!(
r#"
"#
.to_string(),
ReindexOp::Slice => r#"
var src_index = dst_index;
"#,
),
"#
.to_string(),
}
}
}

#[derive(Debug, Clone, strum_macros::EnumIter)]
pub enum NormOp {
LayerNorm,
}

impl std::fmt::Display for NormOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let s = match self {
NormOp::LayerNorm => "layernorm",
};
write!(f, "{}", s)
}
}

#[derive(Debug)]
pub struct KernelGenerator {
tera: Tera,
@@ -168,6 +180,7 @@ impl KernelGenerator {
self.generate_unary()?;
self.generate_binary()?;
self.generate_reindex()?;
self.generate_norm()?;
Ok(())
}

@@ -187,6 +200,31 @@ impl KernelGenerator {
Ok(())
}

fn generate_norm(&mut self) -> anyhow::Result<()> {
for op in NormOp::iter() {
for ke in KernelElement::iter() {
let path = self.templates_path.join("layernorm.wgsl");
self.tera.add_template_file(path, Some("layernorm"))?;

let mut context = Context::new();
context.insert("elem", &ke.as_wgsl(WgslDType::F32));
context.insert("elem_size", &ke.as_size());
let reduction_len = match ke {
KernelElement::Scalar => "metadata.N",
KernelElement::Vec2 => "metadata.ND2",
KernelElement::Vec4 => "metadata.ND4",
};
context.insert("reduction_len", reduction_len);
let rendered = self.tera.render("layernorm", &context)?;

let kernel_fname = format!("{}_{}.wgsl", op, ke);
let mut file = File::create(self.dest_path.join(kernel_fname))?;
file.write_all(rendered.as_bytes())?;
}
}
Ok(())
}

fn generate_unary(&mut self) -> anyhow::Result<()> {
for func in UnaryOp::iter() {
for ke in KernelElement::iter() {
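For orientation, generate_norm() above walks the cross product of NormOp and KernelElement and renders one WGSL file per pair. A minimal stand-alone sketch of the resulting naming scheme follows; the enums here are simplified stand-ins for the crate's own types, not the real definitions in build.rs.

// Sketch only: reproduces the "{op}_{element}.wgsl" naming convention used by
// generate_norm(). The real NormOp/KernelElement live in ratchet-core's build.rs.
#[derive(Clone, Copy)]
enum NormOp {
    LayerNorm,
}

#[derive(Clone, Copy)]
enum KernelElement {
    Scalar,
    Vec2,
    Vec4,
}

impl std::fmt::Display for NormOp {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "layernorm")
    }
}

impl std::fmt::Display for KernelElement {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let s = match self {
            KernelElement::Scalar => "scalar",
            KernelElement::Vec2 => "vec2",
            KernelElement::Vec4 => "vec4",
        };
        write!(f, "{}", s)
    }
}

fn main() {
    for op in [NormOp::LayerNorm] {
        for ke in [KernelElement::Scalar, KernelElement::Vec2, KernelElement::Vec4] {
            // Prints layernorm_scalar.wgsl, layernorm_vec2.wgsl, layernorm_vec4.wgsl,
            // the same names registered in kernels.rs further down.
            println!("{}_{}.wgsl", op, ke);
        }
    }
}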
102 changes: 102 additions & 0 deletions crates/ratchet-core/kernel-templates/layernorm.wgsl
@@ -0,0 +1,102 @@
@group(0) @binding(0)
var<storage, read> X: array<{{ elem }}>;

@group(0) @binding(1)
var<storage, read> S: array<{{ elem }}>;

@group(0) @binding(2)
var<storage, read> B: array<{{ elem }}>;

@group(0) @binding(3)
var<storage, read_write> Y: array<{{ elem }}>;

struct Meta {
M: u32,
N: u32,
ND2: u32,
ND4: u32,
eps: f32,
}

@group(1) @binding(0)
var<uniform> metadata: Meta;

const BLOCK_SIZE: u32 = 128u;

var<workgroup> smem: array<{{ elem }}, BLOCK_SIZE>; //max 16kb

fn block_sum(index: u32, stride: u32) {
if index < stride {
smem[index] += smem[index + stride];
}
workgroupBarrier();
}

fn mu(local_id: vec3<u32>, anchor: u32) -> f32 {
var threadSum = {{ elem }}(0.0);
for (var i: u32 = local_id.x; i < {{ reduction_len }}; i += BLOCK_SIZE) {
threadSum += X[anchor + i];
}
workgroupBarrier();
smem[local_id.x] = threadSum;
workgroupBarrier();

block_sum(local_id.x, 64u);
block_sum(local_id.x, 32u);
block_sum(local_id.x, 16u);
block_sum(local_id.x, 8u);
block_sum(local_id.x, 4u);
block_sum(local_id.x, 2u);
block_sum(local_id.x, 1u);

{% if elem == "f32" -%}
return smem[0] / f32(metadata.N);
{% else -%}
return dot(smem[0], {{ elem }}(1.0)) / f32(metadata.N);
{% endif %}
}

fn sigma(local_id: vec3<u32>, anchor: u32, mu: f32) -> f32 {
var threadSum = {{ elem }}(0.0);
//Compute σ
for (var i: u32 = local_id.x; i < {{ reduction_len }}; i += BLOCK_SIZE) {
let val = X[anchor + i] - mu;
threadSum = fma(val, val, threadSum);
}

workgroupBarrier();
smem[local_id.x] = threadSum;
workgroupBarrier();

block_sum(local_id.x, 64u);
block_sum(local_id.x, 32u);
block_sum(local_id.x, 16u);
block_sum(local_id.x, 8u);
block_sum(local_id.x, 4u);
block_sum(local_id.x, 2u);
block_sum(local_id.x, 1u);

{% if elem == "f32" -%}
return smem[0] / f32(metadata.N);
{% else -%}
return dot(smem[0], {{ elem }}(1.0)) / f32(metadata.N);
{% endif %}
}

@compute @workgroup_size(128, 1, 1)
fn main(
@builtin(local_invocation_id) local_id: vec3<u32>,
@builtin(workgroup_id) group_id: vec3<u32>,
@builtin(global_invocation_id) global_id: vec3<u32>
) {
let anchor = (group_id.y * metadata.M * {{ reduction_len }}) + group_id.x * {{ reduction_len }};
let mu = mu(local_id, anchor);
let sigma = sigma(local_id, anchor, mu);

let denom = inverseSqrt(sigma + {{ elem }}(metadata.eps));

for(var i: u32 = local_id.x; i < {{ reduction_len }}; i += BLOCK_SIZE) {
let val = (X[anchor + i] - mu) * denom;
Y[anchor + i] = fma(val, S[i], B[i]);
}
}
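In conventional notation, the kernel above normalizes each row of length N with a two-pass reduction (mu(), then sigma()), each pass summing 128 per-thread partials in smem by repeatedly halving the stride; one workgroup handles one row. S holds the per-element scale (γ) and B the bias (β):

\mu = \frac{1}{N}\sum_{i=1}^{N} x_i, \qquad
\sigma^2 = \frac{1}{N}\sum_{i=1}^{N}\left(x_i - \mu\right)^2, \qquad
y_i = \gamma_i\,\frac{x_i - \mu}{\sqrt{\sigma^2 + \varepsilon}} + \beta_i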
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/compiled_op.rs
@@ -26,7 +26,7 @@ impl CompiledOp {
bind_group_layouts: RVec<BindGroupLayoutHandle>,
device: &WgpuDevice,
inplace: bool,
kernel_name: &str,
_kernel_name: &str,
) -> Result<RVec<GpuBindGroup>, OperationError> {
let mut bind_group_entries = drvec![];

9 changes: 5 additions & 4 deletions crates/ratchet-core/src/gpu/buffer_allocator.rs
@@ -58,6 +58,8 @@ impl BufferAllocator {
) -> PooledGPUBuffer {
let buf = self.pool.borrow_mut().get_or_create(desc, device);
device.queue().write_buffer(&buf.inner, 0, contents);
device.queue().submit(None);
device.poll(wgpu::Maintain::Wait);
buf
}

@@ -110,11 +112,10 @@ impl BufferAllocator {
return GraphBuffer::from(self.create_buffer(&descriptor, device));
}

let result = match closest_index {
match closest_index {
Some(idx) => free.remove(idx),
None => GraphBuffer::from(self.create_buffer(&descriptor, device)),
};
result
}
}

/// # Inplace operations
@@ -249,7 +250,7 @@ impl BufferAllocator {
source.id(),
just_allocated.inner().global_id(),
);
assignments.insert(source.id(), GraphBuffer::from(just_allocated.clone()));
assignments.insert(source.id(), just_allocated.clone());
}
}

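The two lines added after write_buffer above make the copy observable before the pooled buffer is returned: write_buffer only stages the data, while an empty submit plus a blocking poll forces the queue to drain. A minimal sketch of the same pattern, assuming a device and queue obtained elsewhere (this is plain wgpu, not Ratchet's own API):

// Sketch: flush a staged buffer write eagerly. Queue::write_buffer only schedules
// the copy; submitting no command buffers creates a flush point, and
// Maintain::Wait blocks until the GPU has consumed `contents`.
fn write_and_flush(device: &wgpu::Device, queue: &wgpu::Queue, buf: &wgpu::Buffer, contents: &[u8]) {
    queue.write_buffer(buf, 0, contents);
    queue.submit(None);
    device.poll(wgpu::Maintain::Wait);
}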
1 change: 0 additions & 1 deletion crates/ratchet-core/src/gpu/device.rs
@@ -131,7 +131,6 @@ impl WgpuDevice {
let backends = wgpu::util::backend_bits_from_env().unwrap_or(wgpu::Backends::PRIMARY);
let adapter = instance
.enumerate_adapters(backends)
.into_iter()
.max_by_key(|adapter| match adapter.get_info().device_type {
DeviceType::DiscreteGpu => 5,
DeviceType::Other => 4,
7 changes: 5 additions & 2 deletions crates/ratchet-core/src/gpu/pools/buffer_pool.rs
@@ -63,12 +63,15 @@ impl BufferPool {
};
self.inner.get_or_create(&descriptor, |descriptor| {
let (size, usage, mapped_at_creation) = descriptor.fields();
device.create_buffer(&wgpu::BufferDescriptor {
let buf = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size,
usage,
mapped_at_creation,
})
});
device.queue().submit(None);
device.poll(wgpu::Maintain::Wait);
buf
})
}

12 changes: 12 additions & 0 deletions crates/ratchet-core/src/kernels.rs
@@ -52,6 +52,10 @@ lazy_static! {
"cos_vec2",
include_str!(r"../kernels/generated/cos_vec2.wgsl"),
);
m.insert(
"layernorm_vec4",
include_str!(r"../kernels/generated/layernorm_vec4.wgsl"),
);
m.insert(
"sub_scalar",
include_str!(r"../kernels/generated/sub_scalar.wgsl"),
@@ -92,6 +96,10 @@ lazy_static! {
"cos_vec4",
include_str!(r"../kernels/generated/cos_vec4.wgsl"),
);
m.insert(
"layernorm_vec2",
include_str!(r"../kernels/generated/layernorm_vec2.wgsl"),
);
m.insert(
"tanh_scalar",
include_str!(r"../kernels/generated/tanh_scalar.wgsl"),
@@ -112,6 +120,10 @@ lazy_static! {
"floor_vec2",
include_str!(r"../kernels/generated/floor_vec2.wgsl"),
);
m.insert(
"layernorm_scalar",
include_str!(r"../kernels/generated/layernorm_scalar.wgsl"),
);
m.insert(
"div_scalar",
include_str!(r"../kernels/generated/div_scalar.wgsl"),
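The three layernorm entries join the existing registry of generated kernels: each include_str! embeds the build-script output into the binary, keyed by "{op}_{element}". A reduced stand-in for that lookup pattern (the real map lives in kernels.rs behind lazy_static! under its own name):

// Reduced stand-in: kernel name -> embedded WGSL source. In the crate the map is
// built once inside lazy_static! and the values come from include_str!.
use std::collections::HashMap;

fn build_registry() -> HashMap<&'static str, &'static str> {
    let mut m = HashMap::new();
    m.insert("layernorm_scalar", "/* generated WGSL */");
    m.insert("layernorm_vec2", "/* generated WGSL */");
    m.insert("layernorm_vec4", "/* generated WGSL */");
    m
}

fn main() {
    let kernels = build_registry();
    // An op selects its shader by the same "{op}_{element}" key the build script emitted.
    assert!(kernels.contains_key("layernorm_vec4"));
}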
6 changes: 5 additions & 1 deletion crates/ratchet-core/src/op.rs
@@ -8,7 +8,7 @@ use crate::gpu::{
PoolError, WgpuDevice, WorkgroupCount, UNIFORM_ALIGN,
};
use crate::{
rvec, Binary, CompiledOp, InvariantError, KernelElement, Matmul, RVec, Reindex, Softmax,
rvec, Binary, CompiledOp, InvariantError, KernelElement, Matmul, Norm, RVec, Reindex, Softmax,
StorageView, Tensor, Unary,
};

@@ -20,6 +20,7 @@ pub enum LazyOp {
Softmax(Softmax), //Should be custom
Unary(Unary),
Reindex(Reindex),
Norm(Norm),
Const,
}

@@ -31,6 +32,7 @@ impl LazyOp {
LazyOp::Softmax(s) => s.name(),
LazyOp::Unary(u) => u.name(),
LazyOp::Reindex(r) => r.name(),
LazyOp::Norm(n) => n.name(),
LazyOp::Const => "Const",
}
}
@@ -42,6 +44,7 @@ impl LazyOp {
LazyOp::Softmax(s) => s.srcs(),
LazyOp::Unary(u) => u.srcs(),
LazyOp::Reindex(r) => r.srcs(),
LazyOp::Norm(n) => n.srcs(),
LazyOp::Const => rvec![], //end of the line kid
_ => unimplemented!(),
}
@@ -53,6 +56,7 @@ impl LazyOp {
LazyOp::Matmul(m) => m.supports_inplace(),
LazyOp::Softmax(s) => s.supports_inplace(),
LazyOp::Unary(u) => u.supports_inplace(),
LazyOp::Reindex(r) => r.supports_inplace(),
LazyOp::Const => false,
_ => false,
}
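As the hunks above show, the new Norm variant is threaded through name() and srcs(), while supports_inplace() lets it fall through the wildcard arm to false. A cut-down stand-in of that dispatch shape, using placeholder types rather than the crate's real LazyOp:

// Cut-down stand-in for the LazyOp dispatch pattern. Variants delegate to their
// op struct; anything without an explicit supports_inplace arm defaults to false.
struct Norm;

impl Norm {
    fn name(&self) -> &'static str {
        "layernorm"
    }
}

enum LazyOp {
    Norm(Norm),
    Const,
}

impl LazyOp {
    fn name(&self) -> &'static str {
        match self {
            LazyOp::Norm(n) => n.name(),
            LazyOp::Const => "Const",
        }
    }

    fn supports_inplace(&self) -> bool {
        match self {
            // Norm has no explicit arm, so it falls through to false.
            _ => false,
        }
    }
}

fn main() {
    assert_eq!(LazyOp::Norm(Norm).name(), "layernorm");
    assert_eq!(LazyOp::Const.name(), "Const");
    assert!(!LazyOp::Norm(Norm).supports_inplace());
}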
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/ops/binary.rs
@@ -99,7 +99,7 @@ impl MetaOperation for Binary {

fn storage_bind_group_layout(
&self,
inplace: bool,
_inplace: bool,
) -> Result<BindGroupLayoutDescriptor, OperationError> {
/*
if inplace {
2 changes: 2 additions & 0 deletions crates/ratchet-core/src/ops/mod.rs
@@ -1,11 +1,13 @@
mod binary;
mod matmul;
mod norm;
mod reindex;
mod softmax;
mod unary;

pub use binary::*;
pub use matmul::*;
pub use norm::*;
pub use reindex::*;
pub use softmax::*;
pub use unary::*;
