Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
daxcay committed Jun 23, 2024
1 parent 205788f commit f918241
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 0 deletions.
6 changes: 6 additions & 0 deletions __init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
from .classes.DATASET_xCopy import N_CLASS_MAPPINGS as xCopyMappings, N_DISPLAY_NAME_MAPPINGS as xCopyNameMappings
from .classes.DATASET_OpenAIChat import N_CLASS_MAPPINGS as OpenAIChatMappings, N_DISPLAY_NAME_MAPPINGS as OpenAIChatNameMappings
from .classes.DATASET_OpenAIChatImage import N_CLASS_MAPPINGS as OpenAIChatImageMappings, N_DISPLAY_NAME_MAPPINGS as OpenAIChatImageNameMappings
from .classes.DATASET_LoadImage import N_CLASS_MAPPINGS as LoadImageMappings, N_DISPLAY_NAME_MAPPINGS as LoadImageNameMappings
from .classes.DATASET_SaveImage import N_CLASS_MAPPINGS as SaveImageMappings, N_DISPLAY_NAME_MAPPINGS as SaveImageNameMappings

NODE_CLASS_MAPPINGS = {}
NODE_CLASS_MAPPINGS.update(TXTFileSaverMappings)
Expand All @@ -36,6 +38,8 @@
NODE_CLASS_MAPPINGS.update(xCopyMappings)
NODE_CLASS_MAPPINGS.update(OpenAIChatMappings)
NODE_CLASS_MAPPINGS.update(OpenAIChatImageMappings)
NODE_CLASS_MAPPINGS.update(LoadImageMappings)
NODE_CLASS_MAPPINGS.update(SaveImageMappings)

NODE_DISPLAY_NAME_MAPPINGS = {}
NODE_DISPLAY_NAME_MAPPINGS.update(TXTFileSaverNameMappings)
Expand All @@ -46,5 +50,7 @@
NODE_DISPLAY_NAME_MAPPINGS.update(xCopyNameMappings)
NODE_DISPLAY_NAME_MAPPINGS.update(OpenAIChatNameMappings)
NODE_DISPLAY_NAME_MAPPINGS.update(OpenAIChatImageNameMappings)
NODE_DISPLAY_NAME_MAPPINGS.update(LoadImageNameMappings)
NODE_DISPLAY_NAME_MAPPINGS.update(SaveImageNameMappings)

WEB_DIRECTORY = "./web"
97 changes: 97 additions & 0 deletions classes/DATASET_LoadImage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
import folder_paths
import os
import node_helpers
import numpy as np
import torch
import hashlib
from PIL import Image, ImageOps, ImageSequence, ImageFile

#code credit: nodes.py comfui
class DATASET_LoadImage:
@classmethod
def INPUT_TYPES(s):
input_dir = folder_paths.get_input_directory()
files = [f for f in os.listdir(input_dir) if os.path.isfile(os.path.join(input_dir, f))]
return {
"required": {
"image": (sorted(files), {"image_upload": True})
},
}

CATEGORY = "image"

RETURN_TYPES = ("IMAGE", "MASK", "STRING", "STRING", "STRING", "STRING")
RETURN_NAMES = ("image", "image_mask", "image_name", "image_name_without_extension", "image_path", "image_directory_path")
FUNCTION = "load_image"

def load_image(self, image):

image_path = folder_paths.get_annotated_filepath(image)
img = node_helpers.pillow(Image.open, image_path)

output_images = []
output_masks = []
w, h = None, None

excluded_formats = ['MPO']

for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i)

if i.mode == 'I':
i = i.point(lambda i: i * (1 / 255))
image = i.convert("RGB")

if len(output_images) == 0:
w = image.size[0]
h = image.size[1]

if image.size[0] != w or image.size[1] != h:
continue

image = np.array(image).astype(np.float32) / 255.0
image = torch.from_numpy(image)[None,]
if 'A' in i.getbands():
mask = np.array(i.getchannel('A')).astype(np.float32) / 255.0
mask = 1. - torch.from_numpy(mask)
else:
mask = torch.zeros((64, 64), dtype=torch.float32, device="cpu")
output_images.append(image)
output_masks.append(mask.unsqueeze(0))

if len(output_images) > 1 and img.format not in excluded_formats:
output_image = torch.cat(output_images, dim=0)
output_mask = torch.cat(output_masks, dim=0)
else:
output_image = output_images[0]
output_mask = output_masks[0]

image_name = os.path.basename(image_path)
image_dir = os.path.dirname(image_path)
image_name_without_extension = os.path.splitext(image_name)[0]

return (output_image, output_mask, image_name, image_name_without_extension, image_path, image_dir)

@classmethod
def IS_CHANGED(s, image):
image_path = folder_paths.get_annotated_filepath(image)
m = hashlib.sha256()
with open(image_path, 'rb') as f:
m.update(f.read())
return m.digest().hex()

@classmethod
def VALIDATE_INPUTS(s, image):
if not folder_paths.exists_annotated_filepath(image):
return "Invalid image file: {}".format(image)

return True


N_CLASS_MAPPINGS = {
"DATASET_LoadImage": DATASET_LoadImage,
}

N_DISPLAY_NAME_MAPPINGS = {
"DATASET_LoadImage": "DATASET_LoadImage",
}
69 changes: 69 additions & 0 deletions classes/DATASET_SaveImage.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import os
import json
import json
import numpy as np
from PIL import Image
from PIL.PngImagePlugin import PngInfo
from comfy.cli_args import args

class DATASET_SaveImage:

def __init__(self):
self.compression = 4

@classmethod
def INPUT_TYPES(s):
return {
"required": {
"Images": ("IMAGE",),
"Directory": ("STRING", {}),
"Filename": ("STRING", {"default": "Image"}),
},
"hidden": {"prompt": "PROMPT", "extra_pnginfo": "EXTRA_PNGINFO"},
}

RETURN_TYPES = ()
FUNCTION = "BatchSave"
OUTPUT_NODE = True
CATEGORY = "🔵 JDCN 🔵"

def BatchSave(self, Images, Directory, Filename, prompt=None, extra_pnginfo=None):

try:

Directory = Directory
Filename = Filename

if not os.path.exists(Directory):
os.makedirs(Directory)

for image in Images:

image = image.cpu().numpy()
image = (image * 255).astype(np.uint8)
img = Image.fromarray(image)
metadata = None
if not args.disable_metadata:
metadata = PngInfo()
if prompt is not None:
metadata.add_text("prompt", json.dumps(prompt))
if extra_pnginfo is not None:
for x in extra_pnginfo:
metadata.add_text(x, json.dumps(extra_pnginfo[x]))

file_path = os.path.join(Directory,Filename)
img.save(file_path, pnginfo=metadata, compress_level=self.compression)

except Exception as e:
print(f"Error saving image: {e}")

return ()


N_CLASS_MAPPINGS = {
"DATASET_SaveImage": DATASET_SaveImage,
}

N_DISPLAY_NAME_MAPPINGS = {
"DATASET_SaveImage": "DATASET_SaveImage",
}

0 comments on commit f918241

Please sign in to comment.