Skip to content

Commit

Permalink
New Nodes
Browse files Browse the repository at this point in the history
Batch OpenAI Image Chat
  • Loading branch information
daxcay committed Jun 23, 2024
1 parent f918241 commit 8dff16d
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 0 deletions.
6 changes: 6 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@
from .classes.DATASET_OpenAIChatImage import N_CLASS_MAPPINGS as OpenAIChatImageMappings, N_DISPLAY_NAME_MAPPINGS as OpenAIChatImageNameMappings
from .classes.DATASET_LoadImage import N_CLASS_MAPPINGS as LoadImageMappings, N_DISPLAY_NAME_MAPPINGS as LoadImageNameMappings
from .classes.DATASET_SaveImage import N_CLASS_MAPPINGS as SaveImageMappings, N_DISPLAY_NAME_MAPPINGS as SaveImageNameMappings
from .classes.DATASET_TXTFileSaverBatch import N_CLASS_MAPPINGS as TXTFileSaverBatchMappings, N_DISPLAY_NAME_MAPPINGS as TXTFileSaverBatchNameMappings
from .classes.DATASET_OpenAIChatImageBatch import N_CLASS_MAPPINGS as OpenAIChatImageBatchMappings, N_DISPLAY_NAME_MAPPINGS as OpenAIChatImageBatchNameMappings

NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS.update(TXTFileSaverMappings)
Expand All @@ -40,6 +42,8 @@
NODE_CLASS_MAPPINGS.update(OpenAIChatImageMappings)
NODE_CLASS_MAPPINGS.update(LoadImageMappings)
NODE_CLASS_MAPPINGS.update(SaveImageMappings)
NODE_CLASS_MAPPINGS.update(TXTFileSaverBatchMappings)
NODE_CLASS_MAPPINGS.update(OpenAIChatImageBatchMappings)

NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TXTFileSaverNameMappings)
Expand All @@ -52,5 +56,7 @@
NODE_DISPLAY_NAME_MAPPINGS.update(OpenAIChatImageNameMappings)
NODE_DISPLAY_NAME_MAPPINGS.update(LoadImageNameMappings)
NODE_DISPLAY_NAME_MAPPINGS.update(SaveImageNameMappings)
NODE_DISPLAY_NAME_MAPPINGS.update(TXTFileSaverNameMappings)
NODE_DISPLAY_NAME_MAPPINGS.update(OpenAIChatImageBatchNameMappings)

WEB_DIRECTORY = "./web"
84 changes: 84 additions & 0 deletions classes/DATASET_OpenAIChatImageBatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import base64
import io
from PIL import Image
import numpy as np
from openai import OpenAI
import os

class DATASET_OpenAIChatImageBatch:

def __init__(self):
pass

@classmethod
def INPUT_TYPES(cls):
return {
"required": {
"images": ("IMAGE",),
"image_detail": (["low","high"], {"default": "high"}),
"prompt": ("STRING", {"multiline": True, "default": ""}),
"model": (["gpt-4o","gpt-4", "gpt-4-32k", "gpt-3.5-turbo", "gpt-4-0125-preview", "gpt-4-turbo-preview", "gpt-4-1106-preview", "gpt-4-0613"], {"default": "gpt-4o"}),
"api_url": ("STRING", {"multiline": False, "default": "https://api.openai.com/v1"}),
"api_key": ("STRING", {"multiline": False}),
"token_length": ("INT", {"default": 1024})
}
}

INPUT_IS_LIST = True
RETURN_TYPES = ("STRING",)
OUTPUT_IS_LIST = (True,)
FUNCTION = "generate"
CATEGORY = "🔶DATASET🔶"

def to_base64(self, image):
image = image[0]
i = 255. * image.cpu().numpy()
img = Image.fromarray(np.clip(i, 0, 255).astype(np.uint8))
buffered = io.BytesIO()
img.save(buffered, format="PNG")
return base64.b64encode(buffered.getvalue()).decode("utf-8")

def generate(self, images, image_detail, model, api_url, api_key, prompt, token_length):

try:

image_detail = image_detail[0]
model = model[0]
api_url = api_url[0]
api_key = api_key[0]
prompt = prompt[0]
token_length = token_length[0]

answers = []

for image in images:

ai = OpenAI(api_key=api_key, base_url=api_url)
base64img = self.to_base64(image)
if not api_key:
return "OpenAI API key is required."

request = [{"role": "system","content": "You are GPT-4."}]
request.append({"role": "user","content": [{"type": "image_url", "image_url": {"url": f"data:image/png;base64,{base64img}", "detail": image_detail}}]})
request.append({"role": "user","content": prompt})
response = ai.chat.completions.create(model=model,messages=request,max_tokens=token_length)
answer = response.choices[0].message.content

answers.append(answer)

return (answers,)

except Exception as e:
return (f"Error: {str(e)}",)

@classmethod
def IS_CHANGED(s, image, image_detail, model, api_url, api_key, prompt, token_length):
return os.urandom(16).hex()

N_CLASS_MAPPINGS = {
"DATASET_OpenAIChatImageBatch": DATASET_OpenAIChatImageBatch,
}

N_DISPLAY_NAME_MAPPINGS = {
"DATASET_OpenAIChatImageBatch": "DATASET_OpenAIChatImageBatch",
}
85 changes: 85 additions & 0 deletions classes/DATASET_TXTFileSaverBatch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import os

def save_file(filename, output_dir, content, mode='SaveNew'):
os.makedirs(output_dir, exist_ok=True)
file_path = os.path.join(output_dir, filename)

if mode == 'SaveNew':
counter = 0
while os.path.exists(file_path):
counter += 1
file_path = os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_{counter}{os.path.splitext(filename)[1]}")
elif mode == 'Merge' and os.path.exists(file_path):
with open(file_path, 'a') as file:
file.write(content)
print(f"Content appended successfully to {file_path}")
return
elif mode == 'Overwrite' and os.path.exists(file_path):
os.remove(file_path)
elif mode == 'MergeAndSaveNew' and os.path.exists(file_path):
with open(file_path, 'r') as file:
existing_content = file.read()
content = existing_content + content
counter = 0
while os.path.exists(file_path):
counter += 1
file_path = os.path.join(output_dir, f"{os.path.splitext(filename)[0]}_{counter}{os.path.splitext(filename)[1]}")

with open(file_path, 'w') as file:
file.write(content)
print(f"File saved successfully at {file_path}")

class DATASET_TXTFileSaverBatch:

def __init__(self):
pass

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"file_names": ("STRING",{"forceInput": True}),
"contents": ("STRING",{"forceInput": True}),
"save_in": ("STRING", {"default": "directory path"}),
"save_mode": (['Merge','Overwrite','SaveNew','MergeAndSaveNew'],),
},
}

INPUT_IS_LIST = True
RETURN_TYPES = ()
FUNCTION = "SaveIT"
OUTPUT_NODE = True

CATEGORY = "🔶DATASET🔶"

def SaveIT(self, file_names, contents, save_in, save_mode):
try:

directory = save_in[0]
mode = save_mode[0]

if not os.path.exists(directory):
os.makedirs(directory)

for i in range(0, len(contents)):
text = contents[i]
file_name = file_names[i]
save_file(f"{file_name}.txt", directory, text, mode)

except Exception as e:
print(f"Error saving: {e}")

return ()

@classmethod
def IS_CHANGED(s, content, file_name, directory, mode):
return os.urandom(16).hex()


N_CLASS_MAPPINGS = {
"DATASET_TXTFileSaverBatch": DATASET_TXTFileSaverBatch,
}

N_DISPLAY_NAME_MAPPINGS = {
"DATASET_TXTFileSaverBatch": "DATASET_TXTFileSaverBatch",
}

0 comments on commit 8dff16d

Please sign in to comment.