Skip to content

Commit

Permalink
Merge pull request #45 from FL33TW00D/feature/janitor-duty
Browse files Browse the repository at this point in the history
Feature/janitor duty
  • Loading branch information
FL33TW00D authored Jan 26, 2024
2 parents f46cc47 + 8f62718 commit 198025d
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 86 deletions.
10 changes: 3 additions & 7 deletions .github/workflows/ratbot.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,9 @@ jobs:
uses: actions/github-script@v6
with:
script: |
const codeReport = `
\`\`\`\n
```
${{ steps.scc.outputs.scc }}
```
\`\`\`
`;
const codeReport = "\\`\\`\\`\\n" +
${{ steps.scc.outputs.scc }} +
"\\n\\`\\`\\`";
github.rest.issues.createComment({
issue_number: context.issue.number,
owner: context.repo.owner,
Expand Down
13 changes: 5 additions & 8 deletions crates/ratchet-core/src/compiled_op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use crate::gpu::{
BindGroupDescriptor, BindGroupLayoutHandle, ComputePipelineHandle, GpuBindGroup, WgpuDevice,
WorkgroupCount,
};
use crate::{drvec, rvec, RVec, Tensor};
use crate::{drvec, rvec, OperationError, RVec, Tensor};
use derive_new::new;
use wgpu::DynamicOffset;

Expand All @@ -20,14 +20,13 @@ pub struct CompiledOp {
impl CompiledOp {
const MAX_BINDINGS_PER_GROUP: usize = 4;

//TODO: Should return a Result
pub fn create_storage_bind_groups(
srcs: &[&Tensor],
dst: &Tensor,
bind_group_layouts: RVec<BindGroupLayoutHandle>,
device: &WgpuDevice,
inplace: bool,
) -> RVec<GpuBindGroup> {
) -> Result<RVec<GpuBindGroup>, OperationError> {
let mut bind_group_entries = drvec![];

for tensor in srcs.iter() {
Expand All @@ -44,12 +43,10 @@ impl CompiledOp {
let entries = bind_group_entries[group_range].into();
let layout = *bind_group_layout;

let bind_group = device
.get_or_create_bind_group(&BindGroupDescriptor { entries, layout })
.unwrap();
storage_groups.push(bind_group);
let bg = device.get_or_create_bind_group(&BindGroupDescriptor { entries, layout })?;
storage_groups.push(bg);
}
storage_groups
Ok(storage_groups)
}

/// Determines which bindings belong to which bind group
Expand Down
1 change: 1 addition & 0 deletions crates/ratchet-core/src/dtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ impl DType {
}
}

//TODO: use a different method, total_bytes won't work with 256 byte padding
pub fn segments(&self, total_bytes: usize) -> RVec<BufferSegment> {
match self {
DType::WQ8 => {
Expand Down
26 changes: 16 additions & 10 deletions crates/ratchet-core/src/gpu/buffer_allocator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,24 @@ impl BufferAllocator {
/// from the actual source (i.e. the first non-inplace operation)
///
/// Under what conditions do we terminate the upward traversal?
/// 1. We reach a constant
/// 2. We reach an operation that does not support inplace
/// 3. We reach an operation that has more than one consumer
/// 4. We reach an operation that has more than one source
fn traverse_upwards_for_inplace(source: &Tensor) -> &Tensor {
/// 1. We reach an operation that does not support inplace
/// 2. We reach an operation that has more than one consumer
/// 3. We reach an operation that has more than one source
fn determine_tensor_source<'a>(source: &'a Tensor, execution_order: &[Tensor]) -> &'a Tensor {
let mut true_source = source;
loop {
let is_const = true_source.op().is_const();
let cant_inplace = !true_source.op().supports_inplace();
let multiple_sources = true_source.op().srcs().len() > 1;
let multiple_consumers = false; //TODO: implement
if cant_inplace || multiple_sources || multiple_consumers || is_const {
let ts_index = execution_order
.iter()
.position(|t| t.id() == true_source.id())
.unwrap();
let multiple_consumers = execution_order[ts_index + 1..]
.iter()
.filter(|t| t.op().srcs().contains(&true_source))
.count()
> 1;
if cant_inplace || multiple_sources || multiple_consumers {
break;
}

Expand Down Expand Up @@ -169,7 +175,7 @@ impl BufferAllocator {
// If the current tensor is an inplace operation,
// we traverse upwards until we find a non-inplace operation.
for source in t.op().srcs() {
let true_source = Self::traverse_upwards_for_inplace(source);
let true_source = Self::determine_tensor_source(source, execution_order);
assignments.entry(true_source.id()).or_insert_with(|| {
self.graph_allocate(
BufferDescriptor::new(
Expand All @@ -194,7 +200,7 @@ impl BufferAllocator {
//We know we need an allocation for the output.
//We traverse upwards until we find the first non-inplace operation, and use its buffer.
let output = execution_order.last().unwrap();
let output_source = Self::traverse_upwards_for_inplace(output);
let output_source = Self::determine_tensor_source(output, execution_order);

//If output source is allocated, we can use its buffer
//Otherwise, we need to allocate a new buffer
Expand Down
34 changes: 14 additions & 20 deletions crates/ratchet-core/src/gpu/uniform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::num::NonZeroU64;

use crate::{
gpu::{BindGroupEntry, BindGroupLayoutDescriptor},
rvec,
rvec, OperationError,
};

use super::{BindGroupDescriptor, GpuBindGroup, PooledGPUBuffer, WgpuDevice};
Expand Down Expand Up @@ -32,26 +32,20 @@ impl CpuUniform {
}

/// Consumes the CPU repr of the uniform buffer and writes to the GPU.
pub(crate) fn into_gpu(self, device: &WgpuDevice) -> GpuUniform {
let uniform_buf = device.create_uniform_init(self);
let bind_group_layout = device
.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())
.unwrap();
let bind_group = device
.get_or_create_bind_group(&BindGroupDescriptor {
entries: rvec![BindGroupEntry {
handle: uniform_buf.handle,
offset: 0,
size: NonZeroU64::new(uniform_buf.size()),
}],
layout: bind_group_layout,
})
.unwrap();
pub(crate) fn into_gpu(self, device: &WgpuDevice) -> Result<GpuUniform, OperationError> {
let buf = device.create_uniform_init(self);
let layout =
device.get_or_create_bind_group_layout(&BindGroupLayoutDescriptor::uniform())?;
let bind_group = device.get_or_create_bind_group(&BindGroupDescriptor {
entries: rvec![BindGroupEntry {
handle: buf.handle,
offset: 0,
size: NonZeroU64::new(buf.size()),
}],
layout,
})?;

GpuUniform {
buf: uniform_buf,
bind_group,
}
Ok(GpuUniform { buf, bind_group })
}
}

Expand Down
5 changes: 2 additions & 3 deletions crates/ratchet-core/src/op.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ use crate::gpu::{CpuUniform, PoolError, WgpuDevice, UNIFORM_ALIGN};
use crate::{rvec, Binary, CompiledOp, InvariantError, Matmul, RVec, Softmax, StorageView, Tensor};

#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum LazyOp {
Dummy(Tensor),
Matmul(Matmul),
Binary(Binary),
Softmax(Softmax),
Expand All @@ -21,7 +21,6 @@ impl LazyOp {
LazyOp::Binary(b) => b.srcs(),
LazyOp::Matmul(m) => m.srcs(),
LazyOp::Softmax(s) => s.srcs(),
LazyOp::Dummy(t) => rvec![t],
LazyOp::Const => rvec![], //end of the line kid
_ => unimplemented!(),
}
Expand All @@ -32,7 +31,7 @@ impl LazyOp {
LazyOp::Binary(b) => b.supports_inplace(),
LazyOp::Matmul(m) => m.supports_inplace(),
LazyOp::Softmax(s) => s.supports_inplace(),
LazyOp::Const => true,
LazyOp::Const => false,
_ => false,
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/ops/binary.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ impl Operation for Binary {
rvec![storage_layout],
device,
false,
);
)?;

Ok(CompiledOp::new(
pipeline_handle,
Expand Down
7 changes: 6 additions & 1 deletion crates/ratchet-core/src/ops/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,8 @@ pub struct Matmul {
rhs: Tensor,
}

// Empty inherent impl: no methods are defined on `Matmul` in this block.
// NOTE(review): presumably a placeholder for helpers such as the
// `Matmul::compute_output_shape` mentioned in the commented-out line inside
// `infer_output` — confirm before relying on it.
impl Matmul {}

#[allow(clippy::too_many_arguments)]
#[derive(Debug, Clone, ShaderType)]
pub struct MatmulMeta {
Expand Down Expand Up @@ -297,7 +299,7 @@ impl Operation for Matmul {
rvec![storage_layout],
device,
false,
);
)?;

Ok(CompiledOp::new(
pipeline_handle,
Expand All @@ -308,6 +310,9 @@ impl Operation for Matmul {
}

/// Infers the `StorageView` of the matmul output from its source tensors.
///
/// Currently this returns a clone of the first source's view. As the TODO
/// in the body states, that is not the correct result for a matmul.
///
/// # Panics
///
/// Panics if `srcs` has fewer than two elements, because `srcs[0]` and
/// `srcs[1]` are indexed directly.
fn infer_output(&self, srcs: &[&Tensor]) -> Result<StorageView, OperationError> {
// Both operands are bound but unused for now (underscore-prefixed); the
// commented-out line below sketches the intended shape computation.
let (_a, _b) = (srcs[0], srcs[1]);
//let c_shape = Matmul::compute_output_shape(a.clone(), b.clone()).unwrap();

//TODO: THIS IS WRONG 🚨
Ok(srcs[0].view().clone())
}
Expand Down
3 changes: 1 addition & 2 deletions crates/ratchet-core/src/ops/softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ impl Operation for Softmax {
rvec![storage_layout],
device,
can_inplace,
);
)?;

Ok(CompiledOp::new(
pipeline_handle,
Expand All @@ -98,7 +98,6 @@ impl Operation for Softmax {
}

/// Infers the output `StorageView` by cloning the view of the first source tensor.
///
/// # Panics
///
/// Panics if `srcs` is empty, because `srcs[0]` is indexed directly.
fn infer_output(&self, srcs: &[&Tensor]) -> Result<StorageView, OperationError> {
//TODO: FIX
Ok(srcs[0].view().clone())
}

Expand Down
51 changes: 17 additions & 34 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use crate::gpu::{BindGroupEntry, CpuUniform, WgpuDevice};
use crate::{
ops::*, rvec, shape, strides, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable,
ops::*, rvec, shape, CPUBuffer, CompiledOp, DType, Device, DeviceStorage, Executable,
GPUBuffer, Operation, OperationError, RVec, RawCPUBuffer, Shape, Storage, Strides, TensorDType,
TensorId,
};
Expand Down Expand Up @@ -55,15 +55,6 @@ impl Tensor {
Self::new(op, meta, None, device)
}

pub fn dummy(src: Tensor) -> Self {
Self::new(
LazyOp::Dummy(src),
StorageView::new(shape![], DType::F32, Strides::default()),
None,
Device::CPU,
)
}

fn update_storage(&self, storage: Storage) {
*self.inner.storage.write() = Some(storage);
}
Expand Down Expand Up @@ -270,12 +261,12 @@ impl Tensor {
let handle = gpu_buf.inner().handle;
let segments = self.dt().segments(gpu_buf.inner().size() as usize);
segments.iter().fold(rvec![], |mut entries, segment| {
let entry = BindGroupEntry {
let (offset, size) = (segment.offset, segment.size);
entries.push(BindGroupEntry {
handle,
offset: segment.offset,
size: segment.size,
};
entries.push(entry);
offset,
size,
});
entries
})
}
Expand All @@ -298,19 +289,14 @@ impl Tensor {
if visited.contains(&tensor) {
continue;
}
match &tensor.inner.op {
LazyOp::Const => {}
LazyOp::Binary(b) => {
stack.extend(b.srcs().into_iter().cloned());
}
LazyOp::Matmul(m) => {
stack.extend(m.srcs().into_iter().cloned());
}
LazyOp::Softmax(s) => {
stack.extend(s.srcs().into_iter().cloned());
}
let srcs = match &tensor.inner.op {
LazyOp::Const => rvec![],
LazyOp::Binary(b) => b.srcs(),
LazyOp::Matmul(m) => m.srcs(),
LazyOp::Softmax(s) => s.srcs(),
_ => unimplemented!(),
}
};
stack.extend(srcs.into_iter().cloned());
visited.push(tensor);
}
visited.reverse();
Expand All @@ -337,7 +323,6 @@ impl Tensor {
let device = self.device().try_gpu()?;

let execution_order = self.execution_order();
println!("EXECUTION ORDER: \n{:#?}", execution_order);
let mut compiled_ops = Vec::with_capacity(execution_order.len());
let allocations = device.allocate_cfg(&execution_order, device)?;

Expand All @@ -353,6 +338,7 @@ impl Tensor {
t.update_storage(Storage::GPU(storage));
}

//Can inplace && only 1 consumer
let can_inplace = t.op().supports_inplace()
&& execution_order[tix + 1..]
.iter()
Expand All @@ -364,7 +350,7 @@ impl Tensor {
compiled_ops.push(compiled_op);
}
}
let executable = Executable::new(compiled_ops, uniform.into_gpu(device));
let executable = Executable::new(compiled_ops, uniform.into_gpu(device)?);
let index = executable.dispatch_operations(device).unwrap();
device.poll(wgpu::MaintainBase::WaitForSubmissionIndex(index));
Ok(())
Expand Down Expand Up @@ -470,6 +456,7 @@ impl Tensor {
}
}

#[derive(Default)]
struct CloseStats {
total_error: f32,
max_abs_error: f32,
Expand All @@ -483,13 +470,9 @@ struct CloseStats {
impl CloseStats {
fn new(atol: f32, rtol: f32) -> Self {
Self {
total_error: 0.0,
max_abs_error: 0.0,
max_abs_error_idxs: None,
element_count: 0,
fail_count: 0,
atol,
rtol,
..Default::default()
}
}

Expand Down

0 comments on commit 198025d

Please sign in to comment.