Add Anthropic and update setup for embedding, handler, llm, retrieval & vector store with cache, timer and optional checks
luandro committed May 8, 2024
1 parent 3dfb932 commit 029bdba
Showing 10 changed files with 1,499 additions and 1,340 deletions.
9 changes: 8 additions & 1 deletion .env.example
@@ -1,6 +1,8 @@
# GENERAL
PRODUCTION=false
LOCAL_MODE=false
CHECK=true
CACHE_DIR=/tmp/rag_cache

# LANGFUSE
# If using cloud, get keys from https://cloud.langfuse.com
@@ -14,6 +16,11 @@ OPENAI_MODEL=gpt-3.5-turbo
OPENAI_API_KEY=
OPENAI_BASE_URL=https://api.openai.com/v1

# ANTHROPIC
ANTHROPIC_API_KEY=
ANTHROPIC_MODEL=


# QDRANT
# If using cloud, get keys from https://cloud.qdrant.io
QDRANT_API_KEY=
@@ -35,7 +42,7 @@ TOGETHERAI_MODEL=microsoft/phi-2

# GROQ
GROQ_API_KEY=
GROQ_MODEL=mixtral-8x7b-32768
GROQ_MODEL=llama3-8b-8192

# RETRIEVAL
TOP_K=4
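For reference, a minimal sketch of how the two new general settings are consumed elsewhere in this commit (the same `os.getenv` patterns appear in app/internal/rag.py and app/rag/embedding.py below); `python-dotenv` loads the values from this file.

```python
import os

from dotenv import load_dotenv

load_dotenv()  # read the .env values shown above

# CHECK gates the optional connectivity / health probes added in this commit.
check = os.getenv("CHECK", "false").lower() == "true"

# CACHE_DIR is the root for on-disk caches, e.g. downloaded embedding models.
cache_dir = os.getenv("CACHE_DIR", "/tmp/rag_cache")
embedding_cache = cache_dir + "/embeddings"
```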
1 change: 1 addition & 0 deletions README.md
@@ -1,6 +1,7 @@
# Kakawa RAG API

[![Offline Stack Test](https://github.com/digidem/kakawa-rag-api/actions/workflows/offline-stack.yml/badge.svg)](https://github.com/digidem/kakawa-rag-api/actions/workflows/offline-stack.yml)
[![Online Stack Test](https://github.com/digidem/kakawa-rag-api/actions/workflows/online-stack.yml/badge.svg)](https://github.com/digidem/kakawa-rag-api/actions/workflows/online-stack.yml)


The document retrieval API for Kakawa, a product-support bot that efficiently responds to questions given any documentation. It can run online with cutting-edge AI models, or offline on regular hardware. It can also be integrated with existing platforms such as WhatsApp.
47 changes: 28 additions & 19 deletions app/internal/rag.py
@@ -1,6 +1,7 @@
import logging
import os
import sys
import time
from shutil import rmtree

import requests
@@ -9,9 +10,6 @@
# LlamaIndex
from llama_index.core import Settings

# Embedding
from app.rag.embedding import setup_embedding

# Handler
from app.rag.handler import setup_langfuse

@@ -29,39 +27,47 @@
from app.rag.vector_store import initialize_vector_store

# Load environment variables
start_time = time.time()
load_dotenv()
local_mode = os.getenv("LOCAL_MODE", "false").lower() == "true"
top_k = int(os.getenv("TOP_K", "4"))
check = os.getenv("CHECK", "false").lower() == "true"

logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

if local_mode:
print("Running in local mode")
else:
connectivity_test_url = "https://httpbin.org/get"
try:
with requests.get(connectivity_test_url, timeout=5) as response:
response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
print("Running in cloud mode")
except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
print(f"Cloud mode connectivity failed ({e}), switching to local mode")
local_mode = True
if check:
connectivity_test_url = "https://httpbin.org/get"
try:
with requests.get(connectivity_test_url, timeout=5) as response:
response.raise_for_status() # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
print("Running in cloud mode")
except (requests.ConnectionError, requests.Timeout, requests.HTTPError) as e:
print(f"Cloud mode connectivity failed ({e}), switching to local mode")
local_mode = True
else:
print("Running in cloud mode without connectivity check")
# Setup LLM
used_llm = setup_llm(local_mode)
# Setup Langfuse as handler
langfuse_handler = setup_langfuse(local_mode)
# Setup embedding
used_embedding_model = setup_embedding(local_mode)
# Log embedding and LLM details
logging.info(f"Using embedding model: {Settings.embed_model.__class__.__name__}")
logging.info(f"Using LLM: {Settings.llm.__class__.__name__}")
# Setup Langfuse as handler
if os.getenv("TEST", "false").lower() != "true":
langfuse_handler = setup_langfuse(local_mode)
logging.info("Using LangFuse handler")
else:
langfuse_handler = None

# Setup vector store
document_files, vector_database, index = initialize_vector_store(local_mode)
document_files, vector_database, vector_index, used_embedding_model = (
initialize_vector_store(local_mode)
)

colbert_reranker, retrieval_strategy = initialize_colbert_reranker(top_k)
logging.info(f"Creating query engine with similarity top k set to '{top_k}'.")
query_engine = index.as_query_engine(
query_engine = vector_index.as_query_engine(
similarity_top_k=top_k, node_postprocessors=[colbert_reranker]
)

@@ -81,6 +87,9 @@
for key, value in metadata.items():
logging.info(f"Metadata - {key}: {value}")

initialization_time = time.time() - start_time
logging.info(f"RAG app initialized in {initialization_time:.2f} seconds")


def rag(query, user_id="test_user", session_id="test_session"):
if langfuse_handler is not None:
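For orientation, a hypothetical call into this module; the `user_id` and `session_id` defaults come from the signature above, while the assumption that `rag()` returns the query engine's response is not shown in this hunk.

```python
# Hypothetical usage sketch; importing the module runs the initialization above.
from app.internal.rag import rag

# The question text is illustrative only.
answer = rag(
    "How do I set up the device?",
    user_id="demo_user",
    session_id="demo_session",
)
print(answer)
```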
26 changes: 24 additions & 2 deletions app/rag/embedding.py
@@ -1,4 +1,6 @@
import logging
import os
import time

from llama_index.core import Settings
from llama_index.embeddings.cohere import CohereEmbedding
@@ -7,6 +9,10 @@


def setup_embedding(local_mode):
start_time = time.time()
cache_dir = os.getenv("CACHE_DIR", "/tmp/rag_cache")
embedding_cache = cache_dir + "/embeddings"
logging.info("Setting up embedding model.")
local_embedding = os.getenv("LOCAL_EMBEDDING", "false").lower() == "true"
cohere_api_key = os.getenv("COHERE_API_KEY")
openai_api_key = os.getenv("OPENAI_API_KEY")
@@ -16,16 +22,26 @@ def setup_embedding(local_mode):
default_openai_embedding_model = "text-embedding-3-small"
default_cohere_embedding_model = "embed-english-v3.0"
default_baai_embedding_model = "BAAI/bge-small-en-v1.5"

logging.info(f"local_mode: {local_mode}, local_embedding: {local_embedding}")
if local_mode or local_embedding:
used_embedding_model = default_baai_embedding_model
Settings.embed_model = FastEmbedEmbedding()
logging.info("Setting up FastEmbedEmbedding.")
try:
Settings.embed_model = FastEmbedEmbedding(
model_name=used_embedding_model, cache_dir=embedding_cache
)
logging.info(f"Using local embedding model: {used_embedding_model}")
except Exception as e:
logging.error(f"Failed to set up FastEmbedEmbedding: {e}")
elif openai_api_key:
used_embedding_model = os.getenv(
"EMBEDDING_MODEL", default_openai_embedding_model
)
Settings.embed_model = OpenAIEmbedding(
api_key=openai_api_key, model_name=used_embedding_model
)
logging.info(f"Using OpenAI embedding model: {used_embedding_model}")
elif cohere_api_key:
used_embedding_model = os.getenv(
"EMBEDDING_MODEL", default_cohere_embedding_model
@@ -35,7 +51,13 @@ def setup_embedding(local_mode):
model_name=used_embedding_model,
input_type="search_document",
)

logging.info(f"Using Cohere embedding model: {used_embedding_model}")
else:
logging.error("No API key found for OpenAI or Cohere.")
raise ValueError("No API key found for OpenAI or Cohere.")

initialization_time = time.time() - start_time
logging.info(
f"Embedding model setup completed in {initialization_time:.2f} seconds"
)
return used_embedding_model
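A minimal sketch of exercising this function on its own, assuming the `LOCAL_EMBEDDING` and `CACHE_DIR` variables it reads are set as below; the expected model name comes from the defaults above.

```python
import os

from llama_index.core import Settings

from app.rag.embedding import setup_embedding

# Force the FastEmbed branch and point its cache at the scratch directory;
# these are the environment variables read by setup_embedding.
os.environ["LOCAL_EMBEDDING"] = "true"
os.environ["CACHE_DIR"] = "/tmp/rag_cache"

model_name = setup_embedding(local_mode=False)
print(model_name)  # "BAAI/bge-small-en-v1.5" by default
print(Settings.embed_model.__class__.__name__)  # FastEmbedEmbedding
```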
26 changes: 15 additions & 11 deletions app/rag/handler.py
@@ -1,3 +1,4 @@
import logging
import os

import requests
@@ -12,16 +13,19 @@ def setup_langfuse(local_mode):
)
os.getenv("LANGFUSE_PUBLIC_KEY")
os.getenv("LANGFUSE_SECRET_KEY")
os.getenv("LANGFUSE_HOST", langfuse_default_host)
langfuse_handler = None
langfuse_host = os.getenv("LANGFUSE_HOST", "http://langfuse:3000")
try:
response = requests.get(f"{langfuse_host}/api/public/health")
if response.ok:
langfuse_handler = LlamaIndexCallbackHandler()
Settings.callback_manager = CallbackManager([langfuse_handler])
else:
print("Langfuse isn't running")
except requests.exceptions.RequestException as e:
print("Langfuse isn't running, error:", e)
langfuse_host = os.getenv("LANGFUSE_HOST", langfuse_default_host)
check = os.getenv("CHECK", "false").lower() == "true"
if check:
logging.info(f"Pinging {langfuse_host}/api/public/health for health check.")
try:
response = requests.get(f"{langfuse_host}/api/public/health", timeout=5)
if response.ok:
logging.info("Langfuse is running")
langfuse_handler = LlamaIndexCallbackHandler()
Settings.callback_manager = CallbackManager([langfuse_handler])
else:
print("Langfuse isn't running")
except requests.exceptions.RequestException as e:
print("Langfuse isn't running, error:", e)
return langfuse_handler
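A small usage sketch of the updated function, assuming a Langfuse instance at the hypothetical `http://localhost:3000`; with `CHECK=true` it pings `/api/public/health` first, and it returns `None` whenever no handler was registered.

```python
import os

from app.rag.handler import setup_langfuse

os.environ["CHECK"] = "true"
os.environ["LANGFUSE_HOST"] = "http://localhost:3000"  # assumed local instance

handler = setup_langfuse(local_mode=False)
if handler is None:
    print("Langfuse tracing disabled or unreachable")
```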
110 changes: 70 additions & 40 deletions app/rag/llm.py
@@ -4,64 +4,94 @@

import requests
from llama_index.core import Settings
from llama_index.llms.anthropic import Anthropic
from llama_index.llms.groq import Groq
from llama_index.llms.ollama import Ollama
from llama_index.llms.openai import OpenAI
from llama_index.llms.together import TogetherLLM


def setup_llm(local_mode):
def setup_llm(local_mode, eval=False):
openai_model = os.getenv("OPENAI_MODEL", "gpt-3.5-turbo")
ollama_model = os.getenv("OLLAMA_MODEL", "phi")
togetherai_api_key = os.getenv("TOGETHERAI_API_KEY")
togetherai_model = os.getenv("TOGETHERAI_MODEL", "microsoft/phi-2")
groq_api_key = os.getenv("GROQ_API_KEY")
groq_model = os.getenv("GROQ_MODEL", "mixtral-8x7b-32768")
groq_model = os.getenv("GROQ_MODEL", "llama3-8b-8192")
openai_api_key = os.getenv("OPENAI_API_KEY")
openai_base_url = os.getenv("OPENAI_BASE_URL", "https://api.openai.com/v1")
ollama_base_url = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
ollama_timeout = float(os.getenv("OLLAMA_TIMEOUT", "120.0"))
anthropic_api_key = os.getenv("ANTHROPIC_API_KEY")
anthropic_model = os.getenv("ANTHROPIC_MODEL", "claude-3-opus-20240229")
eval_model = os.getenv("EVAL_MODEL", "openhermes2.5-mistral")

# Setup LLM
used_llm = (
ollama_model
if local_mode
else (
groq_model
if groq_api_key
if eval and not local_mode:
if eval_model.startswith("claude"):
tokenizer = Anthropic().tokenizer
Settings.tokenizer = tokenizer
Settings.llm = Anthropic(
temperature=0, model=eval_model, api_key=anthropic_api_key
)
elif eval_model.startswith("gpt"):
Settings.llm = OpenAI(
temperature=0,
model=eval_model,
api_key=openai_api_key,
openai_base_url=openai_base_url,
)
else:
Settings.llm = Ollama(
model=eval_model,
request_timeout=ollama_timeout,
base_url=ollama_base_url,
)
else:
# Setup LLM
used_llm = (
ollama_model
if local_mode
else (
togetherai_model
if togetherai_api_key
else (openai_model if openai_api_key else ollama_model)
groq_model
if groq_api_key
else (
togetherai_model
if togetherai_api_key
else (openai_model if openai_api_key else ollama_model)
)
)
)
)
if used_llm == openai_model:
Settings.llm = OpenAI(
temperature=0.1,
model=openai_model,
api_key=openai_api_key,
openai_base_url=openai_base_url,
)
elif used_llm == togetherai_model:
Settings.llm = TogetherLLM(model=togetherai_model, api_key=togetherai_api_key)
elif used_llm == groq_model:
Settings.llm = Groq(model=groq_model, api_key=groq_api_key)
elif used_llm == ollama_model:
# Check if the OLLAMA_MODEL is available
try:
logging.info(f"Checking if {ollama_model} is available...")
response = requests.post(
f"{ollama_base_url}/api/show", data=json.dumps({"name": ollama_model})
if used_llm == openai_model:
Settings.llm = OpenAI(
temperature=0.1,
model=openai_model,
api_key=openai_api_key,
openai_base_url=openai_base_url,
)
if not response.ok:
print(response)
print(f"Error checking {ollama_model}: {response.text}")
except requests.exceptions.RequestException as e:
raise ValueError(f"Request to check {ollama_model} failed: {e}")
Settings.llm = Ollama(
model=ollama_model, request_timeout=ollama_timeout, base_url=ollama_base_url
)
else:
raise ValueError(f"No LLM configured for model: {used_llm}")
elif used_llm == togetherai_model:
Settings.llm = TogetherLLM(
model=togetherai_model, api_key=togetherai_api_key
)
elif used_llm == groq_model:
Settings.llm = Groq(model=groq_model, api_key=groq_api_key)
elif used_llm == ollama_model:
# Check if the OLLAMA_MODEL is available
try:
logging.info(f"Checking if {ollama_model} is available...")
response = requests.post(
f"{ollama_base_url}/api/show",
data=json.dumps({"name": ollama_model}),
)
if not response.ok:
print(response)
print(f"Error checking {ollama_model}: {response.text}")
except requests.exceptions.RequestException as e:
raise ValueError(f"Request to check {ollama_model} failed: {e}")
Settings.llm = Ollama(
model=ollama_model,
request_timeout=ollama_timeout,
base_url=ollama_base_url,
)
else:
raise ValueError(f"No LLM configured for model: {used_llm}")
return Settings.llm
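A sketch of the new `eval` path, assuming `EVAL_MODEL` is pointed at a Claude model so that the Anthropic branch added in this commit is taken; the API key value is a placeholder.

```python
import os

from app.rag.llm import setup_llm

os.environ["EVAL_MODEL"] = "claude-3-opus-20240229"
os.environ["ANTHROPIC_API_KEY"] = "sk-ant-placeholder"  # placeholder, not a real key

# eval=True with a model name starting with "claude" selects the Anthropic
# client and also swaps in its tokenizer via Settings.tokenizer.
llm = setup_llm(local_mode=False, eval=True)
print(llm.__class__.__name__)  # Anthropic
```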
5 changes: 5 additions & 0 deletions app/rag/retrieval.py
@@ -4,6 +4,9 @@


def initialize_colbert_reranker(top_k):
import time

start_time = time.time()
logging.info("Initializing ColbertRerank with top_k={}".format(top_k))
retrieval_strategy = "ColbertRerank"
colbert_reranker = ColbertRerank(
@@ -12,4 +15,6 @@ def initialize_colbert_reranker(top_k):
tokenizer="colbert-ir/colbertv2.0",
keep_retrieval_score=True,
)
initialization_time = time.time() - start_time
logging.info(f"ColbertRerank initialized in {initialization_time:.2f} seconds")
return colbert_reranker, retrieval_strategy
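A sketch of wiring the returned reranker into a query engine, mirroring the updated app/internal/rag.py above; the four-value unpacking of `initialize_vector_store` follows that file, and the query string is illustrative.

```python
from app.rag.retrieval import initialize_colbert_reranker
from app.rag.vector_store import initialize_vector_store

local_mode = True
top_k = 4  # matches the TOP_K default in .env.example

document_files, vector_database, vector_index, used_embedding_model = (
    initialize_vector_store(local_mode)
)
colbert_reranker, retrieval_strategy = initialize_colbert_reranker(top_k)

# Rerank the top_k retrieved nodes with ColBERT before answering.
query_engine = vector_index.as_query_engine(
    similarity_top_k=top_k, node_postprocessors=[colbert_reranker]
)
print(query_engine.query("What does the documentation say about setup?"))
```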