Skip to content

Commit

Permalink
[v2.4.14] Make openai key not required for other chat and embedding m…
Browse files Browse the repository at this point in the history
…odels (logancyang#261)
  • Loading branch information
logancyang committed Jan 26, 2024
1 parent 48ef083 commit d6a28ec
Show file tree
Hide file tree
Showing 10 changed files with 87 additions and 40 deletions.
2 changes: 1 addition & 1 deletion manifest.json
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"id": "copilot",
"name": "Copilot",
"version": "2.4.13",
"version": "2.4.14",
"minAppVersion": "0.15.0",
"description": "A ChatGPT Copilot in Obsidian.",
"author": "Logan Yang",
Expand Down
4 changes: 2 additions & 2 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "obsidian-copilot",
"version": "2.4.13",
"version": "2.4.14",
"description": "ChatGPT integration for Obsidian",
"main": "main.js",
"scripts": {
Expand Down
29 changes: 21 additions & 8 deletions src/LLMProviders/chainManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ export default class ChainManager {
this.memoryManager = MemoryManager.getInstance(this.langChainParams);
this.chatModelManager = ChatModelManager.getInstance(this.langChainParams);
this.promptManager = PromptManager.getInstance(this.langChainParams);
this.embeddingsManager = EmbeddingsManager.getInstance(this.langChainParams);
this.createChainWithNewModel(this.langChainParams.modelDisplayName);
}

Expand Down Expand Up @@ -134,12 +133,15 @@ export default class ChainManager {
return;
}
this.validateChainType(chainType);
// MUST set embeddingsManager when switching to QA mode
if (chainType === ChainType.RETRIEVAL_QA_CHAIN) {
this.embeddingsManager = EmbeddingsManager.getInstance(this.langChainParams);
}

// Get chatModel, memory, prompt, and embeddingAPI from respective managers
const chatModel = this.chatModelManager.getChatModel();
const memory = this.memoryManager.getMemory();
const chatPrompt = this.promptManager.getChatPrompt();
const embeddingsAPI = this.embeddingsManager.getEmbeddingsAPI();

switch (chainType) {
case ChainType.LLM_CHAIN: {
Expand Down Expand Up @@ -182,8 +184,14 @@ export default class ChainManager {
const parsedMemoryVectors: MemoryVector[] | undefined = await VectorDBManager.getMemoryVectors(docHash);
if (parsedMemoryVectors) {
// Index already exists
const embeddingsAPI = this.embeddingsManager.getEmbeddingsAPI();
if (!embeddingsAPI) {
console.error('Error getting embeddings API. Please check your settings.');
return;
}
const vectorStore = await VectorDBManager.rebuildMemoryVectorStore(
parsedMemoryVectors, embeddingsAPI
parsedMemoryVectors,
embeddingsAPI,
);

// Create new conversational retrieval chain
Expand Down Expand Up @@ -391,14 +399,19 @@ export default class ChainManager {
}

async buildIndex(noteContent: string, docHash: string): Promise<void> {
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });

const docs = await textSplitter.createDocuments([noteContent]);
const embeddingsAPI = this.embeddingsManager.getEmbeddingsAPI();

// Note: HF can give 503 errors frequently (it's free)
console.log('Creating vector store...');
try {
const textSplitter = new RecursiveCharacterTextSplitter({ chunkSize: 1000 });

const docs = await textSplitter.createDocuments([noteContent]);
const embeddingsAPI = this.embeddingsManager.getEmbeddingsAPI();
if (!embeddingsAPI) {
const errorMsg = 'Failed to create vector store, embedding API is not set correctly, please check your settings.';
new Notice(errorMsg);
console.error(errorMsg);
return;
}
this.vectorStore = await MemoryVectorStore.fromDocuments(
docs, embeddingsAPI,
);
Expand Down
69 changes: 42 additions & 27 deletions src/LLMProviders/embeddingManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ export default class EmbeddingManager {
return EmbeddingManager.instance;
}

getEmbeddingsAPI(): Embeddings {
getEmbeddingsAPI(): Embeddings | undefined {
const {
openAIApiKey,
azureOpenAIApiKey,
Expand All @@ -33,26 +33,32 @@ export default class EmbeddingManager {

// Note that openAIProxyBaseUrl has the highest priority.
// If openAIProxyBaseUrl is set, it overrides both chat and embedding models.
const OpenAIEmbeddingsAPI = openAIProxyBaseUrl ?
new ProxyOpenAIEmbeddings({
modelName: this.langChainParams.embeddingModel,
openAIApiKey,
maxRetries: 3,
maxConcurrency: 3,
timeout: 10000,
openAIProxyBaseUrl,
}):
new OpenAIEmbeddings({
modelName: this.langChainParams.embeddingModel,
openAIApiKey,
maxRetries: 3,
maxConcurrency: 3,
timeout: 10000,
});
const OpenAIEmbeddingsAPI = openAIApiKey ? (
openAIProxyBaseUrl ?
new ProxyOpenAIEmbeddings({
modelName: this.langChainParams.embeddingModel,
openAIApiKey,
maxRetries: 3,
maxConcurrency: 3,
timeout: 10000,
openAIProxyBaseUrl,
}) :
new OpenAIEmbeddings({
modelName: this.langChainParams.embeddingModel,
openAIApiKey,
maxRetries: 3,
maxConcurrency: 3,
timeout: 10000,
})
) : null;

switch(this.langChainParams.embeddingProvider) {
case ModelProviders.OPENAI:
return OpenAIEmbeddingsAPI
if (OpenAIEmbeddingsAPI) {
return OpenAIEmbeddingsAPI;
}
console.error('OpenAI API key is not provided for the embedding model.');
break;
case ModelProviders.HUGGINGFACE:
return new HuggingFaceInferenceEmbeddings({
apiKey: this.langChainParams.huggingfaceApiKey,
Expand All @@ -66,18 +72,27 @@ export default class EmbeddingManager {
maxConcurrency: 3,
});
case ModelProviders.AZURE_OPENAI:
return new OpenAIEmbeddings({
azureOpenAIApiKey,
azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName: azureOpenAIApiEmbeddingDeploymentName,
azureOpenAIApiVersion,
if (azureOpenAIApiKey) {
return new OpenAIEmbeddings({
azureOpenAIApiKey,
azureOpenAIApiInstanceName,
azureOpenAIApiDeploymentName: azureOpenAIApiEmbeddingDeploymentName,
azureOpenAIApiVersion,
maxRetries: 3,
maxConcurrency: 3,
});
}
console.error('Azure OpenAI API key is not provided for the embedding model.');
break;
default:
console.error('No embedding provider set or no valid API key provided. Defaulting to OpenAI.');
return OpenAIEmbeddingsAPI || new OpenAIEmbeddings({
modelName: this.langChainParams.embeddingModel,
openAIApiKey: 'default-key',
maxRetries: 3,
maxConcurrency: 3,
timeout: 10000,
});
default:
console.error('No embedding provider set. Using OpenAI.');
return OpenAIEmbeddingsAPI;
}
}

}
1 change: 1 addition & 0 deletions src/settings/SettingsPage.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ export class CopilotSettingTab extends PluginSettingTab {
await this.plugin.saveSettings();

// Reload the plugin
// eslint-disable-next-line @typescript-eslint/no-explicit-any
const app = (this.plugin.app as any);
await app.plugins.disablePlugin("copilot");
await app.plugins.enablePlugin("copilot");
Expand Down
12 changes: 12 additions & 0 deletions src/settings/components/ApiSettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ interface ApiSettingsProps {
setAzureOpenAIApiDeploymentName: (value: string) => void;
azureOpenAIApiVersion: string;
setAzureOpenAIApiVersion: (value: string) => void;
azureOpenAIApiEmbeddingDeploymentName: string;
setAzureOpenAIApiEmbeddingDeploymentName: (value: string) => void;
}

const ApiSettings: React.FC<ApiSettingsProps> = ({
Expand All @@ -39,6 +41,8 @@ const ApiSettings: React.FC<ApiSettingsProps> = ({
setAzureOpenAIApiDeploymentName,
azureOpenAIApiVersion,
setAzureOpenAIApiVersion,
azureOpenAIApiEmbeddingDeploymentName,
setAzureOpenAIApiEmbeddingDeploymentName,
}) => {
return (
<div>
Expand Down Expand Up @@ -154,6 +158,14 @@ const ApiSettings: React.FC<ApiSettingsProps> = ({
placeholder="Enter Azure OpenAI API Version"
type="text"
/>
<ApiSetting
title="Azure OpenAI API Embedding Deployment Name"
description="(Optional) For embedding provider Azure OpenAI"
value={azureOpenAIApiEmbeddingDeploymentName}
setValue={setAzureOpenAIApiEmbeddingDeploymentName}
placeholder="Enter Azure OpenAI API Embedding Deployment Name"
type="text"
/>
</div>
</Collapsible>
</div>
Expand Down
1 change: 1 addition & 0 deletions src/settings/components/QASettings.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ const QASettings: React.FC<QASettingsProps> = ({
/>
<DropdownComponent
name="OpenAI Embedding Model"
description="(for when embedding provider is OpenAI)"
value={embeddingModel}
onChange={setEmbeddingModel}
options={OPENAI_EMBEDDING_MODELS}
Expand Down
4 changes: 4 additions & 0 deletions src/settings/components/SettingsMain.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export default function SettingsMain({ plugin, reloadPlugin }: SettingsMainProps
const [azureOpenAIApiInstanceName, setAzureOpenAIApiInstanceName] = useState(plugin.settings.azureOpenAIApiInstanceName);
const [azureOpenAIApiDeploymentName, setAzureOpenAIApiDeploymentName] = useState(plugin.settings.azureOpenAIApiDeploymentName);
const [azureOpenAIApiVersion, setAzureOpenAIApiVersion] = useState(plugin.settings.azureOpenAIApiVersion);
const [azureOpenAIApiEmbeddingDeploymentName, setAzureOpenAIApiEmbeddingDeploymentName] = useState(plugin.settings.azureOpenAIApiEmbeddingDeploymentName);

// QA settings
const [embeddingProvider, setEmbeddingProvider] = useState(plugin.settings.embeddingProvider);
Expand Down Expand Up @@ -66,6 +67,7 @@ export default function SettingsMain({ plugin, reloadPlugin }: SettingsMainProps
plugin.settings.azureOpenAIApiInstanceName = azureOpenAIApiInstanceName;
plugin.settings.azureOpenAIApiDeploymentName = azureOpenAIApiDeploymentName;
plugin.settings.azureOpenAIApiVersion = azureOpenAIApiVersion;
plugin.settings.azureOpenAIApiEmbeddingDeploymentName = azureOpenAIApiEmbeddingDeploymentName;

// QA settings
plugin.settings.embeddingProvider = embeddingProvider;
Expand Down Expand Up @@ -184,6 +186,8 @@ export default function SettingsMain({ plugin, reloadPlugin }: SettingsMainProps
setAzureOpenAIApiDeploymentName={setAzureOpenAIApiDeploymentName}
azureOpenAIApiVersion={azureOpenAIApiVersion}
setAzureOpenAIApiVersion={setAzureOpenAIApiVersion}
azureOpenAIApiEmbeddingDeploymentName={azureOpenAIApiEmbeddingDeploymentName}
setAzureOpenAIApiEmbeddingDeploymentName={setAzureOpenAIApiEmbeddingDeploymentName}
/>
<QASettings
embeddingProvider={embeddingProvider}
Expand Down
3 changes: 2 additions & 1 deletion versions.json
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,6 @@
"2.4.10": "0.15.0",
"2.4.11": "0.15.0",
"2.4.12": "0.15.0",
"2.4.13": "0.15.0"
"2.4.13": "0.15.0",
"2.4.14": "0.15.0"
}

0 comments on commit d6a28ec

Please sign in to comment.