Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Ported quantize.cpp #84

Merged
merged 22 commits into from
Apr 25, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Merge branch 'main' of github.com:rustformers/llama-rs into quantize
  • Loading branch information
philpax committed Apr 13, 2023
commit 02a89997f19d1f52733288f74745ee5700889e3d
63 changes: 30 additions & 33 deletions llama-rs/src/file.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,73 +4,61 @@ use crate::LoadError;
pub use std::fs::File;
pub use std::io::{BufRead, BufReader, BufWriter, Read, Write};

fn read(reader: &mut impl BufRead, bytes: &mut [u8]) -> Result<(), LoadError> {
reader
.read_exact(bytes)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: bytes.len(),
})
}

fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
pub fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
let mut bytes = [0u8; N];
read(reader, &mut bytes)?;
Ok(bytes)
}

fn rw<const N: usize>(
reader: &mut impl BufRead,
writer: &mut impl Write,
) -> Result<[u8; N], LoadError> {
let mut bytes = [0u8; N];
pub fn read_bytes_with_len(reader: &mut impl BufRead, len: usize) -> Result<Vec<u8>, LoadError> {
let mut bytes = vec![0u8; len];
read(reader, &mut bytes)?;
writer.write_all(&bytes)?;
Ok(bytes)
}

pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
pub fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<i32, LoadError> {
pub fn rw_i32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<i32, LoadError> {
Ok(i32::from_le_bytes(rw::<4>(reader, writer)?))
}

pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
pub fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn rw_u32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<u32, LoadError> {
pub fn rw_u32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<u32, LoadError> {
Ok(u32::from_le_bytes(rw::<4>(reader, writer)?))
}

pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
pub fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
}

pub(crate) fn rw_f32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<f32, LoadError> {
pub fn rw_f32(reader: &mut impl BufRead, writer: &mut impl Write) -> Result<f32, LoadError> {
Ok(f32::from_le_bytes(rw::<4>(reader, writer)?))
}

/// Helper function. Reads a string from the buffer and returns it.
pub(crate) fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> {
let mut buf = vec![0; len];
pub fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> {
Ok(String::from_utf8(read_bytes_with_len(reader, len)?)?)
}

fn read(reader: &mut impl BufRead, bytes: &mut [u8]) -> Result<(), LoadError> {
reader
.read_exact(&mut buf)
.read_exact(bytes)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: buf.len(),
})?;
let s = String::from_utf8(buf)?;
Ok(s)
bytes: bytes.len(),
})
}

pub(crate) fn rw_string(
pub fn rw_bytes_with_len(
reader: &mut impl BufRead,
writer: &mut impl Write,
len: usize,
) -> Result<String, LoadError> {
) -> Result<Vec<u8>, LoadError> {
let mut buf = vec![0; len];
reader
.read_exact(&mut buf)
Expand All @@ -79,6 +67,15 @@ pub(crate) fn rw_string(
bytes: buf.len(),
})?;
writer.write_all(&buf)?;
let s = String::from_utf8_lossy(&buf);
Ok(s.into_owned())
Ok(buf)
}

fn rw<const N: usize>(
reader: &mut impl BufRead,
writer: &mut impl Write,
) -> Result<[u8; N], LoadError> {
let mut bytes = [0u8; N];
read(reader, &mut bytes)?;
writer.write_all(&bytes)?;
Ok(bytes)
}
2 changes: 1 addition & 1 deletion llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,7 @@ impl Model {
n_context_tokens: usize,
mut load_progress_callback: impl FnMut(LoadProgress),
) -> Result<(Model, Vocabulary), LoadError> {
use crate::file::{read_f32, read_i32, read_string, read_u32};
use crate::file::{read_bytes_with_len, read_f32, read_i32, read_string, read_u32};
use std::fs::File;
use std::io::BufReader;

Expand Down
2 changes: 1 addition & 1 deletion llama-rs/src/quantize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ pub fn quantize(

for i in 0..hparams.n_vocab {
let len = rw_u32(&mut finp, &mut fout)?.try_into()?;
let word = rw_string(&mut finp, &mut fout, len)?;
let word = rw_bytes_with_len(&mut finp, &mut fout, len)?;
let score = rw_f32(&mut finp, &mut fout)?;

vocab.token_to_id.insert(word.clone(), i.try_into()?);
Expand Down
You are viewing a condensed version of this merge commit. You can view the full changes here.