This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Ported quantize.cpp #84

Merged · 22 commits · Apr 25, 2023
Changes from 9 commits
2 changes: 1 addition & 1 deletion .vscode/settings.json
@@ -1,3 +1,3 @@
{
-    "rust-analyzer.cargo.features": ["convert"]
+    "rust-analyzer.cargo.features": ["convert", "quantize"]
}
16 changes: 16 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

82 changes: 41 additions & 41 deletions ggml-sys/ggml/ggml.c
@@ -1962,7 +1962,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    /* Prepare the constants we will need during execution */
    const __m256i lowMask = _mm256_set1_epi8( 0xF );
    const __m256i offset_8 = _mm256_set1_epi16( 8 );

@@ -1973,60 +1973,60 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
    // Main loop
    for (int i = 0; i < nb; i+=UNROLL_COUNT) {

        // This loop will be unrolled by the compiler
        for (int u=0;u<UNROLL_COUNT;u++) {
            /* Compute combined scale for the block */
            const __m256 scale = _mm256_mul_ps(
                _mm256_broadcast_ss( &x[i+u].d ),
                _mm256_broadcast_ss( &y[i+u].d ) );

            /* get input from x
               Input: 32 Nibbles (16 bytes) at *x[i+u]
               Output: 2 vectors with 16 values of type int16_t (x_high_q, x_low_q) */

            /* Load 16 bytes from memory */
            const __m128i tmp_x = _mm_loadu_si128( ( const __m128i* ) x[i+u].qs);
            /* Expand bytes into uint16_t values */
            const __m256i bytes_x = _mm256_cvtepu8_epi16(tmp_x);
            /* Unpack values into individual bytes */
            __m256i x_low_q = _mm256_and_si256( lowMask, bytes_x );
            const __m256i pre_shift_x_high_q = _mm256_andnot_si256( lowMask, bytes_x );
            __m256i x_high_q = _mm256_srli_epi16( pre_shift_x_high_q, 4 );
            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
            x_high_q = _mm256_sub_epi16( x_high_q, offset_8 );
            x_low_q = _mm256_sub_epi16( x_low_q, offset_8 );

            /* get input from y
               Input: 32 Nibbles (16 bytes) at *y[i+u]
               Output: 2 vectors with 16 values of type int16_t (y_high_q, y_low_q) */

            /* Load 16 bytes from memory */
            const __m128i tmp_y = _mm_loadu_si128( (const __m128i* ) y[i+u].qs);
            /* Expand bytes into uint16_t values */
            const __m256i bytes_y = _mm256_cvtepu8_epi16(tmp_y);
            /* Unpack values into individual bytes */
            const __m256i pre_shift_y_high_q = _mm256_andnot_si256( lowMask, bytes_y );
            __m256i y_high_q = _mm256_srli_epi16( pre_shift_y_high_q, 4 );
            __m256i y_low_q = _mm256_and_si256( lowMask, bytes_y );
            /* Now we have two vectors with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval. */
            y_high_q = _mm256_sub_epi16( y_high_q, offset_8 );
            y_low_q = _mm256_sub_epi16( y_low_q, offset_8 );

            /* Compute products of int16_t integers, add pairwise, store as int32_t */
            __m256i xy_high_q = _mm256_madd_epi16( x_high_q, y_high_q );
            __m256i xy_low_q = _mm256_madd_epi16( x_low_q, y_low_q );

            /* Accumulate the products of int32_t integers -> we now have a vector of 8 int32_t */
            __m256i xy_q = _mm256_add_epi32( xy_high_q, xy_low_q );

            /* Convert vector of 8 int32_t to 8 floats */
            __m256 q = _mm256_cvtepi32_ps( xy_q );

            /* Multiply q with scale and accumulate */
            acc = _mm256_fmadd_ps( scale, q, acc );
        }
    }

    // Return horizontal sum of the acc vector
    __m128 res = _mm256_extractf128_ps( acc, 1 );
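For orientation, here is a scalar sketch in Rust (not part of this PR) of what the unrolled AVX2 loop above computes, assuming ggml's q4_0 block layout: an f32 scale `d` followed by 16 bytes packing 32 4-bit values, each stored with an offset of 8.

fn dot_q4_0_scalar(x: &[(f32, [u8; 16])], y: &[(f32, [u8; 16])]) -> f32 {
    let mut acc = 0.0f32;
    for ((xd, xqs), (yd, yqs)) in x.iter().zip(y.iter()) {
        let scale = xd * yd; // combined per-block scale, as in the AVX2 code
        let mut dot = 0i32;
        for (&xb, &yb) in xqs.iter().zip(yqs.iter()) {
            // split each byte into low/high nibbles, offset [0, 15] -> [-8, 7]
            let (x0, x1) = ((xb & 0xF) as i32 - 8, (xb >> 4) as i32 - 8);
            let (y0, y1) = ((yb & 0xF) as i32 - 8, (yb >> 4) as i32 - 8);
            dot += x0 * y0 + x1 * y1;
        }
        acc += scale * dot as f32;
    }
    acc
}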
46 changes: 46 additions & 0 deletions ggml/src/lib.rs
@@ -466,3 +466,49 @@ fn i32_to_usize(val: i32) -> usize {
fn i64_to_usize(val: i64) -> usize {
    usize::try_from(val).unwrap()
}

/// Quantizes `src` into `dst` using `q4_0` quantization.
///
/// # Safety
///
/// You must ensure the arrays passed in are of the correct size.
pub unsafe fn quantize_q4_0(
    src: &[f32],
    dst: &mut [f32],
    n: i32,
    k: i32,
    hist: &mut [i64],
) -> usize {
    unsafe {
        ggml_sys::ggml_quantize_q4_0(
            src.as_ptr(),
            dst.as_mut_ptr() as *mut c_void,
            n,
            k,
            hist.as_mut_ptr(),
        )
    }
}

/// Quantizes `src` into `dst` using `q4_1` quantization.
///
/// # Safety
///
/// You must ensure the arrays passed in are of the correct size.
pub unsafe fn quantize_q4_1(
    src: &[f32],
    dst: &mut [f32],
    n: i32,
    k: i32,
    hist: &mut [i64],
) -> usize {
    unsafe {
        ggml_sys::ggml_quantize_q4_1(
            src.as_ptr(),
            dst.as_mut_ptr() as *mut c_void,
            n,
            k,
            hist.as_mut_ptr(),
        )
    }
}
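A minimal usage sketch of the wrapper above (not part of this diff; the row length of 64 and the buffer sizing are assumptions based on ggml's q4_0 conventions, where lengths must be multiples of 32 and the histogram has one bucket per possible 4-bit value):

let src: Vec<f32> = (0..64).map(|i| i as f32 * 0.01).collect();
// The q4_0 output is smaller than the f32 input, so an f32 buffer of the
// same element count is comfortably large enough for the quantized bytes.
let mut dst = vec![0.0f32; src.len()];
let mut hist = vec![0i64; 16];
let bytes_written =
    unsafe { quantize_q4_0(&src, &mut dst, src.len() as i32, 64, &mut hist) };
assert!(bytes_written <= dst.len() * std::mem::size_of::<f32>());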
2 changes: 1 addition & 1 deletion llama-cli/Cargo.toml
@@ -6,7 +6,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
-llama-rs = { path = "../llama-rs", features = ["convert"] }
+llama-rs = { path = "../llama-rs", features = ["convert", "quantize"] }

rand = { workspace = true }

16 changes: 15 additions & 1 deletion llama-cli/src/cli_args.rs
@@ -39,6 +39,9 @@ pub enum Args {
    ///
    /// For reference, see [the PR](https://github.com/rustformers/llama-rs/pull/83).
    Convert(Box<Convert>),

    /// Quantize a GGML model to 4-bit.
    Quantize(Box<Quantize>),
}

#[derive(Parser, Debug)]
@@ -244,7 +247,7 @@ fn parse_bias(s: &str) -> Result<TokenBias, String> {
pub struct ModelLoad {
    /// Where to load the model from
    #[arg(long, short = 'm')]
-    pub model_path: String,
+    pub model_path: PathBuf,

    /// Sets the size of the context (in tokens). Allows feeding longer prompts.
    /// Note that this affects memory.
@@ -367,6 +370,17 @@ pub struct Convert {
    pub element_type: ElementType,
}

#[derive(Parser, Debug)]
pub struct Quantize {
    /// The path to the model to quantize
    #[arg()]
    pub source: PathBuf,

    /// The path to save the quantized model to
    #[arg()]
    pub destination: PathBuf,
}

#[derive(Parser, Debug, ValueEnum, Clone, Copy)]
pub enum ElementType {
    /// Quantized 4-bit (type 0).
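With the subcommand wired up in main.rs below, an invocation would look something like this (binary name and paths are illustrative, not taken from this diff):

llama-cli quantize ./models/ggml-model-f16.bin ./models/ggml-model-q4_0.bin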
13 changes: 13 additions & 0 deletions llama-cli/src/main.rs
@@ -24,6 +24,7 @@ fn main() {
        Args::Repl(args) => interactive(&args, false),
        Args::ChatExperimental(args) => interactive(&args, true),
        Args::Convert(args) => convert_pth_to_ggml(&args.directory, args.element_type.into()),
        Args::Quantize(args) => quantize(&args),
    }
}

@@ -191,6 +192,18 @@ fn interactive(
    }
}

fn quantize(args: &cli_args::Quantize) {
    llama_rs::quantize::quantize(
        &args.source,
        &args.destination,
        llama_rs::ElementType::Q4_0,
        |p| {
            println!("{p:?}");
        },
    )
    .unwrap();
}

fn load_prompt_file_with_prompt(
    prompt_file: &cli_args::PromptFile,
    prompt: Option<&str>,
6 changes: 5 additions & 1 deletion llama-rs/Cargo.toml
@@ -22,5 +22,9 @@ serde_json = { version = "1.0.94", optional = true }
protobuf = { version = "= 2.14.0", optional = true }
rust_tokenizers = { version = "3.1.2", optional = true }

# Used for the `quantize` feature
half = { version = "2.2.1", optional = true }

[features]
convert = ["dep:serde_json", "dep:protobuf", "dep:rust_tokenizers"]
quantize = ["dep:half"]
84 changes: 84 additions & 0 deletions llama-rs/src/file.rs
@@ -0,0 +1,84 @@
#![allow(dead_code)]

use crate::LoadError;
pub use std::fs::File;
pub use std::io::{BufRead, BufReader, BufWriter, Read, Write};

fn read(reader: &mut impl BufRead, bytes: &mut [u8]) -> Result<(), LoadError> {
    reader
        .read_exact(bytes)
        .map_err(|e| LoadError::ReadExactFailed {
            source: e,
            bytes: bytes.len(),
        })
}

fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
    let mut bytes = [0u8; N];
    read(reader, &mut bytes)?;
    Ok(bytes)
}

fn rw<const N: usize>(
    reader: &mut impl BufRead,
    writer: &mut impl Write,
) -> Result<[u8; N], LoadError> {
    let mut bytes = [0u8; N];
    read(reader, &mut bytes)?;
    writer.write_all(&bytes)?;
    Ok(bytes)
}

pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
    Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<i32, LoadError> {
    Ok(i32::from_le_bytes(rw::<4>(reader, writer)?))
}

pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
    Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn rw_u32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<u32, LoadError> {
    Ok(u32::from_le_bytes(rw::<4>(reader, writer)?))
}

pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
    Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn rw_f32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<f32, LoadError> {
    Ok(f32::from_le_bytes(rw::<4>(reader, writer)?))
}

/// Helper function. Reads a string from the buffer and returns it.
pub(crate) fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> {
    let mut buf = vec![0; len];
    reader
        .read_exact(&mut buf)
        .map_err(|e| LoadError::ReadExactFailed {
            source: e,
            bytes: buf.len(),
        })?;
    let s = String::from_utf8(buf)?;
    Ok(s)
}

pub(crate) fn rw_string(
    reader: &mut impl BufRead,
    writer: &mut impl Write,
    len: usize,
) -> Result<String, LoadError> {
    let mut buf = vec![0; len];
    reader
        .read_exact(&mut buf)
        .map_err(|e| LoadError::ReadExactFailed {
            source: e,
            bytes: buf.len(),
        })?;
    writer.write_all(&buf)?;
    let s = String::from_utf8_lossy(&buf);
    Ok(s.into_owned())
}
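As a sketch of how these helpers compose during quantization (hypothetical, not part of this diff): each rw_* call reads a little-endian field from the source model while teeing its bytes into the output, so header fields can be inspected as they are copied.

fn copy_two_header_fields(
    reader: &mut impl BufRead,
    writer: &mut impl Write,
) -> Result<(u32, i32), LoadError> {
    let magic = rw_u32(reader, writer)?;   // file magic passes through unchanged
    let n_vocab = rw_i32(reader, writer)?; // hypothetical field, for illustration
    Ok((magic, n_vocab))
}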