
Once upon a time I was told to comment my code
carlini committed Feb 24, 2022
1 parent 262cea5 commit b682f7c
Showing 1 changed file with 162 additions and 90 deletions.
252 changes: 162 additions & 90 deletions src/main.rs
@@ -168,7 +168,7 @@ pub fn to_bytes(input: &[u64]) -> Vec<u8> {
bytes
}

/* Convert a uint8 array to a uint64. Only called on (relatively) small files. */
pub fn from_bytes(input: Vec<u8>) -> Vec<u64> {
println!("S {}", input.len());
let mut bytes:Vec<u64> = Vec::with_capacity(input.len()/8);
@@ -193,13 +193,15 @@ fn get_next_pointer_from_table(mut tablestream:&mut TableStream) -> u64 {
return out;
}

/* For a suffix array, just compute A[i], but load off disk because A is biiiiiiigggggg. */
fn table_load_disk(table:&mut BufReader<File>, index:usize) -> usize{
table.seek(std::io::SeekFrom::Start ((index*8) as u64)).expect ("Seek failed!");
let mut tmp = [0u8; 8];
table.read_exact(&mut tmp).unwrap();
return u64::from_le_bytes(tmp) as usize;
}

/* Binary search to find where query happens to exist in text */
fn off_disk_position(text: &[u8], table: &mut BufReader<File>, query: &[u8]) -> usize {
let (mut left, mut right) = (0, text.len());
while left < right {
@@ -220,26 +222,6 @@ struct TableStream {
}




fn make_table(path: std::string::String, offset: usize) -> TableStream {
let mut table = TableStream {
file: std::io::BufReader::new(fs::File::open(path).unwrap()),
@@ -253,6 +235,14 @@ fn make_table(path: std::string::String, offset: usize) -> TableStream {

const HACKSIZE:usize=100000;
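// Note (added for clarity, not part of this commit): HACKSIZE is the amount of overlap
// maintained between adjacent dataset parts so that suffixes near a part boundary still
// sort correctly when the per-part suffix arrays are merged; see the comment on
// cmd_merge below.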

/*
 * Helper function that actually counts the number of times something is repeated.
 * This should be fairly simple.
 * First, perform binary search using the on-disk suffix array to find the first place
 * where the string occurs. If it doesn't exist then return 0.
 * Then, binary search again to find the last location it occurs.
* Return the difference between the two.
*/
fn count_occurances(text: &mut File, mut table: &mut BufReader<File>, size: u64, str: &[u8]) -> u64{
let mut buf = vec![0u8; str.len()];

@@ -294,22 +284,23 @@ fn count_occurances(text: &mut File, mut table: &mut BufReader<File>, size: u64,
return low-start;
}
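
// [Illustrative sketch, not part of this commit.] The same two binary searches that the
// comment on count_occurances above describes, written against an in-memory suffix array
// instead of the on-disk table: `low` is the first suffix that is >= the query, `high` is
// the first suffix whose query-length prefix already exceeds the query, and the count is
// the gap between them.
fn count_occurances_in_memory(text: &[u8], suffix_array: &[usize], query: &[u8]) -> usize {
    let low = suffix_array.partition_point(|&p| &text[p..] < query);
    let high = suffix_array.partition_point(|&p| {
        let s = &text[p..];
        let k = query.len().min(s.len());
        &s[..k] <= query
    });
    high - low
}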

/*
* Create a suffix array for a given file in one go.
* Calling this method is memory heavy---it's technically linear in the
* length of the file, but the constant is quite big.
* As a result, this method should only be called for files that comfortably
* fit into memory.
*
* The result of calling this method is a new file with ".table.bin" appended
* to the name which is the suffix array of sorted suffix pointers. This file
* should be exactly 8x larger than the original file (one u64 pointer per
* byte of the original).
*
* If the file does not fit into memory, then instead you should use the
* alternate save_part and then merge_parallel in two steps. See the comments
* below for how those work.
*/
fn cmd_build(fpath: &String) -> std::io::Result<()> {
let now = Instant::now();
println!("Reading the dataset at time t={}ms", now.elapsed().as_millis());
let mut text_ = Vec::with_capacity(std::fs::metadata(fpath.clone()).unwrap().len() as usize);
@@ -333,12 +324,13 @@ fn cmd_build(fpath: &String) -> std::io::Result<()> {
Ok(())
}
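
// [Illustrative sketch, not part of this commit.] What the ".table.bin" file described
// above contains, built the naive way for a tiny in-memory input: the indices 0..n
// sorted by the suffix starting at each index, serialized as little-endian u64 words
// (hence exactly 8x the size of the text). The real builder above is linear time per the
// comment; this O(n^2 log n) sort is only for illustration.
fn naive_table_bytes(text: &[u8]) -> Vec<u8> {
    let mut table: Vec<u64> = (0..text.len() as u64).collect();
    table.sort_by(|&a, &b| text[a as usize..].cmp(&text[b as usize..]));
    table.iter().flat_map(|x| x.to_le_bytes()).collect()
}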

/*
* Create a suffix array for a subsequence of bytes.
* As with save, this method is linear in the number of bytes that are
* being saved but the constant is rather high. This method does exactly
* the same thing as save except on a range of bytes.
*/
fn cmd_build_part(fpath: &String, start: u64, end: u64) -> std::io::Result<()> {
let now = Instant::now();
println!("Opening up the dataset files");

@@ -370,6 +362,17 @@ fn cmd_build_part(fpath: &String, start: u64, end: u64) -> std::io::Result<()>
Ok(())
}
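
// Note (added for clarity, not part of this commit): for a dataset too large for
// cmd_build above, the intended two-step flow per the comments is to build a suffix
// array for each byte range with this function and then combine the per-part tables
// with cmd_merge further below.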

/*
* Count how many times a particular string has occurred in the dataset.
*
* This is the easiest method to understand. It just performs binary search on the
 * suffix array and uses it exactly as it was designed. It will output the count.
*
 * NOTE: This function allows overlapping sequences to count as different duplicates.
 * So if our string is `aaaa` and we count how many times `aa` occurs, it will return 3,
 * not 2. This is different from python's "aaaa".count("aa") which will say 2.
 * This may or may not be a problem for you. But if it is, that's your problem, not mine.
*/
fn cmd_count_occurrences(fpath: &String, querypath: &String) -> std::io::Result<()> {
@@ -389,34 +392,50 @@ fn cmd_count_occurrences(fpath: &String, querypath: &String) -> std::io::Resul
Ok(())
}
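
// [Illustrative sketch, not part of this commit.] The overlapping-count semantics
// described above, demonstrated on plain slices: counting "aa" inside "aaaa" gives 3
// here (and with the suffix-array count above), while Python's "aaaa".count("aa")
// gives 2 because it skips past each match it finds.
fn overlapping_count(haystack: &[u8], needle: &[u8]) -> usize {
    if needle.is_empty() || needle.len() > haystack.len() {
        return 0;
    }
    haystack.windows(needle.len()).filter(|w| *w == needle).count()
}
// overlapping_count(b"aaaa", b"aa") == 3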


/*
* Given a string S and suffix array A, compute statistics about how many
* sequences in A are duplicated (and do it using as many threads as possible).
*
* The basic algorithm is simple. For every pair of items (i,i+1) in the
 * suffix array, we compare the suffixes S[A[i]..] and S[A[i+1]..] and count
* how many characters they have in common. We then report various statistics
* about this (e.g., the length of the match, which sequences match each other
* with at least T tokens, etc).
*
* The first complication is that we can't load all of A into memory at once.
 * This is too big (e.g., the suffix array for C4 is 2.7 terabytes!).
* We might be able to fit 345GB in memory on current hardware, but not
* 2.7TB. (If you're reading this in 2030, hello there. This must all look
* very silly to you. But I promise that, today, 2.7TB of memory is just too
* much. By the way, has AGI taken over the world? I hope not.)
*
* Fortunately our algorithm doesn't require random access into A, so we can
* just stream it off disk and then immediately throw away the old data.
*
* The second complication is that we want this to be fast. Very fast. So
* we're going to parallelize the algorithm over as many threads as possible.
* Fortunately this is Rust, and not Python, so the GIL is not going to make
* life terrible. We set up one copy of the string S in memory, and then we
* can have each of the threads in parallel stream over A starting at different
* offsets.
*
* The output of this algorithm is a bunch of files saved to cache_dir named
* /cache_dir/dups_S_i-j
* /cache_dir/sizes_S_i-j
* Where i-j is the range of bytes that are covered by this file.
 * The dups file stores just a list of 8-byte values [x_i], the indices where S[x_i..x_i+T]
* is duplicated elsewhere in the dataset.
*
* Because the list is produced in lexical order, the duplicates for the same string
* will all be sequential in the list, and this is where the sizes file comes in.
* The sizes file says which duplicates from the dups file correspond to the same "cluster".
* So if sizes = [5, 2, 8 ...] then it means the first 5 entries in the dups file correspond
* to the same string that's repeated 5 times, and the next 2 entries in the dups file are
* a pair of repeated strings.
*/
fn cmd_self_similar(data_file: &String, length_threshold: &usize, frequency_threshold: &usize,
only_save_one: &bool, cache_dir: &String, num_threads: i64) -> std::io::Result<()> {
println!("Start load!");

let text = filebuffer::FileBuffer::open(data_file).unwrap();
@@ -518,6 +537,29 @@ fn cmd_self_similar(data_file: &String, length_threshold: &usize, frequency_thre
Ok(())
}
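
// [Illustrative sketch, not part of this commit.] The core pairwise comparison that the
// comment on cmd_self_similar describes, on an in-memory suffix array: for every
// adjacent pair of entries, measure the shared prefix of the two suffixes and record
// the position whenever it reaches the length threshold. A complete version streams the
// table off disk and groups the recorded positions into the dups/sizes files.
fn shared_prefix_len(a: &[u8], b: &[u8]) -> usize {
    a.iter().zip(b.iter()).take_while(|(x, y)| x == y).count()
}

fn self_similar_in_memory(text: &[u8], suffix_array: &[usize], threshold: usize) -> Vec<usize> {
    let mut dups = Vec::new();
    for pair in suffix_array.windows(2) {
        let (i, j) = (pair[0], pair[1]);
        if shared_prefix_len(&text[i..], &text[j..]) >= threshold {
            dups.push(j); // text[j..j+threshold] also occurs starting at position i
        }
    }
    dups
}

// [Illustrative sketch, continued.] How the dups/sizes output format described above
// fits together: each entry of `sizes` says how many of the next `dups` entries belong
// to one cluster of positions that all share the same repeated string.
fn read_clusters(dups: &[u64], sizes: &[u64]) -> Vec<Vec<u64>> {
    let mut out = Vec::new();
    let mut start = 0usize;
    for &s in sizes {
        out.push(dups[start..start + s as usize].to_vec());
        start += s as usize;
    }
    out
}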

/*
* Given a string S1 and suffix array A1, and another string S2 with array A2,
 * find all sequences that are duplicated between S1 and S2 of at least a particular length.
*
* The basic algorithm is simple, and seems very much like a merge operation.
* Start enumerating all sequences from A1 which gives a sorted enumeration of S1.
* If S1[A1[0]..] < S2[A2[0]..] then advance the pointer walking S1, otherwise
 * advance the pointer walking S2. If ever S1[A1[i]..A1[i]+L] = S2[A2[j]..A2[j]+L]
* then we have a match and write it down.
*
* As with the self-similar comparison, we can't fit A1 or A2 into memory. So do the
 * same streaming tricks. And again we want things to go fast, so we're going to run
* it on as many parallel threads as possible.
*
* The output of this algorithm is a bunch of files saved to cache_dir named
 *     /cache_dir/dups_S1_i-j_S2-k-l
 *     /cache_dir/sizes_S1_i-j_S2-k-l
 * Here, S1 and S2 are the two files we're cross-deduplicating (probably a train and test set).
 * i-j is the range of bytes that are covered by this file in S1, and similarly k-l for S2.
*
* The dups and size file have the same interpretation as before. But this time there are
 * two, one for the S1 -> S2 comparison, and another for the S2 -> S1 comparison.
*/
fn cmd_across_similar(data_file_1: &String, data_file_2: &String, cache_dir: &String,
length_threshold: usize, num_threads: i64) -> std::io::Result<()> {
let text1 = filebuffer::FileBuffer::open(data_file_1).unwrap();
@@ -670,26 +712,59 @@ fn cmd_across_similar(data_file_1: &String, data_file_2: &String, cache_dir: &St
Ok(())
}
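
// [Illustrative sketch, not part of this commit.] A simplified, fully in-memory version
// of the merge walk that the comment on cmd_across_similar describes: step through both
// suffix arrays in sorted order, advancing whichever side currently has the smaller
// suffix, and record a pair of positions whenever the first `length_threshold` bytes
// agree. A complete version has to stream the arrays off disk and handle runs of equal
// prefixes on either side, which this sketch glosses over.
fn across_similar_in_memory(
    s1: &[u8], a1: &[usize],
    s2: &[u8], a2: &[usize],
    length_threshold: usize,
) -> Vec<(usize, usize)> {
    let (mut i, mut j) = (0, 0);
    let mut out = Vec::new();
    while i < a1.len() && j < a2.len() {
        let x = &s1[a1[i]..];
        let y = &s2[a2[j]..];
        let k = length_threshold.min(x.len()).min(y.len());
        if k == length_threshold && x[..k] == y[..k] {
            out.push((a1[i], a2[j]));
        }
        if x <= y { i += 1; } else { j += 1; }
    }
    out
}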


/*
* A little bit of state for the merge operation below.
 * - suffix is the suffix of one of the parts of the dataset we're merging;
 *   this is the value we're sorting on
 * - position is the location of this suffix (so suffix = text[position..])
* - table_index says which suffix array this suffix is a part of
*/
#[derive(Copy, Clone, Eq, PartialEq)]
struct MergeState<'a> {
suffix: &'a [u8],
position: u64,
table_index: usize
}
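
// Note (added for clarity, not part of this commit): the comparison in the Ord impl
// below is deliberately reversed (other vs. self), presumably so that a
// std::collections::BinaryHeap of MergeState values (which is a max-heap) pops the
// lexicographically smallest suffix first, i.e. behaves like a min-heap during the merge.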

impl<'a> Ord for MergeState<'a> {
fn cmp(&self, other: &Self) -> Ordering {
other.suffix.cmp(&self.suffix)
}
}

impl<'a> PartialOrd for MergeState<'a> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
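
// [Illustrative sketch, not part of this commit.] How a heap of MergeState values can
// drive the k-way merge that the comment on cmd_merge below describes, done fully in
// memory: seed a BinaryHeap with the first entry of every part's suffix array, and
// because MergeState's Ord is reversed, pop() always returns the part whose current
// suffix is lexicographically smallest. A complete version would stream the tables off
// disk, offset positions by the cumulative part lengths, and account for the HACKSIZE
// overlap between adjacent parts.
fn merge_in_memory(texts: &[Vec<u8>], tables: &[Vec<u64>]) -> Vec<(usize, u64)> {
    use std::collections::BinaryHeap;
    let mut next = vec![0usize; tables.len()]; // index of the next unread entry per table
    let mut heap = BinaryHeap::new();
    for (k, table) in tables.iter().enumerate() {
        if let Some(&p) = table.first() {
            heap.push(MergeState { suffix: &texts[k][p as usize..], position: p, table_index: k });
            next[k] = 1;
        }
    }
    let mut merged = Vec::new();
    while let Some(state) = heap.pop() {
        merged.push((state.table_index, state.position));
        let k = state.table_index;
        if next[k] < tables[k].len() {
            let p = tables[k][next[k]];
            next[k] += 1;
            heap.push(MergeState { suffix: &texts[k][p as usize..], position: p, table_index: k });
        }
    }
    merged
}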

/*
* Merge together M different suffix arrays (probably created with make-part).
* That is, given strings S_i and suffix arrays A_i compute the suffix array
* A* = make-suffix-array(concat S_i)
* In order to do this we just implement mergesort's Merge operation on each
* of the arrays A_i to construct a sorted array A*.
*
* This algorithm is *NOT A LINEAR TIME ALGORITHM* in the worst case. If you run
* it on a dataset consisting entirely of the character A it will be quadratic.
* Fortunately for us, language model datasets typically don't just repeat the same
* character a hundred million times in a row. So in practice, it's linear time.
*
 * There are three complications here.
*
* As with selfsimilar_parallel, we can't fit all A_i into memory at once, and
* we want to make things fast and so parallelize our execution. So we do the
* same tricks as before to make things work.
*
* However we have one more problem. In order to know how to merge the final
 * few bytes of array S_0 into their correct position, we need to know what bytes come next.
* So in practice we make sure that S_{i}[-HACKSIZE:] === S_{i+1}[:HACKSIZE].
* As long as HACKSIZE is longer than the longest potential match, everything
* will work out correctly. (I did call it hacksize after all.....)
* In practice this works. It may not for your use case if there are long duplicates.
*/
fn cmd_merge(data_files: &Vec<String>, output_file: &String, num_threads: i64) -> std::io::Result<()> {
let nn:usize = data_files.len();

fn load_text2<'s,'t>(fpath:String) -> Vec<u8> {
@@ -700,12 +775,11 @@ fn cmd_merge(data_files: &Vec<String>, output_file: &String, num_threads: i64)
println!("Done read buffer");
return text_;
}

let texts:Vec<Vec<u8>> = (0..nn).map(|x| load_text2(data_files[x].clone())).collect();

let texts_len:Vec<usize> = texts.iter().enumerate().map(|(i,x)| x.len() - (if i+1 == texts.len() {0} else {HACKSIZE})).collect();


let metadatas:Vec<u64> = (0..nn).map(|x| {
let meta = fs::metadata(format!("{}.table.bin", data_files[x].clone())).unwrap();
assert!(meta.len()%(texts[x].len() as u64) == 0);
@@ -715,8 +789,6 @@ fn cmd_merge(data_files: &Vec<String>, output_file: &String, num_threads: i64)
let ratio = metadatas[0] / (texts[0].len() as u64);
assert!(ratio == 8);

println!("Loading ratio is {}", ratio);

fn sdf(texts:&Vec<Vec<u8>>, starts:Vec<usize>, ends:Vec<usize>, texts_len:Vec<usize>, part:usize,
output_file: String, data_files: Vec<String>) {

