{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"name": "first-order-model-demo",
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"",
""
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "cdO_RxQZLahB"
},
"source": [
"# Demo for paper \"First Order Motion Model for Image Animation\"\n",
"To try the demo, run the two code cells below in order (press each play button), then scroll to the bottom. Note that loading may take several minutes."
]
},
{
"cell_type": "code",
"metadata": {
"id": "UCMFMJV7K-ag"
},
"source": [
"# Install the FFmpeg wrapper into the kernel's own environment.\n",
"%pip install -q ffmpy\n",
"# Fetch the model code into the working directory. The remote is only added when it is missing and the\n",
"# demo assets are only cloned when absent, so re-running this cell does not fail.\n",
"!git init -q .\n",
"!git remote get-url origin > /dev/null 2>&1 || git remote add origin https://github.com/AliaksandrSiarohin/first-order-model\n",
"!git pull -q origin master\n",
"![ -d demo ] || git clone -q https://github.com/graphemecluster/first-order-model-demo demo"
],
"execution_count": null,
"outputs": []
},
{
"cell_type": "code",
"metadata": {
"id": "Oxi6-riLOgnm"
},
"source": [
"import IPython.display\n",
"import PIL.Image\n",
"import cv2\n",
"import imageio\n",
"import io\n",
"import ipywidgets\n",
"import numpy\n",
"import os.path\n",
"import requests\n",
"import skimage.transform\n",
"import warnings\n",
"from base64 import b64encode\n",
"from demo import load_checkpoints, make_animation\n",
"from ffmpy import FFmpeg\n",
"from google.colab import files, output\n",
"from IPython.display import HTML, Javascript\n",
"from skimage import img_as_ubyte\n",
"warnings.filterwarnings(\"ignore\")\n",
"os.makedirs(\"user\", exist_ok=True)\n",
"\n",
"display(HTML(\"\"\"\n",
"\n",
"\"\"\"))\n",
"\n",
"def thumbnail(file):\n",
"\t\"\"\"Return the first frame of the video `file`.\n",
"\n",
"\t`file` is handed straight to `imageio.get_reader`; this notebook passes both file paths and the raw\n",
"\tbytes of an upload. The reader is used as a context manager so it is closed before returning\n",
"\tinstead of being left open until garbage collection.\n",
"\t\"\"\"\n",
"\twith imageio.get_reader(file, mode='I', format='FFMPEG') as reader:\n",
"\t\treturn reader.get_next_data()\n",
"\n",
"def create_image(i, j):\n",
"\t\"\"\"Build the thumbnail widget for the bundled image demo/images/<i><j>.png.\n",
"\n",
"\tThe widget is tagged with the CSS classes `resource`, `resource-image` and `resource-image<i><j>`;\n",
"\tthe front-end script reads the two digits back out of the last class name when the image is clicked.\n",
"\t\"\"\"\n",
"\t# Read through a context manager so the file handle is closed promptly.\n",
"\twith open('demo/images/%d%d.png' % (i, j), 'rb') as image_file:\n",
"\t\timage_bytes = image_file.read()\n",
"\timage_widget = ipywidgets.Image(value=image_bytes, format='png')\n",
"\tfor css_class in ('resource', 'resource-image', 'resource-image%d%d' % (i, j)):\n",
"\t\timage_widget.add_class(css_class)\n",
"\treturn image_widget\n",
"\n",
"def create_video(i):\n",
"\t\"\"\"Build the thumbnail widget for the bundled video demo/videos/<i>.mp4.\n",
"\n",
"\tThe first frame is converted with COLOR_RGB2BGR before `cv2.imencode` and PNG-encoded in memory.\n",
"\tThe widget is tagged with the CSS classes `resource`, `resource-video` and `resource-video<i>`.\n",
"\t\"\"\"\n",
"\tfirst_frame = thumbnail('demo/videos/%d.mp4' % i)\n",
"\tencoded_png = cv2.imencode('.png', cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))[1]\n",
"\tvideo_widget = ipywidgets.Image(\n",
"\t\t# tobytes() replaces the deprecated ndarray.tostring() alias.\n",
"\t\tvalue=encoded_png.tobytes(),\n",
"\t\tformat='png'\n",
"\t)\n",
"\tfor css_class in ('resource', 'resource-video', 'resource-video%d' % i):\n",
"\t\tvideo_widget.add_class(css_class)\n",
"\treturn video_widget\n",
"\n",
"def create_title(title):\n",
"\t\"\"\"Return a Label showing `title`, tagged with the `title` CSS class.\"\"\"\n",
"\theading = ipywidgets.Label(title)\n",
"\theading.add_class('title')\n",
"\treturn heading\n",
"\n",
"def download_output(button):\n",
"\t\"\"\"Click handler: send output.mp4 to the browser.\n",
"\n",
"\tThe loading pane is shown while the download is prepared. The result pane is restored in a\n",
"\t`finally` block so a failed download does not leave the UI stuck on the spinner.\n",
"\t`button` is the clicked widget supplied by `on_click`; it is not used.\n",
"\t\"\"\"\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\tfiles.download('output.mp4')\n",
"\tfinally:\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tcomplete.layout.display = ''\n",
"\n",
"def convert_output(button):\n",
"\t\"\"\"Click handler: convert output.mp4 to a 1920x1080 scaled.mp4 and send it to the browser.\n",
"\n",
"\tThe FFmpeg filter scales the video to 1080x1080 with Lanczos and pads it to 1920x1080 at x offset 420.\n",
"\tThe result pane is restored in a `finally` block so a failed conversion or download does not leave\n",
"\tthe UI stuck on the spinner. `button` is the clicked widget supplied by `on_click`; it is not used.\n",
"\t\"\"\"\n",
"\tcomplete.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\tFFmpeg(inputs={'output.mp4': None}, outputs={'scaled.mp4': '-vf \"scale=1080x1080:flags=lanczos,pad=1920:1080:420:0\" -y'}).run()\n",
"\t\tfiles.download('scaled.mp4')\n",
"\tfinally:\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tcomplete.layout.display = ''\n",
"\n",
"def back_to_main(button):\n",
"\t\"\"\"Click handler: hide the result pane and show the main pane again.\"\"\"\n",
"\tfor pane, display_value in ((complete, 'none'), (main, '')):\n",
"\t\tpane.layout.display = display_value\n",
"\n",
"# Separator label placed between the upload column and the gallery tabs.\n",
"label_or = ipywidgets.Label('or')\n",
"label_or.add_class('label-or')\n",
"\n",
"# One gallery tab per category; image_lengths[i] is how many thumbnails tab i gets.\n",
"image_titles = ['Peoples', 'Cartoons', 'Dolls', 'Game of Thrones', 'Statues']\n",
"image_lengths = [8, 4, 8, 9, 4]\n",
"\n",
"image_tab = ipywidgets.Tab()\n",
"image_tab.children = [\n",
"\tipywidgets.HBox([create_image(category, index) for index in range(count)])\n",
"\tfor category, count in enumerate(image_lengths)\n",
"]\n",
"for i, title in enumerate(image_titles):\n",
"\timage_tab.set_title(i, title)\n",
"\n",
"# Left column: preview of the current source image with the upload button underneath.\n",
"input_image_widget = ipywidgets.Output()\n",
"input_image_widget.add_class('input-widget')\n",
"upload_input_image_button = ipywidgets.FileUpload(accept='image/*', button_style='primary')\n",
"upload_input_image_button.add_class('input-button')\n",
"image_upload_column = ipywidgets.VBox([input_image_widget, upload_input_image_button])\n",
"image_part = ipywidgets.HBox([image_upload_column, label_or, image_tab])\n",
"\n",
"# A single tab holding the thumbnails of the five bundled driving videos.\n",
"video_tab = ipywidgets.Tab()\n",
"video_gallery = ipywidgets.HBox([create_video(i) for i in range(5)])\n",
"video_tab.children = [video_gallery]\n",
"video_tab.set_title(0, 'All Videos')\n",
"\n",
"# Left column: preview of the current driving video with the upload button underneath.\n",
"input_video_widget = ipywidgets.Output()\n",
"input_video_widget.add_class('input-widget')\n",
"upload_input_video_button = ipywidgets.FileUpload(accept='video/*', button_style='primary')\n",
"upload_input_video_button.add_class('input-button')\n",
"video_upload_column = ipywidgets.VBox([input_video_widget, upload_input_video_button])\n",
"video_part = ipywidgets.HBox([video_upload_column, label_or, video_tab])\n",
"\n",
"# The selected name is used to build both the checkpoint file name and the config path.\n",
"model = ipywidgets.Dropdown(\n",
"\tdescription=\"Model:\",\n",
"\toptions=['vox', 'vox-adv', 'taichi', 'taichi-adv', 'nemo', 'mgif', 'fashion', 'bair']\n",
")\n",
"# Notice next to the drop-down; it gets the `warn` CSS class whenever a non-vox model is chosen.\n",
"warning = ipywidgets.HTML('Warning: Upload your own images and videos (see README)')\n",
"warning.add_class('warning')\n",
"model_part = ipywidgets.HBox(children=[model, warning])\n",
"\n",
"# Animation options passed to make_animation; both default to enabled.\n",
"relative = ipywidgets.Checkbox(description=\"Relative keypoint displacement (Inherit object proportions from the video)\", value=True)\n",
"adapt_movement_scale = ipywidgets.Checkbox(description=\"Adapt movement scale (Don’t touch unless you know what you are doing)\", value=True)\n",
"generate_button = ipywidgets.Button(description=\"Generate\", button_style='primary')\n",
"# Main pane: source image chooser, driving video chooser, settings and the Generate button.\n",
"main = ipywidgets.VBox([\n",
"\tcreate_title('Choose Image'),\n",
"\timage_part,\n",
"\tcreate_title('Choose Video'),\n",
"\tvideo_part,\n",
"\tcreate_title('Settings'),\n",
"\tmodel_part,\n",
"\trelative,\n",
"\tadapt_movement_scale,\n",
"\tgenerate_button\n",
"])\n",
"\n",
"# Loading pane: a spinner placeholder and a caption, shown while a long-running action is in progress.\n",
"loader = ipywidgets.Label()\n",
"loader.add_class(\"loader\")\n",
"loading_label = ipywidgets.Label(\"This may take several minutes to process…\")\n",
"loading_label.add_class(\"loading-label\")\n",
"loading = ipywidgets.VBox(children=[loader, loading_label])\n",
"loading.add_class('loading')\n",
"\n",
"# Result pane, left column: the generated video and its action buttons.\n",
"output_widget = ipywidgets.Output()\n",
"output_widget.add_class('output-widget')\n",
"download = ipywidgets.Button(description='Download', button_style='primary')\n",
"convert = ipywidgets.Button(description='Convert to 1920×1080', button_style='primary')\n",
"back = ipywidgets.Button(description='Back', button_style='primary')\n",
"for action_button, handler in ((download, download_output), (convert, convert_output), (back, back_to_main)):\n",
"\taction_button.add_class('output-button')\n",
"\taction_button.on_click(handler)\n",
"\n",
"# Result pane, right column: the driving video for side-by-side comparison.\n",
"comparison_widget = ipywidgets.Output()\n",
"comparison_widget.add_class('comparison-widget')\n",
"comparison_label = ipywidgets.Label('Comparison')\n",
"comparison_label.add_class('comparison-label')\n",
"\n",
"output_column = ipywidgets.VBox([output_widget, download, convert, back])\n",
"comparison_column = ipywidgets.VBox([comparison_widget, comparison_label])\n",
"complete = ipywidgets.HBox([output_column, comparison_column])\n",
"\n",
"display(ipywidgets.VBox([main, loading, complete]))\n",
"# Front-end glue: clicking a gallery thumbnail marks it as selected and reports its id to the kernel\n",
"# through the callbacks registered as notebook.select_image / notebook.select_video.\n",
"# `images` and `videos` start as empty arrays so deselectImages()/deselectVideos() are safe to call\n",
"# before binding has happened, and binding is retried every second until the thumbnails exist instead\n",
"# of failing on an empty list. The patterns use [0-9] so the Python string carries no backslash escape.\n",
"display(Javascript(\"\"\"\n",
"var images = [], videos = [];\n",
"function deselectImages() {\n",
"\timages.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function deselectVideos() {\n",
"\tvideos.forEach(function(item) {\n",
"\t\titem.classList.remove(\"selected\");\n",
"\t});\n",
"}\n",
"function invokePython(func) {\n",
"\tgoogle.colab.kernel.invokeFunction(\"notebook.\" + func, [].slice.call(arguments, 1), {});\n",
"}\n",
"function bindResources(className, pattern, callback, deselect) {\n",
"\tvar items = [].slice.call(document.getElementsByClassName(className));\n",
"\titems.forEach(function(item) {\n",
"\t\titem.addEventListener(\"click\", function() {\n",
"\t\t\tdeselect();\n",
"\t\t\titem.classList.add(\"selected\");\n",
"\t\t\tinvokePython(callback, item.className.match(pattern)[1]);\n",
"\t\t});\n",
"\t});\n",
"\titems[0].classList.add(\"selected\");\n",
"\treturn items;\n",
"}\n",
"function setupResources() {\n",
"\tvar imageCount = document.getElementsByClassName(\"resource-image\").length;\n",
"\tvar videoCount = document.getElementsByClassName(\"resource-video\").length;\n",
"\tif (!imageCount || !videoCount) {\n",
"\t\tsetTimeout(setupResources, 1000);\n",
"\t\treturn;\n",
"\t}\n",
"\timages = bindResources(\"resource-image\", /resource-image([0-9][0-9])/, \"select_image\", deselectImages);\n",
"\tvideos = bindResources(\"resource-video\", /resource-video([0-9])/, \"select_video\", deselectVideos);\n",
"}\n",
"setTimeout(setupResources, 1000);\n",
"\"\"\"))\n",
"\n",
"# Source image currently chosen for animation; set by select_image() and upload_image().\n",
"selected_image = None\n",
"def select_image(filename):\n",
"\t\"\"\"Make the bundled image demo/images/<filename>.png the current source image.\n",
"\n",
"\tRegistered as the `notebook.select_image` callback, so the front-end script can call it with the\n",
"\tdigits taken from the clicked thumbnail's class name.\n",
"\t\"\"\"\n",
"\tglobal selected_image\n",
"\tbundled_image = PIL.Image.open('demo/images/%s.png' % filename)\n",
"\tselected_image = resize(bundled_image.convert(\"RGB\"))\n",
"\tinput_image_widget.clear_output(wait=True)\n",
"\twith input_image_widget:\n",
"\t\tdisplay(HTML('Image'))\n",
"\tinput_image_widget.remove_class('uploaded')\n",
"output.register_callback(\"notebook.select_image\", select_image)\n",
"\n",
"# Path of the driving video currently chosen; set by select_video() and upload_video().\n",
"selected_video = None\n",
"def select_video(filename):\n",
"\t\"\"\"Make the bundled video demo/videos/<filename>.mp4 the current driving video.\n",
"\n",
"\tRegistered as the `notebook.select_video` callback, so the front-end script can call it with the\n",
"\tdigit taken from the clicked thumbnail's class name.\n",
"\t\"\"\"\n",
"\tglobal selected_video\n",
"\tbundled_path = 'demo/videos/%s.mp4' % filename\n",
"\tselected_video = bundled_path\n",
"\tinput_video_widget.clear_output(wait=True)\n",
"\twith input_video_widget:\n",
"\t\tdisplay(HTML('Video'))\n",
"\tinput_video_widget.remove_class('uploaded')\n",
"output.register_callback(\"notebook.select_video\", select_video)\n",
"\n",
"def resize(image, size=(256, 256)):\n",
"\t\"\"\"Crop the centred square out of `image` and scale it to `size` with Lanczos resampling.\"\"\"\n",
"\twidth, height = image.size\n",
"\tside = min(width, height)\n",
"\tleft = (width - side) // 2\n",
"\ttop = (height - side) // 2\n",
"\tright = (width + side) // 2\n",
"\tbottom = (height + side) // 2\n",
"\treturn image.resize(size, resample=PIL.Image.LANCZOS, box=(left, top, right, bottom))\n",
"\n",
"def upload_image(change):\n",
"\t\"\"\"Observer for the image upload button: use the uploaded file as the source image.\n",
"\n",
"\tThe content of the last entry in the button's value is decoded, converted to RGB and passed through\n",
"\t`resize`; the preview is refreshed and any selected gallery thumbnail is deselected. Nothing\n",
"\thappens when the value holds no content. `change` is supplied by `observe`; it is not used.\n",
"\t\"\"\"\n",
"\tglobal selected_image\n",
"\t# Initialised up front: with an empty upload value the loop body never runs.\n",
"\tcontent = None\n",
"\tfor file_info in upload_input_image_button.value.values():\n",
"\t\tcontent = file_info['content']\n",
"\tif content is None:\n",
"\t\treturn\n",
"\tselected_image = resize(PIL.Image.open(io.BytesIO(content)).convert(\"RGB\"))\n",
"\tinput_image_widget.clear_output(wait=True)\n",
"\twith input_image_widget:\n",
"\t\tdisplay(selected_image)\n",
"\tinput_image_widget.add_class('uploaded')\n",
"\tdisplay(Javascript('deselectImages()'))\n",
"upload_input_image_button.observe(upload_image, names='value')\n",
"\n",
"def upload_video(change):\n",
"\t\"\"\"Observer for the video upload button: save the upload under user/ and use it as the driving video.\n",
"\n",
"\tThe first frame is decoded for the preview before anything is written, and `selected_video` is only\n",
"\tupdated after the file has been saved, so an unreadable upload leaves the previous selection intact.\n",
"\tNothing happens when the value holds no content. `change` is supplied by `observe`; it is not used.\n",
"\t\"\"\"\n",
"\tglobal selected_video\n",
"\t# Initialised up front: with an empty upload value the loop body never runs.\n",
"\tname = content = None\n",
"\tfor name, file_info in upload_input_video_button.value.items():\n",
"\t\tcontent = file_info['content']\n",
"\tif content is None:\n",
"\t\treturn\n",
"\tpreview = resize(PIL.Image.fromarray(thumbnail(content)).convert(\"RGB\"))\n",
"\t# Keep only the final path component so the upload cannot be written outside user/.\n",
"\tvideo_path = 'user/' + os.path.basename(name)\n",
"\twith open(video_path, 'wb') as video:\n",
"\t\tvideo.write(content)\n",
"\tselected_video = video_path\n",
"\tinput_video_widget.clear_output(wait=True)\n",
"\twith input_video_widget:\n",
"\t\tdisplay(preview)\n",
"\tinput_video_widget.add_class('uploaded')\n",
"\tdisplay(Javascript('deselectVideos()'))\n",
"upload_input_video_button.observe(upload_video, names='value')\n",
"\n",
"def change_model(change):\n",
"\t\"\"\"Observer for the model drop-down: the notice carries the `warn` class only for non-vox models.\"\"\"\n",
"\ttoggle_class = warning.remove_class if model.value.startswith('vox') else warning.add_class\n",
"\ttoggle_class('warn')\n",
"model.observe(change_model, names='value')\n",
"\n",
"def generate(button):\n",
"\t\"\"\"Click handler for the Generate button: animate the selected image with the selected driving video.\n",
"\n",
"\tDownloads the checkpoint for the chosen model unless it is already on disk, runs `make_animation`,\n",
"\twrites output.mp4 and shows it next to the driving video. If any of those steps raises, the main\n",
"\tpane is shown again and the exception is re-raised, so the UI is not left on the loading spinner.\n",
"\t`button` is the clicked widget supplied by `on_click`; it is not used.\n",
"\t\"\"\"\n",
"\tmain.layout.display = 'none'\n",
"\tloading.layout.display = ''\n",
"\ttry:\n",
"\t\tfilename = model.value + ('' if model.value == 'fashion' else '-cpk') + '.pth.tar'\n",
"\t\tif not os.path.isfile(filename):\n",
"\t\t\t# The first request resolves the public share to a direct download link.\n",
"\t\t\tlink_response = requests.get('https://cloud-api.yandex.net/v1/disk/public/resources/download?public_key=https://yadi.sk/d/lEw8uRm140L_eQ&path=/' + filename)\n",
"\t\t\tlink_response.raise_for_status()\n",
"\t\t\thref = link_response.json().get('href')\n",
"\t\t\tif not href:\n",
"\t\t\t\traise RuntimeError('No download link returned for checkpoint %s' % filename)\n",
"\t\t\tcheckpoint_response = requests.get(href)\n",
"\t\t\tcheckpoint_response.raise_for_status()\n",
"\t\t\t# Write under a temporary name first: a failed download must not be mistaken for a cached\n",
"\t\t\t# checkpoint by the isfile() check on the next run.\n",
"\t\t\tpartial_filename = filename + '.part'\n",
"\t\t\twith open(partial_filename, 'wb') as checkpoint:\n",
"\t\t\t\tcheckpoint.write(checkpoint_response.content)\n",
"\t\t\tos.replace(partial_filename, filename)\n",
"\t\t# The reader is closed as soon as every frame has been copied into memory.\n",
"\t\twith imageio.get_reader(selected_video, mode='I', format='FFMPEG') as reader:\n",
"\t\t\tfps = reader.get_meta_data()['fps']\n",
"\t\t\tdriving_video = []\n",
"\t\t\tfor frame in reader:\n",
"\t\t\t\tdriving_video.append(frame)\n",
"\t\tgenerator, kp_detector = load_checkpoints(config_path='config/%s-256.yaml' % model.value, checkpoint_path=filename)\n",
"\t\tpredictions = make_animation(\n",
"\t\t\tskimage.transform.resize(numpy.asarray(selected_image), (256, 256)),\n",
"\t\t\t[skimage.transform.resize(frame, (256, 256)) for frame in driving_video],\n",
"\t\t\tgenerator,\n",
"\t\t\tkp_detector,\n",
"\t\t\trelative=relative.value,\n",
"\t\t\tadapt_movement_scale=adapt_movement_scale.value\n",
"\t\t)\n",
"\t\toutput_frames = [img_as_ubyte(frame) for frame in predictions]\n",
"\t\tif selected_video.startswith('user/') or selected_video == 'demo/videos/0.mp4':\n",
"\t\t\t# Stream-copy the generated frames together with the driving video as a second input\n",
"\t\t\t# (presumably to carry over its audio track; confirm).\n",
"\t\t\timageio.mimsave('temp.mp4', output_frames, fps=fps)\n",
"\t\t\tFFmpeg(inputs={'temp.mp4': None, selected_video: None}, outputs={'output.mp4': '-c copy -y'}).run()\n",
"\t\telse:\n",
"\t\t\timageio.mimsave('output.mp4', output_frames, fps=fps)\n",
"\t\twith open('output.mp4', 'rb') as output_file:\n",
"\t\t\toutput_data = b64encode(output_file.read()).decode()\n",
"\t\twith open(selected_video, 'rb') as driving_file:\n",
"\t\t\tdriving_data = b64encode(driving_file.read()).decode()\n",
"\texcept Exception:\n",
"\t\tloading.layout.display = 'none'\n",
"\t\tmain.layout.display = ''\n",
"\t\traise\n",
"\tloading.layout.display = 'none'\n",
"\tcomplete.layout.display = ''\n",
"\t# Drop the players from a previous run; otherwise the ids \"left\" and \"right\" would be duplicated.\n",
"\toutput_widget.clear_output()\n",
"\tcomparison_widget.clear_output()\n",
"\t# NOTE(review): the player markup is reconstructed. The format strings were empty, which makes the\n",
"\t# % operator raise TypeError; the ids match the getElementById calls in the script below. Confirm\n",
"\t# the attributes (controls / muted) against the intended design.\n",
"\twith output_widget:\n",
"\t\tdisplay(HTML('<video id=\"left\" controls src=\"data:video/mp4;base64,%s\"></video>' % output_data))\n",
"\twith comparison_widget:\n",
"\t\tdisplay(HTML('<video id=\"right\" muted src=\"data:video/mp4;base64,%s\"></video>' % driving_data))\n",
"\t# Keep the comparison player in step with the output player.\n",
"\tdisplay(Javascript(\"\"\"\n",
"\t(function(left, right) {\n",
"\t\tleft.addEventListener(\"play\", function() {\n",
"\t\t\tright.play();\n",
"\t\t});\n",
"\t\tleft.addEventListener(\"pause\", function() {\n",
"\t\t\tright.pause();\n",
"\t\t});\n",
"\t\tleft.addEventListener(\"seeking\", function() {\n",
"\t\t\tright.currentTime = left.currentTime;\n",
"\t\t});\n",
"\t})(document.getElementById(\"left\"), document.getElementById(\"right\"));\n",
"\t\"\"\"))\n",
"\n",
"generate_button.on_click(generate)\n",
"\n",
"# Initial UI state: only the main pane is visible, with the first bundled image and video preselected.\n",
"for hidden_pane in (loading, complete):\n",
"\thidden_pane.layout.display = 'none'\n",
"select_image('00')\n",
"select_video('0')"
],
"execution_count": null,
"outputs": []
}
]
}