Skip to content

Commit

Permalink
Add support for createChatCompletion
Browse files Browse the repository at this point in the history
  • Loading branch information
dqbd committed Mar 5, 2023
1 parent df06be4 commit 937bcde
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions scripts/override_stream.ts
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,6 @@ import {
MethodDeclaration,
} from "ts-morph";

const METHOD_NAME = "createCompletion";

function writeBody(writer: CodeBlockWriter, argName: string, body: string) {
return writer
.write(`if (${argName}.stream) `)
Expand Down Expand Up @@ -47,10 +45,10 @@ function extractKeyTypes(method: MethodDeclaration) {
return { requestName, requestType, responseType };
}

function transformObject(sourceNode: Node) {
function transformObject(methodName: string, sourceNode: Node) {
const method = sourceNode
.getDescendantsOfKind(ts.SyntaxKind.MethodDeclaration)
.filter((declaration) => declaration.getName() === METHOD_NAME)[0]!;
.filter((declaration) => declaration.getName() === methodName)[0]!;

const keyTypes = extractKeyTypes(method);
const body = method
Expand All @@ -65,19 +63,19 @@ function transformObject(sourceNode: Node) {
.getParent()
.asKindOrThrow(ts.SyntaxKind.ObjectLiteralExpression)
.insertPropertyAssignment(method.getChildIndex(), {
name: METHOD_NAME,
name: methodName,
leadingTrivia: docs + "\n",
initializer: (writer) =>
writer
.write("(() => {")
.indent(() => {
writer
.setIndentationLevel(writer.getIndentationLevel() - 1)
.write(`function ${METHOD_NAME}(): unknown`)
.write(`function ${methodName}(): unknown`)
.inlineBlock(() => writeBody(writer, keyTypes.requestName, body))
.write(";");

writer.writeLine(`return ${METHOD_NAME};`);
writer.writeLine(`return ${methodName};`);
})
.write("})()"),
})
Expand Down Expand Up @@ -109,10 +107,10 @@ function transformObject(sourceNode: Node) {
method.remove();
}

function transformClass(sourceNode: Node) {
function transformClass(methodName: string, sourceNode: Node) {
const method = sourceNode
.getDescendantsOfKind(ts.SyntaxKind.MethodDeclaration)
.filter((declaration) => declaration.getName() === METHOD_NAME)[0]!;
.filter((declaration) => declaration.getName() === methodName)[0]!;

const keyTypes = extractKeyTypes(method);

Expand All @@ -123,7 +121,7 @@ function transformClass(sourceNode: Node) {
const structure = method.getStructure();

method.set({
name: METHOD_NAME,
name: methodName,
scope: structure.scope,
parameters: structure.parameters,
docs: [],
Expand Down Expand Up @@ -153,8 +151,10 @@ project.addSourceFilesAtPaths(process.argv[process.argv.length - 1]);
const sourceFile = project.getSourceFileOrThrow("api.ts");
const declarations = sourceFile.getExportedDeclarations();

transformObject(declarations.get("OpenAIApiFp")![0]);
transformObject(declarations.get("OpenAIApiFactory")![0]);
transformClass(declarations.get("OpenAIApi")![0]);
for (const method of ["createCompletion", "createChatCompletion"]) {
transformObject(method, declarations.get("OpenAIApiFp")![0]);
transformObject(method, declarations.get("OpenAIApiFactory")![0]);
transformClass(method, declarations.get("OpenAIApi")![0]);
}

sourceFile.saveSync();

0 comments on commit 937bcde

Please sign in to comment.