Skip to content

Commit

Permalink
feat: vector lookup and chat is twice as fast now
Browse files Browse the repository at this point in the history
  • Loading branch information
Dhravya committed Jun 23, 2024
1 parent 9df975e commit 9f751ba
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 105 deletions.
184 changes: 102 additions & 82 deletions apps/cf-ai-backend/src/index.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { z } from "zod";
import { Hono } from "hono";
import { CoreMessage, generateText, streamText } from "ai";
import { CoreMessage, streamText } from "ai";
import { chatObj, Env, vectorObj } from "./types";
import {
batchCreateChunksAndEmbeddings,
Expand Down Expand Up @@ -193,90 +193,102 @@ app.post(
const body = c.req.valid("json");

const sourcesOnly = query.sourcesOnly === "true";

// Return early for dumb requests
if (sourcesOnly && body.sources) {
return c.json(body.sources);
}

const spaces = query.spaces?.split(",") ?? [undefined];

// Get the AI model maker and vector store
const { model, store } = await initQuery(c, query.model);

const filter: VectorizeVectorMetadataFilter = {
[`user-${query.user}`]: 1,
};
console.log("Spaces", spaces);

// Converting the query to a vector so that we can search for similar vectors
const queryAsVector = await store.embeddings.embedQuery(query.query);
const responses: VectorizeMatches = { matches: [], count: 0 };

console.log("hello world", spaces);

// SLICED to 5 to avoid too many queries
for (const space of spaces.slice(0, 5)) {
console.log("space", space);
if (!space && spaces.length > 1) {
// it's possible for space list to be [undefined] so we only add space filter conditionally
filter[`space-${query.user}-${space}`] = 1;
if (!body.sources) {
const filter: VectorizeVectorMetadataFilter = {
[`user-${query.user}`]: 1,
};
console.log("Spaces", spaces);

// Converting the query to a vector so that we can search for similar vectors
const queryAsVector = await store.embeddings.embedQuery(query.query);
const responses: VectorizeMatches = { matches: [], count: 0 };

console.log("hello world", spaces);

// SLICED to 5 to avoid too many queries
for (const space of spaces.slice(0, 5)) {
console.log("space", space);
if (!space && spaces.length > 1) {
// it's possible for space list to be [undefined] so we only add space filter conditionally
filter[`space-${query.user}-${space}`] = 1;
}

// Because there's no OR operator in the filter, we have to make multiple queries
const resp = await c.env.VECTORIZE_INDEX.query(queryAsVector, {
topK: query.topK,
filter,
returnMetadata: true,
});

// Basically recreating the response object
if (resp.count > 0) {
responses.matches.push(...resp.matches);
responses.count += resp.count;
}
}

// Because there's no OR operator in the filter, we have to make multiple queries
const resp = await c.env.VECTORIZE_INDEX.query(queryAsVector, {
topK: query.topK,
filter,
returnMetadata: true,
});

// Basically recreating the response object
if (resp.count > 0) {
responses.matches.push(...resp.matches);
responses.count += resp.count;
}
}
const minScore = Math.min(...responses.matches.map(({ score }) => score));
const maxScore = Math.max(...responses.matches.map(({ score }) => score));

// We are "normalising" the scores - if all of them are on top, we want to make sure that
// we have a way to filter out the noise.
const normalizedData = responses.matches.map((data) => ({
...data,
normalizedScore:
maxScore !== minScore
? 1 + ((data.score - minScore) / (maxScore - minScore)) * 98
: 50, // If all scores are the same, set them to the middle of the scale
}));

let highScoreData = normalizedData.filter(
({ normalizedScore }) => normalizedScore > 50,
);

const minScore = Math.min(...responses.matches.map(({ score }) => score));
const maxScore = Math.max(...responses.matches.map(({ score }) => score));

// We are "normalising" the scores - if all of them are on top, we want to make sure that
// we have a way to filter out the noise.
const normalizedData = responses.matches.map((data) => ({
...data,
normalizedScore:
maxScore !== minScore
? 1 + ((data.score - minScore) / (maxScore - minScore)) * 98
: 50, // If all scores are the same, set them to the middle of the scale
}));

let highScoreData = normalizedData.filter(
({ normalizedScore }) => normalizedScore > 50,
);
// If normalisation leaves nothing above the threshold of 50 (e.g. every score is
// identical), fall back to the top 3 matches by raw score
if (highScoreData.length === 0) {
highScoreData = normalizedData
.sort((a, b) => b.score - a.score)
.slice(0, 3);
}

// If normalisation leaves nothing above the threshold of 50 (e.g. every score is
// identical), fall back to the top 3 matches by raw score
if (highScoreData.length === 0) {
highScoreData = normalizedData
.sort((a, b) => b.score - a.score)
.slice(0, 3);
}
const sortedHighScoreData = highScoreData.sort(
(a, b) => b.normalizedScore - a.normalizedScore,
);

const sortedHighScoreData = highScoreData.sort(
(a, b) => b.normalizedScore - a.normalizedScore,
);
body.sources = {
normalizedData,
};

// So this is kinda hacky, but the frontend needs to do 2 calls to get sources and chat.
// I think this is fine for now, but we can improve this later.
if (sourcesOnly) {
const idsAsStrings = sortedHighScoreData.map((dataPoint) =>
dataPoint.id.toString(),
);
// So this is kinda hacky, but the frontend needs to do 2 calls to get sources and chat.
// I think this is fine for now, but we can improve this later.
if (sourcesOnly) {
const idsAsStrings = sortedHighScoreData.map((dataPoint) =>
dataPoint.id.toString(),
);

const storedContent = await Promise.all(
idsAsStrings.map(async (id) => await c.env.KV.get(id)),
);
const storedContent = await Promise.all(
idsAsStrings.map(async (id) => await c.env.KV.get(id)),
);

const metadata = normalizedData.map((datapoint) => datapoint.metadata);
const metadata = normalizedData.map((datapoint) => datapoint.metadata);

return c.json({ ids: storedContent, metadata });
return c.json({ ids: storedContent, metadata, normalizedData });
}
}

const preparedContext = normalizedData.map(
const preparedContext = body.sources.normalizedData.map(
({ metadata, score, normalizedScore }) => ({
context: `Website title: ${metadata!.title}\nDescription: ${metadata!.description}\nURL: ${metadata!.url}\nContent: ${metadata!.text}`,
score,
Expand Down Expand Up @@ -330,20 +342,28 @@ app.delete(
},
);

// ERROR #1 - this is the api that the editor uses, it is just a scrape off of /api/chat so you may check that out
app.get('/api/editorai', zValidator(
"query",
z.object({
context: z.string(),
request: z.string(),
}),
), async (c)=> {
const { context, request } = c.req.valid("query");
const { model } = await initQuery(c);
// ERROR #1 - this is the api that the editor uses, it is just a scrape off of /api/chat so you may check that out
app.get(
"/api/editorai",
zValidator(
"query",
z.object({
context: z.string(),
request: z.string(),
}),
),
async (c) => {
const { context, request } = c.req.valid("query");
const { model } = await initQuery(c);

const response = await streamText({ model, prompt: `${request}-${context}`, maxTokens: 224 });
const response = await streamText({
model,
prompt: `${request}-${context}`,
maxTokens: 224,
});

return response.toTextStreamResponse();
})
return response.toTextStreamResponse();
},
);

export default app;
2 changes: 2 additions & 0 deletions apps/cf-ai-backend/src/types.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import { sourcesZod } from "@repo/shared-types";
import { z } from "zod";

export type Env = {
Expand Down Expand Up @@ -37,6 +38,7 @@ export const contentObj = z.object({

export const chatObj = z.object({
chatHistory: z.array(contentObj).optional(),
sources: sourcesZod.optional(),
});

export const vectorObj = z.object({
Expand Down
17 changes: 3 additions & 14 deletions apps/web/app/(dash)/chat/chatWindow.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import QueryInput from "../home/queryinput";
import { cn } from "@repo/ui/lib/utils";
import { motion } from "framer-motion";
import { useRouter } from "next/navigation";
import { ChatHistory } from "@repo/shared-types";
import { ChatHistory, sourcesZod } from "@repo/shared-types";
import {
Accordion,
AccordionContent,
Expand All @@ -20,15 +20,10 @@ import rehypeKatex from "rehype-katex";
import rehypeHighlight from "rehype-highlight";
import { code, p } from "./markdownRenderHelpers";
import { codeLanguageSubset } from "@/lib/constants";
import { z } from "zod";
import { toast } from "sonner";
import Link from "next/link";
import { createChatObject } from "@/app/actions/doers";
import {
ClipboardIcon,
ShareIcon,
SpeakerWaveIcon,
} from "@heroicons/react/24/outline";
import { ClipboardIcon } from "@heroicons/react/24/outline";
import { SendIcon } from "lucide-react";

function ChatWindow({
Expand Down Expand Up @@ -83,11 +78,6 @@ function ChatWindow({
// TODO: handle this properly
const sources = await sourcesFetch.json();

const sourcesZod = z.object({
ids: z.array(z.string()),
metadata: z.array(z.any()),
});

const sourcesParsed = sourcesZod.safeParse(sources);

if (!sourcesParsed.success) {
Expand All @@ -100,7 +90,6 @@ function ChatWindow({
behavior: "smooth",
});

// Assuming this is part of a larger function within a React component
const updateChatHistoryAndFetch = async () => {
// Step 1: Update chat history with the assistant's response
await new Promise((resolve) => {
Expand Down Expand Up @@ -143,7 +132,7 @@ function ChatWindow({
`/api/chat?q=${query}&spaces=${spaces}&threadId=${threadId}`,
{
method: "POST",
body: JSON.stringify({ chatHistory }),
body: JSON.stringify({ chatHistory, sources: sourcesParsed.data }),
},
);

Expand Down
24 changes: 15 additions & 9 deletions apps/web/app/api/chat/route.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import { type NextRequest } from "next/server";
import { ChatHistoryZod, convertChatHistoryList } from "@repo/shared-types";
import {
ChatHistory,
ChatHistoryZod,
convertChatHistoryList,
SourcesFromApi,
} from "@repo/shared-types";
import { ensureAuth } from "../ensureAuth";
import { z } from "zod";

Expand All @@ -23,17 +28,19 @@ export async function POST(req: NextRequest) {

const sourcesOnly = url.searchParams.get("sourcesOnly") ?? "false";

const chatHistory = await req.json();
// The request body is untyped JSON; this cast only documents the expected shape.
// `sources` is the single object produced by the earlier sourcesOnly call (the
// client forwards `sourcesParsed.data`), not an array of such objects.
const jsonRequest = (await req.json()) as {
  chatHistory: ChatHistory[];
  sources: SourcesFromApi | undefined;
};
const { chatHistory, sources } = jsonRequest;

// Reject missing or whitespace-only queries. `trim` has to be invoked:
// `query.trim.length` is the arity of the function itself (always 0), and
// comparing it with `< 0` meant this guard could never fire on its own.
if (!query || query.trim().length === 0) {
  return new Response(JSON.stringify({ message: "Invalid query" }), {
    status: 400,
  });
}

const validated = z
.object({ chatHistory: z.array(ChatHistoryZod) })
.safeParse(chatHistory ?? []);
const validated = z.array(ChatHistoryZod).safeParse(chatHistory ?? []);

if (!validated.success) {
return new Response(
Expand All @@ -45,9 +52,7 @@ export async function POST(req: NextRequest) {
);
}

const modelCompatible = await convertChatHistoryList(
validated.data.chatHistory,
);
const modelCompatible = await convertChatHistoryList(validated.data);

const resp = await fetch(
`${process.env.BACKEND_BASE_URL}/api/chat?query=${query}&user=${session.user.id}&sourcesOnly=${sourcesOnly}&spaces=${spaces}`,
Expand All @@ -59,12 +64,13 @@ export async function POST(req: NextRequest) {
method: "POST",
body: JSON.stringify({
chatHistory: modelCompatible,
sources,
}),
},
);

if (sourcesOnly == "true") {
const data = await resp.json();
const data = (await resp.json()) as SourcesFromApi;
return new Response(JSON.stringify(data), { status: 200 });
}

Expand Down
8 changes: 8 additions & 0 deletions packages/shared-types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,11 @@ export function convertChatHistoryList(

return convertedChats;
}

/**
 * Schema for the "sources" payload exchanged between the web app and the AI
 * backend: the backend returns it when called with `sourcesOnly=true`, and the
 * client may send it back as `sources` in the follow-up chat request.
 */
export const sourcesZod = z.object({
// Despite the name, the backend fills this with the content stored in KV for
// each high-scoring match, not with the vector ids themselves.
// NOTE(review): a KV lookup for a missing key may yield null, which
// z.string() would reject — confirm against the backend.
ids: z.array(z.string()),
// Metadata object of every match; left unvalidated (z.any()).
metadata: z.array(z.any()),
// Matches together with their raw and normalised scores; optional so that
// payloads without this field still parse.
normalizedData: z.array(z.any()).optional(),
});

/** Static type inferred from {@link sourcesZod}. */
export type SourcesFromApi = z.infer<typeof sourcesZod>;

1 comment on commit 9f751ba

@CodeTorso
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is full of cool stuff,

Please sign in to comment.