Skip to content

Commit

Permalink
Merge pull request #6 from chenxwh/replicate
Browse files Browse the repository at this point in the history
Add Replicate demo and API 
  • Loading branch information
mv-lab committed Feb 10, 2023
2 parents 7eeebfb + dc93d73 commit d0afb83
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 0 deletions.
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
![visitors](https://visitor-badge.glitch.me/badge?page_id=mv-lab/swin2sr)
[ <a href="https://colab.research.google.com/drive/1paPrt62ydwLv2U2eZqfcFsePI4X4WRR1?usp=sharing"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="google colab logo"></a>](https://colab.research.google.com/drive/1paPrt62ydwLv2U2eZqfcFsePI4X4WRR1?usp=sharing)
[![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/jjourney1125/swin2sr)
[![Replicate](https://replicate.com/cjwbw/japanese-stable-diffusion/badge)](https://replicate.com/cjwbw/japanese-stable-diffusion)
[ <a href="https://www.kaggle.com/code/jesucristo/super-resolution-demo-swin2sr-official/"><img src="https://upload.wikimedia.org/wikipedia/commons/7/7c/Kaggle_logo.png?20140912155123" alt="kaggle logo" width=50></a>](https://www.kaggle.com/code/jesucristo/super-resolution-demo-swin2sr-official/)


Expand Down
13 changes: 13 additions & 0 deletions cog.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
build:
gpu: true
cuda: "11.6.2"
python_version: "3.10"
system_packages:
- "libgl1-mesa-glx"
- "libglib2.0-0"
python_packages:
- "ipython==8.4.0"
- "torch==1.12.1 --extra-index-url=https://download.pytorch.org/whl/cu116"
- "opencv-python==4.6.0.66"
- "timm==0.6.11"
predict: "predict.py:Predictor"
91 changes: 91 additions & 0 deletions predict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import argparse
import cv2
import numpy as np
import torch
from cog import BasePredictor, Input, Path

from main_test_swin2sr import define_model, test


class Predictor(BasePredictor):
def setup(self):
"""Load the model into memory to make running multiple predictions efficient"""
print("Loading pipeline...")

self.device = "cuda:0"

args = argparse.Namespace()
args.scale = 4
args.large_model = False

tasks = ["classical_sr", "compressed_sr", "real_sr"]
paths = [
"weights/Swin2SR_ClassicalSR_X4_64.pth",
"weights/Swin2SR_CompressedSR_X4_48.pth",
"weights/Swin2SR_RealworldSR_X4_64_BSRGAN_PSNR.pth",
]
sizes = [64, 48, 128]

self.models = {}
for task, path, size in zip(tasks, paths, sizes):
args.training_patch_size = size
args.task, args.model_path = task, path
self.models[task] = define_model(args)
self.models[task].eval()
self.models[task] = self.models[task].to(self.device)

def predict(
self,
image: Path = Input(description="Input image"),
task: str = Input(
description="Choose a task",
choices=["classical_sr", "real_sr", "compressed_sr"],
default="real_sr",
),
) -> Path:
"""Run a single prediction on the model"""

model = self.models[task]

window_size = 8
scale = 4

img_lq = cv2.imread(str(image), cv2.IMREAD_COLOR).astype(np.float32) / 255.0
img_lq = np.transpose(
img_lq if img_lq.shape[2] == 1 else img_lq[:, :, [2, 1, 0]], (2, 0, 1)
) # HCW-BGR to CHW-RGB
img_lq = (
torch.from_numpy(img_lq).float().unsqueeze(0).to(self.device)
) # CHW-RGB to NCHW-RGB

# inference
with torch.no_grad():
# pad input image to be a multiple of window_size
_, _, h_old, w_old = img_lq.size()
h_pad = (h_old // window_size + 1) * window_size - h_old
w_pad = (w_old // window_size + 1) * window_size - w_old
img_lq = torch.cat([img_lq, torch.flip(img_lq, [2])], 2)[
:, :, : h_old + h_pad, :
]
img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[
:, :, :, : w_old + w_pad
]

output = model(img_lq)

if task == "compressed_sr":
output = output[0][..., : h_old * scale, : w_old * scale]
else:
output = output[..., : h_old * scale, : w_old * scale]

# save image
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
if output.ndim == 3:
output = np.transpose(
output[[2, 1, 0], :, :], (1, 2, 0)
) # CHW-RGB to HCW-BGR
output = (output * 255.0).round().astype(np.uint8) # float32 to uint8
output_path = "/tmp/out.png"
cv2.imwrite(output_path, output)

return Path(output_path)

0 comments on commit d0afb83

Please sign in to comment.