{ "nbformat": 4, "nbformat_minor": 0, "metadata": { "colab": { "name": "first-order-model-demo", "provenance": [] }, "kernelspec": { "name": "python3", "display_name": "Python 3" }, "accelerator": "GPU" }, "cells": [ { "cell_type": "markdown", "metadata": { "id": "view-in-github", "colab_type": "text" }, "source": [ "\"Open", "\"Kaggle\"" ] }, { "cell_type": "markdown", "metadata": { "id": "cdO_RxQZLahB" }, "source": [ "# Demo for paper \"First Order Motion Model for Image Animation\"\n", "To try the demo, press the 2 play buttons in order and scroll to the bottom. Note that it may take several minutes to load." ] }, { "cell_type": "code", "metadata": { "id": "UCMFMJV7K-ag" }, "source": [ "!pip install ffmpy &> /dev/null\n", "!git init -q .\n", "!git remote add origin https://github.com/AliaksandrSiarohin/first-order-model\n", "!git pull -q origin master\n", "!git clone -q https://github.com/graphemecluster/first-order-model-demo demo" ], "execution_count": null, "outputs": [] }, { "cell_type": "code", "metadata": { "id": "Oxi6-riLOgnm" }, "source": [ "import IPython.display\n", "import PIL.Image\n", "import cv2\n", "import imageio\n", "import io\n", "import ipywidgets\n", "import numpy\n", "import os.path\n", "import requests\n", "import skimage.transform\n", "import warnings\n", "from base64 import b64encode\n", "from demo import load_checkpoints, make_animation\n", "from ffmpy import FFmpeg\n", "from google.colab import files, output\n", "from IPython.display import HTML, Javascript\n", "from skimage import img_as_ubyte\n", "warnings.filterwarnings(\"ignore\")\n", "os.makedirs(\"user\", exist_ok=True)\n", "\n", "display(HTML(\"\"\"\n", "\n", "\"\"\"))\n", "\n", "def thumbnail(file):\n", "\treturn imageio.get_reader(file, mode='I', format='FFMPEG').get_next_data()\n", "\n", "def create_image(i, j):\n", "\timage_widget = ipywidgets.Image(\n", "\t\tvalue=open('demo/images/%d%d.png' % (i, j), 'rb').read(),\n", "\t\tformat='png'\n", "\t)\n", 
"\timage_widget.add_class('resource')\n", "\timage_widget.add_class('resource-image')\n", "\timage_widget.add_class('resource-image%d%d' % (i, j))\n", "\treturn image_widget\n", "\n", "def create_video(i):\n", "\t# Thumbnail is the first frame of the clip; imageio yields RGB while OpenCV encodes BGR, hence the conversion.\n", "\tvideo_widget = ipywidgets.Image(\n", "\t\t# ndarray.tostring() is deprecated since NumPy 1.19; tobytes() returns the same raw bytes.\n", "\t\tvalue=cv2.imencode('.png', cv2.cvtColor(thumbnail('demo/videos/%d.mp4' % i), cv2.COLOR_RGB2BGR))[1].tobytes(),\n", "\t\tformat='png'\n", "\t)\n", "\tvideo_widget.add_class('resource')\n", "\tvideo_widget.add_class('resource-video')\n", "\tvideo_widget.add_class('resource-video%d' % i)\n", "\treturn video_widget\n", "\n", "def create_title(title):\n", "\ttitle_widget = ipywidgets.Label(title)\n", "\ttitle_widget.add_class('title')\n", "\treturn title_widget\n", "\n", "def download_output(button):\n", "\tcomplete.layout.display = 'none'\n", "\tloading.layout.display = ''\n", "\ttry:\n", "\t\tfiles.download('output.mp4')\n", "\tfinally:\n", "\t\t# Restore the result view even if the download raises, so the UI is not left stuck on the loader.\n", "\t\tloading.layout.display = 'none'\n", "\t\tcomplete.layout.display = ''\n", "\n", "def convert_output(button):\n", "\tcomplete.layout.display = 'none'\n", "\tloading.layout.display = ''\n", "\ttry:\n", "\t\t# Scale to 1080x1080, then pad to 1920x1080 with a 420px left offset ((1920 - 1080) / 2) so the frame is centered.\n", "\t\tFFmpeg(inputs={'output.mp4': None}, outputs={'scaled.mp4': '-vf \"scale=1080x1080:flags=lanczos,pad=1920:1080:420:0\" -y'}).run()\n", "\t\tfiles.download('scaled.mp4')\n", "\tfinally:\n", "\t\t# Restore the result view even if ffmpeg or the download fails.\n", "\t\tloading.layout.display = 'none'\n", "\t\tcomplete.layout.display = ''\n", "\n", "def back_to_main(button):\n", "\tcomplete.layout.display = 'none'\n", "\tmain.layout.display = ''\n", "\n", "label_or = ipywidgets.Label('or')\n", "label_or.add_class('label-or')\n", "\n", "image_titles = ['Peoples', 'Cartoons', 'Dolls', 'Game of Thrones', 'Statues']\n", "image_lengths = [8, 4, 8, 9, 4]\n", "\n", "image_tab = ipywidgets.Tab()\n", "image_tab.children = [ipywidgets.HBox([create_image(i, j) for j in range(length)]) for i, length in enumerate(image_lengths)]\n", "for i, title in enumerate(image_titles):\n", "\timage_tab.set_title(i, title)\n", "\n", "input_image_widget = ipywidgets.Output()\n", "input_image_widget.add_class('input-widget')\n", 
"upload_input_image_button = ipywidgets.FileUpload(accept='image/*', button_style='primary')\n", "upload_input_image_button.add_class('input-button')\n", "image_part = ipywidgets.HBox([\n", "\tipywidgets.VBox([input_image_widget, upload_input_image_button]),\n", "\tlabel_or,\n", "\timage_tab\n", "])\n", "\n", "video_tab = ipywidgets.Tab()\n", "video_tab.children = [ipywidgets.HBox([create_video(i) for i in range(5)])]\n", "video_tab.set_title(0, 'All Videos')\n", "\n", "input_video_widget = ipywidgets.Output()\n", "input_video_widget.add_class('input-widget')\n", "upload_input_video_button = ipywidgets.FileUpload(accept='video/*', button_style='primary')\n", "upload_input_video_button.add_class('input-button')\n", "video_part = ipywidgets.HBox([\n", "\tipywidgets.VBox([input_video_widget, upload_input_video_button]),\n", "\tlabel_or,\n", "\tvideo_tab\n", "])\n", "\n", "model = ipywidgets.Dropdown(\n", "\tdescription=\"Model:\",\n", "\toptions=[\n", "\t\t'vox',\n", "\t\t'vox-adv',\n", "\t\t'taichi',\n", "\t\t'taichi-adv',\n", "\t\t'nemo',\n", "\t\t'mgif',\n", "\t\t'fashion',\n", "\t\t'bair'\n", "\t]\n", ")\n", "warning = ipywidgets.HTML('Warning: Upload your own images and videos (see README)')\n", "warning.add_class('warning')\n", "model_part = ipywidgets.HBox([model, warning])\n", "\n", "relative = ipywidgets.Checkbox(description=\"Relative keypoint displacement (Inherit object proporions from the video)\", value=True)\n", "adapt_movement_scale = ipywidgets.Checkbox(description=\"Adapt movement scale (Don’t touch unless you know want you are doing)\", value=True)\n", "generate_button = ipywidgets.Button(description=\"Generate\", button_style='primary')\n", "main = ipywidgets.VBox([\n", "\tcreate_title('Choose Image'),\n", "\timage_part,\n", "\tcreate_title('Choose Video'),\n", "\tvideo_part,\n", "\tcreate_title('Settings'),\n", "\tmodel_part,\n", "\trelative,\n", "\tadapt_movement_scale,\n", "\tgenerate_button\n", "])\n", "\n", "loader = ipywidgets.Label()\n", 
"loader.add_class(\"loader\")\n", "loading_label = ipywidgets.Label(\"This may take several minutes to process…\")\n", "loading_label.add_class(\"loading-label\")\n", "loading = ipywidgets.VBox([loader, loading_label])\n", "loading.add_class('loading')\n", "\n", "output_widget = ipywidgets.Output()\n", "output_widget.add_class('output-widget')\n", "download = ipywidgets.Button(description='Download', button_style='primary')\n", "download.add_class('output-button')\n", "download.on_click(download_output)\n", "convert = ipywidgets.Button(description='Convert to 1920×1080', button_style='primary')\n", "convert.add_class('output-button')\n", "convert.on_click(convert_output)\n", "back = ipywidgets.Button(description='Back', button_style='primary')\n", "back.add_class('output-button')\n", "back.on_click(back_to_main)\n", "\n", "comparison_widget = ipywidgets.Output()\n", "comparison_widget.add_class('comparison-widget')\n", "comparison_label = ipywidgets.Label('Comparison')\n", "comparison_label.add_class('comparison-label')\n", "complete = ipywidgets.HBox([\n", "\tipywidgets.VBox([output_widget, download, convert, back]),\n", "\tipywidgets.VBox([comparison_widget, comparison_label])\n", "])\n", "\n", "display(ipywidgets.VBox([main, loading, complete]))\n", "display(Javascript(\"\"\"\n", "var images, videos;\n", "function deselectImages() {\n", "\timages.forEach(function(item) {\n", "\t\titem.classList.remove(\"selected\");\n", "\t});\n", "}\n", "function deselectVideos() {\n", "\tvideos.forEach(function(item) {\n", "\t\titem.classList.remove(\"selected\");\n", "\t});\n", "}\n", "function invokePython(func) {\n", "\tgoogle.colab.kernel.invokeFunction(\"notebook.\" + func, [].slice.call(arguments, 1), {});\n", "}\n", "setTimeout(function() {\n", "\t(images = [].slice.call(document.getElementsByClassName(\"resource-image\"))).forEach(function(item) {\n", "\t\titem.addEventListener(\"click\", function() {\n", "\t\t\tdeselectImages();\n", 
"\t\t\titem.classList.add(\"selected\");\n", "\t\t\tinvokePython(\"select_image\", item.className.match(/resource-image(\\d\\d)/)[1]);\n", "\t\t});\n", "\t});\n", "\timages[0].classList.add(\"selected\");\n", "\t(videos = [].slice.call(document.getElementsByClassName(\"resource-video\"))).forEach(function(item) {\n", "\t\titem.addEventListener(\"click\", function() {\n", "\t\t\tdeselectVideos();\n", "\t\t\titem.classList.add(\"selected\");\n", "\t\t\tinvokePython(\"select_video\", item.className.match(/resource-video(\\d)/)[1]);\n", "\t\t});\n", "\t});\n", "\tvideos[0].classList.add(\"selected\");\n", "}, 1000);\n", "\"\"\"))\n", "\n", "selected_image = None\n", "def select_image(filename):\n", "\tglobal selected_image\n", "\tselected_image = resize(PIL.Image.open('demo/images/%s.png' % filename).convert(\"RGB\"))\n", "\tinput_image_widget.clear_output(wait=True)\n", "\twith input_image_widget:\n", "\t\tdisplay(HTML('Image'))\n", "\tinput_image_widget.remove_class('uploaded')\n", "output.register_callback(\"notebook.select_image\", select_image)\n", "\n", "selected_video = None\n", "def select_video(filename):\n", "\tglobal selected_video\n", "\tselected_video = 'demo/videos/%s.mp4' % filename\n", "\tinput_video_widget.clear_output(wait=True)\n", "\twith input_video_widget:\n", "\t\tdisplay(HTML('Video'))\n", "\tinput_video_widget.remove_class('uploaded')\n", "output.register_callback(\"notebook.select_video\", select_video)\n", "\n", "def resize(image, size=(256, 256)):\n", " w, h = image.size\n", " d = min(w, h)\n", " r = ((w - d) // 2, (h - d) // 2, (w + d) // 2, (h + d) // 2)\n", " return image.resize(size, resample=PIL.Image.LANCZOS, box=r)\n", "\n", "def upload_image(change):\n", "\tglobal selected_image\n", "\tfor name, file_info in upload_input_image_button.value.items():\n", "\t\tcontent = file_info['content']\n", "\tif content is not None:\n", "\t\tselected_image = resize(PIL.Image.open(io.BytesIO(content)).convert(\"RGB\"))\n", 
"\t\tinput_image_widget.clear_output(wait=True)\n", "\t\twith input_image_widget:\n", "\t\t\tdisplay(selected_image)\n", "\t\tinput_image_widget.add_class('uploaded')\n", "\t\tdisplay(Javascript('deselectImages()'))\n", "upload_input_image_button.observe(upload_image, names='value')\n", "\n", "def upload_video(change):\n", "\tglobal selected_video\n", "\t# FileUpload.value is used here as a dict mapping each file name to an info dict; the last entry wins.\n", "\t# Initialise content so an empty value (e.g. after the widget is reset) does not raise UnboundLocalError.\n", "\tcontent = None\n", "\tfor name, file_info in upload_input_video_button.value.items():\n", "\t\tcontent = file_info['content']\n", "\tif content is not None:\n", "\t\t# basename() keeps a crafted upload name from writing outside the user/ directory.\n", "\t\tselected_video = 'user/' + os.path.basename(name)\n", "\t\tpreview = resize(PIL.Image.fromarray(thumbnail(content)).convert(\"RGB\"))\n", "\t\tinput_video_widget.clear_output(wait=True)\n", "\t\twith input_video_widget:\n", "\t\t\tdisplay(preview)\n", "\t\tinput_video_widget.add_class('uploaded')\n", "\t\tdisplay(Javascript('deselectVideos()'))\n", "\t\twith open(selected_video, 'wb') as video:\n", "\t\t\tvideo.write(content)\n", "upload_input_video_button.observe(upload_video, names='value')\n", "\n", "def change_model(change):\n", "\tif model.value.startswith('vox'):\n", "\t\twarning.remove_class('warn')\n", "\telse:\n", "\t\twarning.add_class('warn')\n", "model.observe(change_model, names='value')\n", "\n", "def generate(button):\n", "\tmain.layout.display = 'none'\n", "\tloading.layout.display = ''\n", "\tfilename = model.value + ('' if model.value == 'fashion' else '-cpk') + '.pth.tar'\n", "\tif not os.path.isfile(filename):\n", "\t\t# Ask the Yandex Disk public API for a direct link to the checkpoint, then fetch it.\n", "\t\t# Named checkpoint_response so it does not shadow the global Download button.\n", "\t\tcheckpoint_response = requests.get(requests.get('https://cloud-api.yandex.net/v1/disk/public/resources/download?public_key=https://yadi.sk/d/lEw8uRm140L_eQ&path=/' + filename).json().get('href'))\n", "\t\t# Never cache an HTTP error body: a bad file would pass the isfile() check above on every later run.\n", "\t\tcheckpoint_response.raise_for_status()\n", "\t\twith open(filename, 'wb') as checkpoint:\n", "\t\t\tcheckpoint.write(checkpoint_response.content)\n", "\treader = imageio.get_reader(selected_video, mode='I', format='FFMPEG')\n", "\tfps = reader.get_meta_data()['fps']\n", "\tdriving_video = []\n", "\tfor frame in reader:\n", "\t\tdriving_video.append(frame)\n", "\tgenerator, kp_detector = 
load_checkpoints(config_path='config/%s-256.yaml' % model.value, checkpoint_path=filename)\n", "\tpredictions = make_animation(\n", "\t\tskimage.transform.resize(numpy.asarray(selected_image), (256, 256)),\n", "\t\t[skimage.transform.resize(frame, (256, 256)) for frame in driving_video],\n", "\t\tgenerator,\n", "\t\tkp_detector,\n", "\t\trelative=relative.value,\n", "\t\tadapt_movement_scale=adapt_movement_scale.value\n", "\t)\n", "\tif selected_video.startswith('user/') or selected_video == 'demo/videos/0.mp4':\n", "\t\timageio.mimsave('temp.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n", "\t\tFFmpeg(inputs={'temp.mp4': None, selected_video: None}, outputs={'output.mp4': '-c copy -y'}).run()\n", "\telse:\n", "\t\timageio.mimsave('output.mp4', [img_as_ubyte(frame) for frame in predictions], fps=fps)\n", "\tloading.layout.display = 'none'\n", "\tcomplete.layout.display = ''\n", "\twith output_widget:\n", "\t\tdisplay(HTML('