Skip to content

Commit

Permalink
Merge pull request #251 from huggingface/various-clippy-fixes
Browse files Browse the repository at this point in the history
Various clippy fixes
  • Loading branch information
FL33TW00D authored Aug 30, 2024
2 parents f3b3cda + 6ae90c5 commit 6e9bbf4
Show file tree
Hide file tree
Showing 15 changed files with 26 additions and 34 deletions.
4 changes: 2 additions & 2 deletions crates/ratchet-core/src/cpu/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ pub mod gemm;

use crate::{
Binary, BinaryOp, CPUBuffer, CPUOperation, Cast, DType, IndexSelect, InvariantError, OpGuards,
Operation, OperationError, Quantization, Quantizer, RVec, Segments, Storage, StorageView,
Tensor, TensorDType, Unary, UnaryOp,
Operation, OperationError, Quantization, Quantizer, RVec, Storage, StorageView, Tensor,
TensorDType, Unary, UnaryOp,
};
use anyhow::anyhow;
use bytemuck::NoUninit;
Expand Down
4 changes: 2 additions & 2 deletions crates/ratchet-core/src/gpu/align.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,11 @@ impl Align for usize {
}

fn align_for_copy(&self) -> usize {
self + &self.calculate_alignment(Self::COPY_BUFFER_ALIGNMENT)
self + self.calculate_alignment(Self::COPY_BUFFER_ALIGNMENT)
}

fn align_for_offset(&self) -> usize {
self + &self.calculate_alignment(Self::STORAGE_BUFFER_OFFSET_ALIGNMENT)
self + self.calculate_alignment(Self::STORAGE_BUFFER_OFFSET_ALIGNMENT)
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/gpu/wgsl/kernel.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ pub trait Kernel: KernelRenderable {
/// with the derive macro: `WgslMetadata`. The author still needs to implement
/// `write_metadata`.
/// 2. Dynamic Metadata - the structure is not known at compile time, so the author must
/// implement both `render` and `write_metadata`.
/// implement both `render` and `write_metadata`.
type Metadata: KernelMetadata + 'static;

fn kernel_name(&self) -> String;
Expand Down
1 change: 0 additions & 1 deletion crates/ratchet-core/src/ops/matmul/gemm.rs
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,6 @@ impl GEMM {
mut kernel_builder: WgslKernelBuilder,
) -> Result<KernelSource, OperationError> {
const ROW_PER_THREAD: usize = 4;
const COL_PER_THREAD: usize = 4;
const TILE_DIM: usize = 32;

let accessor = P::render_type();
Expand Down
4 changes: 0 additions & 4 deletions crates/ratchet-core/src/ops/matmul/subgroup_gemv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,6 @@ impl KernelRenderable for SubgroupGEMV {
const BM: usize = 8;
const BN: usize = 32;

if matches!(self.lhs.dt(), DType::Q8_0F(_) | DType::Q8_0H(_)) {
assert!(TN == 4);
}

let device = self.lhs.device().try_gpu().unwrap();
let mut kernel_builder = WgslKernelBuilder::new(
workgroup_size.clone(),
Expand Down
4 changes: 1 addition & 3 deletions crates/ratchet-core/src/ops/softmax.rs
Original file line number Diff line number Diff line change
Expand Up @@ -301,9 +301,7 @@ impl GPUOperation for Softmax {
type KernelEnum = SoftmaxKernels;

fn select_kernel(&self) -> Self::KernelEnum {
match self {
Self { .. } => SoftmaxKernels::Standard(self.clone()),
}
SoftmaxKernels::Standard(self.clone())
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/ratchet-core/src/storage/cpu_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ impl CPUBuffer {
Self::new(raw)
}

pub unsafe fn into_bytes(self) -> Vec<u8> {
pub fn into_bytes(self) -> Vec<u8> {
Arc::try_unwrap(self.inner).unwrap().into_bytes()
}

Expand Down
2 changes: 0 additions & 2 deletions crates/ratchet-core/src/storage/gpu_buffer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ pub struct GPUBuffer {
}

impl GPUBuffer {
const MIN_SIZE: u64 = 16;

pub fn from_slice<T: NoUninit>(data: &[T], shape: &Shape, device: &WgpuDevice) -> Self {
assert_eq!(data.len(), shape.numel());
Self::from_bytes(
Expand Down
3 changes: 3 additions & 0 deletions crates/ratchet-core/src/storage/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ pub enum Storage {
}

impl Storage {
/// # Safety
///
/// Inherited from the `from_quantized` method of the `CPUBuffer` and `GPUBuffer` structs.
pub unsafe fn from_quantized<T: NoUninit>(data: &[T], device: &Device) -> Self {
match device {
Device::CPU => Storage::CPU(unsafe { CPUBuffer::from_quantized(data) }),
Expand Down
23 changes: 11 additions & 12 deletions crates/ratchet-core/src/tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,6 @@ impl Inner {
storage,
}
}

pub(crate) fn storage(&self) -> RwLockReadGuard<Option<Storage>> {
self.storage.read()
}
}

impl Tensor {
Expand Down Expand Up @@ -590,6 +586,9 @@ impl Tensor {
Ok(storage.into_bytes())
}

/// # Safety
///
/// Inherited from `Storage::from_quantized`.
pub unsafe fn from_quantized<T: TensorDType, U: AsRef<[T]>>(
data: U,
dt: DType,
Expand Down Expand Up @@ -745,16 +744,16 @@ impl Tensor {
LazyOp::Binary(b) => cpu_binary(b, dst).ok(),
LazyOp::Cast(c) => cpu_cast(c, dst).ok(),
LazyOp::Matmul(m) => m.apply(dst).ok(),
LazyOp::Softmax(s) => todo!(),
LazyOp::RoPE(r) => todo!(),
LazyOp::Softmax(_s) => todo!(),
LazyOp::RoPE(_r) => todo!(),
LazyOp::Unary(u) => cpu_unary(u, dst).ok(),
LazyOp::Reindex(r) => todo!(),
LazyOp::Concat(c) => todo!(),
LazyOp::Norm(n) => todo!(),
LazyOp::Conv(c) => todo!(),
LazyOp::Reindex(_r) => todo!(),
LazyOp::Concat(_c) => todo!(),
LazyOp::Norm(_n) => todo!(),
LazyOp::Conv(_c) => todo!(),
LazyOp::Select(i) => cpu_index_select(i, dst).ok(),
LazyOp::IndexWrite(i) => todo!(),
LazyOp::Cache(c) => todo!(),
LazyOp::IndexWrite(_i) => todo!(),
LazyOp::Cache(_c) => todo!(),
LazyOp::Const => None,
LazyOp::View(_) => None,
}
Expand Down
1 change: 0 additions & 1 deletion crates/ratchet-loader/src/gguf/gguf.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,6 @@ pub enum ValueType {
// The value is a UTF-8 non-null-terminated string, with length prepended.
String,
// The value is an array of other values, with the length and type prepended.
///
// Arrays can be nested, and the length of the array is the number of elements in the array, not the number of bytes.
Array,
}
Expand Down
2 changes: 1 addition & 1 deletion crates/ratchet-loader/src/gguf/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl<R: std::io::Seek + std::io::Read, Other: std::io::Write> ReadInto<Other> fo
fn read_u8s_into(&mut self, other: &mut Other, length: usize) -> Result<()> {
let mut temp = vec![0u8; length];
self.read_exact(&mut temp)?;
other.write_all(&mut temp)?;
other.write_all(&temp)?;
Ok(())
}
}
Expand Down
2 changes: 1 addition & 1 deletion crates/ratchet-loader/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ impl TryFrom<u32> for GgmlDType {
}

impl GgmlDType {
pub(crate) fn to_u32(self) -> u32 {
pub fn to_u32(self) -> u32 {
match self {
Self::F32 => 0,
Self::F16 => 1,
Expand Down
2 changes: 1 addition & 1 deletion crates/ratchet-models/src/moondream/generate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ pub fn generate(

let mut tos = TokenOutputStream::new(tokenizer);

let img = image::io::Reader::new(std::io::Cursor::new(image_bytes))
let img = image::ImageReader::new(std::io::Cursor::new(image_bytes))
.with_guessed_format()?
.decode()
.unwrap()
Expand Down
4 changes: 2 additions & 2 deletions crates/ratchet-models/src/whisper/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,9 +129,9 @@ impl Whisper {
}

#[cfg(not(target_arch = "wasm32"))]
pub fn detect_language(&mut self, mel: Tensor) -> anyhow::Result<Language> {
pub fn detect_language(&mut self, _mel: Tensor) -> anyhow::Result<Language> {
panic!("DETECT LANGUAGE NOT IMPLEMENTED");
let audio_ctx = self.encoder.schedule(mel)?.resolve()?;
let audio_ctx = self.encoder.schedule(_mel)?.resolve()?;
let sot = Tensor::from_data([WhisperTokenizer::SOT], shape![1, 1], self.device.clone());

let logits = self.decoder.schedule([audio_ctx, sot])?.full()?.resolve()?;
Expand Down

0 comments on commit 6e9bbf4

Please sign in to comment.