Skip to content

Commit

Permalink
feat: create embeddings for uploaded documents
Browse files Browse the repository at this point in the history
  • Loading branch information
marcusschiesser committed Oct 23, 2023
1 parent dfa3310 commit 345b350
Show file tree
Hide file tree
Showing 12 changed files with 208 additions and 85 deletions.
36 changes: 36 additions & 0 deletions app/api/fetch/embeddings.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import { Embedding } from "@/app/client/fetch";
import {
DATASOURCES_CHUNK_OVERLAP,
DATASOURCES_CHUNK_SIZE,
} from "@/scripts/constants.mjs";
import {
Document,
MetadataMode,
SentenceSplitter,
VectorStoreIndex,
getNodesFromDocument,
serviceContextFromDefaults,
} from "llamaindex";

/**
 * Splits a raw document string into sentence-based chunks and computes an
 * embedding for each chunk.
 *
 * @param document - The full text content to split and embed.
 * @returns One {@link Embedding} per chunk, pairing the chunk's text
 *   (metadata excluded) with its embedding vector.
 */
export default async function splitAndEmbed(
  document: string,
): Promise<Embedding[]> {
  // Chunking parameters are shared with the datasource generation scripts so
  // uploaded documents are split the same way as pre-built datasources.
  const splitter = new SentenceSplitter({
    chunkSize: DATASOURCES_CHUNK_SIZE,
    chunkOverlap: DATASOURCES_CHUNK_OVERLAP,
  });
  const nodes = getNodesFromDocument(new Document({ text: document }), splitter);

  const embeddedNodes = await VectorStoreIndex.getNodeEmbeddingResults(
    nodes,
    serviceContextFromDefaults(),
    true,
  );

  return embeddedNodes.map((node) => ({
    text: node.getContent(MetadataMode.NONE),
    embedding: node.getEmbedding(),
  }));
}
58 changes: 48 additions & 10 deletions app/api/fetch/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ import {
getPDFContentFromBuffer,
} from "@/app/utils/content";
import { NextResponse, NextRequest } from "next/server";
import splitAndEmbed from "./embeddings";
import { URLDetailContent } from "@/app/client/fetch";

export async function GET(request: NextRequest) {
const url = new URL(request.url);
Expand All @@ -26,22 +28,58 @@ export async function GET(request: NextRequest) {
}
}

/**
 * Builds the response payload for an uploaded plain-text file: embeds the
 * text and returns it together with its embeddings and basic metadata.
 *
 * @param fileName - Original name of the uploaded file (used as the `url`).
 * @param text - The raw text content of the file.
 * @returns The text content plus embeddings, size, and MIME type.
 */
async function handleText(
  fileName: string,
  text: string,
): Promise<URLDetailContent> {
  const embeddings = await splitAndEmbed(text);
  const result: URLDetailContent = {
    content: text,
    embeddings,
    url: fileName,
    size: text.length,
    type: "text/plain",
  };
  return result;
}

/**
 * Builds the response payload for an uploaded PDF: decodes the base64 data,
 * extracts its text content, embeds it, and returns content plus metadata.
 *
 * @param fileName - Original name of the uploaded file (used as the `url`).
 * @param pdf - The PDF file contents, base64-encoded.
 * @returns The extracted text plus embeddings, size, and MIME type.
 */
async function handlePDF(
  fileName: string,
  pdf: string,
): Promise<URLDetailContent> {
  const buffer = Buffer.from(pdf, "base64");
  const { content, size } = await getPDFContentFromBuffer(buffer);
  const embeddings = await splitAndEmbed(content);
  return {
    content,
    embeddings,
    size,
    type: "application/pdf",
    url: fileName,
  };
}

// Request body for POST /api/fetch: the uploaded file's name plus exactly one
// of `pdf` (base64-encoded PDF data) or `text` (raw plain-text content).
type Input = {
fileName: string;
pdf?: string;
text?: string;
};

export async function POST(request: NextRequest) {
try {
const body = await request.json();
if (!body || !body.pdf) {
const { fileName, pdf, text }: Input = await request.json();
if (!fileName && (!pdf || !text)) {
return NextResponse.json(
{ error: "PDF file is required in the request body" },
{
error:
"filename and either text or pdf is required in the request body",
},
{ status: 400 },
);
}

const pdfBuffer = Buffer.from(body.pdf, "base64");
const pdfData = await getPDFContentFromBuffer(pdfBuffer);
return NextResponse.json({
...pdfData,
url: body.fileName,
});
const json = await (pdf
? handlePDF(fileName, pdf)
: handleText(fileName, text!));
return NextResponse.json(json);
} catch (error) {
return NextResponse.json(
{
Expand Down
51 changes: 47 additions & 4 deletions app/api/llm/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,36 @@ import {
ChatMessage,
DefaultContextGenerator,
HistoryChatEngine,
IndexDict,
OpenAI,
ServiceContext,
SimpleChatHistory,
SummaryChatHistory,
TextNode,
VectorStoreIndex,
serviceContextFromDefaults,
} from "llamaindex";
import { NextRequest, NextResponse } from "next/server";
import { LLMConfig } from "../../client/platforms/llm";
import { getDataSource } from "./datasource";
import { DATASOURCES_CHUNK_SIZE } from "@/scripts/constants.mjs";
import { Embedding } from "@/app/client/fetch";

async function createChatEngine(
serviceContext: ServiceContext,
datasource?: string,
embeddings?: Embedding[],
) {
let contextGenerator;
if (datasource) {
const index = await getDataSource(serviceContext, datasource);
const retriever = index.asRetriever();
if (datasource || embeddings) {
let index;
if (datasource) {
index = await getDataSource(serviceContext, datasource);
}
if (embeddings) {
index = await createIndex(serviceContext, embeddings);
}
const retriever = index!.asRetriever();
retriever.similarityTopK = 5;

contextGenerator = new DefaultContextGenerator({ retriever });
Expand All @@ -32,6 +43,32 @@ async function createChatEngine(
});
}

// Builds an in-memory VectorStoreIndex from precomputed embeddings, skipping
// the embedding step entirely: each (text, embedding) pair becomes a TextNode
// that is registered directly with the index's stores.
// NOTE(review): this hand-wires llamaindex index internals (vectorStore,
// docStore, indexStore, indexStruct); the mutation order appears intentional —
// confirm against the llamaindex version in use before changing it.
async function createIndex(
serviceContext: ServiceContext,
embeddings: Embedding[],
) {
// Wrap each precomputed embedding in a TextNode so the index can store it.
const embeddingResults = embeddings.map((config) => {
return new TextNode({ text: config.text, embedding: config.embedding });
});
const indexDict = new IndexDict();
for (const node of embeddingResults) {
indexDict.addNode(node);
}

// Initialize an empty index around the prebuilt index struct.
const index = await VectorStoreIndex.init({
indexStruct: indexDict,
serviceContext: serviceContext,
});

index.vectorStore.add(embeddingResults);
// If the vector store does not persist node text itself, keep the full
// documents in the doc store so retrieved nodes still have their content.
if (!index.vectorStore.storesText) {
await index.docStore.addDocuments(embeddingResults, true);
}
await index.indexStore?.addIndexStruct(indexDict);
index.indexStruct = indexDict;
return index;
}

export async function POST(request: NextRequest) {
try {
const body = await request.json();
Expand All @@ -40,11 +77,13 @@ export async function POST(request: NextRequest) {
chatHistory: messages,
datasource,
config,
embeddings,
}: {
message: string;
chatHistory: ChatMessage[];
datasource: string | undefined;
config: LLMConfig;
embeddings: Embedding[] | undefined;
} = body;
if (!message || !messages || !config) {
return NextResponse.json(
Expand All @@ -68,7 +107,11 @@ export async function POST(request: NextRequest) {
chunkSize: DATASOURCES_CHUNK_SIZE,
});

const chatEngine = await createChatEngine(serviceContext, datasource);
const chatEngine = await createChatEngine(
serviceContext,
datasource,
embeddings,
);
const chatHistory = config.sendMemory
? new SummaryChatHistory({ llm, messages })
: new SimpleChatHistory({ messages });
Expand Down
4 changes: 2 additions & 2 deletions app/bots/bot.data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ export const BUILTIN_BOTS: Bot[] = [
botHello: "Hello! How can I assist you today?",
context: [],
modelConfig: {
model: "gpt-3.5-turbo",
model: "gpt-3.5-turbo-16k",
temperature: 0.5,
maxTokens: 6000,
maxTokens: 8000,
sendMemory: true,
},
readOnly: true,
Expand Down
15 changes: 15 additions & 0 deletions app/client/fetch.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
// A single text chunk paired with its embedding vector.
export type Embedding = {
text: string;
embedding: number[];
};

// Metadata for a fetched or uploaded resource. `url` holds either the source
// URL or, for uploads, the original file name; `embeddings` is present when
// the content has been chunked and embedded server-side.
export type URLDetail = {
url: string;
size: number;
type: "text/html" | "application/pdf" | "text/plain";
embeddings?: Embedding[];
};

// URLDetail extended with the resource's extracted text content.
export type URLDetailContent = URLDetail & {
content?: string;
};
3 changes: 3 additions & 0 deletions app/client/platforms/llm.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ import { REQUEST_TIMEOUT_MS } from "@/app/constant";

import { prettyObject } from "@/app/utils/format";
import { fetchEventSource } from "@fortaine/fetch-event-source";
import { Embedding } from "../fetch";

export const MESSAGE_ROLES = [
"system",
Expand Down Expand Up @@ -40,6 +41,7 @@ export interface ChatOptions {
chatHistory: RequestMessage[];
config: LLMConfig;
datasource?: string;
embeddings?: Embedding[];
onUpdate?: (message: string) => void;
onFinish: (newMessages: RequestMessage[]) => void;
onError?: (err: Error) => void;
Expand All @@ -56,6 +58,7 @@ export class LLMApi {
})),
config: options.config,
datasource: options.datasource,
embeddings: options.embeddings,
};

console.log("[Request] payload: ", requestPayload);
Expand Down
24 changes: 21 additions & 3 deletions app/components/chat/chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import {
REQUEST_TIMEOUT_MS,
} from "../../constant";
import Locale from "../../locales";
import { ChatMessage, createMessage } from "../../store";
import { ChatMessage, callSession, createMessage } from "../../store";
import { useMobileScreen } from "../../utils/mobile";
import { autoGrowTextArea } from "../../utils/autogrow";
import { copyToClipboard } from "@/app/utils/clipboard";
Expand Down Expand Up @@ -129,16 +129,34 @@ export function Chat() {
});
};

// Sends user input (typed text or an uploaded file) to the current session.
// For files, the file name is used as the visible message content while the
// file itself is forwarded separately to callSession.
const onUserInput = async (input: string | FileWrap) => {
  const isFile = input instanceof FileWrap;
  const content = isFile ? input.name : input;
  await callSession(
    bot,
    session,
    content,
    {
      onUpdateMessages: (messages) => {
        botStore.updateCurrentSession((session) => {
          // reassign messages to trigger a re-render
          session.messages = messages;
        });
      },
    },
    isFile ? input : undefined,
  );
};

// Submits an uploaded file as user input to the current session.
const doSubmitFile = async (fileInput: FileWrap) => {
await onUserInput(fileInput);
};

const doSubmit = (userInput: string) => {
if (userInput.trim() === "") return;
if (isURL(userInput)) {
setTemporaryURLInput(userInput);
}
botStore.onUserInput(userInput).then(() => {
onUserInput(userInput).then(() => {
setTemporaryURLInput("");
});
setUserInput("");
Expand Down
11 changes: 0 additions & 11 deletions app/global.d.ts

This file was deleted.

24 changes: 1 addition & 23 deletions app/store/bot.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@ import { create } from "zustand";
import { persist } from "zustand/middleware";
import { LLMConfig } from "../client/platforms/llm";
import { Deployment } from "./deployment";
import { ChatSession, ChatMessage, callSession } from "./session";
import { ChatSession, ChatMessage } from "./session";
import {
BUILTIN_BOTS,
botListToMap,
createEmptyBot,
} from "@/app/bots/bot.data";
import { FileWrap } from "@/app/utils/file";

export type Share = {
id: string;
Expand Down Expand Up @@ -41,7 +40,6 @@ type BotStore = BotState & {
selectBot: (id: string) => void;
currentSession: () => ChatSession;
updateCurrentSession: (updater: (session: ChatSession) => void) => void;
onUserInput: (input: string | FileWrap) => Promise<void>;
get: (id: string) => Bot | undefined;
getByShareId: (shareId: string) => Bot | undefined;
getAll: () => Bot[];
Expand Down Expand Up @@ -76,25 +74,6 @@ export const useBotStore = create<BotStore>()(
updater(bots[get().currentBotId].session);
set(() => ({ bots }));
},
async onUserInput(input) {
const inputContent = input instanceof FileWrap ? input.name : input;
const session = get().currentSession();
await callSession(
get().currentBot(),
session,
inputContent,
{
onUpdateMessages: (messages) => {
get().updateCurrentSession((session) => {
// trigger re-render of messages
session.messages = messages;
});
},
},
input instanceof FileWrap ? input : undefined,
);
},

// Looks up a bot by id; returns undefined if no bot with that id exists.
get(id) {
return get().bots[id];
},
Expand All @@ -110,7 +89,6 @@ export const useBotStore = create<BotStore>()(
.getAll()
.find((b) => shareId === b.share?.id);
},

create(bot, options) {
const bots = get().bots;
const id = nanoid();
Expand Down
Loading

0 comments on commit 345b350

Please sign in to comment.