Skip to content

Commit

Permalink
Add image targets
Browse files Browse the repository at this point in the history
  • Loading branch information
hupe1980 committed Apr 21, 2024
1 parent 6740590 commit f9d7660
Show file tree
Hide file tree
Showing 7 changed files with 287 additions and 23 deletions.
4 changes: 3 additions & 1 deletion aisploit/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from .model import BaseChatModel, BaseEmbeddings, BaseLLM, BaseModel
from .prompt import BasePromptValue
from .report import BaseReport
from .target import BaseTarget, Response
from .target import BaseImageTarget, BaseTarget, ContentFilteredException, Response
from .vectorstore import BaseVectorStore

__all__ = [
Expand All @@ -30,6 +30,8 @@
"BasePromptValue",
"BaseReport",
"BaseTarget",
"BaseImageTarget",
"Response",
"ContentFilteredException",
"BaseVectorStore",
]
22 changes: 21 additions & 1 deletion aisploit/core/target.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
import base64
import io
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Any, Dict
from typing import Any, Dict, Literal

from PIL import Image

from .prompt import BasePromptValue


class ContentFilteredException(Exception):
pass


@dataclass
class Response:
"""A class representing a response from the target."""
Expand Down Expand Up @@ -32,3 +40,15 @@ def send_prompt(self, prompt: BasePromptValue) -> Response:
Response: The response from the target.
"""
pass


@dataclass
class BaseImageTarget(ABC):
size: Literal["512x512", "1024x1024"] = "512x512"
show_image: bool = False

def _show_base64_image(self, base64_image: str) -> None:
base64_bytes = base64_image.encode("ascii")
image_bytes = base64.b64decode(base64_bytes)
image = Image.open(io.BytesIO(image_bytes))
image.show()
4 changes: 3 additions & 1 deletion aisploit/targets/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .copilot import CopilotClient, CopilotTarget
from .email import EmailReceiver, EmailSender, EmailTarget, UserPasswordAuth
from .image import OpenAIImageTarget
from .image import BedrockAmazonImageTarget, BedrockStabilityImageTarget, OpenAIImageTarget
from .langchain import LangchainTarget
from .stdout import StdOutTarget
from .target import WrapperTarget, target
Expand All @@ -12,6 +12,8 @@
"EmailSender",
"EmailReceiver",
"UserPasswordAuth",
"BedrockAmazonImageTarget",
"BedrockStabilityImageTarget",
"OpenAIImageTarget",
"LangchainTarget",
"StdOutTarget",
Expand Down
125 changes: 119 additions & 6 deletions aisploit/targets/image.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,117 @@
import json
import os
from dataclasses import dataclass
from abc import ABC
from dataclasses import dataclass, field
from typing import Optional

import boto3
from botocore.exceptions import ClientError
from openai import OpenAI

from ..core import BasePromptValue, BaseTarget, Response
from ..core import BaseImageTarget, BasePromptValue, ContentFilteredException, Response


@dataclass
class OpenAIImageTarget(BaseTarget):
class BaseBedrockImageTarget(BaseImageTarget, ABC):
session: boto3.Session = field(default_factory=lambda: boto3.Session())
region_name: str = "us-east-1"

def __post_init__(self):
self._client = self.session.client("bedrock-runtime", region_name=self.region_name)


@dataclass
class BedrockAmazonImageTarget(BaseBedrockImageTarget):
model: str = "titan-image-generator-v1"
quality: str = "standard"
seed: int = 0
cfg_scale: int = 8

def send_prompt(self, prompt: BasePromptValue) -> Response:
width, height = self.size.split("x")
body = {
"textToImageParams": {
"text": prompt.to_string(),
},
"taskType": "TEXT_IMAGE",
"imageGenerationConfig": {
"seed": self.seed,
"cfgScale": self.cfg_scale,
"quality": self.quality,
"width": int(width),
"height": int(height),
"numberOfImages": 1,
},
}

try:
response = self._client.invoke_model(
body=json.dumps(body),
modelId=f"amazon.{self.model}",
)

response_body = json.loads(response["body"].read())

if response_body["error"]:
raise Exception(response_body["error"])

base64_image = response_body["images"][0]

if self.show_image:
self._show_base64_image(base64_image)

return Response(content=base64_image)
except ClientError as e:
if e.response['Error']['Code'] == 'ValidationException':
if "blocked by our content filters" in e.response['Error']['Message']:
raise ContentFilteredException(e.response['Error']['Message']) from e

raise e


@dataclass
class BedrockStabilityImageTarget(BaseBedrockImageTarget):
model: str = "stable-diffusion-xl-v1"
steps: int = 50
seed: int = 0
cfg_scale: int = 8

def send_prompt(self, prompt: BasePromptValue) -> Response:
width, height = self.size.split("x")
body = {
"text_prompts": [{"text": prompt.to_string(), "weight": 1}],
"seed": self.seed,
"cfg_scale": self.cfg_scale,
"width": int(width),
"height": int(height),
"steps": self.steps,
}

response = self._client.invoke_model(
body=json.dumps(body),
modelId=f"stability.{self.model}",
)

response_body = json.loads(response["body"].read())

finish_reason = response_body.get("artifacts")[0].get("finishReason")

if finish_reason == "CONTENT_FILTERED":
raise ContentFilteredException(f"Image error: {finish_reason}")

if finish_reason == "ERROR":
raise Exception(f"Image error: {finish_reason}")

base64_image = response_body["artifacts"][0]["base64"]

if self.show_image:
self._show_base64_image(base64_image)

return Response(content=base64_image)


@dataclass
class OpenAIImageTarget(BaseImageTarget):
api_key: Optional[str] = None

def __post_init__(self):
Expand All @@ -18,6 +121,16 @@ def __post_init__(self):
self._client = OpenAI(api_key=self.api_key)

def send_prompt(self, prompt: BasePromptValue) -> Response:
response = self._client.images.generate(prompt=prompt.to_string(), n=1)
print(response)
return Response(content="")
response = self._client.images.generate(
prompt=prompt.to_string(),
size=self.size,
n=1,
response_format="b64_json",
)

base64_image = response.data[0].b64_json

if self.show_image:
self._show_base64_image(base64_image)

return Response(content=base64_image)
86 changes: 73 additions & 13 deletions examples/target.ipynb

Large diffs are not rendered by default.

68 changes: 67 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ python-docx = "^1.1.0"
brotli = "^1.1.0"
stdlib-list = "^0.10.0"
presidio-analyzer = "^2.2.354"
boto3 = "^1.34.88"

[tool.poetry.group.dev.dependencies]
chromadb = "^0.4.23"
Expand Down

0 comments on commit f9d7660

Please sign in to comment.