Skip to content

Commit

Permalink
Added encoding initialization by model name
Browse files Browse the repository at this point in the history
  • Loading branch information
dmytrostruk committed Apr 18, 2023
1 parent 84ce15a commit e38695d
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 11 deletions.
6 changes: 5 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,13 @@ To use SharpToken in your project, first import the library:
```csharp
using SharpToken;
```
Next, create an instance of GptEncoding by specifying the desired encoding:
Next, create an instance of GptEncoding by specifying the desired encoding or model:
```csharp
// Get encoding by encoding name
var encoding = GptEncoding.GetEncoding("cl100k_base");

// Get encoding by model name
var encoding = GptEncoding.GetEncodingForModel("gpt-4");
```

You can then use the Encode method to encode a string:
Expand Down
21 changes: 19 additions & 2 deletions SharpToken.Tests/SharpToken.Tests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,14 @@ public class Tests
{
private static readonly List<string> ModelsList = new() { "p50k_base", "r50k_base", "cl100k_base" };

private static List<Tuple<string, string, List<int>>> _testData =
private static readonly List<Tuple<string, string, List<int>>> _testData =
TestHelpers.ReadTestPlans("SharpToken.Tests.data.TestPlans.txt");

[SetUp]
public void Setup()
{
}


[Test]
[TestCaseSource(nameof(_testData))]
public void TestEncodingAndDecoding(Tuple<string, string, List<int>> resource)
Expand Down Expand Up @@ -100,4 +99,22 @@ public async Task TestLocalResourceMatchesRemoteResource(string modelName)
// Compare the contents of the files and assert their equality
Assert.That(normalizedEmbeddedResourceText, Is.EqualTo(normalizedRemoteResourceText));
}

[Test]
public void TestEncodingForModel()
{
const string modelName = "gpt-4";
const string inputText = "Hello, world!";
var expectedEncoded = new List<int> { 9906, 11, 1917, 0 };

var encoding = GptEncoding.GetEncodingForModel(modelName);
var encoded = encoding.Encode(inputText);
var decodedText = encoding.Decode(encoded);

Assert.Multiple(() =>
{
Assert.That(encoded, Is.EqualTo(expectedEncoded));
Assert.That(decodedText, Is.EqualTo(inputText));
});
}
}
13 changes: 9 additions & 4 deletions SharpToken/Lib/Encoding.cs
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
using System;
using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using System.Text.RegularExpressions;

namespace SharpToken
{

public class GptEncoding
{
private readonly BytePairEncodingCore _bytePairEncodingCoreProcessor;
Expand Down Expand Up @@ -43,15 +42,21 @@ private GptEncoding(string patternString,

private int MaxTokenValue { get; }

public static GptEncoding GetEncoding(string modelName)
public static GptEncoding GetEncoding(string encodingName)
{
var modelParams = ModelParamsGenerator.GetModelParams(modelName);
var modelParams = ModelParamsGenerator.GetModelParams(encodingName);

var encoding = new GptEncoding(modelParams.PatStr, modelParams.MergeableRanks,
modelParams.SpecialTokens, modelParams.ExplicitNVocab);
return encoding;
}

public static GptEncoding GetEncodingForModel(string modelName)
{
var encodingName = Model.GetEncodingNameForModel(modelName);
return GetEncoding(encodingName);
}

private static string SpecialTokenRegex(ISet<string> tokens)
{
var escapedTokens = new List<string>();
Expand Down
61 changes: 61 additions & 0 deletions SharpToken/Lib/Model.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System;
using System.Collections.Generic;

namespace SharpToken
{
public static class Model
{
private static readonly Dictionary<string, string> ModelToEncodingMapping = new Dictionary<string, string>()
{
// chat
{ "gpt-4", "cl100k_base" },
{ "gpt-3.5-turbo", "cl100k_base" },
// text
{ "text-davinci-003", "p50k_base" },
{ "text-davinci-002", "p50k_base" },
{ "text-davinci-001", "r50k_base" },
{ "text-curie-001", "r50k_base" },
{ "text-babbage-001", "r50k_base" },
{ "text-ada-001", "r50k_base" },
{ "davinci", "r50k_base" },
{ "curie", "r50k_base" },
{ "babbage", "r50k_base" },
{ "ada", "r50k_base" },
// code
{ "code-davinci-002", "p50k_base" },
{ "code-davinci-001", "p50k_base" },
{ "code-cushman-002", "p50k_base" },
{ "code-cushman-001", "p50k_base" },
{ "davinci-codex", "p50k_base" },
{ "cushman-codex", "p50k_base" },
// edit
{ "text-davinci-edit-001", "p50k_edit" },
{ "code-davinci-edit-001", "p50k_edit" },
// embeddings
{ "text-embedding-ada-002", "cl100k_base" },
// old embeddings
{ "text-similarity-davinci-001", "r50k_base" },
{ "text-similarity-curie-001", "r50k_base" },
{ "text-similarity-babbage-001", "r50k_base" },
{ "text-similarity-ada-001", "r50k_base" },
{ "text-search-davinci-doc-001", "r50k_base" },
{ "text-search-curie-doc-001", "r50k_base" },
{ "text-search-babbage-doc-001", "r50k_base" },
{ "text-search-ada-doc-001", "r50k_base" },
{ "code-search-babbage-code-001", "r50k_base" },
{ "code-search-ada-code-001", "r50k_base" },
};

public static string GetEncodingNameForModel(string modelName)
{
if (ModelToEncodingMapping.TryGetValue(modelName, out var encodingName))
{
return encodingName;
}

throw new Exception(
$"Could not automatically map {modelName} to a tokenizer. " +
$"Please use {nameof(GptEncoding.GetEncoding)} to explicitly get the tokenizer you expect.");
}
}
}
8 changes: 4 additions & 4 deletions SharpToken/Lib/ModelParamsGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using System;
using System;
using System.Collections.Generic;

namespace SharpToken
Expand Down Expand Up @@ -32,15 +32,15 @@ public static class ModelParamsGenerator
private const string FimSuffix = "<|fim_suffix|>";
private const string EndOfPrompt = "<|endofprompt|>";

public static ModelParams GetModelParams(string name)
public static ModelParams GetModelParams(string encodingName)
{
return name.ToLower() switch
return encodingName.ToLower() switch
{
"r50k_base" => R50KBase(),
"p50k_base" => P50KBase(),
"p50k_edit" => P50KEdit(),
"cl100k_base" => Cl100KBase(),
_ => throw new ArgumentException($"Unknown model name: {name}")
_ => throw new ArgumentException($"Unknown encoding name: {encodingName}")
};
}

Expand Down

0 comments on commit e38695d

Please sign in to comment.