Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Ported quantize.cpp #84

Merged
merged 22 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fixed according to comments
  • Loading branch information
FloppyDisck committed Apr 1, 2023
commit 69d7ddce2b1c09028c34eee5c61027048b2fc2ff
82 changes: 82 additions & 0 deletions llama-rs/src/file.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
use crate::LoadError;
pub use std::fs::File;
pub use std::io::{BufRead, BufReader, BufWriter, Read, Write};

/// Fills `bytes` completely from `reader`.
///
/// # Errors
///
/// If the underlying `read_exact` call fails, the I/O error is wrapped in
/// `LoadError::ReadExactFailed` together with the number of bytes requested.
fn read(reader: &mut impl BufRead, bytes: &mut [u8]) -> Result<(), LoadError> {
    // Capture the requested length up front so it can be reported on failure.
    let requested = bytes.len();
    match reader.read_exact(bytes) {
        Ok(()) => Ok(()),
        Err(source) => Err(LoadError::ReadExactFailed {
            source,
            bytes: requested,
        }),
    }
}

/// Reads exactly `N` bytes from `reader` and returns them as a fixed-size array.
///
/// # Errors
///
/// Propagates whatever error `read` returns when the bytes cannot be read.
fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
    let mut buffer = [0u8; N];
    read(reader, &mut buffer[..]).map(|()| buffer)
}

/// Reads exactly `N` bytes from `reader`, writes the same bytes to `writer`,
/// and returns them.
///
/// # Errors
///
/// Returns the error from `read` if the bytes cannot be read (nothing is
/// written in that case), or the converted I/O error if the write fails.
fn rw<const N: usize>(
    reader: &mut impl BufRead,
    writer: &mut impl Write,
) -> Result<[u8; N], LoadError> {
    let mut chunk = [0u8; N];
    read(reader, &mut chunk[..])?;
    // Forward the bytes unchanged before handing them back to the caller.
    writer.write_all(&chunk[..])?;
    Ok(chunk)
}

/// Reads four bytes from `reader` and decodes them as a little-endian `i32`.
pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
    read_bytes::<4>(reader).map(i32::from_le_bytes)
}

/// Copies four bytes from `reader` to `writer` and returns them decoded as a
/// little-endian `i32`.
pub(crate) fn rw_i32(
    reader: &mut impl BufRead,
    writer: &mut impl Write,
) -> Result<i32, LoadError> {
    rw::<4>(reader, writer).map(i32::from_le_bytes)
}

/// Reads four bytes from `reader` and decodes them as a little-endian `u32`.
pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
    read_bytes::<4>(reader).map(u32::from_le_bytes)
}

/// Copies four bytes from `reader` to `writer` and returns them decoded as a
/// little-endian `u32`.
pub(crate) fn rw_u32(
    reader: &mut impl BufRead,
    writer: &mut impl Write,
) -> Result<u32, LoadError> {
    rw::<4>(reader, writer).map(u32::from_le_bytes)
}

/// Reads four bytes from `reader` and decodes them as a little-endian `f32`.
pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
    read_bytes::<4>(reader).map(f32::from_le_bytes)
}

/// Copies four bytes from `reader` to `writer` and returns them decoded as a
/// little-endian `f32`.
pub(crate) fn rw_f32(
    reader: &mut impl BufRead,
    writer: &mut impl Write,
) -> Result<f32, LoadError> {
    rw::<4>(reader, writer).map(f32::from_le_bytes)
}

/// Reads exactly `len` bytes from `reader` and returns them as a `String`.
///
/// Accepts any [`BufRead`] source, matching the other helpers in this module,
/// so in-memory readers work as well as a `BufReader<File>`.
///
/// # Errors
///
/// * `LoadError::ReadExactFailed` if `len` bytes could not be read; the
///   variant carries the underlying I/O error and the requested byte count.
/// * The `LoadError` converted from `FromUtf8Error` if the bytes are not
///   valid UTF-8 (decoding is strict, nothing is replaced).
pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result<String, LoadError> {
    let mut buf = vec![0u8; len];
    reader
        .read_exact(&mut buf)
        .map_err(|e| LoadError::ReadExactFailed {
            source: e,
            bytes: len,
        })?;
    Ok(String::from_utf8(buf)?)
}

/// Copies `len` bytes from `reader` to `writer` and returns them as a `String`.
///
/// The bytes are written to `writer` unchanged. The returned string is a lossy
/// UTF-8 decoding of those bytes, so invalid sequences do not cause an error.
///
/// # Errors
///
/// * `LoadError::ReadExactFailed` if `len` bytes could not be read; nothing is
///   written in that case.
/// * The `LoadError` converted from the I/O error if writing fails.
pub(crate) fn rw_string(
    reader: &mut impl BufRead,
    writer: &mut impl Write,
    len: usize,
) -> Result<String, LoadError> {
    let mut raw = vec![0; len];
    if let Err(source) = reader.read_exact(&mut raw) {
        return Err(LoadError::ReadExactFailed {
            source,
            bytes: raw.len(),
        });
    }
    writer.write_all(&raw)?;
    Ok(String::from_utf8_lossy(&raw).into_owned())
}
40 changes: 3 additions & 37 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
mod file;
mod ggml;
mod quantize;

Expand All @@ -15,7 +16,7 @@ use thiserror::Error;
use partial_sort::PartialSort;
use rand::{distributions::WeightedIndex, prelude::Distribution};

pub use quantize::llama_model_quantize;
pub use quantize::{quantize, QuantizeLoadProgress};
pub const EOD_TOKEN_ID: TokenId = 2; // Hardcoded (for now?)

#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord)]
Expand Down Expand Up @@ -424,6 +425,7 @@ impl Model {
n_ctx: i32,
load_progress_callback: impl Fn(LoadProgress),
) -> Result<(Model, Vocabulary), LoadError> {
use crate::file::{read_f32, read_i32, read_string, read_u32};
use std::fs::File;
use std::io::BufReader;

Expand All @@ -437,42 +439,6 @@ impl Model {
})?,
);

fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
let mut bytes = [0u8; N];
reader
.read_exact(&mut bytes)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: N,
})?;
Ok(bytes)
}

fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}

fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
}

fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
}

/// Helper function. Reads a string from the buffer and returns it.
fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> {
let mut buf = vec![0; len];
reader
.read_exact(&mut buf)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: buf.len(),
})?;
let s = String::from_utf8(buf)?;
Ok(s)
}

// Verify magic
let is_legacy_model: bool = match read_u32(&mut reader)? {
ggml::FILE_MAGIC => false,
Expand Down
Loading