Skip to content
This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit

Permalink
Use functions from the new crate
Browse files Browse the repository at this point in the history
  • Loading branch information
iacore committed Apr 8, 2023
1 parent 15fe19b commit 2e9311d
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 75 deletions.
9 changes: 5 additions & 4 deletions llama-loader/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ pub trait LoadHandler<T> {
}

/// # Returns
///
///
/// `None` to skip copying
/// `Some(buf)` to provide a buffer for copying weights into
fn get_tensor_buffer(&mut self, info: TensorInfo) -> ControlFlow<T, Option<&mut [u8]>> {
Expand Down Expand Up @@ -242,10 +242,9 @@ fn load_weights_ggjt<T>(
n_dims,
n_elements,
ftype,
start_offset: offset_aligned
start_offset: offset_aligned,
};


let type_size = ggml::type_size(ftype);
if let Some(buf) = retchk(handler.get_tensor_buffer(tensor_info))? {
reader.seek(SeekFrom::Start(offset_aligned))?;
Expand All @@ -258,7 +257,9 @@ fn load_weights_ggjt<T>(
reader.read_exact(buf)?;
} else {
// skip if no buffer is given
reader.seek(SeekFrom::Start(offset_aligned + (type_size * n_elements) as u64))?;
reader.seek(SeekFrom::Start(
offset_aligned + (type_size * n_elements) as u64,
))?;
}
}

Expand Down
45 changes: 2 additions & 43 deletions llama-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use thiserror::Error;
#[cfg(feature = "mmap")]
use memmap2::Mmap;

use llama_loader::util::*;
use llama_loader::{decode_element_type, ContainerType};

/// dummy struct
Expand Down Expand Up @@ -609,48 +610,6 @@ impl Model {
})?;
let mut reader = BufReader::new(&file);

/// Reads exactly `N` bytes from `reader` into a fixed-size array.
///
/// A failure of `read_exact` is reported as `LoadError::ReadExactFailed`,
/// carrying the underlying I/O error together with the requested size `N`.
fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
    let mut buf = [0u8; N];
    match reader.read_exact(&mut buf) {
        Ok(()) => Ok(buf),
        Err(source) => Err(LoadError::ReadExactFailed { source, bytes: N }),
    }
}

/// Reads exactly `len` bytes from `reader` into a freshly allocated vector.
///
/// A failure of `read_exact` is reported as `LoadError::ReadExactFailed`,
/// carrying the underlying I/O error together with the requested `len`.
fn read_bytes_with_len(
    reader: &mut impl BufRead,
    len: usize,
) -> Result<Vec<u8>, LoadError> {
    let mut buf = vec![0u8; len];
    if let Err(source) = reader.read_exact(&mut buf) {
        return Err(LoadError::ReadExactFailed { source, bytes: len });
    }
    Ok(buf)
}

/// Reads four bytes from `reader` and decodes them as a little-endian `i32`.
fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
    read_bytes::<4>(reader).map(i32::from_le_bytes)
}

/// Reads four bytes from `reader` and decodes them as a little-endian `u32`.
fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
    read_bytes::<4>(reader).map(u32::from_le_bytes)
}

/// Reads four bytes from `reader` and decodes them as a little-endian `f32`.
fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
    read_bytes::<4>(reader).map(f32::from_le_bytes)
}

/// Helper function. Reads `len` bytes from the buffer and returns them as a
/// `String`.
///
/// Both the read error and the UTF-8 conversion error are propagated with `?`.
fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> {
    let raw = read_bytes_with_len(reader, len)?;
    let text = String::from_utf8(raw)?;
    Ok(text)
}

// Verify magic
let model_type: ContainerType = match read_u32(&mut reader)? {
ggml::FILE_MAGIC_GGMF => ContainerType::GGMF,
Expand Down Expand Up @@ -710,7 +669,7 @@ impl Model {

for i in 0..hparams.n_vocab {
let len = read_i32(&mut reader)?;
let token = read_bytes_with_len(&mut reader, len)?;
let token = read_bytes_with_len(&mut reader, len.try_into()?)?;
max_token_length = max_token_length.max(token.len());
id_to_token.push(token.clone());
token_to_id.insert(token, TokenId::try_from(i)?);
Expand Down
29 changes: 1 addition & 28 deletions llama-rs/src/loader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,7 @@ use std::{
use crate::ElementType;
use crate::{util, LoadError, LoadProgress, Model};
use llama_loader::decode_element_type;

/// Fills a fixed-size `[u8; N]` array with the next `N` bytes of `reader`.
///
/// If `read_exact` fails, the I/O error is wrapped in
/// `LoadError::ReadExactFailed` along with the requested byte count `N`.
pub(crate) fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
    let mut chunk = [0u8; N];
    if let Err(source) = reader.read_exact(&mut chunk) {
        return Err(LoadError::ReadExactFailed { source, bytes: N });
    }
    Ok(chunk)
}

/// Decodes the next four bytes of `reader` as a little-endian `i32`.
pub(crate) fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
    let raw: [u8; 4] = read_bytes(reader)?;
    Ok(i32::from_le_bytes(raw))
}

/// Decodes the next four bytes of `reader` as a little-endian `u32`.
pub(crate) fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
    let raw: [u8; 4] = read_bytes(reader)?;
    Ok(u32::from_le_bytes(raw))
}

/// Decodes the next four bytes of `reader` as a little-endian `f32`.
pub(crate) fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
    let raw: [u8; 4] = read_bytes(reader)?;
    Ok(f32::from_le_bytes(raw))
}
use llama_loader::util::*;

/// Helper function. Reads a string from the buffer and returns it.
pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result<String, LoadError> {
Expand All @@ -43,11 +21,6 @@ pub(crate) fn read_string(reader: &mut impl BufRead, len: usize) -> Result<Strin
Ok(s)
}

// NOTE: Implementation from #![feature(buf_read_has_data_left)]
/// Reports whether `reader` can still yield bytes.
///
/// Calls `fill_buf` and returns `true` when the resulting buffer is non-empty;
/// any I/O error from `fill_buf` is passed through unchanged. No bytes are
/// consumed.
pub(crate) fn has_data_left(reader: &mut impl BufRead) -> Result<bool, std::io::Error> {
    let pending = reader.fill_buf()?;
    Ok(!pending.is_empty())
}

pub(crate) fn load_weights_ggmf_or_unversioned(
file_offset: u64,
main_path: &Path,
Expand Down

0 comments on commit 2e9311d

Please sign in to comment.