
Standalone loader #125

Merged: 46 commits, merged Apr 22, 2023

Commits
bdbea68
Add loader stub for GGJT
iacore Apr 6, 2023
b0a666f
Add loading code for ggjt
iacore Apr 6, 2023
9eefdc5
code cleanup that doesn't change anything
iacore Apr 6, 2023
c212c53
more code cleanup
iacore Apr 6, 2023
bfaec3a
minor change
iacore Apr 7, 2023
b6044ee
Add non-mmap loader for GGJT
iacore Apr 7, 2023
1872dda
Prefer traits in loader.rs
iacore Apr 7, 2023
ec1fca7
cargo fmt
iacore Apr 7, 2023
cc846ae
cargo clippy --fix
iacore Apr 7, 2023
bf847dd
Remove ggml::Tensor::set_data
iacore Apr 7, 2023
ea7094c
fix(llama): buffer tokens until valid UTF-8
philpax Apr 7, 2023
c848d5e
Add standalone loader
iacore Apr 8, 2023
8390593
Move loader to standalone crate llama-loader
iacore Apr 8, 2023
15fe19b
[llama-loader] Support non-copy loader
iacore Apr 8, 2023
2e9311d
Use functions from the new crate
iacore Apr 8, 2023
4dd0fc5
Merge branch 'main' into llama-loader
philpax Apr 13, 2023
c40e36e
Merge branch 'main' of github.com:rustformers/llama-rs into llama-loader
philpax Apr 13, 2023
34429e0
refactor(llama): pass mut tensors down
philpax Apr 13, 2023
38e7d58
feat/loader Make hparams configurable
iacore Apr 14, 2023
5dfc55d
feat/loader Add hook to support multi-part model loading
iacore Apr 14, 2023
48efd74
rename llama-loader to ggml-loader
iacore Apr 14, 2023
0fbbedd
Merge branch 'main' into llama-loader
philpax Apr 19, 2023
d65996d
fix
jon-chuang Apr 12, 2023
267d8ae
no_alloc
jon-chuang Apr 12, 2023
81a6979
chore: fix clippy
philpax Apr 19, 2023
80d189e
refactor(util): make find_all_model_files error
philpax Apr 19, 2023
85e1148
UnsupportedElementtype -> UnsupportedElementType
philpax Apr 19, 2023
3f29992
feat: experimental loader2 wire-up (incomplete)
philpax Apr 19, 2023
94951c4
fix dead doc link
philpax Apr 19, 2023
69f355b
feat: turn mmap on by default, add --no-mmap
philpax Apr 19, 2023
17bc0cc
Fix loading GGJT
iacore Apr 20, 2023
6641ae9
minor fix
iacore Apr 20, 2023
3910b6a
Add mmap
iacore Apr 20, 2023
e4834bd
cargo fmt
iacore Apr 20, 2023
c380cee
Make loader2 default
iacore Apr 20, 2023
5b9788b
fix: remove dbg!(start_pos)
philpax Apr 22, 2023
cbf0756
fix: respect --no-mmap
philpax Apr 22, 2023
8813b0f
Merge branch 'main' of github.com:rustformers/llama-rs into llama-loader
philpax Apr 22, 2023
430abfe
chore: remove old comments
philpax Apr 22, 2023
bf6a917
chore: remove unused error case
philpax Apr 22, 2023
9b908ae
fix: remove some panics
philpax Apr 22, 2023
d8c4ca6
feat: remove AlreadyAdded error
philpax Apr 22, 2023
cabc4c9
minor fix
iacore Apr 22, 2023
1930496
fix: Vocabulary::push_token is infallible
philpax Apr 22, 2023
bdb9856
fix: bail on multipart models with loader2
philpax Apr 22, 2023
b41fe14
refactor: make Vocabulary::push_token pub(crate)
philpax Apr 22, 2023
Changes from 1 commit: bdbea689c5ec17916dc86092d44ad8d55cf1b244 ("Add loader stub for GGJT")
iacore committed Apr 8, 2023
ggml/src/lib.rs (6 changes: 4 additions & 2 deletions)
@@ -14,8 +14,10 @@ use std::{
sync::{Arc, Weak},
};

/// Magic constant for `ggml` files (versioned).
pub const FILE_MAGIC: u32 = 0x67676d66;
/// Magic constant for `ggml` files (versioned, ggmf).
pub const FILE_MAGIC_GGMF: u32 = 0x67676d66;
/// Magic constant for `ggml` files (versioned, ggjt).
pub const FILE_MAGIC_GGJT: u32 = 0x67676a74;
/// Magic constant for `ggml` files (unversioned).
pub const FILE_MAGIC_UNVERSIONED: u32 = 0x67676d6c;

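Note (not part of the diff): each magic constant above spells its format tag when the u32 value is viewed as big-endian bytes, and the loader selects a container format by matching the value it reads from the start of the file (see the `match read_u32(...)` hunk below). The following is a minimal, self-contained sketch of that detection step; the `ContainerType` enum and `detect_container` helper are illustrative names, not taken from this PR.

```rust
use std::io::{self, Read};

// Constants copied from the hunk above; everything else is illustrative.
const FILE_MAGIC_GGMF: u32 = 0x6767_6d66; // "ggmf" as big-endian bytes
const FILE_MAGIC_GGJT: u32 = 0x6767_6a74; // "ggjt" as big-endian bytes
const FILE_MAGIC_UNVERSIONED: u32 = 0x6767_6d6c; // "ggml" as big-endian bytes

#[derive(Debug)]
enum ContainerType {
    Ggmf,        // versioned, pre-existing format
    Ggjt,        // versioned, the format this PR adds loading support for
    Unversioned, // original ggml format
}

/// Read the little-endian u32 magic from a model file and classify it.
fn detect_container(mut r: impl Read) -> io::Result<ContainerType> {
    let mut buf = [0u8; 4];
    r.read_exact(&mut buf)?;
    match u32::from_le_bytes(buf) {
        FILE_MAGIC_GGMF => Ok(ContainerType::Ggmf),
        FILE_MAGIC_GGJT => Ok(ContainerType::Ggjt),
        FILE_MAGIC_UNVERSIONED => Ok(ContainerType::Unversioned),
        other => Err(io::Error::new(
            io::ErrorKind::InvalidData,
            format!("unknown magic {other:#010x}"),
        )),
    }
}

fn main() -> io::Result<()> {
    // The tag is visible when each constant is viewed as big-endian bytes.
    for magic in [FILE_MAGIC_GGMF, FILE_MAGIC_GGJT, FILE_MAGIC_UNVERSIONED] {
        let tag: String = magic.to_be_bytes().iter().map(|&b| b as char).collect();
        println!("{magic:#010x} = {tag}");
    }
    // A header starting with the GGJT magic (stored little-endian) is classified as Ggjt.
    let header = FILE_MAGIC_GGJT.to_le_bytes();
    println!("{:?}", detect_container(&header[..])?);
    Ok(())
}
```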
llama-rs/src/lib.rs (322 changes: 36 additions & 286 deletions)
@@ -1,6 +1,8 @@
#![deny(missing_docs)]
//! LLaMA-rs is a Rust port of the llama.cpp project. This allows running inference for Facebook's LLaMA model on a CPU with good performance using full precision, f16 or 4-bit quantized versions of the model.

mod loader;

use core::slice;
use std::{
collections::HashMap,
@@ -580,6 +582,7 @@ impl Model {
n_context_tokens: usize,
load_progress_callback: impl Fn(LoadProgress),
) -> Result<(Model, Vocabulary), LoadError> {
use loader::*;
use std::fs::File;
use std::io::BufReader;

@@ -593,46 +596,11 @@
})?,
);

fn read_bytes<const N: usize>(reader: &mut impl BufRead) -> Result<[u8; N], LoadError> {
let mut bytes = [0u8; N];
reader
.read_exact(&mut bytes)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: N,
})?;
Ok(bytes)
}

fn read_i32(reader: &mut impl BufRead) -> Result<i32, LoadError> {
Ok(i32::from_le_bytes(read_bytes::<4>(reader)?))
}

fn read_u32(reader: &mut impl BufRead) -> Result<u32, LoadError> {
Ok(u32::from_le_bytes(read_bytes::<4>(reader)?))
}

fn read_f32(reader: &mut impl BufRead) -> Result<f32, LoadError> {
Ok(f32::from_le_bytes(read_bytes::<4>(reader)?))
}

/// Helper function. Reads a string from the buffer and returns it.
fn read_string(reader: &mut BufReader<File>, len: usize) -> Result<String, LoadError> {
let mut buf = vec![0; len];
reader
.read_exact(&mut buf)
.map_err(|e| LoadError::ReadExactFailed {
source: e,
bytes: buf.len(),
})?;
let s = String::from_utf8(buf)?;
Ok(s)
}

// Verify magic
let is_legacy_model: bool = match read_u32(&mut reader)? {
ggml::FILE_MAGIC => false,
ggml::FILE_MAGIC_UNVERSIONED => true,
let model_type: ModelType = match read_u32(&mut reader)? {
ggml::FILE_MAGIC_GGMF => ModelType::GGMF,
ggml::FILE_MAGIC_GGJT => ModelType::GGJT,
ggml::FILE_MAGIC_UNVERSIONED => ModelType::Unversioned,
_ => {
return Err(LoadError::InvalidMagic {
path: main_path.to_owned(),
@@ -641,12 +609,14 @@
};

// Load format version
if !is_legacy_model {
#[allow(unused_variables)]
let version: u32 = match read_u32(&mut reader)? {
ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion { value: version }),
};
match model_type {
ModelType::GGMF | ModelType::GGJT => {
let _version: u32 = match read_u32(&mut reader)? {
ggml::FORMAT_VERSION => ggml::FORMAT_VERSION,
version => return Err(LoadError::InvalidFormatVersion { value: version }),
};
}
ModelType::Unversioned => {}
}

// =================
@@ -681,8 +651,12 @@
let mut max_token_length = 0;

for i in 0..hparams.n_vocab {
let len = read_i32(&mut reader)?;
if let Ok(word) = read_string(&mut reader, len as usize) {
let len = match model_type {
// `read_i32` may be a typo
ModelType::GGMF | ModelType::Unversioned => read_i32(&mut reader)? as usize,
ModelType::GGJT => read_u32(&mut reader)? as usize,
};
if let Ok(word) = read_string(&mut reader, len) {
max_token_length = max_token_length.max(word.len());
id_to_token.push(word.clone());
token_to_id.insert(word, TokenId::try_from(i)?);
@@ -692,13 +666,16 @@
}

// Token score, currently unused
if !is_legacy_model {
if let Ok(score) = read_f32(&mut reader) {
id_to_token_score.push(score);
match model_type {
ModelType::GGMF | ModelType::GGJT => {
if let Ok(score) = read_f32(&mut reader) {
id_to_token_score.push(score);
}
}
ModelType::Unversioned => {
// Legacy model, set empty score
id_to_token_score.push(0.);
}
} else {
// Legacy model, set empty score
id_to_token_score.push(0.);
}
}

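Note (not part of the diff): taken together, the vocabulary hunks above read one entry per token: a length prefix (read as i32 for GGMF/unversioned files, u32 for GGJT), the token's UTF-8 bytes, and then an f32 score for the versioned formats only (unversioned files default to 0.0). The sketch below shows that per-entry layout in isolation; the function name and error handling are illustrative, not taken from this PR.

```rust
use std::io::{self, Read};

/// Illustrative sketch of one vocabulary entry as the loop above reads it:
/// a 4-byte length prefix, the UTF-8 token bytes, then an f32 score for
/// versioned (GGMF/GGJT) files only.
fn read_vocab_entry(mut r: impl Read, versioned: bool) -> io::Result<(String, f32)> {
    let mut word = [0u8; 4];
    r.read_exact(&mut word)?;
    // GGMF/unversioned store the length as i32 and GGJT as u32; both are
    // 4 little-endian bytes, so a u32 read covers the non-negative case.
    let len = u32::from_le_bytes(word) as usize;

    let mut token = vec![0u8; len];
    r.read_exact(&mut token)?;
    let token = String::from_utf8(token)
        .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

    let score = if versioned {
        r.read_exact(&mut word)?;
        f32::from_le_bytes(word)
    } else {
        0.0 // legacy (unversioned) files carry no score
    };
    Ok((token, score))
}

fn main() -> io::Result<()> {
    // A hand-built GGJT-style entry: len = 5, "hello", score = 1.0.
    let mut bytes = Vec::new();
    bytes.extend_from_slice(&5u32.to_le_bytes());
    bytes.extend_from_slice(b"hello");
    bytes.extend_from_slice(&1.0f32.to_le_bytes());
    println!("{:?}", read_vocab_entry(&bytes[..], true)?);
    Ok(())
}
```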
@@ -825,240 +802,13 @@
}
};

// Close the file, but keep its offset. That way we know how to skip the
// metadata when loading the parts.
let file_offset = reader.stream_position()?;
drop(reader);

let paths = util::find_all_model_files(main_path)?;
let n_parts = paths.len();

for (i, part_path) in paths.into_iter().enumerate() {
let part_id = i;

load_progress_callback(LoadProgress::PartLoading {
file: &part_path,
current_part: i,
total_parts: n_parts,
});

let mut part_reader = BufReader::new(File::open(&part_path)?);

// Skip metadata
part_reader.seek(SeekFrom::Start(file_offset))?;

let mut total_size = 0;
let mut n_tensors = 0;

// Load weights
loop {
// NOTE: Implementation from #![feature(buf_read_has_data_left)]
let is_eof = part_reader.fill_buf().map(|b| b.is_empty())?;

if is_eof {
break;
}

let n_dims = usize::try_from(read_i32(&mut part_reader)?)?;
let length = read_i32(&mut part_reader)?;
let ftype = read_u32(&mut part_reader)?;

let mut nelements = 1;
let mut ne = [1i64, 1i64];

#[allow(clippy::needless_range_loop)]
for i in 0..n_dims {
ne[i] = read_i32(&mut part_reader)? as i64;
nelements *= usize::try_from(ne[i])?;
}

let tensor_name = read_string(&mut part_reader, length as usize)?;

let Some(tensor) = model.tensors.get(&tensor_name)
else {
return Err(LoadError::UnknownTensor { tensor_name, path: part_path });
};

// split_type = 0: split by columns
// split_type = 1: split by rows
//
// split_type = 0:
// regex:
// - tok_embeddings.*
// - layers.*.attention.wo.weight
// - layers.*.feed_forward.w2.weight

// split_type = 1:
// regex:
// - output.*
// - layers.*.attention.wq.weight
// - layers.*.attention.wk.weight
// - layers.*.attention.wv.weight
// - layers.*.feed_forward.w1.weight
// - layers.*.feed_forward.w3.weight
#[allow(clippy::if_same_then_else)]
let split_type = if tensor_name.contains("tok_embeddings") {
0
} else if tensor_name.contains("layers") {
if tensor_name.contains("attention.wo.weight") {
0
} else if tensor_name.contains("feed_forward.w2.weight") {
0
} else {
1
}
} else if tensor_name.contains("output") {
1
} else {
0
};

if n_dims == 1 {
if tensor.nelements() != nelements {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}
} else if tensor.nelements() / n_parts != nelements {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}

if n_dims == 1 {
if tensor.get_ne()[0] != ne[0] || tensor.get_ne()[1] != ne[1] {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}
} else if split_type == 0 {
if tensor.get_ne()[0] / i64::try_from(n_parts)? != ne[0]
|| tensor.get_ne()[1] != ne[1]
{
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}
} else if tensor.get_ne()[0] != ne[0]
|| tensor.get_ne()[1] / i64::try_from(n_parts)? != ne[1]
{
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}

let bpe = match ftype {
0 => ggml::type_size(ggml::Type::F32),
1 => ggml::type_size(ggml::Type::F16),
2 => {
assert_eq!(ne[0] % 64, 0);
ggml::type_size(ggml::Type::Q4_0)
}
3 => {
assert_eq!(ne[0] % 64, 0);
ggml::type_size(ggml::Type::Q4_1)
}
_ => {
return Err(LoadError::InvalidFtype {
tensor_name,
ftype,
path: part_path,
})
}
};

if n_dims == 1 || n_parts == 1 {
if (nelements * bpe) / ggml::blck_size(tensor.get_type()) != tensor.nbytes() {
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}

if part_id == 0 {
// SAFETY: yolo, same as original code
let slice = unsafe {
let data = tensor.data();
std::slice::from_raw_parts_mut(data as *mut u8, tensor.nbytes())
};
part_reader.read_exact(slice)?;
} else {
part_reader.seek(SeekFrom::Current(tensor.nbytes() as i64))?;
}

total_size += tensor.nbytes();
} else {
if (nelements * bpe) / ggml::blck_size(tensor.get_type())
!= tensor.nbytes() / n_parts
{
return Err(LoadError::TensorWrongSize {
tensor_name,
path: part_path,
});
}

if split_type == 0 {
let np0 = ne[0];
let row_size = (usize::try_from(tensor.get_ne()[0])?
/ ggml::blck_size(tensor.get_type()))
* ggml::type_size(tensor.get_type());

assert_eq!(row_size, tensor.get_nb()[1]);

for i1 in 0..ne[1] {
let offset_row = i1 as usize * row_size;
let offset = offset_row
+ ((part_id * np0 as usize) / ggml::blck_size(tensor.get_type()))
* ggml::type_size(tensor.get_type());
// SAFETY: yolo, same as original code
unsafe {
let ptr = tensor.data().add(offset);
let slice = std::slice::from_raw_parts_mut(
ptr as *mut u8,
row_size / n_parts,
);
part_reader.read_exact(slice)?;
}
}
} else {
let np1 = ne[1];
let row_size = (usize::try_from(tensor.get_ne()[0])?
/ ggml::blck_size(tensor.get_type()))
* ggml::type_size(tensor.get_type());

for i1 in 0..ne[1] {
let offset_row = (i1 as usize + part_id * np1 as usize) * row_size;
// SAFETY: yolo, same as original code
unsafe {
let ptr = tensor.data().add(offset_row);
let slice =
std::slice::from_raw_parts_mut(ptr as *mut u8, row_size);
part_reader.read_exact(slice)?;
}
}
}

total_size += tensor.nbytes() / n_parts;
}

n_tensors += 1;
load_progress_callback(LoadProgress::PartTensorLoaded {
file: &part_path,
current_tensor: n_tensors.try_into()?,
tensor_count: model.tensors.len(),
});
match model_type {
ModelType::GGMF | ModelType::Unversioned => {
load_weights_ggmf_or_unversioned(reader, main_path, load_progress_callback, &model)?
}
ModelType::GGJT => {
load_weights_ggjt(reader, main_path, load_progress_callback, &model)?
}
}

load_progress_callback(LoadProgress::PartLoaded {
file: &part_path,
byte_size: total_size,
tensor_count: n_tensors.try_into()?,
});
}

Ok((model, vocab))