Skip to content

Commit

Permalink
moved config into json
Browse files Browse the repository at this point in the history
  • Loading branch information
eisber committed Feb 27, 2023
1 parent 42548d8 commit e3ab3f6
Show file tree
Hide file tree
Showing 15 changed files with 241 additions and 308 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,4 @@ global-include py.typed
recursive-include scripts *.py
recursive-include tests *.py
recursive-include src *.rs
recursive-include tiktoken *.json
139 changes: 41 additions & 98 deletions core/src/openai_public.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,112 +2,55 @@
use rustc_hash::FxHashMap as HashMap;
use std::error::Error;
use std::sync::RwLock;
use json;

#[path = "load.rs"]
mod load;

type Result<T> = std::result::Result<T, Box<dyn Error>>;

lazy_static! {
    /// Registry of all known encodings, keyed by encoding name.
    ///
    /// The definitions live in `tiktoken/registry.json` (embedded at compile
    /// time via `include_str!`) instead of being hard-coded here, so Python and
    /// Rust share a single source of truth. Each JSON entry must provide
    /// exactly one loading strategy:
    ///   - `"data_gym_to_mergeable_bpe_ranks"`: GPT-2 style vocab.bpe + encoder.json pair
    ///   - `"load_tiktoken_bpe"`: a single .tiktoken BPE rank file
    ///
    /// Panics on first access if the embedded JSON is malformed — this is a
    /// build artifact, so a parse failure is a packaging bug, not a runtime
    /// condition callers should handle.
    pub static ref REGISTRY: HashMap<String, EncodingLazy> = {
        json::parse(include_str!("../../tiktoken/registry.json"))
            .expect("failed to parse embedded tiktoken/registry.json")
            .entries()
            .map(|(key, value)| {
                // Pick the loading strategy from whichever key is present.
                let loading_strategy = if value.has_key("data_gym_to_mergeable_bpe_ranks") {
                    let def = &value["data_gym_to_mergeable_bpe_ranks"];
                    EncoderLoadingStrategy::DataGym(DataGymDef {
                        vocab_bpe_file: def["vocab_bpe_file"]
                            .as_str()
                            .expect("data_gym_to_mergeable_bpe_ranks.vocab_bpe_file must be a string")
                            .into(),
                        encoder_json_file: def["encoder_json_file"]
                            .as_str()
                            .expect("data_gym_to_mergeable_bpe_ranks.encoder_json_file must be a string")
                            .into(),
                    })
                } else if value.has_key("load_tiktoken_bpe") {
                    EncoderLoadingStrategy::BPE(
                        value["load_tiktoken_bpe"]
                            .as_str()
                            .expect("load_tiktoken_bpe must be a string")
                            .into(),
                    )
                } else {
                    panic!("encoding '{}' has no recognized loading strategy", key);
                };

                EncodingLazy::new(
                    key.into(),
                    // explicit_n_vocab is optional in the JSON; as_usize()
                    // yields None when the key is absent.
                    value["explicit_n_vocab"].as_usize(),
                    value["pat_str"]
                        .as_str()
                        .expect("pat_str must be a string")
                        .into(),
                    value["special_tokens"]
                        .entries()
                        .map(|(tok, id)| {
                            (
                                tok.into(),
                                id.as_usize().expect("special token id must be an unsigned integer"),
                            )
                        })
                        .collect::<HashMap<String, usize>>(),
                    loading_strategy,
                )
            })
            // Key the registry by the encoding's own name.
            .map(|enc| (enc.name.clone(), enc))
            .collect::<HashMap<String, EncodingLazy>>()
    };

    /// Maps an OpenAI model name (e.g. "text-davinci-003") to the name of the
    /// encoding it uses (e.g. "p50k_base"). Loaded from the embedded
    /// `tiktoken/model_to_encoding.json`; panics on first access if that JSON
    /// is malformed (packaging bug).
    pub static ref MODEL_TO_ENCODING: HashMap<String, String> =
        json::parse(include_str!("../../tiktoken/model_to_encoding.json"))
            .expect("failed to parse embedded tiktoken/model_to_encoding.json")
            .entries()
            .map(|(model, enc)| {
                (
                    model.into(),
                    enc.as_str().expect("encoding name must be a string").into(),
                )
            })
            .collect::<HashMap<String, String>>();
}

#[derive(Clone, PartialEq, Eq, Hash)]
Expand Down
43 changes: 27 additions & 16 deletions java/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

<name>tiktoken</name>
<url>https://github.com/openai/tiktoken</url>
<packaging>jar</packaging>

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
Expand All @@ -24,9 +25,23 @@
<version>4.11</version>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scijava</groupId>
<artifactId>native-lib-loader</artifactId>
<version>2.4.0</version>
</dependency>
</dependencies>

<build>
<resources>
<resource>
<directory>${project.build.directory}/../../target/release/</directory>
<targetPath>${project.build.directory}/classes/natives/linux_64</targetPath>
<includes>
<include>lib_tiktoken_jni.so</include>
</includes>
</resource>
</resources>
<pluginManagement><!-- lock down plugins versions to avoid using Maven defaults (may be moved to parent pom) -->
<plugins>
<!-- clean lifecycle, see https://maven.apache.org/ref/current/maven-core/lifecycles.html#clean_Lifecycle -->
Expand Down Expand Up @@ -69,22 +84,18 @@
<version>3.0.0</version>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>2.17</version>
<executions>
<execution>
<id>surefire-test</id>
<phase>test</phase>
<goals>
<goal>test</goal>
</goals>
</execution>
</executions>
<configuration>
<argLine>-Djava.library.path=${project.build.directory}/../../target/debug/</argLine>
</configuration>
</plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-failsafe-plugin</artifactId>
<version>2.22.1</version>
<executions>
<execution>
<goals>
<goal>integration-test</goal>
<goal>verify</goal>
</goals>
</execution>
</executions>
</plugin>
</plugins>
</pluginManagement>
</build>
Expand Down
15 changes: 12 additions & 3 deletions java/src/main/java/tiktoken/Encoding.java
Original file line number Diff line number Diff line change
@@ -1,20 +1,29 @@
package tiktoken;

import org.scijava.nativelib.NativeLoader;
import java.io.IOException;

public class Encoding implements AutoCloseable
{
static {
System.loadLibrary("_tiktoken_jni");
// TODO: unpack the library from the jar
// System.loadLibrary("_tiktoken_jni");
try {
NativeLoader.loadLibrary("_tiktoken_jni");
}
catch(IOException e) {
throw new RuntimeException(e);
}
}

// initialized by init
private long handle;

private native void init(String modelName);

public native long[] encode(String text, String[] allowedSpecialTokens, long maxTokenLength);

private native void destroy();

public native long[] encode(String text, String[] allowedSpecialTokens, long maxTokenLength);

public Encoding(String modelName) {
this.init(modelName);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
package tiktoken;

import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertTrue;

import org.junit.Test;

public class EncodingTest
// run test: mvn failsafe:integration-test
public class EncodingTestIT
{
@Test
public void shouldAnswerWithTrue() throws Exception
Expand All @@ -16,7 +16,6 @@ public void shouldAnswerWithTrue() throws Exception

encoding.close();

assertTrue( true );
assertArrayEquals(new long[] {9288}, a);
}
}
29 changes: 29 additions & 0 deletions java/src/tiktoken_Encoding.h

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions jni/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,8 @@ jni = "0.20.0"

[profile.release]
incremental = true
opt-level = 'z' # Optimize for size
lto = true # Enable link-time optimization
codegen-units = 1 # Reduce number of codegen units to increase optimizations
panic = 'abort' # Abort on panic
strip = true          # Strip symbols from binary
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
debug=False,
)
],
package_data={"tiktoken": ["py.typed"]},
packages=["tiktoken", "tiktoken_ext"],
include_package_data=True,
package_data={ "tiktoken": ["py.typed", "registry.json", "model_to_encoding.json"] },
packages=["tiktoken"],
zip_safe=False,
)
15 changes: 0 additions & 15 deletions tests/test_simple_public.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,18 +24,3 @@ def test_encoding_for_model():
assert enc.name == "gpt2"
enc = tiktoken.encoding_for_model("text-davinci-003")
assert enc.name == "p50k_base"

def test_loading():
x = tiktoken.load.data_gym_to_mergeable_bpe_ranks(
vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
)

print(len(x))

y = tiktoken._tiktoken.py_data_gym_to_mergable_bpe_ranks(
vocab_bpe_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/vocab.bpe",
encoder_json_file="https://openaipublic.blob.core.windows.net/gpt-2/encodings/main/encoder.json",
)

print(len(y))
4 changes: 0 additions & 4 deletions tiktoken/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def data_gym_to_mergeable_bpe_ranks(
# NB: do not add caching to this function
rank_to_intbyte = [b for b in range(2**8) if chr(b).isprintable() and chr(b) != " "]

print(f"rank_to_intbyte: {len(rank_to_intbyte)}")
data_gym_byte_to_byte = {chr(b): b for b in rank_to_intbyte}
n = 0
for b in range(2**8):
Expand All @@ -75,9 +74,6 @@ def decode_data_gym(value: str) -> bytes:
# add the single byte tokens
bpe_ranks = {bytes([b]): i for i, b in enumerate(rank_to_intbyte)}

# print(len(rank_to_intbyte))
print(f"py data gym: {len(data_gym_byte_to_byte)} '{data_gym_byte_to_byte[chr(288)]}'")

# add the merged tokens
n = len(bpe_ranks)
for first, second in bpe_merges:
Expand Down
Loading

0 comments on commit e3ab3f6

Please sign in to comment.