
Commit 2d430e1

updating
Fanghua-Yu committed Mar 5, 2024
1 parent c5b0f25 commit 2d430e1
Showing 3 changed files with 28 additions and 11 deletions.
SUPIR/util.py: 9 changes (6 additions & 3 deletions)
@@ -54,7 +54,7 @@ def load_QF_ckpt(config_path):
     return ckpt_Q, ckpt_F


-def PIL2Tensor(img, upsacle=1, min_size=1024):
+def PIL2Tensor(img, upsacle=1, min_size=1024, fix_resize=None):
     '''
     PIL.Image -> Tensor[C, H, W], RGB, [-1, 1]
     '''
@@ -67,8 +67,11 @@ def PIL2Tensor(img, upsacle=1, min_size=1024):
         _upsacle = min_size / min(w, h)
         w *= _upsacle
         h *= _upsacle
-    else:
-        _upsacle = 1
+    if fix_resize is not None:
+        _upsacle = fix_resize / min(w, h)
+        w *= _upsacle
+        h *= _upsacle
+    w0, h0 = round(w), round(h)
     w = int(np.round(w / 64.0)) * 64
     h = int(np.round(h / 64.0)) * 64
     x = img.resize((w, h), Image.BICUBIC)
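What the new fix_resize argument does: after the usual min_size upscaling, the image is rescaled once more so its short side lands exactly on fix_resize, and only then is the size snapped to a multiple of 64. A minimal, self-contained sketch of that sizing arithmetic (the helper name _target_size is invented for illustration; only the math mirrors the diff):

import numpy as np

def _target_size(w, h, upsacle=1, min_size=1024, fix_resize=None):
    """Hypothetical helper mirroring the sizing arithmetic in PIL2Tensor."""
    w, h = w * upsacle, h * upsacle
    if min(w, h) < min_size:           # upscale so the short side reaches min_size
        _upsacle = min_size / min(w, h)
        w, h = w * _upsacle, h * _upsacle
    if fix_resize is not None:         # then pin the short side to fix_resize
        _upsacle = fix_resize / min(w, h)
        w, h = w * _upsacle, h * _upsacle
    w0, h0 = round(w), round(h)        # size reported back for Tensor2PIL
    w64 = int(np.round(w / 64.0)) * 64  # actual resize target,
    h64 = int(np.round(h / 64.0)) * 64  # snapped to multiples of 64
    return (w0, h0), (w64, h64)

# A 200x300 input upscaled 2x: min_size pushes it to 1024x1536,
# then fix_resize=512 brings the short side back down to 512.
print(_target_size(200, 300, upsacle=2, fix_resize=512))
# ((512, 768), (512, 768))

With fix_resize=512, every input comes out with a 512-pixel short side regardless of upsacle or min_size, which is what the test.py change below relies on.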
gradio_demo.py: 2 changes (1 addition & 1 deletion)
@@ -44,7 +44,7 @@
 if args.loading_half_params:
     model = model.half()
 if args.use_tile_vae:
-    model.init_tile_vae(encoder_tile_size=512, decoder_tile_size=64)
+    model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
 model = model.to(SUPIR_device)
 model.first_stage_model.denoise_encoder_s1 = copy.deepcopy(model.first_stage_model.denoise_encoder)
 model.current_model = 'v0-Q'
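This gradio_demo.py hunk only replaces the hard-coded tile sizes with CLI-supplied values; since the new flags default to 512 and 64 (see the test.py diff below), behavior is unchanged unless the user overrides them, e.g. passing --encoder_tile_size 256 to trade throughput for lower VRAM during VAE encoding. The matching argparse definitions for gradio_demo.py are not visible in this one-line hunk, so their presence there is assumed.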
test.py: 28 changes (21 additions & 7 deletions)
@@ -5,6 +5,8 @@
 from llava.llava_agent import LLavaAgent
 from CKPT_PTH import LLAVA_MODEL_PATH
 import os
+from torch.nn.functional import interpolate
+
 if torch.cuda.device_count() >= 2:
     SUPIR_device = 'cuda:0'
     LLaVA_device = 'cuda:1'
@@ -47,31 +49,43 @@
 parser.add_argument("--ae_dtype", type=str, default="bf16", choices=['fp32', 'bf16'])
 parser.add_argument("--diff_dtype", type=str, default="fp16", choices=['fp32', 'fp16', 'bf16'])
 parser.add_argument("--no_llava", action='store_true', default=False)
+parser.add_argument("--loading_half_params", action='store_true', default=False)
+parser.add_argument("--use_tile_vae", action='store_true', default=False)
+parser.add_argument("--encoder_tile_size", type=int, default=512)
+parser.add_argument("--decoder_tile_size", type=int, default=64)
+parser.add_argument("--load_8bit_llava", action='store_true', default=False)
 args = parser.parse_args()
 print(args)
 use_llava = not args.no_llava

 # load SUPIR
-model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign).to(SUPIR_device)
+model = create_SUPIR_model('options/SUPIR_v0.yaml', SUPIR_sign=args.SUPIR_sign)
+if args.loading_half_params:
+    model = model.half()
+if args.use_tile_vae:
+    model.init_tile_vae(encoder_tile_size=args.encoder_tile_size, decoder_tile_size=args.decoder_tile_size)
 model.ae_dtype = convert_dtype(args.ae_dtype)
 model.model.dtype = convert_dtype(args.diff_dtype)
+model = model.to(SUPIR_device)
 # load LLaVA
 if use_llava:
-    llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device)
+    llava_agent = LLavaAgent(LLAVA_MODEL_PATH, device=LLaVA_device, load_8bit=args.load_8bit_llava, load_4bit=False)
 else:
     llava_agent = None

 os.makedirs(args.save_dir, exist_ok=True)
 for img_pth in os.listdir(args.img_dir):
     img_name = os.path.splitext(img_pth)[0]

-    LQ_img = Image.open(os.path.join(args.img_dir, img_pth))
-    LQ_img, h0, w0 = PIL2Tensor(LQ_img, upsacle=args.upscale, min_size=args.min_size)
+    LQ_ips = Image.open(os.path.join(args.img_dir, img_pth))
+    LQ_img, h0, w0 = PIL2Tensor(LQ_ips, upsacle=args.upscale, min_size=args.min_size)
     LQ_img = LQ_img.unsqueeze(0).to(SUPIR_device)[:, :3, :, :]

-    # step 1: Pre-denoise for LLaVA)
-    clean_imgs = model.batchify_denoise(LQ_img)
-    clean_PIL_img = Tensor2PIL(clean_imgs[0], h0, w0)
+    # step 1: Pre-denoise for LLaVA, resize to 512
+    LQ_img_512, h1, w1 = PIL2Tensor(LQ_ips, upsacle=args.upscale, min_size=args.min_size, fix_resize=512)
+    LQ_img_512 = LQ_img_512.unsqueeze(0).to(SUPIR_device)[:, :3, :, :]
+    clean_imgs = model.batchify_denoise(LQ_img_512)
+    clean_PIL_img = Tensor2PIL(clean_imgs[0], h1, w1)

     # step 2: LLaVA
     if use_llava:
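Two details of the test.py rewrite are easy to miss. First, the .to(SUPIR_device) call moved from model creation down to after the optional model.half(), so with --loading_half_params only fp16 weights ever cross to the GPU. A generic sketch of that ordering in plain torch (a stand-in module, not SUPIR code):

import torch

model = torch.nn.Linear(4096, 4096)  # stand-in for a large model
half = True                          # plays the role of --loading_half_params

if half:
    model = model.half()             # convert weights on the CPU first...
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = model.to(device)             # ...so only fp16 tensors are copied over
print(next(model.parameters()).dtype)  # torch.float16 when half is set

Second, step 1 no longer denoises the full-resolution tensor: a second copy of the input is produced with fix_resize=512, denoised, and handed to LLaVA, and Tensor2PIL(clean_imgs[0], h1, w1) restores it at the 512-scale size (h1, w1) rather than the output size (h0, w0). Captioning does not need full resolution, so this keeps step 1's memory and compute roughly constant no matter how large the requested upscale is. An invocation exercising the new flags might look like python test.py --img_dir inputs --save_dir results --loading_half_params --use_tile_vae --load_8bit_llava (flag names from this diff; the --img_dir/--save_dir values are illustrative).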
