Skip to content

Commit

Permalink
[Multi-User Part 3]: Separate chat sessions based on authenticated u…
Browse files Browse the repository at this point in the history
…sers (khoj-ai#511)

- Add a data model which allows us to store Conversations with users. This does a minimal lift over the current setup, where the underlying data is stored in a JSON file. This maintains parity with that configuration.
- There does _seem_ to be some regression in chat quality, which is most likely attributable to search results.

This will help us with khoj-ai#275. It should become much easier to maintain multiple Conversations in a given table in the backend now. We will have to do some thinking on the UI.
  • Loading branch information
sabaimran committed Oct 26, 2023
1 parent a8a82d2 commit 4b6ec24
Show file tree
Hide file tree
Showing 24 changed files with 716 additions and 623 deletions.
2 changes: 2 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
DJANGO_SETTINGS_MODULE = app.settings
pythonpath = . src
testpaths = tests
markers =
chatquality: marks tests as chatquality (deselect with '-m "not chatquality"')
145 changes: 134 additions & 11 deletions src/database/adapters/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Type, TypeVar, List
import uuid
from datetime import date

from django.db import models
Expand All @@ -21,6 +20,13 @@
GithubConfig,
Embeddings,
GithubRepoConfig,
Conversation,
ConversationProcessorConfig,
OpenAIProcessorConversationConfig,
OfflineChatProcessorConversationConfig,
)
from khoj.utils.rawconfig import (
ConversationProcessorConfig as UserConversationProcessorConfig,
)
from khoj.search_filter.word_filter import WordFilter
from khoj.search_filter.file_filter import FileFilter
Expand Down Expand Up @@ -54,18 +60,17 @@ async def get_or_create_user(token: dict) -> KhojUser:


async def create_google_user(token: dict) -> KhojUser:
user_info = token.get("userinfo")
user = await KhojUser.objects.acreate(username=user_info.get("email"), email=user_info.get("email"))
user = await KhojUser.objects.acreate(username=token.get("email"), email=token.get("email"))
await user.asave()
await GoogleUser.objects.acreate(
sub=user_info.get("sub"),
azp=user_info.get("azp"),
email=user_info.get("email"),
name=user_info.get("name"),
given_name=user_info.get("given_name"),
family_name=user_info.get("family_name"),
picture=user_info.get("picture"),
locale=user_info.get("locale"),
sub=token.get("sub"),
azp=token.get("azp"),
email=token.get("email"),
name=token.get("name"),
given_name=token.get("given_name"),
family_name=token.get("family_name"),
picture=token.get("picture"),
locale=token.get("locale"),
user=user,
)

Expand Down Expand Up @@ -137,6 +142,124 @@ async def set_user_github_config(user: KhojUser, pat_token: str, repos: list):
return config


class ConversationAdapters:
    """Database helpers for per-user conversations and chat-processor settings.

    Every method is a ``@staticmethod`` whose queries are filtered by the
    ``user`` argument. Synchronous methods use the regular ORM calls; the
    ``async`` methods use Django's ``a``-prefixed asynchronous equivalents.
    """

    @staticmethod
    def get_conversation_by_user(user: KhojUser):
        """Return the first conversation stored for ``user``, creating one if none exists."""
        # A single `first()` replaces the previous `exists()` + `first()` pair:
        # one query instead of two, and no window between check and fetch.
        conversation = Conversation.objects.filter(user=user).first()
        if conversation is not None:
            return conversation
        return Conversation.objects.create(user=user)

    @staticmethod
    async def aget_conversation_by_user(user: KhojUser):
        """Async variant of ``get_conversation_by_user``."""
        conversation = await Conversation.objects.filter(user=user).afirst()
        if conversation is not None:
            return conversation
        return await Conversation.objects.acreate(user=user)

    @staticmethod
    def has_any_conversation_config(user: KhojUser):
        """Return True if ``user`` has a ``ConversationProcessorConfig`` row."""
        return ConversationProcessorConfig.objects.filter(user=user).exists()

    @staticmethod
    def get_openai_conversation_config(user: KhojUser):
        """Return the first OpenAI chat config for ``user``, or None."""
        return OpenAIProcessorConversationConfig.objects.filter(user=user).first()

    @staticmethod
    def get_offline_chat_conversation_config(user: KhojUser):
        """Return the first offline chat config for ``user``, or None."""
        return OfflineChatProcessorConversationConfig.objects.filter(user=user).first()

    @staticmethod
    def has_valid_offline_conversation_config(user: KhojUser):
        """Return True if ``user`` has an offline chat config with offline chat enabled."""
        return OfflineChatProcessorConversationConfig.objects.filter(user=user, enable_offline_chat=True).exists()

    @staticmethod
    def has_valid_openai_conversation_config(user: KhojUser):
        """Return True if ``user`` has any OpenAI chat config row."""
        return OpenAIProcessorConversationConfig.objects.filter(user=user).exists()

    @staticmethod
    def get_conversation_config(user: KhojUser):
        """Return the first ``ConversationProcessorConfig`` for ``user``, or None."""
        return ConversationProcessorConfig.objects.filter(user=user).first()

    @staticmethod
    def save_conversation(user: KhojUser, conversation_log: dict):
        """Store ``conversation_log`` on the user's conversation(s), creating a row if none exists.

        All ``Conversation`` rows belonging to ``user`` receive the new log,
        exactly as before.
        """
        # `update()` returns the number of matched rows, so the separate
        # `exists()` query is unnecessary.
        # NOTE(review): a queryset `update()` does not call `save()`, so an
        # `auto_now` timestamp such as `updated_at` is presumably not refreshed
        # here — confirm whether that matters to callers.
        matched = Conversation.objects.filter(user=user).update(conversation_log=conversation_log)
        if not matched:
            Conversation.objects.create(user=user, conversation_log=conversation_log)

    @staticmethod
    def set_conversation_processor_config(user: KhojUser, new_config: UserConversationProcessorConfig):
        """Persist ``new_config`` for ``user``.

        Updates (or creates) the user's ``ConversationProcessorConfig`` with the
        prompt size and tokenizer, then upserts the OpenAI and offline chat
        configs when the corresponding sections of ``new_config`` are set.
        """
        conversation_config, _ = ConversationProcessorConfig.objects.get_or_create(user=user)
        conversation_config.max_prompt_size = new_config.max_prompt_size
        conversation_config.tokenizer = new_config.tokenizer
        conversation_config.save()

        if new_config.openai:
            openai_defaults = {"api_key": new_config.openai.api_key}
            # Only override the stored chat model when one was supplied.
            if new_config.openai.chat_model:
                openai_defaults["chat_model"] = new_config.openai.chat_model

            OpenAIProcessorConversationConfig.objects.update_or_create(user=user, defaults=openai_defaults)

        if new_config.offline_chat:
            offline_defaults = {
                # Store a real boolean. The previous `str(...)` relied on the
                # field coercing "True"/"False" and turned an unset (None) flag
                # into the string "None", which is not a valid boolean value.
                # Assumes the incoming flag is a bool or None — TODO confirm.
                "enable_offline_chat": bool(new_config.offline_chat.enable_offline_chat),
            }
            # Only override the stored chat model when one was supplied.
            if new_config.offline_chat.chat_model:
                offline_defaults["chat_model"] = new_config.offline_chat.chat_model

            OfflineChatProcessorConversationConfig.objects.update_or_create(user=user, defaults=offline_defaults)

    @staticmethod
    def get_enabled_conversation_settings(user: KhojUser):
        """Return ``{"openai": bool, "offline_chat": bool}`` describing which chat backends are set up.

        ``openai`` is True when an OpenAI config row exists; ``offline_chat`` is
        True when an offline config row exists and has offline chat enabled.
        """
        openai_config = ConversationAdapters.get_openai_conversation_config(user)
        offline_chat_config = ConversationAdapters.get_offline_chat_conversation_config(user)

        return {
            "openai": openai_config is not None,
            "offline_chat": offline_chat_config is not None and bool(offline_chat_config.enable_offline_chat),
        }

    @staticmethod
    def clear_conversation_config(user: KhojUser):
        """Delete the user's ``ConversationProcessorConfig`` plus their OpenAI and offline chat configs."""
        ConversationProcessorConfig.objects.filter(user=user).delete()
        ConversationAdapters.clear_openai_conversation_config(user)
        ConversationAdapters.clear_offline_chat_conversation_config(user)

    @staticmethod
    def clear_openai_conversation_config(user: KhojUser):
        """Delete every OpenAI chat config row belonging to ``user``."""
        OpenAIProcessorConversationConfig.objects.filter(user=user).delete()

    @staticmethod
    def clear_offline_chat_conversation_config(user: KhojUser):
        """Delete every offline chat config row belonging to ``user``."""
        OfflineChatProcessorConversationConfig.objects.filter(user=user).delete()

    @staticmethod
    async def has_offline_chat(user: KhojUser):
        """Async: return True if ``user`` has an offline chat config with offline chat enabled."""
        return await OfflineChatProcessorConversationConfig.objects.filter(
            user=user, enable_offline_chat=True
        ).aexists()

    @staticmethod
    async def get_offline_chat(user: KhojUser):
        """Async: return the first offline chat config for ``user``, or None."""
        return await OfflineChatProcessorConversationConfig.objects.filter(user=user).afirst()

    @staticmethod
    async def has_openai_chat(user: KhojUser):
        """Async: return True if ``user`` has any OpenAI chat config row."""
        return await OpenAIProcessorConversationConfig.objects.filter(user=user).aexists()

    @staticmethod
    async def get_openai_chat(user: KhojUser):
        """Async: return the first OpenAI chat config for ``user``, or None."""
        return await OpenAIProcessorConversationConfig.objects.filter(user=user).afirst()


class EmbeddingsAdapters:
word_filer = WordFilter()
file_filter = FileFilter()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Generated by Django 4.2.5 on 2023-10-18 05:31

from django.conf import settings
from django.db import migrations, models
import django.db.models.deletion


# Schema migration that reshapes chat configuration storage:
#   * drops `conversation` and `enable_offline_chat` from ConversationProcessorConfig,
#   * adds `max_prompt_size`, `tokenizer` and a `user` foreign key to it,
#   * creates the OpenAIProcessorConversationConfig,
#     OfflineChatProcessorConversationConfig and Conversation tables.
class Migration(migrations.Migration):
# Must run after the migration named below.
dependencies = [
("database", "0006_embeddingsdates"),
]

operations = [
# Remove the two fields that no longer live on ConversationProcessorConfig.
migrations.RemoveField(
model_name="conversationprocessorconfig",
name="conversation",
),
migrations.RemoveField(
model_name="conversationprocessorconfig",
name="enable_offline_chat",
),
# New optional settings on ConversationProcessorConfig (both nullable, default None).
migrations.AddField(
model_name="conversationprocessorconfig",
name="max_prompt_size",
field=models.IntegerField(blank=True, default=None, null=True),
),
migrations.AddField(
model_name="conversationprocessorconfig",
name="tokenizer",
field=models.CharField(blank=True, default=None, max_length=200, null=True),
),
# Tie each ConversationProcessorConfig row to a user.
# `default=1` together with `preserve_default=False` is a one-off value used
# only while applying this migration; it is not kept on the field afterwards.
# NOTE(review): presumably relies on a user with primary key 1 existing when
# rows are already present — confirm before running on populated data.
migrations.AddField(
model_name="conversationprocessorconfig",
name="user",
field=models.ForeignKey(
default=1, on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL
),
preserve_default=False,
),
# Per-user OpenAI chat settings: API key and chat model name.
migrations.CreateModel(
name="OpenAIProcessorConversationConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("api_key", models.CharField(max_length=200)),
("chat_model", models.CharField(max_length=200)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
# Per-user offline chat settings: enable flag (default False) and chat model name.
migrations.CreateModel(
name="OfflineChatProcessorConversationConfig",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("enable_offline_chat", models.BooleanField(default=False)),
("chat_model", models.CharField(default="llama-2-7b-chat.ggmlv3.q4_0.bin", max_length=200)),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
# Per-user conversation storage: the log is kept in a JSON column.
# No default is set on `conversation_log` in this migration.
migrations.CreateModel(
name="Conversation",
fields=[
("id", models.BigAutoField(auto_created=True, primary_key=True, serialize=False, verbose_name="ID")),
("created_at", models.DateTimeField(auto_now_add=True)),
("updated_at", models.DateTimeField(auto_now=True)),
("conversation_log", models.JSONField()),
("user", models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to=settings.AUTH_USER_MODEL)),
],
options={
"abstract": False,
},
),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# Generated by Django 4.2.5 on 2023-10-18 16:46

from django.db import migrations, models


# Follow-up schema migration: gives Conversation.conversation_log a default of
# an empty dict (`default=dict`), so a Conversation can be created without a log.
class Migration(migrations.Migration):
# Must run after the migration that created the Conversation model.
dependencies = [
("database", "0007_remove_conversationprocessorconfig_conversation_and_more"),
]

operations = [
# `dict` is passed as a callable, so each new row gets its own empty dict.
migrations.AlterField(
model_name="conversation",
name="conversation_log",
field=models.JSONField(default=dict),
),
]
22 changes: 20 additions & 2 deletions src/database/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,9 +82,27 @@ class LocalPlaintextConfig(BaseModel):
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


# NOTE(review): a class with this same name is defined again further down in
# this view with different fields; this two-line version appears to be the
# pre-change side of the diff — confirm before relying on it.
class ConversationProcessorConfig(BaseModel):
conversation = models.JSONField()
# Per-user OpenAI chat settings.
class OpenAIProcessorConversationConfig(BaseModel):
# API key, stored as a plain character column of up to 200 characters.
# NOTE(review): no encryption is visible here — confirm that is acceptable.
api_key = models.CharField(max_length=200)
# Name of the chat model to use; required (no default).
chat_model = models.CharField(max_length=200)
# Owning user; rows are removed when the user is deleted (CASCADE).
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


# Per-user offline chat settings.
class OfflineChatProcessorConversationConfig(BaseModel):
# Whether offline chat is turned on for this user; off by default.
enable_offline_chat = models.BooleanField(default=False)
# Name of the offline chat model; defaults to the value given below.
chat_model = models.CharField(max_length=200, default="llama-2-7b-chat.ggmlv3.q4_0.bin")
# Owning user; rows are removed when the user is deleted (CASCADE).
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


# Per-user settings shared by the chat processors.
class ConversationProcessorConfig(BaseModel):
# Optional prompt size limit; NULL/None when not configured.
max_prompt_size = models.IntegerField(default=None, null=True, blank=True)
# Optional tokenizer name; NULL/None when not configured.
tokenizer = models.CharField(max_length=200, default=None, null=True, blank=True)
# Owning user; rows are removed when the user is deleted (CASCADE).
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)


# A conversation stored for a user.
class Conversation(BaseModel):
# Owning user; rows are removed when the user is deleted (CASCADE).
user = models.ForeignKey(KhojUser, on_delete=models.CASCADE)
# Conversation log kept as JSON; `default=dict` gives each new row an empty dict.
conversation_log = models.JSONField(default=dict)


class Embeddings(BaseModel):
Expand Down
Loading

0 comments on commit 4b6ec24

Please sign in to comment.