Vertex AI init commit #616
Merged

Commits (12):
- `00bf5a3`: init commit (Demontego)
- `7144fcf`: fix (Demontego)
- `5073e33`: Merge branch 'main' of https://github.com/Demontego/dspy into vertexai (Demontego)
- `bdd5c46`: add readme for models (Demontego)
- `6e4b395`: fix (Demontego)
- `0bac91c`: Merge branch 'main' of https://github.com/Demontego/dspy into vertexai (Demontego)
- `69683df`: fix print (Demontego)
- `78ce9c4`: Merge branch 'main' of https://github.com/Demontego/dspy into vertexai (Demontego)
- `5f2a0a5`: Merge branch 'main' of https://github.com/Demontego/dspy into vertexai (Demontego)
- `b841d64`: fix (Demontego)
- `a1c6bd5`: Merge branch 'main' of https://github.com/Demontego/dspy into vertexai (Demontego)
- `c7bd2bd`: fix (Demontego)
```python
"""Module for interacting with Google Vertex AI."""
from typing import Any, Dict

import backoff
from pydantic_core import PydanticCustomError

from dsp.modules.lm import LM

try:
    import vertexai  # type: ignore[import-untyped]
    from vertexai.language_models import CodeGenerationModel, TextGenerationModel
    from vertexai.preview.generative_models import GenerativeModel
except ImportError:
    pass


def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        f"Backing off {details['wait']:0.1f} seconds after {details['tries']} tries "
        f"calling function {details['target']} with kwargs "
        f"{details['kwargs']}",
    )


def giveup_hdlr(details):
    """Wrapper function that decides when to give up on retry."""
    # `details` here is the raised exception; match on its string form,
    # since most exceptions have no `.message` attribute.
    if "rate limits" in str(details):
        return False
    return True


class GoogleVertexAI(LM):
    """Wrapper around the Google Vertex AI API.

    Supported model names must contain `text` (e.g. `text-bison@002`),
    `code` (code-generation models), or `gemini` (e.g. `gemini-pro`).
    """

    def __init__(
        self, model_name: str = "text-bison@002", **kwargs,
    ):
        """
        Parameters
        ----------
        model_name : str
            Which pre-trained model from Google to use. The name must
            contain `text`, `code`, or `gemini`, e.g. `text-bison@002`.
        **kwargs : dict
            Additional arguments to pass to the API provider.
        """
        super().__init__(model_name)
        self._is_gemini = "gemini" in model_name
        self._init_vertexai(kwargs)
        if "code" in model_name:
            model_cls = CodeGenerationModel
            self.available_args = {
                'suffix',
                'max_output_tokens',
                'temperature',
                'stop_sequences',
                'candidate_count',
            }
        elif "gemini" in model_name:
            model_cls = GenerativeModel
            self.available_args = {
                'max_output_tokens',
                'temperature',
                'top_k',
                'top_p',
                'stop_sequences',
                'candidate_count',
            }
        elif 'text' in model_name:
            model_cls = TextGenerationModel
            self.available_args = {
                'max_output_tokens',
                'temperature',
                'top_k',
                'top_p',
                'stop_sequences',
                'candidate_count',
            }
        else:
            raise PydanticCustomError(
                'model',
                'model name is not valid, got "{model_name}"',
                dict(model_name=model_name),  # context key must match the template placeholder
            )
        if self._is_gemini:
            self.client = model_cls(model_name=model_name, safety_settings=kwargs.get('safety_settings'))  # pylint: disable=unexpected-keyword-arg,no-value-for-parameter
        else:
            self.client = model_cls.from_pretrained(model_name)
        self.provider = "googlevertexai"
        self.kwargs = {
            **self.kwargs,
            "temperature": 0.7,
            "max_output_tokens": 1024,
            "top_p": 1.0,
            "top_k": 1,
            **kwargs,
        }

    @classmethod
    def _init_vertexai(cls, values: Dict) -> None:
        vertexai.init(
            project=values.get("project"),
            location=values.get("location"),
            credentials=values.get("credentials"),
        )

    def _prepare_params(
        self,
        parameters: Any,
    ) -> dict:
        stop_sequences = parameters.get('stop')
        params_mapping = {"n": "candidate_count", 'max_tokens': 'max_output_tokens'}
        params = {params_mapping.get(k, k): v for k, v in parameters.items()}
        params = {**self.kwargs, "stop_sequences": stop_sequences, **params}
        return {k: params[k] for k in set(params.keys()) & self.available_args}

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs
        kwargs = self._prepare_params(raw_kwargs)
        if self._is_gemini:
            response = self.client.generate_content(
                [prompt],
                generation_config=kwargs,
            )
            history = {
                "prompt": prompt,
                "response": {
                    "prompt": prompt,
                    "choices": [{
                        "text": '\n'.join(v.text for v in c.content.parts),
                        'safetyAttributes': {v.category: v.probability for v in c.safety_ratings},
                    }
                        for c in response.candidates],
                },
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            }
        else:
            response = self.client.predict(prompt, **kwargs).raw_prediction_response
            history = {
                "prompt": prompt,
                "response": {
                    "prompt": prompt,
                    "choices": [{"text": c["content"], 'safetyAttributes': c['safetyAttributes']}
                                for c in response.predictions],
                },
                "kwargs": kwargs,
                "raw_kwargs": raw_kwargs,
            }
        self.history.append(history)

        return [i['text'] for i in history['response']['choices']]

    @backoff.on_exception(
        backoff.expo,
        Exception,
        max_time=1000,
        on_backoff=backoff_hdlr,
        giveup=giveup_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of completions from Google whilst handling API errors."""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        return self.request(prompt, **kwargs)
```

> **Review comment** (on lines +95 to +99, the default `kwargs` block): why are these kwargs available for non-gemini models if they are set to defaults for non-gemini models?
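The `@backoff.on_exception(backoff.expo, ...)` decorator above retries `request` with exponentially growing waits until `max_time` elapses or the `giveup` predicate returns `True`. As a rough, stdlib-only sketch of that behavior (the real `backoff` package also adds jitter, logging hooks, and async support; the function and variable names here are illustrative, not part of the library):

```python
import time

def retry_with_expo_backoff(fn, max_tries=5, base_delay=0.01, giveup=lambda exc: True):
    """Retry `fn` on exception with exponential backoff, mimicking backoff.on_exception."""
    for attempt in range(max_tries):
        try:
            return fn()
        except Exception as exc:
            # Give up immediately if the predicate says so, or if out of tries.
            if giveup(exc) or attempt == max_tries - 1:
                raise
            time.sleep(base_delay * (2 ** attempt))  # 0.01s, 0.02s, 0.04s, ...

# Demo: a flaky callable that fails twice, then succeeds.
calls = {"n": 0}
def flaky():
    calls["n"] += 1
    if calls["n"] < 3:
        raise RuntimeError("rate limits exceeded")
    return "ok"

# Like giveup_hdlr above: keep retrying while the error mentions rate limits.
result = retry_with_expo_backoff(flaky, giveup=lambda exc: "rate limits" not in str(exc))
print(result, calls["n"])  # ok 3
```

Note the inverted sense of the predicate: `backoff` retries while `giveup` returns `False`, which is why `giveup_hdlr` returns `False` for rate-limit errors.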
# GoogleVertexAI Usage Guide

This guide provides instructions on how to use the `GoogleVertexAI` class to interact with Google Vertex AI's API for text and code generation.

## Requirements

- Python 3.10 or higher.
- The `vertexai` package installed, which can be installed via pip.
- A Google Cloud account and a configured project with access to Vertex AI.

## Installation

Ensure you have installed the `vertexai` package along with other necessary dependencies:

```bash
pip install dspy-ai[google-vertex-ai]
```

## Configuration

Before using the `GoogleVertexAI` class, you need to set up access to Google Cloud:

1. Create a project in Google Cloud Platform (GCP).
2. Enable the Vertex AI API for your project.
3. Create authentication credentials and save them in a JSON file.

## Usage

Here's an example of how to instantiate the `GoogleVertexAI` class and send a text generation request:

```python
from dsp.modules import GoogleVertexAI  # Import the GoogleVertexAI class

# Initialize the class with the model name and parameters for Vertex AI
vertex_ai = GoogleVertexAI(
    model_name="text-bison@002",
    project="your-google-cloud-project-id",
    location="us-central1",
    credentials="path-to-your-service-account-file.json",
)
```

## Customizing Requests

You can customize requests by passing additional parameters such as `temperature`, `max_output_tokens`, and others supported by the Vertex AI API. This allows you to control the behavior of the text generation.
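For illustration, the snippet below mirrors the parameter normalization that the wrapper's `_prepare_params` method performs before each request: DSPy-style names such as `n` and `max_tokens` are renamed to Vertex AI's `candidate_count` and `max_output_tokens`, defaults are merged in, and anything the model does not accept is dropped. This is a self-contained sketch of that mapping (no live API call), with the defaults and accepted-argument set copied from the text-model branch of the wrapper:

```python
# Defaults and accepted args as set up for text models in GoogleVertexAI.
DEFAULTS = {"temperature": 0.7, "max_output_tokens": 1024, "top_p": 1.0, "top_k": 1}
AVAILABLE_ARGS = {
    "max_output_tokens", "temperature", "top_k", "top_p",
    "stop_sequences", "candidate_count",
}
PARAMS_MAPPING = {"n": "candidate_count", "max_tokens": "max_output_tokens"}

def prepare_params(parameters: dict) -> dict:
    """Normalize DSPy-style kwargs into Vertex AI request parameters."""
    stop_sequences = parameters.get("stop")
    params = {PARAMS_MAPPING.get(k, k): v for k, v in parameters.items()}
    params = {**DEFAULTS, "stop_sequences": stop_sequences, **params}
    # Keep only what this model class accepts; unknown keys are dropped.
    return {k: params[k] for k in set(params) & AVAILABLE_ARGS}

prepared = prepare_params({"n": 2, "max_tokens": 256, "stop": ["\n\n"], "unknown_flag": True})
print(prepared)
# candidate_count=2, max_output_tokens=256, stop_sequences=["\n\n"],
# plus the temperature/top_p/top_k defaults; unknown_flag is dropped.
```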
## Important Notes

- Make sure you have correctly set up access to Google Cloud to avoid authentication issues.
- Be aware of the quotas and limits of the Vertex AI API to prevent unexpected interruptions in service.

With this guide, you're ready to use `GoogleVertexAI` for interacting with Google Vertex AI's text and code generation services.
> **Review comment:** Can you clarify what the supported models are? Additionally, `model` is not a parameter; `model_name` is. This was confusing when I was using this code for a small demo project.