
Google PaLM API models #239

Closed
hugoferrero opened this issue Feb 7, 2024 · 13 comments

Comments

@hugoferrero

Hi. I want to try Vanna AI with the PaLM API models (bison). Do you have any tutorial or documentation on how to set up those models in Vanna? It is not clear to me how to implement any other model if you choose "Other LLM" in the configuration options. Here is the code I can't figure out how to adapt to the PaLM API models:

class MyCustomLLM(VannaBase):
    def __init__(self, config=None):
        pass

    def generate_plotly_code(self, question: str = None, sql: str = None, df_metadata: str = None, **kwargs) -> str:
        # Implement here
        pass

    def generate_question(self, sql: str, **kwargs) -> str:
        # Implement here
        pass

    def get_followup_questions_prompt(self, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
        # Implement here
        pass

    def get_sql_prompt(self, question: str, question_sql_list: list, ddl_list: list, doc_list: list, **kwargs):
        # Implement here
        pass

    def submit_prompt(self, prompt, **kwargs) -> str:
        # Implement here
        pass


class MyVanna(ChromaDB_VectorStore, MyCustomLLM):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        MyCustomLLM.__init__(self, config=config)

vn = MyVanna()
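One thing worth noting before adapting this to PaLM: Vanna passes `submit_prompt` a list of role-tagged message dicts (built by the `system_message`/`user_message` helpers), while PaLM's chat API takes plain text. A hedged sketch of a helper for that conversion (the helper name is hypothetical, not part of Vanna):

```python
# Sketch: flatten Vanna's list of {"role", "content"} message dicts into a
# single text block, suitable for a single-turn call to a chat API that does
# not accept an OpenAI-style message list. Hypothetical helper, assuming the
# message shape produced by Vanna's system_message/user_message methods.

def flatten_prompt(prompt) -> str:
    """Join a list of {'role', 'content'} dicts into one string."""
    if isinstance(prompt, str):
        return prompt  # already plain text, pass through
    return "\n\n".join(f"{m['role'].upper()}: {m['content']}" for m in prompt)
```

Inside `submit_prompt`, the flattened string could then be sent to the model client instead of the raw list.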
@andreped
Contributor

andreped commented Feb 23, 2024

I am adding the following notebook, which I believe the code example above is derived from:
https://github.com/vanna-ai/vanna/blob/fb384d43a0fb50a7cbe366cd6c242d5d37b16569/notebooks/bigquery-other-llm-chromadb.ipynb

It would be great to have a notebook describing how to do this or something similar :]

@andreped
Contributor

andreped commented Feb 23, 2024

Perhaps you want to take a look at the recently added Ollama implementation:

class Ollama(VannaBase):

I added the code below, which should give you some idea of what is required to add support for any other LLM.

Perhaps that is exactly what you were looking for, @hugoferrero? :]

from ..base import VannaBase
import requests
import json
import re

class Ollama(VannaBase):
    def __init__(self, config=None):
        if config is None or 'ollama_host' not in config:
            self.host = "http://localhost:11434"
        else:
            self.host = config['ollama_host']

        if config is None or 'model' not in config:
            raise ValueError("config must contain an Ollama model")
        else:
            self.model = config['model']

    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}
    
    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.
        
        Args:
        - text (str): The string to search within for an SQL statement.
        
        Returns:
        - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
        """
        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
        pattern = re.compile(r'select.*?(?:;|```|$)', re.IGNORECASE | re.DOTALL)
        
        match = pattern.search(text)
        if match:
            # Remove three backticks from the matched string if they exist
            return match.group(0).replace('```', '')
        else:
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        # Use the super generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_"
        sql = sql.replace("\\_", "_")

        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        url = f"{self.host}/api/chat"
        data = {
            "model": self.model,
            "stream": False,
            "messages": prompt,
        }

        response = requests.post(url, json=data)

        response_dict = response.json()

        self.log(response.text)
        
        return response_dict['message']['content']
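To make the behaviour of `extract_sql_query` concrete (since that part carries over to any LLM backend, not just Ollama), here is the same regex exercised standalone:

```python
import re

def extract_sql_query(text: str) -> str:
    # Same pattern as in the Ollama class above: grab from the first
    # 'select' (case-insensitive) up to ';', '```', or end of string,
    # then strip any triple backticks from the match.
    pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)
    match = pattern.search(text)
    if match:
        return match.group(0).replace("```", "")
    return text  # fall back to the raw text when no SELECT is found

print(extract_sql_query("Here you go:\n```sql\nSELECT a FROM t;\n```"))
```

So a chatty LLM reply wrapped in a fenced code block is reduced to just the SQL statement, which is why `generate_sql` above post-processes with it.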

@hugoferrero
Author

Thanks for the response @andreped. I will try it and send you feedback.

@zainhoda
Contributor

@hugoferrero if you happen to make progress on this, could you pass along your code and we can potentially integrate this into the main Vanna repo?

@andreped
Contributor

andreped commented Feb 26, 2024

@hugoferrero if you happen to make progress on this, could you pass along your code and we can potentially integrate this into the main Vanna repo?

@zainhoda I am open to drafting a PR for :] I can tag you in, @hugoferrero, if you wish to test it before merging.

@andreped
Contributor

I made a PR #264.

It is a rather simple implementation, but sadly I do not have access to Google Cloud, so I am dependent on some of you to test it.

To install:

pip install git+https://github.com/andreped/vanna.git@bison-support
pip install chromadb google-cloud-aiplatform

Then you should be able to initialize it with a vector DB like Chroma like so:

from vertexai.language_models import ChatModel
from vanna.chromadb.chromadb_vector import ChromaDB_VectorStore
from vanna.palm.palm import Palm

class MyVanna(ChromaDB_VectorStore, Palm):
    def __init__(self, config=None):
        ChromaDB_VectorStore.__init__(self, config=config)
        Palm.__init__(self, client=ChatModel("chat-bison@001"), config=config)

vn = MyVanna()

# do as you normally would with any other client(s)
# [...]

@yedhukr

yedhukr commented Feb 28, 2024

@andreped I've tried the code and it's not working; it gives errors like:
'MyVanna' object has no attribute 'temperature'

I then tried explicitly mentioning the parameters as below:

def submit_prompt(self, prompt, **kwargs) -> str:
    temperature = 0.7
    max_tokens = 500
    top_p = 0.95
    top_k = 40

    params = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
    }

    response = self.client.send_message(prompt, **params)
    return response.text

This gave me the error:
'ChatModel' object has no attribute 'send_message'

I also attempted to specify the model more explicitly:

def submit_prompt(self, prompt, **kwargs) -> str:
    temperature = 0.7
    max_tokens = 500
    top_p = 0.95
    top_k = 40

    client = ChatModel("chat-bison@001")
    chat = client.start_chat()
    params = {
        "temperature": temperature,
        "max_output_tokens": max_tokens,
        "top_p": top_p,
        "top_k": top_k,
    }
    print(prompt)
    response = chat.send_message(prompt, **params)
    return response.text

This gave me this error:
400 Invalid resource field value in the request.

Please help 🙏🏼

@andreped
Contributor

@andreped I've tried the code and it's not working, it gives errors like:
'MyVanna' object has no attribute 'temperature'

Hello, @yedhukr! :] Great that you were able to test the implementation!

I don't have access to Google Cloud, so I have no way of testing it. Perhaps someone could reach out and lend me an API key so that I could debug this properly? Just for this PR; the key could be rotated afterwards. @zainhoda?

@yedhukr

yedhukr commented Feb 28, 2024

@andreped Let me know if there's anything else I can do to help!

Modifying the function in this manner gets it to run, but I get a response like:
'Hi there, how can I help you today?'

    def submit_prompt(self, prompt, **kwargs) -> str:
        temperature = 0.7
        max_tokens = 500
        top_p = 0.95
        top_k = 40

        chat_model = ChatModel.from_pretrained("chat-bison@001")
        chat = chat_model.start_chat()
        params = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }
        
        response = chat.send_message("{prompt}", **params)
        return response.text

@andreped
Contributor

andreped commented Feb 28, 2024

Modifying the function in this manner gets it to run, but I get a response like:
'Hi there, how can I help you today?'

Which user prompt did you use? It also sounds like you are missing the system message that Vanna uses.

I think by doing this chat = chat_model.start_chat(), you are not setting the system message. Here you can see an example which includes a system message:
https://cloud.google.com/vertex-ai/generative-ai/docs/sdk-for-llm/sdk-use-text-models#generate-text-chat-sdk

As a test, could you try to feed the system message that the Vanna Base class uses here:
https://github.com/vanna-ai/vanna/blob/main/src/vanna/base/base.py#L448

Basically, change chat = chat_model.start_chat() to:

chat = chat_model.start_chat(
    context="The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
)

If that works, I think I know how to fix the issue.


EDIT: It could also be that you just wrote "Hello" as the user prompt; in that case I think I have gotten the same reply, even with Azure OpenAI and a pretrained Chroma instance. Try a more advanced question and see if it produces a query.
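One more thought: since Vanna hands `submit_prompt` a list of role-tagged messages but `start_chat` takes a single `context` string, a hedged sketch of how the two could be reconciled (the helper below is hypothetical, not part of Vanna or the Vertex AI SDK):

```python
# Hypothetical helper: route Vanna's system message into start_chat(context=...)
# and join the remaining messages into one user turn.

def split_system_and_user(prompt):
    """Split a list of {'role', 'content'} dicts into (context, user_text)."""
    system_parts = [m["content"] for m in prompt if m["role"] == "system"]
    other_parts = [m["content"] for m in prompt if m["role"] != "system"]
    return "\n".join(system_parts), "\n".join(other_parts)

# In submit_prompt this could then be used roughly like (untested sketch,
# requires google-cloud-aiplatform and valid credentials):
#
#   context, message = split_system_and_user(prompt)
#   chat = chat_model.start_chat(context=context)
#   response = chat.send_message(message, **params)
```

That would also avoid passing the literal string "{prompt}" to `send_message`, which is what the snippet above does.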

@yedhukr

yedhukr commented Feb 28, 2024

@andreped
Alright, I made the change you mentioned; here is the full code and response for reference:

[{'role': 'system', 'content': 'The user provides a question and you provide SQL...;} ... {'role': 'user', 'content': 'What are the top 5 properties by sales in 2023?'}]
What are the average sales per store in each state?
What are the average sales per store in each state?
Couldn't run sql: Execution failed on sql 'What are the average sales per store in each state?': near "What": syntax error

class Palm(VannaBase):
    def __init__(self, client=None, config=None):
        VannaBase.__init__(self, config=config)

        if client is not None:
            self.client = client
            return

        # default values for params
        temperature = 0.7
        max_tokens = 500
        top_p = 0.95
        top_k = 40


    def system_message(self, message: str) -> any:
        return {"role": "system", "content": message}

    def user_message(self, message: str) -> any:
        return {"role": "user", "content": message}

    def assistant_message(self, message: str) -> any:
        return {"role": "assistant", "content": message}

    def extract_sql_query(self, text):
        """
        Extracts the first SQL statement after the word 'select', ignoring case,
        matches until the first semicolon, three backticks, or the end of the string,
        and removes three backticks if they exist in the extracted string.

        Args:
        - text (str): The string to search within for an SQL statement.

        Returns:
        - str: The first SQL statement found, with three backticks removed, or an empty string if no match is found.
        """
        # Regular expression to find 'select' (ignoring case) and capture until ';', '```', or end of string
        pattern = re.compile(r"select.*?(?:;|```|$)", re.IGNORECASE | re.DOTALL)

        match = pattern.search(text)
        if match:
            # Remove three backticks from the matched string if they exist
            return match.group(0).replace("```", "")
        else:
            return text

    def generate_sql(self, question: str, **kwargs) -> str:
        # Use the super generate_sql
        sql = super().generate_sql(question, **kwargs)

        # Replace "\_" with "_"
        sql = sql.replace("\\_", "_")

        sql = sql.replace("\\", "")

        return self.extract_sql_query(sql)

    def submit_prompt(self, prompt, **kwargs) -> str:
        temperature = 0.7
        max_tokens = 500
        top_p = 0.95
        top_k = 40

        chat_model = ChatModel.from_pretrained("chat-bison@001")
        chat = chat_model.start_chat(
            context="The user will give you SQL and you will try to guess what the business question this query is answering. Return just the question without any additional explanation. Do not reference the table name in the question."
        )
        params = {
            "temperature": temperature,
            "max_output_tokens": max_tokens,
            "top_p": top_p,
            "top_k": top_k,
        }
        
        response = chat.send_message("{prompt}", **params)
        return response.text
    
    
class MyVanna(VannaDB_VectorStore, Palm):
    def __init__(self, config=None):
        VannaDB_VectorStore.__init__(self, vanna_model="***", vanna_api_key="***", config=config)
        Palm.__init__(self, client=ChatModel("chat-bison@001"), config=config)
        print(self)

vn = MyVanna()
vn.connect_to_sqlite('database/rl_database.sqlite')
vn.ask("What are the top 5 properties by sales in 2023?")

@hugoferrero
Author

Hi guys. Sorry, I can't lend the API key; I have a corporate account.

@andreped
Contributor

Hi guys. Sorry, I can't lend the API key; I have a corporate account.

No problem. I will check around to see if I can get a new trial. Maybe I just need to set up a new account :P
