[v2.4.12] Migrate to LCEL (logancyang#255)
logancyang committed Jan 24, 2024
1 parent 229bf63 commit 29b93b1
Showing 13 changed files with 304 additions and 121 deletions.
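
For context on the migration: LCEL (LangChain Expression Language) replaces the monolithic chain classes (ConversationChain, RetrievalQAChain, ConversationalRetrievalQAChain) with composable runnables, and any composition built with .pipe() or RunnableSequence.from() streams tokens natively instead of relying on callback handlers. A minimal sketch of the pattern, assuming the @langchain/core and @langchain/openai APIs current at the time of this commit (the model name and prompt are illustrative):

import { ChatPromptTemplate } from "@langchain/core/prompts";
import { ChatOpenAI } from "@langchain/openai";

const prompt = ChatPromptTemplate.fromMessages([
  ["system", "You are a helpful assistant."],
  ["human", "{input}"],
]);
const model = new ChatOpenAI({ modelName: "gpt-3.5-turbo" });

// .pipe() composes runnables; the result is a RunnableSequence.
const chain = prompt.pipe(model);

// Streaming is built in, no callback handlers needed.
const stream = await chain.stream({ input: "Hello!" });
for await (const chunk of stream) {
  console.log(chunk.content);
}

This one-type-fits-all design is why the class fields in chainManager.ts below collapse from three specialized chain types to two RunnableSequence fields.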
2 changes: 1 addition & 1 deletion manifest.json
@@ -1,7 +1,7 @@
{
"id": "copilot",
"name": "Copilot",
"version": "2.4.11",
"version": "2.4.12",
"minAppVersion": "0.15.0",
"description": "A ChatGPT Copilot in Obsidian.",
"author": "Logan Yang",
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default.

2 changes: 1 addition & 1 deletion package.json
@@ -1,6 +1,6 @@
{
"name": "obsidian-copilot",
"version": "2.4.11",
"version": "2.4.12",
"description": "ChatGPT integration for Obsidian",
"main": "main.js",
"scripts": {
187 changes: 104 additions & 83 deletions src/LLMProviders/chainManager.ts
@@ -8,22 +8,17 @@ import {
} from '@/constants';
import { ProxyChatOpenAI } from '@/langchainWrappers';
import { ChatMessage } from '@/sharedState';
import { getModelName, isSupportedChain } from '@/utils';
import { extractChatHistory, getModelName, isSupportedChain } from '@/utils';
import VectorDBManager, { MemoryVector } from '@/vectorDBManager';
import { ChatOllama } from "@langchain/community/chat_models/ollama";
import {
BaseChain,
ConversationChain,
ConversationalRetrievalQAChain,
RetrievalQAChain
} from "langchain/chains";
import { RunnableSequence } from "@langchain/core/runnables";
import { BaseChatMemory } from "langchain/memory";
import {
ChatPromptTemplate,
HumanMessagePromptTemplate,
MessagesPlaceholder
} from "langchain/prompts";
import { ContextualCompressionRetriever } from "langchain/retrievers/contextual_compression";
import { LLMChainExtractor } from "langchain/retrievers/document_compressors/chain_extract";
import { MultiQueryRetriever } from "langchain/retrievers/multi_query";
import { RecursiveCharacterTextSplitter } from "langchain/text_splitter";
import { MemoryVectorStore } from "langchain/vectorstores/memory";
import { Notice } from 'obsidian';
@@ -33,9 +28,8 @@ import MemoryManager from './memoryManager';
import PromptManager from './promptManager';

export default class ChainManager {
private static chain: BaseChain;
private static retrievalChain: RetrievalQAChain;
private static conversationalRetrievalChain: ConversationalRetrievalQAChain;
private static chain: RunnableSequence;
private static retrievalChain: RunnableSequence;

private static isOllamaModelActive = false;
private static isOpenRouterModelActive = false;
@@ -162,14 +156,16 @@ export default class ChainManager {
llm: chatModel,
memory: memory,
prompt: options.prompt || chatPrompt,
}) as ConversationChain;
abortController: options.abortController,
}) as RunnableSequence;
} else {
// For navigating back to the plugin view
ChainManager.chain = ChainFactory.getLLMChainFromMap({
llm: chatModel,
memory: memory,
prompt: options.prompt || chatPrompt,
}) as ConversationChain;
abortController: options.abortController,
}) as RunnableSequence;
}

this.langChainParams.chainType = ChainType.LLM_CHAIN;
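
ChainFactory's internals are not part of this diff; the usual LCEL recipe for a chat chain with memory, which createNewLLMChain presumably wraps, injects prior turns via RunnablePassthrough.assign before the prompt runs. A hedged sketch, reusing the chatModel, memory, and chatPrompt names from the call above:

import { RunnablePassthrough, RunnableSequence } from "@langchain/core/runnables";

const chain = RunnableSequence.from([
  RunnablePassthrough.assign({
    // Fill the prompt's MessagesPlaceholder("history") slot on every call.
    history: async () => {
      const { history } = await memory.loadMemoryVariables({});
      return history;
    },
  }),
  chatPrompt,
  chatModel,
]);

The abortController in the options is presumably bound to the model call (e.g. chatModel.bind({ signal: abortController.signal })) so that cancelling it stops the request mid-stream.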
@@ -185,32 +181,37 @@
const docHash = VectorDBManager.getDocumentHash(options.noteContent);
const parsedMemoryVectors: MemoryVector[] | undefined = await VectorDBManager.getMemoryVectors(docHash);
if (parsedMemoryVectors) {
// Index already exists
const vectorStore = await VectorDBManager.rebuildMemoryVectorStore(
parsedMemoryVectors, embeddingsAPI
);
ChainManager.retrievalChain = RetrievalQAChain.fromLLM(
chatModel,
vectorStore.asRetriever(),
);

// Create new conversational retrieval chain
ChainManager.retrievalChain = ChainFactory.createConversationalRetrievalChain({
llm: chatModel,
retriever: vectorStore.asRetriever(),
})
console.log('Existing vector store for document hash: ', docHash);
} else {
// Index doesn't exist
await this.buildIndex(options.noteContent, docHash);
if (!this.vectorStore) {
console.error('Error creating vector store.');
return;
}

const baseCompressor = LLMChainExtractor.fromLLM(chatModel);
const retriever = new ContextualCompressionRetriever({
baseCompressor,
baseRetriever: this.vectorStore.asRetriever(),
const retriever = MultiQueryRetriever.fromLLM({
llm: chatModel,
retriever: this.vectorStore.asRetriever(),
verbose: false,
});
ChainManager.retrievalChain = RetrievalQAChain.fromLLM(
chatModel,
retriever,
);

ChainManager.retrievalChain = ChainFactory.createConversationalRetrievalChain({
llm: chatModel,
retriever: retriever,
})
console.log(
'New retrieval qa chain with contextual compression created for '
'New conversational retrieval qa chain with multi-query retriever created for '
+ 'document hash: ', docHash
);
}
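
ChainFactory.createConversationalRetrievalChain is likewise not shown in this diff. Its inputs must match the { question, chat_history } object streamed in runRetrievalChain further down, so the standard LCEL shape (condense the follow-up into a standalone question, retrieve context for it, then answer) is a reasonable guess. A sketch under those assumptions, where llm and retriever are stand-ins and the prompt text is illustrative:

import { StringOutputParser } from "@langchain/core/output_parsers";
import { ChatPromptTemplate } from "@langchain/core/prompts";
import { RunnableSequence } from "@langchain/core/runnables";

const condensePrompt = ChatPromptTemplate.fromTemplate(
  "Given the chat history, rephrase the follow-up as a standalone question.\n"
  + "Chat history: {chat_history}\nFollow-up: {question}\nStandalone question:"
);
const answerPrompt = ChatPromptTemplate.fromTemplate(
  "Answer using only the context below.\nContext: {context}\nQuestion: {question}"
);

const retrievalChain = RunnableSequence.from([
  {
    // Condense (question, chat_history) into a standalone question.
    question: condensePrompt.pipe(llm).pipe(new StringOutputParser()),
    chat_history: (input: { chat_history: string }) => input.chat_history,
  },
  {
    // Retrieve and concatenate supporting documents for the new question.
    context: async (input: { question: string }) => {
      const docs = await retriever.getRelevantDocuments(input.question);
      return docs.map((d) => d.pageContent).join("\n\n");
    },
    question: (input: { question: string }) => input.question,
  },
  answerPrompt,
  llm,
]);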
@@ -225,36 +226,21 @@
}
}

async buildIndex(noteContent: string, docHash: string): Promise<void> {
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });

const docs = await textSplitter.createDocuments([noteContent]);
const embeddingsAPI = this.embeddingsManager.getEmbeddingsAPI();

// Note: HF can give 503 errors frequently (it's free)
console.log('Creating vector store...');
try {
this.vectorStore = await MemoryVectorStore.fromDocuments(
docs, embeddingsAPI,
);
// Serialize and save vector store to PouchDB
VectorDBManager.setMemoryVectors(this.vectorStore.memoryVectors, docHash);
console.log('Vector store created successfully.');
new Notice('Vector store created successfully.');
} catch (error) {
new Notice('Failed to create vector store, please try again:', error);
console.error('Failed to create vector store, please try again:', error);
}
}

async runChain(
userMessage: string,
abortController: AbortController,
updateCurrentAiMessage: (message: string) => void,
addMessage: (message: ChatMessage) => void,
options: { debug?: boolean, ignoreSystemMessage?: boolean } = {},
options: {
debug?: boolean,
ignoreSystemMessage?: boolean,
updateLoading?: (loading: boolean) => void
} = {},
) {
const { debug = false, ignoreSystemMessage = false } = options;
const {
debug = false,
ignoreSystemMessage = false,
} = options;

// Check if chat model is initialized
if (!this.chatModelManager.validateChatModel(this.chatModelManager.getChatModel())) {
@@ -302,66 +288,50 @@

let fullAIResponse = '';
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const chain = ChainManager.chain as any;
const chatModel = (ChainManager.chain as any).last.bound;
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const chatStream = await ChainManager.chain.stream({ input: userMessage } as any);

try {
switch(chainType) {
case ChainType.LLM_CHAIN:
if (debug) {
console.log(`*** DEBUG INFO ***\n`
+ `user message: ${userMessage}\n`
// ChatOpenAI has modelName, some other ChatModels like ChatOllama have model
+ `model: ${chain.llm.modelName || chain.llm.model}\n`
+ `model: ${chatModel.modelName || chatModel.model}\n`
+ `chain type: ${chainType}\n`
+ `temperature: ${temperature}\n`
+ `maxTokens: ${maxTokens}\n`
+ `system prompt: ${systemPrompt}\n`
+ `chat context turns: ${chatContextTurns}\n`,
);
console.log('chain:', chain);
console.log('chain RunnableSequence:', ChainManager.chain);
console.log('Chat memory:', memory);
}
await ChainManager.chain.call(
{
input: userMessage,
signal: abortController.signal,
},
[
{
handleLLMNewToken: (token) => {
fullAIResponse += token;
updateCurrentAiMessage(fullAIResponse);
}
}
]
);

for await (const chunk of chatStream) {
if (abortController.signal.aborted) break;
fullAIResponse += chunk.content;
updateCurrentAiMessage(fullAIResponse);
}
break;
case ChainType.RETRIEVAL_QA_CHAIN:
if (debug) {
console.log(`*** DEBUG INFO ***\n`
+ `user message: ${userMessage}\n`
+ `model: ${chain.llm.modelName}\n`
+ `model: ${chatModel.modelName || chatModel.model}\n`
+ `chain type: ${chainType}\n`
+ `temperature: ${temperature}\n`
+ `maxTokens: ${maxTokens}\n`
+ `system prompt: ${systemPrompt}\n`
+ `chat context turns: ${chatContextTurns}\n`,
);
console.log('chain:', chain);
console.log('chain RunnableSequence:', ChainManager.chain);
console.log('embedding provider:', this.langChainParams.embeddingProvider);
}
await ChainManager.retrievalChain.call(
{
query: userMessage,
signal: abortController.signal,
},
[
{
handleLLMNewToken: (token) => {
fullAIResponse += token;
updateCurrentAiMessage(fullAIResponse);
}
}
]
fullAIResponse = await this.runRetrievalChain(
userMessage, memory, updateCurrentAiMessage, abortController
);
break;
default:
@@ -380,6 +350,11 @@
}
} finally {
if (fullAIResponse) {
// This line is a must for memory to work with RunnableSequence!
await memory.saveContext(
{ input: userMessage },
{ output: fullAIResponse }
);
addMessage({
message: fullAIResponse,
sender: AI_SENDER,
@@ -390,4 +365,50 @@
}
return fullAIResponse;
}

private async runRetrievalChain(
userMessage: string,
memory: BaseChatMemory,
updateCurrentAiMessage: (message: string) => void,
abortController: AbortController,
): Promise<string> {
const memoryVariables = await memory.loadMemoryVariables({});
const chatHistory = extractChatHistory(memoryVariables);
const qaStream = await ChainManager.retrievalChain.stream({
question: userMessage,
chat_history: chatHistory,
// eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);

let fullAIResponse = '';

for await (const chunk of qaStream) {
if (abortController.signal.aborted) break;
fullAIResponse += chunk.content;
updateCurrentAiMessage(fullAIResponse);
}
return fullAIResponse;
}

async buildIndex(noteContent: string, docHash: string): Promise<void> {
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });

const docs = await textSplitter.createDocuments([noteContent]);
const embeddingsAPI = this.embeddingsManager.getEmbeddingsAPI();

// Note: HF can give 503 errors frequently (it's free)
console.log('Creating vector store...');
try {
this.vectorStore = await MemoryVectorStore.fromDocuments(
docs, embeddingsAPI,
);
// Serialize and save vector store to PouchDB
VectorDBManager.setMemoryVectors(this.vectorStore.memoryVectors, docHash);
console.log('Vector store created successfully.');
new Notice('Vector store created successfully.');
} catch (error) {
new Notice('Failed to create vector store, please try again:', error);
console.error('Failed to create vector store, please try again:', error);
}
}
}
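
Two details of the new flow are easy to miss. First, a RunnableSequence, unlike the old ConversationChain, never writes to memory on its own, which is why runChain now calls memory.saveContext explicitly after each completed turn. Second, extractChatHistory (imported from @/utils) is not shown in this diff; a hypothetical sketch of the shape it likely has, converting stored messages into the text fed to the {chat_history} prompt slot:

// Hypothetical sketch only; the real extractChatHistory lives in @/utils.
import { BaseMessage } from "@langchain/core/messages";

function extractChatHistory(memoryVariables: Record<string, unknown>): string {
  const messages = (memoryVariables.history ?? []) as BaseMessage[];
  return messages
    .map((m) => `${m._getType() === "human" ? "Human" : "AI"}: ${m.content}`)
    .join("\n");
}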
6 changes: 3 additions & 3 deletions src/LLMProviders/memoryManager.ts
@@ -1,9 +1,9 @@
import { LangChainParams } from '@/aiParams';
import { BufferWindowMemory } from "langchain/memory";
import { BaseChatMemory, BufferWindowMemory } from "langchain/memory";

export default class MemoryManager {
private static instance: MemoryManager;
private memory: BufferWindowMemory;
private memory: BaseChatMemory;

private constructor(
private langChainParams: LangChainParams
@@ -29,7 +29,7 @@ export default class MemoryManager {
});
}

getMemory(): BufferWindowMemory {
getMemory(): BaseChatMemory {
return this.memory;
}

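Widening the memory field and getter from BufferWindowMemory to BaseChatMemory means callers depend only on the shared interface (loadMemoryVariables, saveContext, clear), so a different memory implementation could be swapped in without touching call sites. An illustrative example, not part of this commit:

import { BaseChatMemory, ConversationSummaryMemory } from "langchain/memory";
import { ChatOpenAI } from "@langchain/openai";

// Any BaseChatMemory subclass now satisfies getMemory()'s return type:
const summaryMemory: BaseChatMemory = new ConversationSummaryMemory({
  llm: new ChatOpenAI({ modelName: "gpt-3.5-turbo" }),
  memoryKey: "history",
});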
1 change: 1 addition & 0 deletions src/aiParams.ts
@@ -56,4 +56,5 @@ export interface SetChainOptions {
prompt?: ChatPromptTemplate;
noteContent?: string;
forceNewCreation?: boolean;
abortController?: AbortController;
}
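
The new abortController option is what lets the streaming code in chainManager.ts bail out mid-generation: ChainManager checks abortController.signal.aborted inside its for-await loops. A hedged usage sketch; the full setChain signature is not visible in this diff, so the call below is assumed from context:

const controller = new AbortController();
await chainManager.setChain(ChainType.LLM_CHAIN, { abortController: controller });

// Wired to a "stop generating" action in the UI:
controller.abort();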
4 changes: 2 additions & 2 deletions src/aiState.ts
@@ -1,7 +1,7 @@
import ChainManager from '@/LLMProviders/chainManager';
import { SetChainOptions } from '@/aiParams';
import { ChainType } from '@/chainFactory';
import { BufferWindowMemory } from "langchain/memory";
import { BaseChatMemory } from "langchain/memory";
import { useState } from 'react';

/**
@@ -19,7 +19,7 @@ export function useAIState(
const { langChainParams } = chainManager;
const [currentModel, setCurrentModel] = useState<string>(langChainParams.modelDisplayName);
const [currentChain, setCurrentChain] = useState<ChainType>(langChainParams.chainType);
const [, setChatMemory] = useState<BufferWindowMemory | null>(chainManager.memoryManager.getMemory());
const [, setChatMemory] = useState<BaseChatMemory | null>(chainManager.memoryManager.getMemory());

const clearChatMemory = () => {
chainManager.memoryManager.clearChatMemory();