Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Ported quantize.cpp #84

Merged
merged 22 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
feat(ggml): make quantization safe
  • Loading branch information
philpax committed Apr 25, 2023
commit d968bfa3892a0cd9e08ce10d99a7f436aa1063b2
85 changes: 42 additions & 43 deletions ggml/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
//! All [Tensor]s are nodes in this computational graph, and values cannot be retrieved until computation is completed.

use std::{
    os::raw::{c_int, c_void},
    ptr::NonNull,
    sync::{Arc, Weak},
};
Expand Down Expand Up @@ -272,7 +272,7 @@ impl Context {
pub unsafe fn op_map_unary(
&self,
a: &Tensor,
fun: unsafe extern "C" fn(cnt: ::std::os::raw::c_int, dst: *mut f32, src: *const f32),
fun: unsafe extern "C" fn(cnt: c_int, dst: *mut f32, src: *const f32),
) -> Tensor {
let tensor =
unsafe { ggml_sys::ggml_map_unary_f32(self.ptr.as_ptr(), a.ptr.as_ptr(), Some(fun)) };
Expand All @@ -298,12 +298,7 @@ impl Context {
&self,
a: &Tensor,
b: &Tensor,
fun: unsafe extern "C" fn(
cnt: ::std::os::raw::c_int,
dst: *mut f32,
src0: *const f32,
src1: *const f32,
),
fun: unsafe extern "C" fn(cnt: c_int, dst: *mut f32, src0: *const f32, src1: *const f32),
) -> Tensor {
let tensor = unsafe {
ggml_sys::ggml_map_binary_f32(
Expand Down Expand Up @@ -690,48 +685,52 @@ fn i64_to_usize(val: i64) -> usize {
usize::try_from(val).unwrap()
}

/// Contains the result of a quantization operation.
pub struct QuantizationResult {
    /// The quantized output, trimmed to the exact number of bytes the
    /// quantizer reported writing.
    pub output: Vec<u8>,
    /// The quantization history: 16 `i64` buckets filled in by the ggml
    /// quantizer (presumably a histogram of quantized values — confirm
    /// against the ggml documentation).
    pub history: Vec<i64>,
}

/// Quantizes `src` into `dst` using `q4_0` quantization.
///
/// # Safety
///
/// You must ensure the arrays passed in are of the correct size.
pub unsafe fn quantize_q4_0(
src: &[f32],
dst: &mut [u8],
n: usize,
k: usize,
hist: &mut [i64],
) -> usize {
unsafe {
ggml_sys::ggml_quantize_q4_0(
src.as_ptr(),
dst.as_mut_ptr() as *mut c_void,
n.try_into().unwrap(),
k.try_into().unwrap(),
hist.as_mut_ptr(),
)
}
/// You must ensure that `src.len() == n_elements`, and `n_elements_0`
/// is the first dimension of `src`.
pub fn quantize_q4_0(src: &[f32], n_elements: usize, n_elements_0: usize) -> QuantizationResult {
quantize_impl(src, n_elements, n_elements_0, ggml_sys::ggml_quantize_q4_0)
}

/// Quantizes `src` into `dst` using `q4_1` quantization.
///
/// # Safety
///
/// You must ensure the arrays passed in are of the correct size.
pub unsafe fn quantize_q4_1(
/// You must ensure that `src.len() == n_elements`, and `n_elements_0`
/// is the first dimension of `src`.
pub fn quantize_q4_1(src: &[f32], n_elements: usize, n_elements_0: usize) -> QuantizationResult {
quantize_impl(src, n_elements, n_elements_0, ggml_sys::ggml_quantize_q4_1)
}

fn quantize_impl(
src: &[f32],
dst: &mut [u8],
n: usize,
k: usize,
hist: &mut [i64],
) -> usize {
unsafe {
ggml_sys::ggml_quantize_q4_1(
n_elements: usize,
n_elements_0: usize,
quantizer: unsafe extern "C" fn(*const f32, *mut c_void, c_int, c_int, *mut i64) -> usize,
) -> QuantizationResult {
assert_eq!(src.len(), n_elements);
assert_eq!(n_elements % n_elements_0, 0);

// A conservative multiplier of 4 is used here.
let mut output = vec![0u8; n_elements * 4];
let mut history = vec![0i64; 16];
let output_size = unsafe {
quantizer(
src.as_ptr(),
dst.as_mut_ptr() as *mut c_void,
n.try_into().unwrap(),
k.try_into().unwrap(),
hist.as_mut_ptr(),
output.as_mut_ptr() as *mut c_void,
n_elements.try_into().unwrap(),
n_elements_0.try_into().unwrap(),
history.as_mut_ptr(),
)
}
};

output.resize(output_size, 0u8);
QuantizationResult { output, history }
}
37 changes: 10 additions & 27 deletions llama-rs/src/quantize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,40 +291,23 @@ impl<F: Fn(QuantizeProgress)> SaveHandler<QuantizeError> for QuantizeSaver<'_, F
_ => unreachable!(),
};

let mut history_current = vec![0; 16];

// A conservative multiplier of 4 is used here.
let mut work = vec![0u8; tensor.n_elements * 4];
let curr_size = match self.quantization_type {
ggml::Type::Q4_0 => unsafe {
ggml::quantize_q4_0(
&data_f32,
&mut work,
tensor.n_elements,
tensor.dims[0],
&mut history_current,
)
},
ggml::Type::Q4_1 => unsafe {
ggml::quantize_q4_1(
&data_f32,
&mut work,
tensor.n_elements,
tensor.dims[0],
&mut history_current,
)
},
let result = match self.quantization_type {
ggml::Type::Q4_0 => {
ggml::quantize_q4_0(&data_f32, tensor.n_elements, tensor.dims[0])
}
ggml::Type::Q4_1 => {
ggml::quantize_q4_1(&data_f32, tensor.n_elements, tensor.dims[0])
}
_ => unreachable!(),
};
let new_data = result.output;

let mut history_new = vec![];
for (i, val) in history_current.iter().enumerate() {
for (i, val) in result.history.iter().enumerate() {
self.history_all[i] += val;
history_new.push(*val as f32 / tensor.n_elements as f32);
}

let new_data = &work[0..curr_size];

(self.progress_callback)(QuantizeProgress::TensorQuantized {
name: tensor_name,
original_size: raw_data.len(),
Expand All @@ -334,7 +317,7 @@ impl<F: Fn(QuantizeProgress)> SaveHandler<QuantizeError> for QuantizeSaver<'_, F

self.total_size_new += new_data.len();

(self.quantization_type, new_data.to_owned())
(self.quantization_type, new_data)
} else {
(self.progress_callback)(QuantizeProgress::TensorSkipped {
name: tensor_name,
Expand Down