Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vertex AI init commit #616

Merged
merged 12 commits into from
Apr 8, 2024
1 change: 1 addition & 0 deletions dsp/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .colbertv2 import ColBERTv2
from .databricks import *
from .google import *
from .googlevertexai import *
from .gpt3 import *
from .hf import HFModel
from .hf_client import Anyscale, HFClientTGI, Together
Expand Down
174 changes: 174 additions & 0 deletions dsp/modules/googlevertexai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
"""Module for interacting with Google Vertex AI."""
from typing import Any, Dict

import backoff
from pydantic_core import PydanticCustomError

from dsp.modules.lm import LM

try:
import vertexai # type: ignore[import-untyped]
from vertexai.language_models import CodeGenerationModel, TextGenerationModel
from vertexai.preview.generative_models import GenerativeModel
except ImportError:
pass


def backoff_hdlr(details):
    """Log a backoff/retry event (handler shape from https://pypi.org/project/backoff/)."""
    message = (
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}"
    ).format(
        wait=details["wait"],
        tries=details["tries"],
        target=details["target"],
        kwargs=details["kwargs"],
    )
    print(message)


def giveup_hdlr(details):
    """Decide whether to stop retrying after an exception.

    ``backoff`` invokes ``giveup`` callbacks with the raised exception
    instance. The original code read ``details.message``, an attribute that
    Python 3 exceptions do not have, so every invocation raised
    ``AttributeError`` instead of making a decision. Inspect the string form
    of the exception instead.

    Returns
    -------
    bool
        ``False`` (keep retrying) for rate-limit errors, which are
        transient; ``True`` (give up) for everything else.
    """
    if "rate limits" in str(details):
        return False
    return True

class GoogleVertexAI(LM):
    """Wrapper around Google Vertex AI's text-generation API.

    The model family is selected by a substring of ``model_name``:

    - ``"code"``   -> ``CodeGenerationModel`` (e.g. ``code-bison@002``)
    - ``"gemini"`` -> ``GenerativeModel``     (e.g. ``gemini-1.0-pro``)
    - ``"text"``   -> ``TextGenerationModel`` (e.g. ``text-bison@002``)

    Any other name raises at construction time.
    """

    def __init__(
        self, model_name: str = "text-bison@002", **kwargs,
    ):
        """
        Parameters
        ----------
        model_name : str
            Which pre-trained model from Google to use. Must contain one of
            the substrings ``code``, ``gemini``, or ``text``, which selects
            the underlying Vertex AI client class (see class docstring).
        **kwargs: dict
            Additional arguments to pass to the API provider. ``project``,
            ``location``, and ``credentials`` are forwarded to
            ``vertexai.init``; ``safety_settings`` is forwarded to Gemini
            models; remaining keys become default generation parameters.
        """
        super().__init__(model_name)
        self._is_gemini = "gemini" in model_name
        # vertexai.init must run before any client class is instantiated.
        self._init_vertexai(kwargs)
        if "code" in model_name:
            model_cls = CodeGenerationModel
            self.available_args = {
                'suffix',
                'max_output_tokens',
                'temperature',
                'stop_sequences',
                'candidate_count',
            }
        elif "gemini" in model_name:
            model_cls = GenerativeModel
            self.available_args = {
                'max_output_tokens',
                'temperature',
                'top_k',
                'top_p',
                'stop_sequences',
                'candidate_count',
            }
        elif 'text' in model_name:
            model_cls = TextGenerationModel
            self.available_args = {
                'max_output_tokens',
                'temperature',
                'top_k',
                'top_p',
                'stop_sequences',
                'candidate_count',
            }
        else:
            raise PydanticCustomError(
                'model',
                'model name is not valid, got "{model_name}"',
                dict(wrong_value=model_name),
            )
        if self._is_gemini:
            # Gemini clients are constructed directly (no from_pretrained)
            # and accept safety settings at construction time.
            self.client = model_cls(model_name=model_name, safety_settings=kwargs.get('safety_settings'))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        else:
            self.client = model_cls.from_pretrained(model_name)
        self.provider = "googlevertexai"
        # Defaults below are always set; _prepare_params later filters them
        # through self.available_args, so e.g. top_p/top_k are silently
        # dropped for model families that do not accept them.
        self.kwargs = {
            **self.kwargs,
            "temperature": 0.7,
            "max_output_tokens": 1024,
            "top_p": 1.0,
            "top_k": 1,
            **kwargs,
        }

    @classmethod
    def _init_vertexai(cls, values: Dict) -> None:
        """Initialize the Vertex AI SDK from optional project/location/credentials."""
        vertexai.init(
            project=values.get("project"),
            location=values.get("location"),
            credentials=values.get("credentials"),
        )
        return

    def _prepare_params(
        self,
        parameters: Any,
    ) -> dict:
        """Map OpenAI-style kwargs onto Vertex AI parameter names.

        Renames ``n`` -> ``candidate_count`` and ``max_tokens`` ->
        ``max_output_tokens``, layers caller parameters over the instance
        defaults, and drops anything the selected model does not accept.
        """
        stop_sequences = parameters.get('stop')
        params_mapping = {"n": "candidate_count", 'max_tokens': 'max_output_tokens'}
        params = {params_mapping.get(k, k): v for k, v in parameters.items()}
        params = {**self.kwargs, "stop_sequences": stop_sequences, **params}
        # Keep only the keys the chosen model family supports.
        return {k: params[k] for k in set(params.keys()) & self.available_args}

    def basic_request(self, prompt: str, **kwargs):
        """Send one generation request and return the completion texts.

        Appends a normalized record (prompt, choices with text and safety
        attributes, kwargs) to ``self.history`` regardless of model family.
        """
        raw_kwargs = kwargs
        kwargs = self._prepare_params(raw_kwargs)
        if self._is_gemini:
            response = self.client.generate_content(
                [prompt],
                generation_config=kwargs,
            )
            history = {
                "prompt": prompt,
                "response": {
                    "prompt": prompt,
                    "choices": [{
                        "text": '\n'.join(v.text for v in c.content.parts),
                        'safetyAttributes': {v.category: v.probability for v in c.safety_ratings},
                    }
                        for c in response.candidates],
                },
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            }
        else:
            response = self.client.predict(prompt, **kwargs).raw_prediction_response
            history = {
                "prompt": prompt,
                "response": {
                    "prompt": prompt,
                    "choices": [{"text": c["content"], 'safetyAttributes': c['safetyAttributes']}
                                for c in response.predictions],
                },
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            }
        self.history.append(history)

        return [i['text'] for i in history['response']['choices']]

    @backoff.on_exception(
        backoff.expo,
        (Exception),
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Google whilst handling API errors"""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        """Generate completions for ``prompt``.

        ``only_completed`` and ``return_sorted`` are accepted for interface
        compatibility with other LM wrappers but are not used here.
        """
        return self.request(prompt, **kwargs)
1 change: 1 addition & 0 deletions dspy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Pyserini = dsp.PyseriniRetriever
Clarifai = dsp.ClarifaiLLM
Google = dsp.Google
GoogleVertexAI = dsp.GoogleVertexAI

HFClientTGI = dsp.HFClientTGI
HFClientVLLM = HFClientVLLM
Expand Down
52 changes: 52 additions & 0 deletions examples/integrations/googlevertexai/garden_models.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# GoogleVertexAI Usage Guide

This guide provides instructions on how to use the `GoogleVertexAI` class to interact with Google Vertex AI's API for text and code generation.

## Requirements

- Python 3.10 or higher.
- The `vertexai` package installed, which can be installed via pip.
- A Google Cloud account and a configured project with access to Vertex AI.

## Installation

Ensure you have installed the `vertexai` package along with other necessary dependencies:

```bash
pip install dspy-ai[google-vertex-ai]
```

## Configuration

Before using the `GoogleVertexAI` class, you need to set up access to Google Cloud:

1. Create a project in Google Cloud Platform (GCP).
2. Enable the Vertex AI API for your project.
3. Create authentication credentials and save them in a JSON file.

## Usage

Here's an example of how to instantiate the `GoogleVertexAI` class and send a text generation request:

```python
from dsp.modules import GoogleVertexAI # Import the GoogleVertexAI class

# Initialize the class with the model name and parameters for Vertex AI
vertex_ai = GoogleVertexAI(
model_name="text-bison@002",
project="your-google-cloud-project-id",
location="us-central1",
credentials="path-to-your-service-account-file.json"
)
```

## Customizing Requests

You can customize requests by passing additional parameters such as `temperature`, `max_output_tokens`, and others supported by the Vertex AI API. This allows you to control the behavior of the text generation.

## Important Notes

- Make sure you have correctly set up access to Google Cloud to avoid authentication issues.
- Be aware of the quotas and limits of the Vertex AI API to prevent unexpected interruptions in service.

With this guide, you're ready to use `GoogleVertexAI` for interacting with Google Vertex AI's text and code generation services.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
"pinecone": ["pinecone-client~=2.2.4"],
"weaviate": ["weaviate-client~=3.26.1"],
"faiss-cpu": ["sentence_transformers", "faiss-cpu"],
"google-vertex-ai": ["google-cloud-aiplatform==1.43.0"],
},
classifiers=[
"Development Status :: 3 - Alpha",
Expand Down
Loading