Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
柏灌 committed Oct 11, 2023
1 parent 74b78cf commit 371e04b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ _<sup>2</sup>[Department of Computing, The Hong Kong Polytechnic University](htt
<img src="samples/000004x2.gif" width="390px"/> <img src="samples/000080x2.gif" width="390px"/>

## News
(2023-10-11) [Colab demo](https://colab.research.google.com/drive/1lZ_-rSGcmreLCiRniVT973x6JLjFiC-b?usp=sharing) is now available. Credits to [Masahide Okada](https://github.com/MasahideOkada).

(2023-10-09) Add training dataset.

(2023-09-28) Add tiled latent to allow upscaling ultra high-resolution images. Please carefully set ```tiled_size``` in ```pipelines/pipeline_pasd.py``` as well as ```--vae_tiled_size``` when upscaling large images.
Expand Down
3 changes: 2 additions & 1 deletion dataloader/localdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ def __init__(self,
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

self.img_paths = []
folders = os.listdir(pngtxt_dir)
for folder in folders:
self.img_paths = sorted(glob.glob(f'{pngtxt_dir}/{folder}/*.png'))[:]
self.img_paths.extend(sorted(glob.glob(f'{pngtxt_dir}/{folder}/*.png'))[:])

def tokenize_caption(self, caption):
if random.random() < self.null_text_ratio:
Expand Down
25 changes: 18 additions & 7 deletions myutils/vaehook.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,11 +303,18 @@ def build_sampling(task_queue, net, is_decoder):
module = net.up_blocks
func_name = 'upsamplers'
else:
resolution_iter = range(net.num_resolutions)
block_ids = net.num_res_blocks
condition = net.num_resolutions - 1
module = net.down
func_name = 'downsample'
if sd_flag:
resolution_iter = range(net.num_resolutions)
block_ids = net.num_res_blocks
condition = net.num_resolutions - 1
module = net.down
func_name = 'downsample'
else:
resolution_iter = range(len(net.down_blocks))
block_ids = 2
condition = len(net.down_blocks) - 1
module = net.down_blocks
func_name = 'downsamplers'

for i_level in resolution_iter:
for i_block in range(block_ids):
Expand All @@ -319,7 +326,10 @@ def build_sampling(task_queue, net, is_decoder):
if sd_flag:
task_queue.append((func_name, getattr(module[i_level], func_name)))
else:
task_queue.append((func_name, module[i_level].upsamplers[0]))
if is_decoder:
task_queue.append((func_name, module[i_level].upsamplers[0]))
else:
task_queue.append((func_name, module[i_level].downsamplers[0]))

if not is_decoder:
if sd_flag:
Expand Down Expand Up @@ -688,6 +698,7 @@ def vae_tile_forward(self, z):
@return: image
"""
device = next(self.net.parameters()).device
dtype = z.dtype
net = self.net
tile_size = self.tile_size
is_decoder = self.is_decoder
Expand Down Expand Up @@ -830,4 +841,4 @@ def vae_tile_forward(self, z):

# Done!
pbar.close()
return result if result is not None else result_approx.to(device)
return result.to(dtype) if result is not None else result_approx.to(device)

0 comments on commit 371e04b

Please sign in to comment.