require 'pl' local __FILE__ = (function() return string.gsub(debug.getinfo(2, 'S').source, "^@", "") end)() local ROOT = path.dirname(__FILE__) package.path = path.join(ROOT, "lib", "?.lua;") .. package.path _G.TURBO_SSL = true require 'w2nn' local uuid = require 'uuid' local ffi = require 'ffi' local md5 = require 'md5' local iproc = require 'iproc' local reconstruct = require 'reconstruct' local image_loader = require 'image_loader' local alpha_util = require 'alpha_util' local compression = require 'compression' local gm = require 'graphicsmagick' -- Note: turbo and xlua has different implementation of string:split(). -- Therefore, string:split() has conflict issue. -- In this script, use turbo's string:split(). local turbo = require 'turbo' local cmd = torch.CmdLine() cmd:text() cmd:text("waifu2x-api") cmd:text("Options:") cmd:option("-port", 8812, 'listen port') cmd:option("-gpu", 1, 'Device ID') cmd:option("-enable_tta", 0, 'enable TTA query(0|1)') cmd:option("-crop_size", 256, 'patch size per process') cmd:option("-batch_size", 1, 'batch size') cmd:option("-thread", -1, 'number of CPU threads') cmd:option("-force_cudnn", 0, 'use cuDNN backend (0|1)') cmd:option("-max_pixels", 3000 * 3000, 'maximum number of output image pixels (e.g. 3000x3000=9000000)') cmd:option("-curl_request_timeout", 60, "request_timeout for curl") cmd:option("-curl_connect_timeout", 60, "connect_timeout for curl") cmd:option("-curl_max_redirects", 2, "max_redirects for curl") cmd:option("-max_body_size", 5 * 1024 * 1024, "maximum allowed size for uploaded files") cmd:option("-cache_max", 200, "number of cached images on RAM") local opt = cmd:parse(arg) cutorch.setDevice(opt.gpu) torch.setdefaulttensortype('torch.FloatTensor') if opt.thread > 0 then torch.setnumthreads(opt.thread) end if cudnn then cudnn.fastest = true cudnn.benchmark = true end opt.force_cudnn = opt.force_cudnn == 1 opt.enable_tta = opt.enable_tta == 1 --local ART_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "art") local ART_MODEL_DIR = path.join(ROOT, "models", "cunet", "art") local PHOTO_MODEL_DIR = path.join(ROOT, "models", "upconv_7", "photo") local art_model = { scale = w2nn.load_model(path.join(ART_MODEL_DIR, "scale2.0x_model.t7"), opt.force_cudnn), noise0_scale = w2nn.load_model(path.join(ART_MODEL_DIR, "noise0_scale2.0x_model.t7"), opt.force_cudnn), noise1_scale = w2nn.load_model(path.join(ART_MODEL_DIR, "noise1_scale2.0x_model.t7"), opt.force_cudnn), noise2_scale = w2nn.load_model(path.join(ART_MODEL_DIR, "noise2_scale2.0x_model.t7"), opt.force_cudnn), noise3_scale = w2nn.load_model(path.join(ART_MODEL_DIR, "noise3_scale2.0x_model.t7"), opt.force_cudnn), noise0 = w2nn.load_model(path.join(ART_MODEL_DIR, "noise0_model.t7"), opt.force_cudnn), noise1 = w2nn.load_model(path.join(ART_MODEL_DIR, "noise1_model.t7"), opt.force_cudnn), noise2 = w2nn.load_model(path.join(ART_MODEL_DIR, "noise2_model.t7"), opt.force_cudnn), noise3 = w2nn.load_model(path.join(ART_MODEL_DIR, "noise3_model.t7"), opt.force_cudnn) } local photo_model = { scale = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "scale2.0x_model.t7"), opt.force_cudnn), noise0_scale = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "noise0_scale2.0x_model.t7"), opt.force_cudnn), noise1_scale = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "noise1_scale2.0x_model.t7"), opt.force_cudnn), noise2_scale = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "noise2_scale2.0x_model.t7"), opt.force_cudnn), noise3_scale = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "noise3_scale2.0x_model.t7"), opt.force_cudnn), noise0 = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "noise0_model.t7"), opt.force_cudnn), noise1 = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "noise1_model.t7"), opt.force_cudnn), noise2 = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "noise2_model.t7"), opt.force_cudnn), noise3 = w2nn.load_model(path.join(PHOTO_MODEL_DIR, "noise3_model.t7"), opt.force_cudnn) } collectgarbage() local CLEANUP_MODEL = true -- if you are using the low memory GPU, you could use this flag. local CACHE_DIR = path.join(ROOT, "cache") local MAX_NOISE_IMAGE = opt.max_pixels local MAX_SCALE_IMAGE = (math.sqrt(opt.max_pixels) / 2)^2 local PNG_DEPTH = 8 local CURL_OPTIONS = { request_timeout = opt.curl_request_timeout, connect_timeout = opt.curl_connect_timeout, allow_redirects = true, max_redirects = opt.curl_max_redirects } local CURL_MAX_SIZE = opt.max_body_size local function valid_size(x, scale, tta_level) if scale <= 0 then local limit = math.pow(math.floor(math.pow(MAX_NOISE_IMAGE / tta_level, 0.5)), 2) return x:size(2) * x:size(3) <= limit else local limit = math.pow(math.floor(math.pow(MAX_SCALE_IMAGE / tta_level, 0.5)), 2) return x:size(2) * x:size(3) <= limit end end local function auto_tta_level(x, scale) local limit2, limit4, limit8 if scale <= 0 then limit2 = math.pow(math.floor(math.pow(MAX_NOISE_IMAGE / 2, 0.5)), 2) limit4 = math.pow(math.floor(math.pow(MAX_NOISE_IMAGE / 4, 0.5)), 2) limit8 = math.pow(math.floor(math.pow(MAX_NOISE_IMAGE / 8, 0.5)), 2) else limit2 = math.pow(math.floor(math.pow(MAX_SCALE_IMAGE / 2, 0.5)), 2) limit4 = math.pow(math.floor(math.pow(MAX_SCALE_IMAGE / 4, 0.5)), 2) limit8 = math.pow(math.floor(math.pow(MAX_SCALE_IMAGE / 8, 0.5)), 2) end local px = x:size(2) * x:size(3) if px <= limit8 then return 8 elseif px <= limit4 then return 4 elseif px <= limit2 then return 2 else return 1 end end local function cache_url(url) local hash = md5.sumhexa(url) local cache_file = path.join(CACHE_DIR, "url_" .. hash) if path.exists(cache_file) then return image_loader.load_float(cache_file) else local res = coroutine.yield( turbo.async.HTTPClient({verify_ca=false}, nil, CURL_MAX_SIZE):fetch(url, CURL_OPTIONS) ) if res.code == 200 then local content_type = res.headers:get("Content-Type", true) if type(content_type) == "table" then content_type = content_type[1] end if content_type and content_type:find("image") then local fp = io.open(cache_file, "wb") local blob = res.body fp:write(blob) fp:close() return image_loader.decode_float(blob) end end end return nil, nil end local function get_image(req) local file_info = req:get_arguments("file") local url = req:get_argument("url", "") local file = nil local filename = nil if file_info and #file_info == 1 then file = file_info[1][1] local disp = file_info[1]["content-disposition"] if disp and disp["filename"] then filename = path.basename(disp["filename"]) end end if file and file:len() > 0 then local x, meta = image_loader.decode_float(file) return x, meta, filename elseif url and url:len() > 0 then local x, meta = cache_url(url) return x, meta, filename end return nil, nil, nil end local function cleanup_model(model) if CLEANUP_MODEL then model:clearState() -- release GPU memory end end -- cache local g_cache = {} local function cache_count() local count = 0 for _ in pairs(g_cache) do count = count + 1 end return count end local function cache_remove_old() local old_time = nil local old_key = nil for k, v in pairs(g_cache) do if old_time == nil or old_time > v.updated_at then old_key = k old_time = v.updated_at end end if old_key then g_cache[old_key] = nil end end local function cache_compress(raw_image) if raw_image then compressed_image = compression.compress(iproc.float2byte(raw_image)) return compressed_image else return nil end end local function cache_decompress(compressed_image) if compressed_image then local raw_image = compression.decompress(compressed_image) return iproc.byte2float(raw_image) else return nil end end local function cache_get(filename) local cache = g_cache[filename] if cache then return {image = cache_decompress(cache.image), alpha = cache_decompress(cache.alpha)} else return nil end end local function cache_put(filename, image, alpha) g_cache[filename] = {image = cache_compress(image), alpha = cache_compress(alpha), updated_at = os.time()}; local count = cache_count(g_cache) if count > opt.cache_max then cache_remove_old() end end local function convert(x, meta, options) local cache_file = path.join(CACHE_DIR, options.prefix .. ".png") local alpha = meta.alpha local alpha_orig = alpha local cache = cache_get(cache_file) if cache then meta = tablex.copy(meta) meta.alpha = cache.alpha return cache.image, meta else local model = nil if options.style == "art" then model = art_model elseif options.style == "photo" then model = photo_model end if options.border then x = alpha_util.make_border(x, alpha_orig, reconstruct.offset_size(model.scale)) end if (options.method == "scale" or options.method == "noise0_scale" or options.method == "noise1_scale" or options.method == "noise2_scale" or options.method == "noise3_scale") then x = reconstruct.scale_tta(model[options.method], options.tta_level, 2.0, x, opt.crop_size, opt.batch_size) if alpha then if not (alpha:size(2) == x:size(2) and alpha:size(3) == x:size(3)) then alpha = reconstruct.scale(model.scale, 2.0, alpha, opt.crop_size, opt.batch_size) cleanup_model(model.scale) end end cleanup_model(model[options.method]) elseif (options.method == "noise0" or options.method == "noise1" or options.method == "noise2" or options.method == "noise3") then x = reconstruct.image_tta(model[options.method], options.tta_level, x, opt.crop_size, opt.batch_size) cleanup_model(model[options.method]) end cache_put(cache_file, x, alpha) meta = tablex.copy(meta) meta.alpha = alpha return x, meta end end local function client_disconnected(handler) return not(handler.request and handler.request.connection and handler.request.connection.stream and (not handler.request.connection.stream:closed())) end local function make_output_filename(filename, mode) local e = path.extension(filename) local base = filename:sub(0, filename:len() - e:len()) if mode then return base .. "_waifu2x_" .. mode .. ".png" else return base .. ".png" end end local APIHandler = class("APIHandler", turbo.web.RequestHandler) function APIHandler:post() if client_disconnected(self) then self:set_status(400) self:write("client disconnected") return end local x, meta, filename = get_image(self) local scale = tonumber(self:get_argument("scale", "-1")) local noise = tonumber(self:get_argument("noise", "-1")) local tta_level = tonumber(self:get_argument("tta_level", "1")) local style = self:get_argument("style", "art") local download = (self:get_argument("download", "")):len() if client_disconnected(self) then self:set_status(400) self:write("client disconnected") return end if opt.enable_tta then if tta_level == 0 then tta_level = auto_tta_level(x, scale) end if not (tta_level == 0 or tta_level == 1 or tta_level == 2 or tta_level == 4 or tta_level == 8) then tta_level = 1 end else tta_level = 1 end if style ~= "art" then style = "photo" -- style must be art or photo end if x and valid_size(x, scale, tta_level) then local prefix = nil if (noise >= 0 or scale > 0) then local hash = md5.sumhexa(meta.blob) local alpha_prefix = style .. "_" .. hash .. "_alpha" local border = false if scale >= 0 and meta.alpha then border = true end if (scale == 1 or scale == 2) and (noise < 0) then prefix = style .. "_scale_tta_" .. tta_level .. "_" x, meta = convert(x, meta, {method = "scale", style = style, tta_level = tta_level, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) if scale == 1 then x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc") end elseif (scale == 1 or scale == 2) and (noise == 0 or noise == 1 or noise == 2 or noise == 3) then prefix = style .. string.format("_noise%d_scale_tta_", noise) .. tta_level .. "_" x, meta = convert(x, meta, {method = string.format("noise%d_scale", noise), style = style, tta_level = tta_level, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) if scale == 1 then x = iproc.scale(x, x:size(3) * (1.6 / 2.0), x:size(2) * (1.6 / 2.0), "Sinc") end elseif (noise == 0 or noise == 1 or noise == 2 or noise == 3) then prefix = style .. string.format("_noise%d_tta_", noise) .. tta_level .. "_" x = convert(x, meta, {method = string.format("noise%d", noise), style = style, tta_level = tta_level, prefix = prefix .. hash, alpha_prefix = alpha_prefix, border = border}) border = false end end local name = nil if filename then if prefix then name = make_output_filename(filename, prefix:sub(0, prefix:len()-1)) else name = make_output_filename(filename, nil) end else name = uuid() .. ".png" end local blob = image_loader.encode_png(alpha_util.composite(x, meta.alpha), tablex.update({depth = PNG_DEPTH, inplace = true}, meta)) self:set_header("Content-Length", string.format("%d", #blob)) if download > 0 then self:set_header("Content-Type", "application/octet-stream") self:set_header("Content-Disposition", string.format('attachment; filename="%s"', name)) else self:set_header("Content-Type", "image/png") self:set_header("Content-Disposition", string.format('inline; filename="%s"', name)) end self:write(blob) else if not x then self:set_status(400) self:write("ERROR: An error occurred. (unsupported image format/connection timeout/file is too large)") else self:set_status(400) self:write("ERROR: image size exceeds maximum allowable size.") end end collectgarbage() end local FormHandler = class("FormHandler", turbo.web.RequestHandler) local index_ja = file.read(path.join(ROOT, "assets", "index.ja.html")) local index_ru = file.read(path.join(ROOT, "assets", "index.ru.html")) local index_pt = file.read(path.join(ROOT, "assets", "index.pt.html")) local index_es = file.read(path.join(ROOT, "assets", "index.es.html")) local index_fr = file.read(path.join(ROOT, "assets", "index.fr.html")) local index_de = file.read(path.join(ROOT, "assets", "index.de.html")) local index_tr = file.read(path.join(ROOT, "assets", "index.tr.html")) local index_zh_cn = file.read(path.join(ROOT, "assets", "index.zh-CN.html")) local index_zh_tw = file.read(path.join(ROOT, "assets", "index.zh-TW.html")) local index_ko = file.read(path.join(ROOT, "assets", "index.ko.html")) local index_nl = file.read(path.join(ROOT, "assets", "index.nl.html")) local index_ca = file.read(path.join(ROOT, "assets", "index.ca.html")) local index_ro = file.read(path.join(ROOT, "assets", "index.ro.html")) local index_it = file.read(path.join(ROOT, "assets", "index.it.html")) local index_eo = file.read(path.join(ROOT, "assets", "index.eo.html")) local index_no = file.read(path.join(ROOT, "assets", "index.no.html")) local index_uk = file.read(path.join(ROOT, "assets", "index.uk.html")) local index_pl = file.read(path.join(ROOT, "assets", "index.pl.html")) local index_bg = file.read(path.join(ROOT, "assets", "index.bg.html")) local index_en = file.read(path.join(ROOT, "assets", "index.html")) function FormHandler:get() local lang = self.request.headers:get("Accept-Language") if lang then local langs = utils.split(lang, ",") for i = 1, #langs do langs[i] = utils.split(langs[i], ";")[1] end if langs[1] == "ja" then self:write(index_ja) elseif langs[1] == "ru" then self:write(index_ru) elseif langs[1] == "pt" or langs[1] == "pt-BR" then self:write(index_pt) elseif langs[1] == "es" or langs[1] == "es-ES" then self:write(index_es) elseif langs[1] == "fr" then self:write(index_fr) elseif langs[1] == "de" then self:write(index_de) elseif langs[1] == "tr" then self:write(index_tr) elseif langs[1] == "zh-CN" or langs[1] == "zh" then self:write(index_zh_cn) elseif langs[1] == "zh-TW" then self:write(index_zh_tw) elseif langs[1] == "ko" then self:write(index_ko) elseif langs[1] == "nl" then self:write(index_nl) elseif langs[1] == "ca" or langs[1] == "ca-ES" or langs[1] == "ca-FR" or langs[1] == "ca-IT" or langs[1] == "ca-AD" then self:write(index_ca) elseif langs[1] == "ro" then self:write(index_ro) elseif langs[1] == "it" then self:write(index_it) elseif langs[1] == "eo" then self:write(index_eo) elseif langs[1] == "no" then self:write(index_no) elseif langs[1] == "uk" then self:write(index_uk) elseif langs[1] == "pl" then self:write(index_pl) elseif langs[1] == "bg" then self:write(index_bg) else self:write(index_en) end else self:write(index_en) end end turbo.log.categories = { ["success"] = true, ["notice"] = false, ["warning"] = true, ["error"] = true, ["debug"] = false, ["development"] = false } local app = turbo.web.Application:new( { {"^/$", FormHandler}, {"^/api$", APIHandler}, {"^/([%a%d%.%-_]+)$", turbo.web.StaticFileHandler, path.join(ROOT, "assets/")}, } ) app:listen(opt.port, "0.0.0.0", {max_body_size = CURL_MAX_SIZE}) turbo.ioloop.instance():start()