|
@@ -0,0 +1,3269 @@
|
|
|
+{
|
|
|
+ "cells": [
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "TitleTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# Disco Diffusion v5.1 - Now with Turbo\n",
|
|
|
+ "\n",
|
|
|
+ "In case of confusion, Disco is the name of this notebook edit. The diffusion model in use is Katherine Crowson's fine-tuned 512x512 model\n",
|
|
|
+ "\n",
|
|
|
+ "For issues, join the [Disco Diffusion Discord](https://discord.gg/msEZBy4HxA) or message us on twitter at [@somnai_dreams](https://twitter.com/somnai_dreams) or [@gandamu](https://twitter.com/gandamu_ml)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "CreditsChTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "### Credits & Changelog ⬇️"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "Credits"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#### Credits\n",
|
|
|
+ "\n",
|
|
|
+ "Original notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet or Katherine Crowson's fine-tuned 512x512 diffusion model (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images.\n",
|
|
|
+ "\n",
|
|
|
+ "Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.\n",
|
|
|
+ "\n",
|
|
|
+ "Further improvements from Dango233 and nsheppard helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.\n",
|
|
|
+ "\n",
|
|
|
+ "Vark added code to load in multiple Clip models at once, which all prompts are evaluated against, which may greatly improve accuracy.\n",
|
|
|
+ "\n",
|
|
|
+ "The latest zoom, pan, rotation, and keyframes features were taken from Chigozie Nri's VQGAN Zoom Notebook (https://github.com/chigozienri, https://twitter.com/chigozienri)\n",
|
|
|
+ "\n",
|
|
|
+ "Advanced DangoCutn Cutout method is also from Dango223.\n",
|
|
|
+ "\n",
|
|
|
+ "--\n",
|
|
|
+ "\n",
|
|
|
+ "Disco:\n",
|
|
|
+ "\n",
|
|
|
+ "Somnai (https://twitter.com/Somnai_dreams) added Diffusion Animation techniques, QoL improvements and various implementations of tech and techniques, mostly listed in the changelog below.\n",
|
|
|
+ "\n",
|
|
|
+ "3D animation implementation added by Adam Letts (https://twitter.com/gandamu_ml) in collaboration with Somnai.\n",
|
|
|
+ "\n",
|
|
|
+ "Turbo feature by Chris Allen (https://twitter.com/zippy731)"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "LicenseTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#### License"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "License"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "Licensed under the MIT License\n",
|
|
|
+ "\n",
|
|
|
+ "Copyright (c) 2021 Katherine Crowson \n",
|
|
|
+ "\n",
|
|
|
+ "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
|
|
|
+ "of this software and associated documentation files (the \"Software\"), to deal\n",
|
|
|
+ "in the Software without restriction, including without limitation the rights\n",
|
|
|
+ "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
|
|
|
+ "copies of the Software, and to permit persons to whom the Software is\n",
|
|
|
+ "furnished to do so, subject to the following conditions:\n",
|
|
|
+ "\n",
|
|
|
+ "The above copyright notice and this permission notice shall be included in\n",
|
|
|
+ "all copies or substantial portions of the Software.\n",
|
|
|
+ "\n",
|
|
|
+ "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
|
|
|
+ "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
|
|
|
+ "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
|
|
|
+ "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
|
|
|
+ "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
|
|
|
+ "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
|
|
|
+ "THE SOFTWARE.\n",
|
|
|
+ "\n",
|
|
|
+ "--\n",
|
|
|
+ "\n",
|
|
|
+ "MIT License\n",
|
|
|
+ "\n",
|
|
|
+ "Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)\n",
|
|
|
+ "\n",
|
|
|
+ "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
|
|
|
+ "of this software and associated documentation files (the \"Software\"), to deal\n",
|
|
|
+ "in the Software without restriction, including without limitation the rights\n",
|
|
|
+ "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
|
|
|
+ "copies of the Software, and to permit persons to whom the Software is\n",
|
|
|
+ "furnished to do so, subject to the following conditions:\n",
|
|
|
+ "\n",
|
|
|
+ "The above copyright notice and this permission notice shall be included in all\n",
|
|
|
+ "copies or substantial portions of the Software.\n",
|
|
|
+ "\n",
|
|
|
+ "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
|
|
|
+ "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
|
|
|
+ "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
|
|
|
+ "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
|
|
|
+ "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
|
|
|
+ "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n",
|
|
|
+ "SOFTWARE.\n",
|
|
|
+ "\n",
|
|
|
+ "--\n",
|
|
|
+ "\n",
|
|
|
+ "Licensed under the MIT License\n",
|
|
|
+ "\n",
|
|
|
+ "Copyright (c) 2021 Maxwell Ingham\n",
|
|
|
+ "\n",
|
|
|
+ "Copyright (c) 2022 Adam Letts \n",
|
|
|
+ "\n",
|
|
|
+ "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
|
|
|
+ "of this software and associated documentation files (the \"Software\"), to deal\n",
|
|
|
+ "in the Software without restriction, including without limitation the rights\n",
|
|
|
+ "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
|
|
|
+ "copies of the Software, and to permit persons to whom the Software is\n",
|
|
|
+ "furnished to do so, subject to the following conditions:\n",
|
|
|
+ "\n",
|
|
|
+ "The above copyright notice and this permission notice shall be included in\n",
|
|
|
+ "all copies or substantial portions of the Software.\n",
|
|
|
+ "\n",
|
|
|
+ "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
|
|
|
+ "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
|
|
|
+ "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
|
|
|
+ "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
|
|
|
+ "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
|
|
|
+ "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
|
|
|
+ "THE SOFTWARE."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "ChangelogTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#### Changelog"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "cellView": "form",
|
|
|
+ "id": "Changelog"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title <- View Changelog\n",
|
|
|
+ "skip_for_run_all = True #@param {type: 'boolean'}\n",
|
|
|
+ "\n",
|
|
|
+ "if skip_for_run_all == False:\n",
|
|
|
+ " print(\n",
|
|
|
+ " '''\n",
|
|
|
+ " v1 Update: Oct 29th 2021 - Somnai\n",
|
|
|
+ "\n",
|
|
|
+ " QoL improvements added by Somnai (@somnai_dreams), including user friendly UI, settings+prompt saving and improved google drive folder organization.\n",
|
|
|
+ "\n",
|
|
|
+ " v1.1 Update: Nov 13th 2021 - Somnai\n",
|
|
|
+ "\n",
|
|
|
+ " Now includes sizing options, intermediate saves and fixed image prompts and perlin inits. unexposed batch option since it doesn't work\n",
|
|
|
+ "\n",
|
|
|
+ " v2 Update: Nov 22nd 2021 - Somnai\n",
|
|
|
+ "\n",
|
|
|
+ " Initial addition of Katherine Crowson's Secondary Model Method (https://colab.research.google.com/drive/1mpkrhOjoyzPeSWy2r7T8EYRaU7amYOOi#scrollTo=X5gODNAMEUCR)\n",
|
|
|
+ "\n",
|
|
|
+ " Noticed settings were saving with the wrong name so corrected it. Let me know if you preferred the old scheme.\n",
|
|
|
+ "\n",
|
|
|
+ " v3 Update: Dec 24th 2021 - Somnai\n",
|
|
|
+ "\n",
|
|
|
+ " Implemented Dango's advanced cutout method\n",
|
|
|
+ "\n",
|
|
|
+ " Added SLIP models, thanks to NeuralDivergent\n",
|
|
|
+ "\n",
|
|
|
+ " Fixed issue with NaNs resulting in black images, with massive help and testing from @Softology\n",
|
|
|
+ "\n",
|
|
|
+ " Perlin now changes properly within batches (not sure where this perlin_regen code came from originally, but thank you)\n",
|
|
|
+ "\n",
|
|
|
+ " v4 Update: Jan 2021 - Somnai\n",
|
|
|
+ "\n",
|
|
|
+ " Implemented Diffusion Zooming\n",
|
|
|
+ "\n",
|
|
|
+ " Added Chigozie keyframing\n",
|
|
|
+ "\n",
|
|
|
+ " Made a bunch of edits to processes\n",
|
|
|
+ " \n",
|
|
|
+ " v4.1 Update: Jan 14th 2021 - Somnai\n",
|
|
|
+ "\n",
|
|
|
+ " Added video input mode\n",
|
|
|
+ "\n",
|
|
|
+ " Added license that somehow went missing\n",
|
|
|
+ "\n",
|
|
|
+ " Added improved prompt keyframing, fixed image_prompts and multiple prompts\n",
|
|
|
+ "\n",
|
|
|
+ " Improved UI\n",
|
|
|
+ "\n",
|
|
|
+ " Significant under the hood cleanup and improvement\n",
|
|
|
+ "\n",
|
|
|
+ " Refined defaults for each mode\n",
|
|
|
+ "\n",
|
|
|
+ " Added latent-diffusion SuperRes for sharpening\n",
|
|
|
+ "\n",
|
|
|
+ " Added resume run mode\n",
|
|
|
+ "\n",
|
|
|
+ " v4.9 Update: Feb 5th 2022 - gandamu / Adam Letts\n",
|
|
|
+ "\n",
|
|
|
+ " Added 3D\n",
|
|
|
+ "\n",
|
|
|
+ " Added brightness corrections to prevent animation from steadily going dark over time\n",
|
|
|
+ "\n",
|
|
|
+ " v4.91 Update: Feb 19th 2022 - gandamu / Adam Letts\n",
|
|
|
+ "\n",
|
|
|
+ " Cleaned up 3D implementation and made associated args accessible via Colab UI elements\n",
|
|
|
+ "\n",
|
|
|
+ " v4.92 Update: Feb 20th 2022 - gandamu / Adam Letts\n",
|
|
|
+ "\n",
|
|
|
+ " Separated transform code\n",
|
|
|
+ "\n",
|
|
|
+ " v5.01 Update: Mar 10th 2022 - gandamu / Adam Letts\n",
|
|
|
+ "\n",
|
|
|
+ " IPython magic commands replaced by Python code\n",
|
|
|
+ "\n",
|
|
|
+ " v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts\n",
|
|
|
+ "\n",
|
|
|
+ " Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.\n",
|
|
|
+ "\n",
|
|
|
+ " Implemented resume of turbo animations in such a way that it's now possible to resume from different batch folders and batch numbers.\n",
|
|
|
+ "\n",
|
|
|
+ " 3D rotation parameter units are now degrees (rather than radians)\n",
|
|
|
+ "\n",
|
|
|
+ " Corrected name collision in sampling_mode (now diffusion_sampling_mode for plms/ddim, and sampling_mode for 3D transform sampling)\n",
|
|
|
+ "\n",
|
|
|
+ " Added video_init_seed_continuity option to make init video animations more continuous\n",
|
|
|
+ "\n",
|
|
|
+ " '''\n",
|
|
|
+ " )\n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "TutorialTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# Tutorial"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "DiffusionSet"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "**Diffusion settings (Defaults are heavily outdated)**\n",
|
|
|
+ "---\n",
|
|
|
+ "\n",
|
|
|
+ "This section is outdated as of v2\n",
|
|
|
+ "\n",
|
|
|
+ "Setting | Description | Default\n",
|
|
|
+ "--- | --- | ---\n",
|
|
|
+ "**Your vision:**\n",
|
|
|
+ "`text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A\n",
|
|
|
+ "`image_prompts` | Think of these images more as a description of their contents. | N/A\n",
|
|
|
+ "**Image quality:**\n",
|
|
|
+ "`clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000\n",
|
|
|
+ "`tv_scale` | Controls the smoothness of the final output. | 150\n",
|
|
|
+ "`range_scale` | Controls how far out of range RGB values are allowed to be. | 150\n",
|
|
|
+ "`sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0\n",
|
|
|
+ "`cutn` | Controls how many crops to take from the image. | 16\n",
|
|
|
+ "`cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts | 2\n",
|
|
|
+ "**Init settings:**\n",
|
|
|
+ "`init_image` | URL or local path | None\n",
|
|
|
+ "`init_scale` | This enhances the effect of the init image, a good value is 1000 | 0\n",
|
|
|
+ "`skip_steps Controls the starting point along the diffusion timesteps | 0\n",
|
|
|
+ "`perlin_init` | Option to start with random perlin noise | False\n",
|
|
|
+ "`perlin_mode` | ('gray', 'color') | 'mixed'\n",
|
|
|
+ "**Advanced:**\n",
|
|
|
+ "`skip_augs` |Controls whether to skip torchvision augmentations | False\n",
|
|
|
+ "`randomize_class` |Controls whether the imagenet class is randomly changed each iteration | True\n",
|
|
|
+ "`clip_denoised` |Determines whether CLIP discriminates a noisy or denoised image | False\n",
|
|
|
+ "`clamp_grad` |Experimental: Using adaptive clip grad in the cond_fn | True\n",
|
|
|
+ "`seed` | Choose a random seed and print it at end of run for reproduction | random_seed\n",
|
|
|
+ "`fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses | False\n",
|
|
|
+ "`rand_mag` |Controls the magnitude of the random noise | 0.1\n",
|
|
|
+ "`eta` | DDIM hyperparameter | 0.5\n",
|
|
|
+ "\n",
|
|
|
+ "..\n",
|
|
|
+ "\n",
|
|
|
+ "**Model settings**\n",
|
|
|
+ "---\n",
|
|
|
+ "\n",
|
|
|
+ "Setting | Description | Default\n",
|
|
|
+ "--- | --- | ---\n",
|
|
|
+ "**Diffusion:**\n",
|
|
|
+ "`timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100\n",
|
|
|
+ "`diffusion_steps` || 1000\n",
|
|
|
+ "**Diffusion:**\n",
|
|
|
+ "`clip_models` | Models of CLIP to load. Typically the more, the better but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "SetupTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# 1. Set Up"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "cellView": "form",
|
|
|
+ "id": "CheckGPU"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title 1.1 Check GPU Status\n",
|
|
|
+ "import subprocess\n",
|
|
|
+ "simple_nvidia_smi_display = False#@param {type:\"boolean\"}\n",
|
|
|
+ "if simple_nvidia_smi_display:\n",
|
|
|
+ " #!nvidia-smi\n",
|
|
|
+ " nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " print(nvidiasmi_output)\n",
|
|
|
+ "else:\n",
|
|
|
+ " #!nvidia-smi -i 0 -e 0\n",
|
|
|
+ " nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " print(nvidiasmi_output)\n",
|
|
|
+ " nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " print(nvidiasmi_ecc_note)"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "cellView": "form",
|
|
|
+ "id": "PrepFolders"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title 1.2 Prepare Folders\n",
|
|
|
+ "import subprocess\n",
|
|
|
+ "import sys\n",
|
|
|
+ "import ipykernel\n",
|
|
|
+ "\n",
|
|
|
+ "def gitclone(url):\n",
|
|
|
+ " res = subprocess.run(['git', 'clone', url], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " print(res)\n",
|
|
|
+ "\n",
|
|
|
+ "def pipi(modulestr):\n",
|
|
|
+ " res = subprocess.run(['pip', 'install', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " print(res)\n",
|
|
|
+ "\n",
|
|
|
+ "def pipie(modulestr):\n",
|
|
|
+ " res = subprocess.run(['git', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " print(res)\n",
|
|
|
+ "\n",
|
|
|
+ "def wget(url, outputdir):\n",
|
|
|
+ " res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " print(res)\n",
|
|
|
+ "\n",
|
|
|
+ "try:\n",
|
|
|
+ " from google.colab import drive\n",
|
|
|
+ " print(\"Google Colab detected. Using Google Drive.\")\n",
|
|
|
+ " is_colab = True\n",
|
|
|
+ " #@markdown If you connect your Google Drive, you can save the final image of each run on your drive.\n",
|
|
|
+ " google_drive = True #@param {type:\"boolean\"}\n",
|
|
|
+ " #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n",
|
|
|
+ " save_models_to_google_drive = True #@param {type:\"boolean\"}\n",
|
|
|
+ "except:\n",
|
|
|
+ " is_colab = False\n",
|
|
|
+ " google_drive = False\n",
|
|
|
+ " save_models_to_google_drive = False\n",
|
|
|
+ " print(\"Google Colab not detected.\")\n",
|
|
|
+ "\n",
|
|
|
+ "if is_colab:\n",
|
|
|
+ " if google_drive is True:\n",
|
|
|
+ " drive.mount('/content/drive')\n",
|
|
|
+ " root_path = '/content/drive/MyDrive/AI/Disco_Diffusion'\n",
|
|
|
+ " else:\n",
|
|
|
+ " root_path = '/content'\n",
|
|
|
+ "else:\n",
|
|
|
+ " root_path = '.'\n",
|
|
|
+ "\n",
|
|
|
+ "import os\n",
|
|
|
+ "def createPath(filepath):\n",
|
|
|
+ " os.makedirs(filepath, exist_ok=True)\n",
|
|
|
+ "\n",
|
|
|
+ "initDirPath = f'{root_path}/init_images'\n",
|
|
|
+ "createPath(initDirPath)\n",
|
|
|
+ "outDirPath = f'{root_path}/images_out'\n",
|
|
|
+ "createPath(outDirPath)\n",
|
|
|
+ "\n",
|
|
|
+ "if is_colab:\n",
|
|
|
+ " if google_drive and not save_models_to_google_drive or not google_drive:\n",
|
|
|
+ " model_path = '/content/model'\n",
|
|
|
+ " createPath(model_path)\n",
|
|
|
+ " if google_drive and save_models_to_google_drive:\n",
|
|
|
+ " model_path = f'{root_path}/model'\n",
|
|
|
+ " createPath(model_path)\n",
|
|
|
+ "else:\n",
|
|
|
+ " model_path = f'{root_path}/model'\n",
|
|
|
+ " createPath(model_path)\n",
|
|
|
+ "\n",
|
|
|
+ "# libraries = f'{root_path}/libraries'\n",
|
|
|
+ "# createPath(libraries)"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "cellView": "form",
|
|
|
+ "id": "InstallDeps"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title ### 1.3 Install and import dependencies\n",
|
|
|
+ "\n",
|
|
|
+ "import pathlib, shutil\n",
|
|
|
+ "\n",
|
|
|
+ "if not is_colab:\n",
|
|
|
+ " # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.\n",
|
|
|
+ " os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'\n",
|
|
|
+ "\n",
|
|
|
+ "PROJECT_DIR = os.path.abspath(os.getcwd())\n",
|
|
|
+ "USE_ADABINS = True\n",
|
|
|
+ "\n",
|
|
|
+ "if is_colab:\n",
|
|
|
+ " if google_drive is not True:\n",
|
|
|
+ " root_path = f'/content'\n",
|
|
|
+ " model_path = '/content/models' \n",
|
|
|
+ "else:\n",
|
|
|
+ " root_path = f'.'\n",
|
|
|
+ " model_path = f'{root_path}/model'\n",
|
|
|
+ "\n",
|
|
|
+ "model_256_downloaded = False\n",
|
|
|
+ "model_512_downloaded = False\n",
|
|
|
+ "model_secondary_downloaded = False\n",
|
|
|
+ "\n",
|
|
|
+ "if is_colab:\n",
|
|
|
+ " gitclone(\"https://github.com/openai/CLIP\")\n",
|
|
|
+ " #gitclone(\"https://github.com/facebookresearch/SLIP.git\")\n",
|
|
|
+ " gitclone(\"https://github.com/crowsonkb/guided-diffusion\")\n",
|
|
|
+ " gitclone(\"https://github.com/assafshocher/ResizeRight.git\")\n",
|
|
|
+ " gitclone(\"https://github.com/MSFTserver/pytorch3d-lite.git\")\n",
|
|
|
+ " pipie(\"./CLIP\")\n",
|
|
|
+ " pipie(\"./guided-diffusion\")\n",
|
|
|
+ " multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " print(multipip_res)\n",
|
|
|
+ " subprocess.run(['apt', 'install', 'imagemagick'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " gitclone(\"https://github.com/isl-org/MiDaS.git\")\n",
|
|
|
+ " gitclone(\"https://github.com/alembics/disco-diffusion.git\")\n",
|
|
|
+ " pipi(\"pytorch-lightning\")\n",
|
|
|
+ " pipi(\"omegaconf\")\n",
|
|
|
+ " pipi(\"einops\")\n",
|
|
|
+ " # Rename a file to avoid a name conflict..\n",
|
|
|
+ " try:\n",
|
|
|
+ " os.rename(\"MiDaS/utils.py\", \"MiDaS/midas_utils.py\")\n",
|
|
|
+ " shutil.copyfile(\"disco-diffusion/disco_xform_utils.py\", \"disco_xform_utils.py\")\n",
|
|
|
+ " except:\n",
|
|
|
+ " pass\n",
|
|
|
+ "\n",
|
|
|
+ "if not os.path.exists(f'{model_path}'):\n",
|
|
|
+ " pathlib.Path(model_path).mkdir(parents=True, exist_ok=True)\n",
|
|
|
+ "if not os.path.exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n",
|
|
|
+ " wget(\"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt\", model_path)\n",
|
|
|
+ "\n",
|
|
|
+ "import sys\n",
|
|
|
+ "import torch\n",
|
|
|
+ "\n",
|
|
|
+ "# sys.path.append('./SLIP')\n",
|
|
|
+ "sys.path.append('./pytorch3d-lite')\n",
|
|
|
+ "sys.path.append('./ResizeRight')\n",
|
|
|
+ "sys.path.append('./MiDaS')\n",
|
|
|
+ "from dataclasses import dataclass\n",
|
|
|
+ "from functools import partial\n",
|
|
|
+ "import cv2\n",
|
|
|
+ "import pandas as pd\n",
|
|
|
+ "import gc\n",
|
|
|
+ "import io\n",
|
|
|
+ "import math\n",
|
|
|
+ "import timm\n",
|
|
|
+ "from IPython import display\n",
|
|
|
+ "import lpips\n",
|
|
|
+ "from PIL import Image, ImageOps\n",
|
|
|
+ "import requests\n",
|
|
|
+ "from glob import glob\n",
|
|
|
+ "import json\n",
|
|
|
+ "from types import SimpleNamespace\n",
|
|
|
+ "from torch import nn\n",
|
|
|
+ "from torch.nn import functional as F\n",
|
|
|
+ "import torchvision.transforms as T\n",
|
|
|
+ "import torchvision.transforms.functional as TF\n",
|
|
|
+ "from tqdm.notebook import tqdm\n",
|
|
|
+ "sys.path.append('./CLIP')\n",
|
|
|
+ "sys.path.append('./guided-diffusion')\n",
|
|
|
+ "import clip\n",
|
|
|
+ "from resize_right import resize\n",
|
|
|
+ "# from models import SLIP_VITB16, SLIP, SLIP_VITL16\n",
|
|
|
+ "from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults\n",
|
|
|
+ "from datetime import datetime\n",
|
|
|
+ "import numpy as np\n",
|
|
|
+ "import matplotlib.pyplot as plt\n",
|
|
|
+ "import random\n",
|
|
|
+ "from ipywidgets import Output\n",
|
|
|
+ "import hashlib\n",
|
|
|
+ "\n",
|
|
|
+ "#SuperRes\n",
|
|
|
+ "if is_colab:\n",
|
|
|
+ " gitclone(\"https://github.com/CompVis/latent-diffusion.git\")\n",
|
|
|
+ " gitclone(\"https://github.com/CompVis/taming-transformers\")\n",
|
|
|
+ " pipie(\"./taming-transformers\")\n",
|
|
|
+ " pipi(\"ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\")\n",
|
|
|
+ "\n",
|
|
|
+ "#SuperRes\n",
|
|
|
+ "import ipywidgets as widgets\n",
|
|
|
+ "import os\n",
|
|
|
+ "sys.path.append(\".\")\n",
|
|
|
+ "sys.path.append('./taming-transformers')\n",
|
|
|
+ "from taming.models import vqgan # checking correct import from taming\n",
|
|
|
+ "from torchvision.datasets.utils import download_url\n",
|
|
|
+ "\n",
|
|
|
+ "if is_colab:\n",
|
|
|
+ " os.chdir('/content/latent-diffusion')\n",
|
|
|
+ "else:\n",
|
|
|
+ " #os.chdir('latent-diffusion')\n",
|
|
|
+ " sys.path.append('latent-diffusion')\n",
|
|
|
+ "from functools import partial\n",
|
|
|
+ "from ldm.util import instantiate_from_config\n",
|
|
|
+ "from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like\n",
|
|
|
+ "# from ldm.models.diffusion.ddim import DDIMSampler\n",
|
|
|
+ "from ldm.util import ismap\n",
|
|
|
+ "if is_colab:\n",
|
|
|
+ " os.chdir('/content')\n",
|
|
|
+ " from google.colab import files\n",
|
|
|
+ "else:\n",
|
|
|
+ " os.chdir(f'{PROJECT_DIR}')\n",
|
|
|
+ "from IPython.display import Image as ipyimg\n",
|
|
|
+ "from numpy import asarray\n",
|
|
|
+ "from einops import rearrange, repeat\n",
|
|
|
+ "import torch, torchvision\n",
|
|
|
+ "import time\n",
|
|
|
+ "from omegaconf import OmegaConf\n",
|
|
|
+ "import warnings\n",
|
|
|
+ "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
+ "\n",
|
|
|
+ "# AdaBins stuff\n",
|
|
|
+ "if USE_ADABINS:\n",
|
|
|
+ " if is_colab:\n",
|
|
|
+ " gitclone(\"https://github.com/shariqfarooq123/AdaBins.git\")\n",
|
|
|
+ " if not os.path.exists(f'{model_path}/AdaBins_nyu.pt'):\n",
|
|
|
+ " wget(\"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt\", model_path)\n",
|
|
|
+ " pathlib.Path(\"pretrained\").mkdir(parents=True, exist_ok=True)\n",
|
|
|
+ " shutil.copyfile(f\"{model_path}/AdaBins_nyu.pt\", \"pretrained/AdaBins_nyu.pt\")\n",
|
|
|
+ " sys.path.append('./AdaBins')\n",
|
|
|
+ " from infer import InferenceHelper\n",
|
|
|
+ " MAX_ADABINS_AREA = 500000\n",
|
|
|
+ "\n",
|
|
|
+ "import torch\n",
|
|
|
+ "DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
|
|
+ "print('Using device:', DEVICE)\n",
|
|
|
+ "device = DEVICE # At least one of the modules expects this name..\n",
|
|
|
+ "\n",
|
|
|
+ "if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad\n",
|
|
|
+ " print('Disabling CUDNN for A100 gpu', file=sys.stderr)\n",
|
|
|
+ " torch.backends.cudnn.enabled = False"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "cellView": "form",
|
|
|
+ "id": "DefMidasFns"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title ### 1.4 Define Midas functions\n",
|
|
|
+ "\n",
|
|
|
+ "from midas.dpt_depth import DPTDepthModel\n",
|
|
|
+ "from midas.midas_net import MidasNet\n",
|
|
|
+ "from midas.midas_net_custom import MidasNet_small\n",
|
|
|
+ "from midas.transforms import Resize, NormalizeImage, PrepareForNet\n",
|
|
|
+ "\n",
|
|
|
+ "# Initialize MiDaS depth model.\n",
|
|
|
+ "# It remains resident in VRAM and likely takes around 2GB VRAM.\n",
|
|
|
+ "# You could instead initialize it for each frame (and free it after each frame) to save VRAM.. but initializing it is slow.\n",
|
|
|
+ "default_models = {\n",
|
|
|
+ " \"midas_v21_small\": f\"{model_path}/midas_v21_small-70d6b9c8.pt\",\n",
|
|
|
+ " \"midas_v21\": f\"{model_path}/midas_v21-f6b98070.pt\",\n",
|
|
|
+ " \"dpt_large\": f\"{model_path}/dpt_large-midas-2f21e586.pt\",\n",
|
|
|
+ " \"dpt_hybrid\": f\"{model_path}/dpt_hybrid-midas-501f0c75.pt\",\n",
|
|
|
+ " \"dpt_hybrid_nyu\": f\"{model_path}/dpt_hybrid_nyu-2ce69ec7.pt\",}\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def init_midas_depth_model(midas_model_type=\"dpt_large\", optimize=True):\n",
|
|
|
+ " midas_model = None\n",
|
|
|
+ " net_w = None\n",
|
|
|
+ " net_h = None\n",
|
|
|
+ " resize_mode = None\n",
|
|
|
+ " normalization = None\n",
|
|
|
+ "\n",
|
|
|
+ " print(f\"Initializing MiDaS '{midas_model_type}' depth model...\")\n",
|
|
|
+ " # load network\n",
|
|
|
+ " midas_model_path = default_models[midas_model_type]\n",
|
|
|
+ "\n",
|
|
|
+ " if midas_model_type == \"dpt_large\": # DPT-Large\n",
|
|
|
+ " midas_model = DPTDepthModel(\n",
|
|
|
+ " path=midas_model_path,\n",
|
|
|
+ " backbone=\"vitl16_384\",\n",
|
|
|
+ " non_negative=True,\n",
|
|
|
+ " )\n",
|
|
|
+ " net_w, net_h = 384, 384\n",
|
|
|
+ " resize_mode = \"minimal\"\n",
|
|
|
+ " normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
+ " elif midas_model_type == \"dpt_hybrid\": #DPT-Hybrid\n",
|
|
|
+ " midas_model = DPTDepthModel(\n",
|
|
|
+ " path=midas_model_path,\n",
|
|
|
+ " backbone=\"vitb_rn50_384\",\n",
|
|
|
+ " non_negative=True,\n",
|
|
|
+ " )\n",
|
|
|
+ " net_w, net_h = 384, 384\n",
|
|
|
+ " resize_mode=\"minimal\"\n",
|
|
|
+ " normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
+ " elif midas_model_type == \"dpt_hybrid_nyu\": #DPT-Hybrid-NYU\n",
|
|
|
+ " midas_model = DPTDepthModel(\n",
|
|
|
+ " path=midas_model_path,\n",
|
|
|
+ " backbone=\"vitb_rn50_384\",\n",
|
|
|
+ " non_negative=True,\n",
|
|
|
+ " )\n",
|
|
|
+ " net_w, net_h = 384, 384\n",
|
|
|
+ " resize_mode=\"minimal\"\n",
|
|
|
+ " normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
+ " elif midas_model_type == \"midas_v21\":\n",
|
|
|
+ " midas_model = MidasNet(midas_model_path, non_negative=True)\n",
|
|
|
+ " net_w, net_h = 384, 384\n",
|
|
|
+ " resize_mode=\"upper_bound\"\n",
|
|
|
+ " normalization = NormalizeImage(\n",
|
|
|
+ " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
|
|
|
+ " )\n",
|
|
|
+ " elif midas_model_type == \"midas_v21_small\":\n",
|
|
|
+ " midas_model = MidasNet_small(midas_model_path, features=64, backbone=\"efficientnet_lite3\", exportable=True, non_negative=True, blocks={'expand': True})\n",
|
|
|
+ " net_w, net_h = 256, 256\n",
|
|
|
+ " resize_mode=\"upper_bound\"\n",
|
|
|
+ " normalization = NormalizeImage(\n",
|
|
|
+ " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
|
|
|
+ " )\n",
|
|
|
+ " else:\n",
|
|
|
+ " print(f\"midas_model_type '{midas_model_type}' not implemented\")\n",
|
|
|
+ " assert False\n",
|
|
|
+ "\n",
|
|
|
+ " midas_transform = T.Compose(\n",
|
|
|
+ " [\n",
|
|
|
+ " Resize(\n",
|
|
|
+ " net_w,\n",
|
|
|
+ " net_h,\n",
|
|
|
+ " resize_target=None,\n",
|
|
|
+ " keep_aspect_ratio=True,\n",
|
|
|
+ " ensure_multiple_of=32,\n",
|
|
|
+ " resize_method=resize_mode,\n",
|
|
|
+ " image_interpolation_method=cv2.INTER_CUBIC,\n",
|
|
|
+ " ),\n",
|
|
|
+ " normalization,\n",
|
|
|
+ " PrepareForNet(),\n",
|
|
|
+ " ]\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " midas_model.eval()\n",
|
|
|
+ " \n",
|
|
|
+ " if optimize==True:\n",
|
|
|
+ " if DEVICE == torch.device(\"cuda\"):\n",
|
|
|
+ " midas_model = midas_model.to(memory_format=torch.channels_last) \n",
|
|
|
+ " midas_model = midas_model.half()\n",
|
|
|
+ "\n",
|
|
|
+ " midas_model.to(DEVICE)\n",
|
|
|
+ "\n",
|
|
|
+ " print(f\"MiDaS '{midas_model_type}' depth model initialized.\")\n",
|
|
|
+ " return midas_model, midas_transform, net_w, net_h, resize_mode, normalization"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "cellView": "form",
|
|
|
+ "id": "DefFns"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title 1.5 Define necessary functions\n",
|
|
|
+ "\n",
|
|
|
+ "# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n",
|
|
|
+ "\n",
|
|
|
+ "import py3d_tools as p3dT\n",
|
|
|
+ "import disco_xform_utils as dxf\n",
|
|
|
+ "\n",
|
|
|
+ "def interp(t):\n",
|
|
|
+ " return 3 * t**2 - 2 * t ** 3\n",
|
|
|
+ "\n",
|
|
|
+ "def perlin(width, height, scale=10, device=None):\n",
|
|
|
+ " gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n",
|
|
|
+ " xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n",
|
|
|
+ " ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n",
|
|
|
+ " wx = 1 - interp(xs)\n",
|
|
|
+ " wy = 1 - interp(ys)\n",
|
|
|
+ " dots = 0\n",
|
|
|
+ " dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)\n",
|
|
|
+ " dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)\n",
|
|
|
+ " dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))\n",
|
|
|
+ " dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))\n",
|
|
|
+ " return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n",
|
|
|
+ "\n",
|
|
|
+ "def perlin_ms(octaves, width, height, grayscale, device=device):\n",
|
|
|
+ " out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]\n",
|
|
|
+ " # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]\n",
|
|
|
+ " for i in range(1 if grayscale else 3):\n",
|
|
|
+ " scale = 2 ** len(octaves)\n",
|
|
|
+ " oct_width = width\n",
|
|
|
+ " oct_height = height\n",
|
|
|
+ " for oct in octaves:\n",
|
|
|
+ " p = perlin(oct_width, oct_height, scale, device)\n",
|
|
|
+ " out_array[i] += p * oct\n",
|
|
|
+ " scale //= 2\n",
|
|
|
+ " oct_width *= 2\n",
|
|
|
+ " oct_height *= 2\n",
|
|
|
+ " return torch.cat(out_array)\n",
|
|
|
+ "\n",
|
|
|
+ "def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n",
|
|
|
+ " out = perlin_ms(octaves, width, height, grayscale)\n",
|
|
|
+ " if grayscale:\n",
|
|
|
+ " out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))\n",
|
|
|
+ " out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')\n",
|
|
|
+ " else:\n",
|
|
|
+ " out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])\n",
|
|
|
+ " out = TF.resize(size=(side_y, side_x), img=out)\n",
|
|
|
+ " out = TF.to_pil_image(out.clamp(0, 1).squeeze())\n",
|
|
|
+ "\n",
|
|
|
+ " out = ImageOps.autocontrast(out)\n",
|
|
|
+ " return out\n",
|
|
|
+ "\n",
|
|
|
+ "def regen_perlin():\n",
|
|
|
+ " if perlin_mode == 'color':\n",
|
|
|
+ " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
+ " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n",
|
|
|
+ " elif perlin_mode == 'gray':\n",
|
|
|
+ " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n",
|
|
|
+ " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
+ " else:\n",
|
|
|
+ " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
+ " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
+ "\n",
|
|
|
+ " init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
+ " del init2\n",
|
|
|
+ " return init.expand(batch_size, -1, -1, -1)\n",
|
|
|
+ "\n",
|
|
|
+ "def fetch(url_or_path):\n",
|
|
|
+ " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n",
|
|
|
+ " r = requests.get(url_or_path)\n",
|
|
|
+ " r.raise_for_status()\n",
|
|
|
+ " fd = io.BytesIO()\n",
|
|
|
+ " fd.write(r.content)\n",
|
|
|
+ " fd.seek(0)\n",
|
|
|
+ " return fd\n",
|
|
|
+ " return open(url_or_path, 'rb')\n",
|
|
|
+ "\n",
|
|
|
+ "def read_image_workaround(path):\n",
|
|
|
+ " \"\"\"OpenCV reads images as BGR, Pillow saves them as RGB. Work around\n",
|
|
|
+ " this incompatibility to avoid colour inversions.\"\"\"\n",
|
|
|
+ " im_tmp = cv2.imread(path)\n",
|
|
|
+ " return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)\n",
|
|
|
+ "\n",
|
|
|
+ "def parse_prompt(prompt):\n",
|
|
|
+ " if prompt.startswith('http://') or prompt.startswith('https://'):\n",
|
|
|
+ " vals = prompt.rsplit(':', 2)\n",
|
|
|
+ " vals = [vals[0] + ':' + vals[1], *vals[2:]]\n",
|
|
|
+ " else:\n",
|
|
|
+ " vals = prompt.rsplit(':', 1)\n",
|
|
|
+ " vals = vals + ['', '1'][len(vals):]\n",
|
|
|
+ " return vals[0], float(vals[1])\n",
|
|
|
+ "\n",
|
|
|
+ "def sinc(x):\n",
|
|
|
+ " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n",
|
|
|
+ "\n",
|
|
|
+ "def lanczos(x, a):\n",
|
|
|
+ " cond = torch.logical_and(-a < x, x < a)\n",
|
|
|
+ " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n",
|
|
|
+ " return out / out.sum()\n",
|
|
|
+ "\n",
|
|
|
+ "def ramp(ratio, width):\n",
|
|
|
+ " n = math.ceil(width / ratio + 1)\n",
|
|
|
+ " out = torch.empty([n])\n",
|
|
|
+ " cur = 0\n",
|
|
|
+ " for i in range(out.shape[0]):\n",
|
|
|
+ " out[i] = cur\n",
|
|
|
+ " cur += ratio\n",
|
|
|
+ " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n",
|
|
|
+ "\n",
|
|
|
+ "def resample(input, size, align_corners=True):\n",
|
|
|
+ " n, c, h, w = input.shape\n",
|
|
|
+ " dh, dw = size\n",
|
|
|
+ "\n",
|
|
|
+ " input = input.reshape([n * c, 1, h, w])\n",
|
|
|
+ "\n",
|
|
|
+ " if dh < h:\n",
|
|
|
+ " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n",
|
|
|
+ " pad_h = (kernel_h.shape[0] - 1) // 2\n",
|
|
|
+ " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n",
|
|
|
+ " input = F.conv2d(input, kernel_h[None, None, :, None])\n",
|
|
|
+ "\n",
|
|
|
+ " if dw < w:\n",
|
|
|
+ " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n",
|
|
|
+ " pad_w = (kernel_w.shape[0] - 1) // 2\n",
|
|
|
+ " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n",
|
|
|
+ " input = F.conv2d(input, kernel_w[None, None, None, :])\n",
|
|
|
+ "\n",
|
|
|
+ " input = input.reshape([n, c, h, w])\n",
|
|
|
+ " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n",
|
|
|
+ "\n",
|
|
|
+ "class MakeCutouts(nn.Module):\n",
|
|
|
+ " def __init__(self, cut_size, cutn, skip_augs=False):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " self.cut_size = cut_size\n",
|
|
|
+ " self.cutn = cutn\n",
|
|
|
+ " self.skip_augs = skip_augs\n",
|
|
|
+ " self.augs = T.Compose([\n",
|
|
|
+ " T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomGrayscale(p=0.15),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
+ " ])\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, input):\n",
|
|
|
+ " input = T.Pad(input.shape[2]//4, fill=0)(input)\n",
|
|
|
+ " sideY, sideX = input.shape[2:4]\n",
|
|
|
+ " max_size = min(sideX, sideY)\n",
|
|
|
+ "\n",
|
|
|
+ " cutouts = []\n",
|
|
|
+ " for ch in range(self.cutn):\n",
|
|
|
+ " if ch > self.cutn - self.cutn//4:\n",
|
|
|
+ " cutout = input.clone()\n",
|
|
|
+ " else:\n",
|
|
|
+ " size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))\n",
|
|
|
+ " offsetx = torch.randint(0, abs(sideX - size + 1), ())\n",
|
|
|
+ " offsety = torch.randint(0, abs(sideY - size + 1), ())\n",
|
|
|
+ " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
|
|
|
+ "\n",
|
|
|
+ " if not self.skip_augs:\n",
|
|
|
+ " cutout = self.augs(cutout)\n",
|
|
|
+ " cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n",
|
|
|
+ " del cutout\n",
|
|
|
+ "\n",
|
|
|
+ " cutouts = torch.cat(cutouts, dim=0)\n",
|
|
|
+ " return cutouts\n",
|
|
|
+ "\n",
|
|
|
+ "cutout_debug = False\n",
|
|
|
+ "padargs = {}\n",
|
|
|
+ "\n",
|
|
|
+ "class MakeCutoutsDango(nn.Module):\n",
|
|
|
+ " def __init__(self, cut_size,\n",
|
|
|
+ " Overview=4, \n",
|
|
|
+ " InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2\n",
|
|
|
+ " ):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " self.cut_size = cut_size\n",
|
|
|
+ " self.Overview = Overview\n",
|
|
|
+ " self.InnerCrop = InnerCrop\n",
|
|
|
+ " self.IC_Size_Pow = IC_Size_Pow\n",
|
|
|
+ " self.IC_Grey_P = IC_Grey_P\n",
|
|
|
+ " if args.animation_mode == 'None':\n",
|
|
|
+ " self.augs = T.Compose([\n",
|
|
|
+ " T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomGrayscale(p=0.1),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
+ " ])\n",
|
|
|
+ " elif args.animation_mode == 'Video Input':\n",
|
|
|
+ " self.augs = T.Compose([\n",
|
|
|
+ " T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomGrayscale(p=0.15),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
+ " ])\n",
|
|
|
+ " elif args.animation_mode == '2D' or args.animation_mode == '3D':\n",
|
|
|
+ " self.augs = T.Compose([\n",
|
|
|
+ " T.RandomHorizontalFlip(p=0.4),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.RandomGrayscale(p=0.1),\n",
|
|
|
+ " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+ " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),\n",
|
|
|
+ " ])\n",
|
|
|
+ " \n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, input):\n",
|
|
|
+ " cutouts = []\n",
|
|
|
+ " gray = T.Grayscale(3)\n",
|
|
|
+ " sideY, sideX = input.shape[2:4]\n",
|
|
|
+ " max_size = min(sideX, sideY)\n",
|
|
|
+ " min_size = min(sideX, sideY, self.cut_size)\n",
|
|
|
+ " l_size = max(sideX, sideY)\n",
|
|
|
+ " output_shape = [1,3,self.cut_size,self.cut_size] \n",
|
|
|
+ " output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]\n",
|
|
|
+ " pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **padargs)\n",
|
|
|
+ " cutout = resize(pad_input, out_shape=output_shape)\n",
|
|
|
+ "\n",
|
|
|
+ " if self.Overview>0:\n",
|
|
|
+ " if self.Overview<=4:\n",
|
|
|
+ " if self.Overview>=1:\n",
|
|
|
+ " cutouts.append(cutout)\n",
|
|
|
+ " if self.Overview>=2:\n",
|
|
|
+ " cutouts.append(gray(cutout))\n",
|
|
|
+ " if self.Overview>=3:\n",
|
|
|
+ " cutouts.append(TF.hflip(cutout))\n",
|
|
|
+ " if self.Overview==4:\n",
|
|
|
+ " cutouts.append(gray(TF.hflip(cutout)))\n",
|
|
|
+ " else:\n",
|
|
|
+ " cutout = resize(pad_input, out_shape=output_shape)\n",
|
|
|
+ " for _ in range(self.Overview):\n",
|
|
|
+ " cutouts.append(cutout)\n",
|
|
|
+ "\n",
|
|
|
+ " if cutout_debug:\n",
|
|
|
+ " if is_colab:\n",
|
|
|
+ " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"/content/cutout_overview0.jpg\",quality=99)\n",
|
|
|
+ " else:\n",
|
|
|
+ " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"cutout_overview0.jpg\",quality=99)\n",
|
|
|
+ "\n",
|
|
|
+ " \n",
|
|
|
+ " if self.InnerCrop >0:\n",
|
|
|
+ " for i in range(self.InnerCrop):\n",
|
|
|
+ " size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)\n",
|
|
|
+ " offsetx = torch.randint(0, sideX - size + 1, ())\n",
|
|
|
+ " offsety = torch.randint(0, sideY - size + 1, ())\n",
|
|
|
+ " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
|
|
|
+ " if i <= int(self.IC_Grey_P * self.InnerCrop):\n",
|
|
|
+ " cutout = gray(cutout)\n",
|
|
|
+ " cutout = resize(cutout, out_shape=output_shape)\n",
|
|
|
+ " cutouts.append(cutout)\n",
|
|
|
+ " if cutout_debug:\n",
|
|
|
+ " if is_colab:\n",
|
|
|
+ " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"/content/cutout_InnerCrop.jpg\",quality=99)\n",
|
|
|
+ " else:\n",
|
|
|
+ " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"cutout_InnerCrop.jpg\",quality=99)\n",
|
|
|
+ " cutouts = torch.cat(cutouts)\n",
|
|
|
+ " if skip_augs is not True: cutouts=self.augs(cutouts)\n",
|
|
|
+ " return cutouts\n",
|
|
|
+ "\n",
|
|
|
+ "def spherical_dist_loss(x, y):\n",
|
|
|
+ " x = F.normalize(x, dim=-1)\n",
|
|
|
+ " y = F.normalize(y, dim=-1)\n",
|
|
|
+ " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) \n",
|
|
|
+ "\n",
|
|
|
+ "def tv_loss(input):\n",
|
|
|
+ " \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n",
|
|
|
+ " input = F.pad(input, (0, 1, 0, 1), 'replicate')\n",
|
|
|
+ " x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n",
|
|
|
+ " y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n",
|
|
|
+ " return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def range_loss(input):\n",
|
|
|
+ " return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n",
|
|
|
+ "\n",
|
|
|
+ "stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete\n",
|
|
|
+ "\n",
|
|
|
+ "def do_3d_step(img_filepath, frame_num, midas_model, midas_transform):\n",
|
|
|
+ " if args.key_frames:\n",
|
|
|
+ " translation_x = args.translation_x_series[frame_num]\n",
|
|
|
+ " translation_y = args.translation_y_series[frame_num]\n",
|
|
|
+ " translation_z = args.translation_z_series[frame_num]\n",
|
|
|
+ " rotation_3d_x = args.rotation_3d_x_series[frame_num]\n",
|
|
|
+ " rotation_3d_y = args.rotation_3d_y_series[frame_num]\n",
|
|
|
+ " rotation_3d_z = args.rotation_3d_z_series[frame_num]\n",
|
|
|
+ " print(\n",
|
|
|
+ " f'translation_x: {translation_x}',\n",
|
|
|
+ " f'translation_y: {translation_y}',\n",
|
|
|
+ " f'translation_z: {translation_z}',\n",
|
|
|
+ " f'rotation_3d_x: {rotation_3d_x}',\n",
|
|
|
+ " f'rotation_3d_y: {rotation_3d_y}',\n",
|
|
|
+ " f'rotation_3d_z: {rotation_3d_z}',\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " trans_scale = 1.0/200.0\n",
|
|
|
+ " translate_xyz = [-translation_x*trans_scale, translation_y*trans_scale, -translation_z*trans_scale]\n",
|
|
|
+ " rotate_xyz_degrees = [rotation_3d_x, rotation_3d_y, rotation_3d_z]\n",
|
|
|
+ " print('translation:',translate_xyz)\n",
|
|
|
+ " print('rotation:',rotate_xyz_degrees)\n",
|
|
|
+ " rotate_xyz = [math.radians(rotate_xyz_degrees[0]), math.radians(rotate_xyz_degrees[1]), math.radians(rotate_xyz_degrees[2])]\n",
|
|
|
+ " rot_mat = p3dT.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), \"XYZ\").unsqueeze(0)\n",
|
|
|
+ " print(\"rot_mat: \" + str(rot_mat))\n",
|
|
|
+ " next_step_pil = dxf.transform_image_3d(img_filepath, midas_model, midas_transform, DEVICE,\n",
|
|
|
+ " rot_mat, translate_xyz, args.near_plane, args.far_plane,\n",
|
|
|
+ " args.fov, padding_mode=args.padding_mode,\n",
|
|
|
+ " sampling_mode=args.sampling_mode, midas_weight=args.midas_weight)\n",
|
|
|
+ " return next_step_pil\n",
|
|
|
+ "\n",
|
|
|
+ "def do_run():\n",
|
|
|
+ " seed = args.seed\n",
|
|
|
+ " print(range(args.start_frame, args.max_frames))\n",
|
|
|
+ "\n",
|
|
|
+ " if (args.animation_mode == \"3D\") and (args.midas_weight > 0.0):\n",
|
|
|
+ " midas_model, midas_transform, midas_net_w, midas_net_h, midas_resize_mode, midas_normalization = init_midas_depth_model(args.midas_depth_model)\n",
|
|
|
+ " for frame_num in range(args.start_frame, args.max_frames):\n",
|
|
|
+ " if stop_on_next_loop:\n",
|
|
|
+ " break\n",
|
|
|
+ " \n",
|
|
|
+ " display.clear_output(wait=True)\n",
|
|
|
+ "\n",
|
|
|
+ " # Print Frame progress if animation mode is on\n",
|
|
|
+ " if args.animation_mode != \"None\":\n",
|
|
|
+ " batchBar = tqdm(range(args.max_frames), desc =\"Frames\")\n",
|
|
|
+ " batchBar.n = frame_num\n",
|
|
|
+ " batchBar.refresh()\n",
|
|
|
+ "\n",
|
|
|
+ " \n",
|
|
|
+ " # Inits if not video frames\n",
|
|
|
+ " if args.animation_mode != \"Video Input\":\n",
|
|
|
+ " if args.init_image == '':\n",
|
|
|
+ " init_image = None\n",
|
|
|
+ " else:\n",
|
|
|
+ " init_image = args.init_image\n",
|
|
|
+ " init_scale = args.init_scale\n",
|
|
|
+ " skip_steps = args.skip_steps\n",
|
|
|
+ "\n",
|
|
|
+ " if args.animation_mode == \"2D\":\n",
|
|
|
+ " if args.key_frames:\n",
|
|
|
+ " angle = args.angle_series[frame_num]\n",
|
|
|
+ " zoom = args.zoom_series[frame_num]\n",
|
|
|
+ " translation_x = args.translation_x_series[frame_num]\n",
|
|
|
+ " translation_y = args.translation_y_series[frame_num]\n",
|
|
|
+ " print(\n",
|
|
|
+ " f'angle: {angle}',\n",
|
|
|
+ " f'zoom: {zoom}',\n",
|
|
|
+ " f'translation_x: {translation_x}',\n",
|
|
|
+ " f'translation_y: {translation_y}',\n",
|
|
|
+ " )\n",
|
|
|
+ " \n",
|
|
|
+ " if frame_num > 0:\n",
|
|
|
+ " seed += 1\n",
|
|
|
+ " if resume_run and frame_num == start_frame:\n",
|
|
|
+ " img_0 = cv2.imread(batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:04}.png\")\n",
|
|
|
+ " else:\n",
|
|
|
+ " img_0 = cv2.imread('prevFrame.png')\n",
|
|
|
+ " center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)\n",
|
|
|
+ " trans_mat = np.float32(\n",
|
|
|
+ " [[1, 0, translation_x],\n",
|
|
|
+ " [0, 1, translation_y]]\n",
|
|
|
+ " )\n",
|
|
|
+ " rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )\n",
|
|
|
+ " trans_mat = np.vstack([trans_mat, [0,0,1]])\n",
|
|
|
+ " rot_mat = np.vstack([rot_mat, [0,0,1]])\n",
|
|
|
+ " transformation_matrix = np.matmul(rot_mat, trans_mat)\n",
|
|
|
+ " img_0 = cv2.warpPerspective(\n",
|
|
|
+ " img_0,\n",
|
|
|
+ " transformation_matrix,\n",
|
|
|
+ " (img_0.shape[1], img_0.shape[0]),\n",
|
|
|
+ " borderMode=cv2.BORDER_WRAP\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " cv2.imwrite('prevFrameScaled.png', img_0)\n",
|
|
|
+ " init_image = 'prevFrameScaled.png'\n",
|
|
|
+ " init_scale = args.frames_scale\n",
|
|
|
+ " skip_steps = args.calc_frames_skip_steps\n",
|
|
|
+ "\n",
|
|
|
+ " if args.animation_mode == \"3D\":\n",
|
|
|
+ " if frame_num == 0:\n",
|
|
|
+ " pass\n",
|
|
|
+ " else:\n",
|
|
|
+ " seed += 1 \n",
|
|
|
+ " if resume_run and frame_num == start_frame:\n",
|
|
|
+ " img_filepath = batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:04}.png\"\n",
|
|
|
+ " if turbo_mode and frame_num > turbo_preroll:\n",
|
|
|
+ " shutil.copyfile(img_filepath, 'oldFrameScaled.png')\n",
|
|
|
+ " else:\n",
|
|
|
+ " img_filepath = '/content/prevFrame.png' if is_colab else 'prevFrame.png'\n",
|
|
|
+ "\n",
|
|
|
+ " next_step_pil = do_3d_step(img_filepath, frame_num, midas_model, midas_transform)\n",
|
|
|
+ " next_step_pil.save('prevFrameScaled.png')\n",
|
|
|
+ "\n",
|
|
|
+ " ### Turbo mode - skip some diffusions, use 3d morph for clarity and to save time\n",
|
|
|
+ " if turbo_mode:\n",
|
|
|
+ " if frame_num == turbo_preroll: #start tracking oldframe\n",
|
|
|
+ " next_step_pil.save('oldFrameScaled.png')#stash for later blending \n",
|
|
|
+ " elif frame_num > turbo_preroll:\n",
|
|
|
+ " #set up 2 warped image sequences, old & new, to blend toward new diff image\n",
|
|
|
+ " old_frame = do_3d_step('oldFrameScaled.png', frame_num, midas_model, midas_transform)\n",
|
|
|
+ " old_frame.save('oldFrameScaled.png')\n",
|
|
|
+ " if frame_num % int(turbo_steps) != 0: \n",
|
|
|
+ " print('turbo skip this frame: skipping clip diffusion steps')\n",
|
|
|
+ " filename = f'{args.batch_name}({args.batchNum})_{frame_num:04}.png'\n",
|
|
|
+ " blend_factor = ((frame_num % int(turbo_steps))+1)/int(turbo_steps)\n",
|
|
|
+ " print('turbo skip this frame: skipping clip diffusion steps and saving blended frame')\n",
|
|
|
+ " newWarpedImg = cv2.imread('prevFrameScaled.png')#this is already updated..\n",
|
|
|
+ " oldWarpedImg = cv2.imread('oldFrameScaled.png')\n",
|
|
|
+ " blendedImage = cv2.addWeighted(newWarpedImg, blend_factor, oldWarpedImg,1-blend_factor, 0.0)\n",
|
|
|
+ " cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
|
|
|
+ " next_step_pil.save(f'{img_filepath}') # save it also as prev_frame to feed next iteration\n",
|
|
|
+ " continue\n",
|
|
|
+ " else:\n",
|
|
|
+ " #if not a skip frame, will run diffusion and need to blend.\n",
|
|
|
+ " oldWarpedImg = cv2.imread('prevFrameScaled.png')\n",
|
|
|
+ " cv2.imwrite(f'oldFrameScaled.png',oldWarpedImg)#swap in for blending later \n",
|
|
|
+ " print('clip/diff this frame - generate clip diff image')\n",
|
|
|
+ "\n",
|
|
|
+ " init_image = 'prevFrameScaled.png'\n",
|
|
|
+ " init_scale = args.frames_scale\n",
|
|
|
+ " skip_steps = args.calc_frames_skip_steps\n",
|
|
|
+ "\n",
|
|
|
+ " if args.animation_mode == \"Video Input\":\n",
|
|
|
+ " if not video_init_seed_continuity:\n",
|
|
|
+ " seed += 1\n",
|
|
|
+ " init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg'\n",
|
|
|
+ " init_scale = args.frames_scale\n",
|
|
|
+ " skip_steps = args.calc_frames_skip_steps\n",
|
|
|
+ "\n",
|
|
|
+ " loss_values = []\n",
|
|
|
+ " \n",
|
|
|
+ " if seed is not None:\n",
|
|
|
+ " np.random.seed(seed)\n",
|
|
|
+ " random.seed(seed)\n",
|
|
|
+ " torch.manual_seed(seed)\n",
|
|
|
+ " torch.cuda.manual_seed_all(seed)\n",
|
|
|
+ " torch.backends.cudnn.deterministic = True\n",
|
|
|
+ " \n",
|
|
|
+ " target_embeds, weights = [], []\n",
|
|
|
+ " \n",
|
|
|
+ " if args.prompts_series is not None and frame_num >= len(args.prompts_series):\n",
|
|
|
+ " frame_prompt = args.prompts_series[-1]\n",
|
|
|
+ " elif args.prompts_series is not None:\n",
|
|
|
+ " frame_prompt = args.prompts_series[frame_num]\n",
|
|
|
+ " else:\n",
|
|
|
+ " frame_prompt = []\n",
|
|
|
+ " \n",
|
|
|
+ " print(args.image_prompts_series)\n",
|
|
|
+ " if args.image_prompts_series is not None and frame_num >= len(args.image_prompts_series):\n",
|
|
|
+ " image_prompt = args.image_prompts_series[-1]\n",
|
|
|
+ " elif args.image_prompts_series is not None:\n",
|
|
|
+ " image_prompt = args.image_prompts_series[frame_num]\n",
|
|
|
+ " else:\n",
|
|
|
+ " image_prompt = []\n",
|
|
|
+ "\n",
|
|
|
+ " print(f'Frame {frame_num} Prompt: {frame_prompt}')\n",
|
|
|
+ "\n",
|
|
|
+ " model_stats = []\n",
|
|
|
+ " for clip_model in clip_models:\n",
|
|
|
+ " cutn = 16\n",
|
|
|
+ " model_stat = {\"clip_model\":None,\"target_embeds\":[],\"make_cutouts\":None,\"weights\":[]}\n",
|
|
|
+ " model_stat[\"clip_model\"] = clip_model\n",
|
|
|
+ " \n",
|
|
|
+ " \n",
|
|
|
+ " for prompt in frame_prompt:\n",
|
|
|
+ " txt, weight = parse_prompt(prompt)\n",
|
|
|
+ " txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()\n",
|
|
|
+ " \n",
|
|
|
+ " if args.fuzzy_prompt:\n",
|
|
|
+ " for i in range(25):\n",
|
|
|
+ " model_stat[\"target_embeds\"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))\n",
|
|
|
+ " model_stat[\"weights\"].append(weight)\n",
|
|
|
+ " else:\n",
|
|
|
+ " model_stat[\"target_embeds\"].append(txt)\n",
|
|
|
+ " model_stat[\"weights\"].append(weight)\n",
|
|
|
+ " \n",
|
|
|
+ " if image_prompt:\n",
|
|
|
+ " model_stat[\"make_cutouts\"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs) \n",
|
|
|
+ " for prompt in image_prompt:\n",
|
|
|
+ " path, weight = parse_prompt(prompt)\n",
|
|
|
+ " img = Image.open(fetch(path)).convert('RGB')\n",
|
|
|
+ " img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)\n",
|
|
|
+ " batch = model_stat[\"make_cutouts\"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))\n",
|
|
|
+ " embed = clip_model.encode_image(normalize(batch)).float()\n",
|
|
|
+ " if fuzzy_prompt:\n",
|
|
|
+ " for i in range(25):\n",
|
|
|
+ " model_stat[\"target_embeds\"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))\n",
|
|
|
+ " weights.extend([weight / cutn] * cutn)\n",
|
|
|
+ " else:\n",
|
|
|
+ " model_stat[\"target_embeds\"].append(embed)\n",
|
|
|
+ " model_stat[\"weights\"].extend([weight / cutn] * cutn)\n",
|
|
|
+ " \n",
|
|
|
+ " model_stat[\"target_embeds\"] = torch.cat(model_stat[\"target_embeds\"])\n",
|
|
|
+ " model_stat[\"weights\"] = torch.tensor(model_stat[\"weights\"], device=device)\n",
|
|
|
+ " if model_stat[\"weights\"].sum().abs() < 1e-3:\n",
|
|
|
+ " raise RuntimeError('The weights must not sum to 0.')\n",
|
|
|
+ " model_stat[\"weights\"] /= model_stat[\"weights\"].sum().abs()\n",
|
|
|
+ " model_stats.append(model_stat)\n",
|
|
|
+ " \n",
|
|
|
+ " init = None\n",
|
|
|
+ " if init_image is not None:\n",
|
|
|
+ " init = Image.open(fetch(init_image)).convert('RGB')\n",
|
|
|
+ " init = init.resize((args.side_x, args.side_y), Image.LANCZOS)\n",
|
|
|
+ " init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
+ " \n",
|
|
|
+ " if args.perlin_init:\n",
|
|
|
+ " if args.perlin_mode == 'color':\n",
|
|
|
+ " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
+ " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n",
|
|
|
+ " elif args.perlin_mode == 'gray':\n",
|
|
|
+ " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n",
|
|
|
+ " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
+ " else:\n",
|
|
|
+ " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
+ " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
+ " # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)\n",
|
|
|
+ " init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
+ " del init2\n",
|
|
|
+ " \n",
|
|
|
+ " cur_t = None\n",
|
|
|
+ " \n",
|
|
|
+ " def cond_fn(x, t, y=None):\n",
|
|
|
+ " with torch.enable_grad():\n",
|
|
|
+ " x_is_NaN = False\n",
|
|
|
+ " x = x.detach().requires_grad_()\n",
|
|
|
+ " n = x.shape[0]\n",
|
|
|
+ " if use_secondary_model is True:\n",
|
|
|
+ " alpha = torch.tensor(diffusion.sqrt_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n",
|
|
|
+ " sigma = torch.tensor(diffusion.sqrt_one_minus_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n",
|
|
|
+ " cosine_t = alpha_sigma_to_t(alpha, sigma)\n",
|
|
|
+ " out = secondary_model(x, cosine_t[None].repeat([n])).pred\n",
|
|
|
+ " fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n",
|
|
|
+ " x_in = out * fac + x * (1 - fac)\n",
|
|
|
+ " x_in_grad = torch.zeros_like(x_in)\n",
|
|
|
+ " else:\n",
|
|
|
+ " my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t\n",
|
|
|
+ " out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})\n",
|
|
|
+ " fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n",
|
|
|
+ " x_in = out['pred_xstart'] * fac + x * (1 - fac)\n",
|
|
|
+ " x_in_grad = torch.zeros_like(x_in)\n",
|
|
|
+ " for model_stat in model_stats:\n",
|
|
|
+ " for i in range(args.cutn_batches):\n",
|
|
|
+ " t_int = int(t.item())+1 #errors on last step without +1, need to find source\n",
|
|
|
+ " #when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution'\n",
|
|
|
+ " try:\n",
|
|
|
+ " input_resolution=model_stat[\"clip_model\"].visual.input_resolution\n",
|
|
|
+ " except:\n",
|
|
|
+ " input_resolution=224\n",
|
|
|
+ "\n",
|
|
|
+ " cuts = MakeCutoutsDango(input_resolution,\n",
|
|
|
+ " Overview= args.cut_overview[1000-t_int], \n",
|
|
|
+ " InnerCrop = args.cut_innercut[1000-t_int], IC_Size_Pow=args.cut_ic_pow, IC_Grey_P = args.cut_icgray_p[1000-t_int]\n",
|
|
|
+ " )\n",
|
|
|
+ " clip_in = normalize(cuts(x_in.add(1).div(2)))\n",
|
|
|
+ " image_embeds = model_stat[\"clip_model\"].encode_image(clip_in).float()\n",
|
|
|
+ " dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat[\"target_embeds\"].unsqueeze(0))\n",
|
|
|
+ " dists = dists.view([args.cut_overview[1000-t_int]+args.cut_innercut[1000-t_int], n, -1])\n",
|
|
|
+ " losses = dists.mul(model_stat[\"weights\"]).sum(2).mean(0)\n",
|
|
|
+ " loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch\n",
|
|
|
+ " x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches\n",
|
|
|
+ " tv_losses = tv_loss(x_in)\n",
|
|
|
+ " if use_secondary_model is True:\n",
|
|
|
+ " range_losses = range_loss(out)\n",
|
|
|
+ " else:\n",
|
|
|
+ " range_losses = range_loss(out['pred_xstart'])\n",
|
|
|
+ " sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()\n",
|
|
|
+ " loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale\n",
|
|
|
+ " if init is not None and args.init_scale:\n",
|
|
|
+ " init_losses = lpips_model(x_in, init)\n",
|
|
|
+ " loss = loss + init_losses.sum() * args.init_scale\n",
|
|
|
+ " x_in_grad += torch.autograd.grad(loss, x_in)[0]\n",
|
|
|
+ " if torch.isnan(x_in_grad).any()==False:\n",
|
|
|
+ " grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]\n",
|
|
|
+ " else:\n",
|
|
|
+ " # print(\"NaN'd\")\n",
|
|
|
+ " x_is_NaN = True\n",
|
|
|
+ " grad = torch.zeros_like(x)\n",
|
|
|
+ " if args.clamp_grad and x_is_NaN == False:\n",
|
|
|
+ " magnitude = grad.square().mean().sqrt()\n",
|
|
|
+ " return grad * magnitude.clamp(max=args.clamp_max) / magnitude #min=-0.02, min=-clamp_max, \n",
|
|
|
+ " return grad\n",
|
|
|
+ " \n",
|
|
|
+ " if args.diffusion_sampling_mode == 'ddim':\n",
|
|
|
+ " sample_fn = diffusion.ddim_sample_loop_progressive\n",
|
|
|
+ " else:\n",
|
|
|
+ " sample_fn = diffusion.plms_sample_loop_progressive\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " image_display = Output()\n",
|
|
|
+ " for i in range(args.n_batches):\n",
|
|
|
+ " if args.animation_mode == 'None':\n",
|
|
|
+ " display.clear_output(wait=True)\n",
|
|
|
+ " batchBar = tqdm(range(args.n_batches), desc =\"Batches\")\n",
|
|
|
+ " batchBar.n = i\n",
|
|
|
+ " batchBar.refresh()\n",
|
|
|
+ " print('')\n",
|
|
|
+ " display.display(image_display)\n",
|
|
|
+ " gc.collect()\n",
|
|
|
+ " torch.cuda.empty_cache()\n",
|
|
|
+ " cur_t = diffusion.num_timesteps - skip_steps - 1\n",
|
|
|
+ " total_steps = cur_t\n",
|
|
|
+ "\n",
|
|
|
+ " if perlin_init:\n",
|
|
|
+ " init = regen_perlin()\n",
|
|
|
+ "\n",
|
|
|
+ " if args.diffusion_sampling_mode == 'ddim':\n",
|
|
|
+ " samples = sample_fn(\n",
|
|
|
+ " model,\n",
|
|
|
+ " (batch_size, 3, args.side_y, args.side_x),\n",
|
|
|
+ " clip_denoised=clip_denoised,\n",
|
|
|
+ " model_kwargs={},\n",
|
|
|
+ " cond_fn=cond_fn,\n",
|
|
|
+ " progress=True,\n",
|
|
|
+ " skip_timesteps=skip_steps,\n",
|
|
|
+ " init_image=init,\n",
|
|
|
+ " randomize_class=randomize_class,\n",
|
|
|
+ " eta=eta,\n",
|
|
|
+ " )\n",
|
|
|
+ " else:\n",
|
|
|
+ " samples = sample_fn(\n",
|
|
|
+ " model,\n",
|
|
|
+ " (batch_size, 3, args.side_y, args.side_x),\n",
|
|
|
+ " clip_denoised=clip_denoised,\n",
|
|
|
+ " model_kwargs={},\n",
|
|
|
+ " cond_fn=cond_fn,\n",
|
|
|
+ " progress=True,\n",
|
|
|
+ " skip_timesteps=skip_steps,\n",
|
|
|
+ " init_image=init,\n",
|
|
|
+ " randomize_class=randomize_class,\n",
|
|
|
+ " order=2,\n",
|
|
|
+ " )\n",
|
|
|
+ " \n",
|
|
|
+ " \n",
|
|
|
+ " # with run_display:\n",
|
|
|
+ " # display.clear_output(wait=True)\n",
|
|
|
+ " imgToSharpen = None\n",
|
|
|
+ " for j, sample in enumerate(samples): \n",
|
|
|
+ " cur_t -= 1\n",
|
|
|
+ " intermediateStep = False\n",
|
|
|
+ " if args.steps_per_checkpoint is not None:\n",
|
|
|
+ " if j % steps_per_checkpoint == 0 and j > 0:\n",
|
|
|
+ " intermediateStep = True\n",
|
|
|
+ " elif j in args.intermediate_saves:\n",
|
|
|
+ " intermediateStep = True\n",
|
|
|
+ " with image_display:\n",
|
|
|
+ " if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True:\n",
|
|
|
+ " for k, image in enumerate(sample['pred_xstart']):\n",
|
|
|
+ " # tqdm.write(f'Batch {i}, step {j}, output {k}:')\n",
|
|
|
+ " current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')\n",
|
|
|
+ " percent = math.ceil(j/total_steps*100)\n",
|
|
|
+ " if args.n_batches > 0:\n",
|
|
|
+ " #if intermediates are saved to the subfolder, don't append a step or percentage to the name\n",
|
|
|
+ " if cur_t == -1 and args.intermediates_in_subfolder is True:\n",
|
|
|
+ " save_num = f'{frame_num:04}' if animation_mode != \"None\" else i\n",
|
|
|
+ " filename = f'{args.batch_name}({args.batchNum})_{save_num}.png'\n",
|
|
|
+ " else:\n",
|
|
|
+ " #If we're working with percentages, append it\n",
|
|
|
+ " if args.steps_per_checkpoint is not None:\n",
|
|
|
+ " filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png'\n",
|
|
|
+ " # Or else, iIf we're working with specific steps, append those\n",
|
|
|
+ " else:\n",
|
|
|
+ " filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png'\n",
|
|
|
+ " image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))\n",
|
|
|
+ " if j % args.display_rate == 0 or cur_t == -1:\n",
|
|
|
+ " image.save('progress.png')\n",
|
|
|
+ " display.clear_output(wait=True)\n",
|
|
|
+ " display.display(display.Image('progress.png'))\n",
|
|
|
+ " if args.steps_per_checkpoint is not None:\n",
|
|
|
+ " if j % args.steps_per_checkpoint == 0 and j > 0:\n",
|
|
|
+ " if args.intermediates_in_subfolder is True:\n",
|
|
|
+ " image.save(f'{partialFolder}/{filename}')\n",
|
|
|
+ " else:\n",
|
|
|
+ " image.save(f'{batchFolder}/{filename}')\n",
|
|
|
+ " else:\n",
|
|
|
+ " if j in args.intermediate_saves:\n",
|
|
|
+ " if args.intermediates_in_subfolder is True:\n",
|
|
|
+ " image.save(f'{partialFolder}/{filename}')\n",
|
|
|
+ " else:\n",
|
|
|
+ " image.save(f'{batchFolder}/{filename}')\n",
|
|
|
+ " if cur_t == -1:\n",
|
|
|
+ " if frame_num == 0:\n",
|
|
|
+ " save_settings()\n",
|
|
|
+ " if args.animation_mode != \"None\":\n",
|
|
|
+ " image.save('prevFrame.png')\n",
|
|
|
+ " if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
|
|
|
+ " imgToSharpen = image\n",
|
|
|
+ " if args.keep_unsharp is True:\n",
|
|
|
+ " image.save(f'{unsharpenFolder}/{filename}')\n",
|
|
|
+ " else:\n",
|
|
|
+ " image.save(f'{batchFolder}/{filename}')\n",
|
|
|
+ " if args.animation_mode == \"3D\":\n",
|
|
|
+ " # If turbo, save a blended image\n",
|
|
|
+ " if turbo_mode:\n",
|
|
|
+ " # Mix new image with prevFrameScaled\n",
|
|
|
+ " blend_factor = (1)/int(turbo_steps)\n",
|
|
|
+ " newFrame = cv2.imread('prevFrame.png') # This is already updated..\n",
|
|
|
+ " prev_frame_warped = cv2.imread('prevFrameScaled.png')\n",
|
|
|
+ " blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)\n",
|
|
|
+ " cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
|
|
|
+ " else:\n",
|
|
|
+ " image.save(f'{batchFolder}/{filename}')\n",
|
|
|
+ " # if frame_num != args.max_frames-1:\n",
|
|
|
+ " # display.clear_output()\n",
|
|
|
+ "\n",
|
|
|
+ " with image_display: \n",
|
|
|
+ " if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
|
|
|
+ " print('Starting Diffusion Sharpening...')\n",
|
|
|
+ " do_superres(imgToSharpen, f'{batchFolder}/{filename}')\n",
|
|
|
+ " display.clear_output()\n",
|
|
|
+ " \n",
|
|
|
+ " plt.plot(np.array(loss_values), 'r')\n",
|
|
|
+ "\n",
|
|
|
+ "def save_settings():\n",
|
|
|
+ " setting_list = {\n",
|
|
|
+ " 'text_prompts': text_prompts,\n",
|
|
|
+ " 'image_prompts': image_prompts,\n",
|
|
|
+ " 'clip_guidance_scale': clip_guidance_scale,\n",
|
|
|
+ " 'tv_scale': tv_scale,\n",
|
|
|
+ " 'range_scale': range_scale,\n",
|
|
|
+ " 'sat_scale': sat_scale,\n",
|
|
|
+ " # 'cutn': cutn,\n",
|
|
|
+ " 'cutn_batches': cutn_batches,\n",
|
|
|
+ " 'max_frames': max_frames,\n",
|
|
|
+ " 'interp_spline': interp_spline,\n",
|
|
|
+ " # 'rotation_per_frame': rotation_per_frame,\n",
|
|
|
+ " 'init_image': init_image,\n",
|
|
|
+ " 'init_scale': init_scale,\n",
|
|
|
+ " 'skip_steps': skip_steps,\n",
|
|
|
+ " # 'zoom_per_frame': zoom_per_frame,\n",
|
|
|
+ " 'frames_scale': frames_scale,\n",
|
|
|
+ " 'frames_skip_steps': frames_skip_steps,\n",
|
|
|
+ " 'perlin_init': perlin_init,\n",
|
|
|
+ " 'perlin_mode': perlin_mode,\n",
|
|
|
+ " 'skip_augs': skip_augs,\n",
|
|
|
+ " 'randomize_class': randomize_class,\n",
|
|
|
+ " 'clip_denoised': clip_denoised,\n",
|
|
|
+ " 'clamp_grad': clamp_grad,\n",
|
|
|
+ " 'clamp_max': clamp_max,\n",
|
|
|
+ " 'seed': seed,\n",
|
|
|
+ " 'fuzzy_prompt': fuzzy_prompt,\n",
|
|
|
+ " 'rand_mag': rand_mag,\n",
|
|
|
+ " 'eta': eta,\n",
|
|
|
+ " 'width': width_height[0],\n",
|
|
|
+ " 'height': width_height[1],\n",
|
|
|
+ " 'diffusion_model': diffusion_model,\n",
|
|
|
+ " 'use_secondary_model': use_secondary_model,\n",
|
|
|
+ " 'steps': steps,\n",
|
|
|
+ " 'diffusion_steps': diffusion_steps,\n",
|
|
|
+ " 'diffusion_sampling_mode': diffusion_sampling_mode,\n",
|
|
|
+ " 'ViTB32': ViTB32,\n",
|
|
|
+ " 'ViTB16': ViTB16,\n",
|
|
|
+ " 'ViTL14': ViTL14,\n",
|
|
|
+ " 'RN101': RN101,\n",
|
|
|
+ " 'RN50': RN50,\n",
|
|
|
+ " 'RN50x4': RN50x4,\n",
|
|
|
+ " 'RN50x16': RN50x16,\n",
|
|
|
+ " 'RN50x64': RN50x64,\n",
|
|
|
+ " 'cut_overview': str(cut_overview),\n",
|
|
|
+ " 'cut_innercut': str(cut_innercut),\n",
|
|
|
+ " 'cut_ic_pow': cut_ic_pow,\n",
|
|
|
+ " 'cut_icgray_p': str(cut_icgray_p),\n",
|
|
|
+ " 'key_frames': key_frames,\n",
|
|
|
+ " 'max_frames': max_frames,\n",
|
|
|
+ " 'angle': angle,\n",
|
|
|
+ " 'zoom': zoom,\n",
|
|
|
+ " 'translation_x': translation_x,\n",
|
|
|
+ " 'translation_y': translation_y,\n",
|
|
|
+ " 'translation_z': translation_z,\n",
|
|
|
+ " 'rotation_3d_x': rotation_3d_x,\n",
|
|
|
+ " 'rotation_3d_y': rotation_3d_y,\n",
|
|
|
+ " 'rotation_3d_z': rotation_3d_z,\n",
|
|
|
+ " 'midas_depth_model': midas_depth_model,\n",
|
|
|
+ " 'midas_weight': midas_weight,\n",
|
|
|
+ " 'near_plane': near_plane,\n",
|
|
|
+ " 'far_plane': far_plane,\n",
|
|
|
+ " 'fov': fov,\n",
|
|
|
+ " 'padding_mode': padding_mode,\n",
|
|
|
+ " 'sampling_mode': sampling_mode,\n",
|
|
|
+ " 'video_init_path':video_init_path,\n",
|
|
|
+ " 'extract_nth_frame':extract_nth_frame,\n",
|
|
|
+ " 'video_init_seed_continuity': video_init_seed_continuity,\n",
|
|
|
+ " 'turbo_mode':turbo_mode,\n",
|
|
|
+ " 'turbo_steps':turbo_steps,\n",
|
|
|
+ " 'turbo_preroll':turbo_preroll,\n",
|
|
|
+ " }\n",
|
|
|
+ " # print('Settings:', setting_list)\n",
|
|
|
+ " with open(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\", \"w+\") as f: #save settings\n",
|
|
|
+ " json.dump(setting_list, f, ensure_ascii=False, indent=4)"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "cellView": "form",
|
|
|
+ "id": "DefSecModel"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title 1.6 Define the secondary diffusion model\n",
|
|
|
+ "\n",
|
|
|
+ "def append_dims(x, n):\n",
|
|
|
+ " return x[(Ellipsis, *(None,) * (n - x.ndim))]\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def expand_to_planes(x, shape):\n",
|
|
|
+ " return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def alpha_sigma_to_t(alpha, sigma):\n",
|
|
|
+ " return torch.atan2(sigma, alpha) * 2 / math.pi\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def t_to_alpha_sigma(t):\n",
|
|
|
+ " return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "@dataclass\n",
|
|
|
+ "class DiffusionOutput:\n",
|
|
|
+ " v: torch.Tensor\n",
|
|
|
+ " pred: torch.Tensor\n",
|
|
|
+ " eps: torch.Tensor\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "class ConvBlock(nn.Sequential):\n",
|
|
|
+ " def __init__(self, c_in, c_out):\n",
|
|
|
+ " super().__init__(\n",
|
|
|
+ " nn.Conv2d(c_in, c_out, 3, padding=1),\n",
|
|
|
+ " nn.ReLU(inplace=True),\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "class SkipBlock(nn.Module):\n",
|
|
|
+ " def __init__(self, main, skip=None):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " self.main = nn.Sequential(*main)\n",
|
|
|
+ " self.skip = skip if skip else nn.Identity()\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, input):\n",
|
|
|
+ " return torch.cat([self.main(input), self.skip(input)], dim=1)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "class FourierFeatures(nn.Module):\n",
|
|
|
+ " def __init__(self, in_features, out_features, std=1.):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " assert out_features % 2 == 0\n",
|
|
|
+ " self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, input):\n",
|
|
|
+ " f = 2 * math.pi * input @ self.weight.T\n",
|
|
|
+ " return torch.cat([f.cos(), f.sin()], dim=-1)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "class SecondaryDiffusionImageNet(nn.Module):\n",
|
|
|
+ " def __init__(self):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " c = 64 # The base channel count\n",
|
|
|
+ "\n",
|
|
|
+ " self.timestep_embed = FourierFeatures(1, 16)\n",
|
|
|
+ "\n",
|
|
|
+ " self.net = nn.Sequential(\n",
|
|
|
+ " ConvBlock(3 + 16, c),\n",
|
|
|
+ " ConvBlock(c, c),\n",
|
|
|
+ " SkipBlock([\n",
|
|
|
+ " nn.AvgPool2d(2),\n",
|
|
|
+ " ConvBlock(c, c * 2),\n",
|
|
|
+ " ConvBlock(c * 2, c * 2),\n",
|
|
|
+ " SkipBlock([\n",
|
|
|
+ " nn.AvgPool2d(2),\n",
|
|
|
+ " ConvBlock(c * 2, c * 4),\n",
|
|
|
+ " ConvBlock(c * 4, c * 4),\n",
|
|
|
+ " SkipBlock([\n",
|
|
|
+ " nn.AvgPool2d(2),\n",
|
|
|
+ " ConvBlock(c * 4, c * 8),\n",
|
|
|
+ " ConvBlock(c * 8, c * 4),\n",
|
|
|
+ " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
+ " ]),\n",
|
|
|
+ " ConvBlock(c * 8, c * 4),\n",
|
|
|
+ " ConvBlock(c * 4, c * 2),\n",
|
|
|
+ " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
+ " ]),\n",
|
|
|
+ " ConvBlock(c * 4, c * 2),\n",
|
|
|
+ " ConvBlock(c * 2, c),\n",
|
|
|
+ " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
+ " ]),\n",
|
|
|
+ " ConvBlock(c * 2, c),\n",
|
|
|
+ " nn.Conv2d(c, 3, 3, padding=1),\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, input, t):\n",
|
|
|
+ " timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
|
|
|
+ " v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
|
|
|
+ " alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
|
|
|
+ " pred = input * alphas - v * sigmas\n",
|
|
|
+ " eps = input * sigmas + v * alphas\n",
|
|
|
+ " return DiffusionOutput(v, pred, eps)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "class SecondaryDiffusionImageNet2(nn.Module):\n",
|
|
|
+ " def __init__(self):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " c = 64 # The base channel count\n",
|
|
|
+ " cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]\n",
|
|
|
+ "\n",
|
|
|
+ " self.timestep_embed = FourierFeatures(1, 16)\n",
|
|
|
+ " self.down = nn.AvgPool2d(2)\n",
|
|
|
+ " self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n",
|
|
|
+ "\n",
|
|
|
+ " self.net = nn.Sequential(\n",
|
|
|
+ " ConvBlock(3 + 16, cs[0]),\n",
|
|
|
+ " ConvBlock(cs[0], cs[0]),\n",
|
|
|
+ " SkipBlock([\n",
|
|
|
+ " self.down,\n",
|
|
|
+ " ConvBlock(cs[0], cs[1]),\n",
|
|
|
+ " ConvBlock(cs[1], cs[1]),\n",
|
|
|
+ " SkipBlock([\n",
|
|
|
+ " self.down,\n",
|
|
|
+ " ConvBlock(cs[1], cs[2]),\n",
|
|
|
+ " ConvBlock(cs[2], cs[2]),\n",
|
|
|
+ " SkipBlock([\n",
|
|
|
+ " self.down,\n",
|
|
|
+ " ConvBlock(cs[2], cs[3]),\n",
|
|
|
+ " ConvBlock(cs[3], cs[3]),\n",
|
|
|
+ " SkipBlock([\n",
|
|
|
+ " self.down,\n",
|
|
|
+ " ConvBlock(cs[3], cs[4]),\n",
|
|
|
+ " ConvBlock(cs[4], cs[4]),\n",
|
|
|
+ " SkipBlock([\n",
|
|
|
+ " self.down,\n",
|
|
|
+ " ConvBlock(cs[4], cs[5]),\n",
|
|
|
+ " ConvBlock(cs[5], cs[5]),\n",
|
|
|
+ " ConvBlock(cs[5], cs[5]),\n",
|
|
|
+ " ConvBlock(cs[5], cs[4]),\n",
|
|
|
+ " self.up,\n",
|
|
|
+ " ]),\n",
|
|
|
+ " ConvBlock(cs[4] * 2, cs[4]),\n",
|
|
|
+ " ConvBlock(cs[4], cs[3]),\n",
|
|
|
+ " self.up,\n",
|
|
|
+ " ]),\n",
|
|
|
+ " ConvBlock(cs[3] * 2, cs[3]),\n",
|
|
|
+ " ConvBlock(cs[3], cs[2]),\n",
|
|
|
+ " self.up,\n",
|
|
|
+ " ]),\n",
|
|
|
+ " ConvBlock(cs[2] * 2, cs[2]),\n",
|
|
|
+ " ConvBlock(cs[2], cs[1]),\n",
|
|
|
+ " self.up,\n",
|
|
|
+ " ]),\n",
|
|
|
+ " ConvBlock(cs[1] * 2, cs[1]),\n",
|
|
|
+ " ConvBlock(cs[1], cs[0]),\n",
|
|
|
+ " self.up,\n",
|
|
|
+ " ]),\n",
|
|
|
+ " ConvBlock(cs[0] * 2, cs[0]),\n",
|
|
|
+ " nn.Conv2d(cs[0], 3, 3, padding=1),\n",
|
|
|
+ " )\n",
|
|
|
+ "\n",
|
|
|
+ " def forward(self, input, t):\n",
|
|
|
+ " timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
|
|
|
+ " v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
|
|
|
+ " alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
|
|
|
+ " pred = input * alphas - v * sigmas\n",
|
|
|
+ " eps = input * sigmas + v * alphas\n",
|
|
|
+ " return DiffusionOutput(v, pred, eps)"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "cellView": "form",
|
|
|
+ "id": "DefSuperRes"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title 1.7 SuperRes Define\n",
|
|
|
+ "class DDIMSampler(object):\n",
|
|
|
+ " def __init__(self, model, schedule=\"linear\", **kwargs):\n",
|
|
|
+ " super().__init__()\n",
|
|
|
+ " self.model = model\n",
|
|
|
+ " self.ddpm_num_timesteps = model.num_timesteps\n",
|
|
|
+ " self.schedule = schedule\n",
|
|
|
+ "\n",
|
|
|
+ " def register_buffer(self, name, attr):\n",
|
|
|
+ " if type(attr) == torch.Tensor:\n",
|
|
|
+ " if attr.device != torch.device(\"cuda\"):\n",
|
|
|
+ " attr = attr.to(torch.device(\"cuda\"))\n",
|
|
|
+ " setattr(self, name, attr)\n",
|
|
|
+ "\n",
|
|
|
+ " def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n",
|
|
|
+ " self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n",
|
|
|
+ " num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n",
|
|
|
+ " alphas_cumprod = self.model.alphas_cumprod\n",
|
|
|
+ " assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n",
|
|
|
+ " to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n",
|
|
|
+ "\n",
|
|
|
+ " self.register_buffer('betas', to_torch(self.model.betas))\n",
|
|
|
+ " self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n",
|
|
|
+ " self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n",
|
|
|
+ "\n",
|
|
|
+ " # calculations for diffusion q(x_t | x_{t-1}) and others\n",
|
|
|
+ " self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n",
|
|
|
+ " self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n",
|
|
|
+ " self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n",
|
|
|
+ " self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n",
|
|
|
+ " self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n",
|
|
|
+ "\n",
|
|
|
+ " # ddim sampling parameters\n",
|
|
|
+ " ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n",
|
|
|
+ " ddim_timesteps=self.ddim_timesteps,\n",
|
|
|
+ " eta=ddim_eta,verbose=verbose)\n",
|
|
|
+ " self.register_buffer('ddim_sigmas', ddim_sigmas)\n",
|
|
|
+ " self.register_buffer('ddim_alphas', ddim_alphas)\n",
|
|
|
+ " self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n",
|
|
|
+ " self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n",
|
|
|
+ " sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n",
|
|
|
+ " (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n",
|
|
|
+ " 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n",
|
|
|
+ " self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n",
|
|
|
+ "\n",
|
|
|
+ " @torch.no_grad()\n",
|
|
|
+ " def sample(self,\n",
|
|
|
+ " S,\n",
|
|
|
+ " batch_size,\n",
|
|
|
+ " shape,\n",
|
|
|
+ " conditioning=None,\n",
|
|
|
+ " callback=None,\n",
|
|
|
+ " normals_sequence=None,\n",
|
|
|
+ " img_callback=None,\n",
|
|
|
+ " quantize_x0=False,\n",
|
|
|
+ " eta=0.,\n",
|
|
|
+ " mask=None,\n",
|
|
|
+ " x0=None,\n",
|
|
|
+ " temperature=1.,\n",
|
|
|
+ " noise_dropout=0.,\n",
|
|
|
+ " score_corrector=None,\n",
|
|
|
+ " corrector_kwargs=None,\n",
|
|
|
+ " verbose=True,\n",
|
|
|
+ " x_T=None,\n",
|
|
|
+ " log_every_t=100,\n",
|
|
|
+ " **kwargs\n",
|
|
|
+ " ):\n",
|
|
|
+ " if conditioning is not None:\n",
|
|
|
+ " if isinstance(conditioning, dict):\n",
|
|
|
+ " cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
|
|
|
+ " if cbs != batch_size:\n",
|
|
|
+ " print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
|
|
|
+ " else:\n",
|
|
|
+ " if conditioning.shape[0] != batch_size:\n",
|
|
|
+ " print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n",
|
|
|
+ "\n",
|
|
|
+ " self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
|
|
|
+ " # sampling\n",
|
|
|
+ " C, H, W = shape\n",
|
|
|
+ " size = (batch_size, C, H, W)\n",
|
|
|
+ " # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
|
|
|
+ "\n",
|
|
|
+ " samples, intermediates = self.ddim_sampling(conditioning, size,\n",
|
|
|
+ " callback=callback,\n",
|
|
|
+ " img_callback=img_callback,\n",
|
|
|
+ " quantize_denoised=quantize_x0,\n",
|
|
|
+ " mask=mask, x0=x0,\n",
|
|
|
+ " ddim_use_original_steps=False,\n",
|
|
|
+ " noise_dropout=noise_dropout,\n",
|
|
|
+ " temperature=temperature,\n",
|
|
|
+ " score_corrector=score_corrector,\n",
|
|
|
+ " corrector_kwargs=corrector_kwargs,\n",
|
|
|
+ " x_T=x_T,\n",
|
|
|
+ " log_every_t=log_every_t\n",
|
|
|
+ " )\n",
|
|
|
+ " return samples, intermediates\n",
|
|
|
+ "\n",
|
|
|
+ " @torch.no_grad()\n",
|
|
|
+ " def ddim_sampling(self, cond, shape,\n",
|
|
|
+ " x_T=None, ddim_use_original_steps=False,\n",
|
|
|
+ " callback=None, timesteps=None, quantize_denoised=False,\n",
|
|
|
+ " mask=None, x0=None, img_callback=None, log_every_t=100,\n",
|
|
|
+ " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
|
|
|
+ " device = self.model.betas.device\n",
|
|
|
+ " b = shape[0]\n",
|
|
|
+ " if x_T is None:\n",
|
|
|
+ " img = torch.randn(shape, device=device)\n",
|
|
|
+ " else:\n",
|
|
|
+ " img = x_T\n",
|
|
|
+ "\n",
|
|
|
+ " if timesteps is None:\n",
|
|
|
+ " timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n",
|
|
|
+ " elif timesteps is not None and not ddim_use_original_steps:\n",
|
|
|
+ " subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n",
|
|
|
+ " timesteps = self.ddim_timesteps[:subset_end]\n",
|
|
|
+ "\n",
|
|
|
+ " intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
|
|
|
+ " time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n",
|
|
|
+ " total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
|
|
|
+ " print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n",
|
|
|
+ "\n",
|
|
|
+ " iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n",
|
|
|
+ "\n",
|
|
|
+ " for i, step in enumerate(iterator):\n",
|
|
|
+ " index = total_steps - i - 1\n",
|
|
|
+ " ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
|
|
|
+ "\n",
|
|
|
+ " if mask is not None:\n",
|
|
|
+ " assert x0 is not None\n",
|
|
|
+ " img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n",
|
|
|
+ " img = img_orig * mask + (1. - mask) * img\n",
|
|
|
+ "\n",
|
|
|
+ " outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n",
|
|
|
+ " quantize_denoised=quantize_denoised, temperature=temperature,\n",
|
|
|
+ " noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
|
|
|
+ " corrector_kwargs=corrector_kwargs)\n",
|
|
|
+ " img, pred_x0 = outs\n",
|
|
|
+ " if callback: callback(i)\n",
|
|
|
+ " if img_callback: img_callback(pred_x0, i)\n",
|
|
|
+ "\n",
|
|
|
+ " if index % log_every_t == 0 or index == total_steps - 1:\n",
|
|
|
+ " intermediates['x_inter'].append(img)\n",
|
|
|
+ " intermediates['pred_x0'].append(pred_x0)\n",
|
|
|
+ "\n",
|
|
|
+ " return img, intermediates\n",
|
|
|
+ "\n",
|
|
|
+ " @torch.no_grad()\n",
|
|
|
+ " def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n",
|
|
|
+ " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
|
|
|
+ " b, *_, device = *x.shape, x.device\n",
|
|
|
+ " e_t = self.model.apply_model(x, t, c)\n",
|
|
|
+ " if score_corrector is not None:\n",
|
|
|
+ " assert self.model.parameterization == \"eps\"\n",
|
|
|
+ " e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n",
|
|
|
+ "\n",
|
|
|
+ " alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n",
|
|
|
+ " alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n",
|
|
|
+ " sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n",
|
|
|
+ " sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n",
|
|
|
+ " # select parameters corresponding to the currently considered timestep\n",
|
|
|
+ " a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n",
|
|
|
+ " a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n",
|
|
|
+ " sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n",
|
|
|
+ " sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n",
|
|
|
+ "\n",
|
|
|
+ " # current prediction for x_0\n",
|
|
|
+ " pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n",
|
|
|
+ " if quantize_denoised:\n",
|
|
|
+ " pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n",
|
|
|
+ " # direction pointing to x_t\n",
|
|
|
+ " dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n",
|
|
|
+ " noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n",
|
|
|
+ " if noise_dropout > 0.:\n",
|
|
|
+ " noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n",
|
|
|
+ " x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n",
|
|
|
+ " return x_prev, pred_x0\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def download_models(mode):\n",
|
|
|
+ "\n",
|
|
|
+ " if mode == \"superresolution\":\n",
|
|
|
+ " # this is the small bsr light model\n",
|
|
|
+ " url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n",
|
|
|
+ " url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n",
|
|
|
+ "\n",
|
|
|
+ " path_conf = f'{model_path}/superres/project.yaml'\n",
|
|
|
+ " path_ckpt = f'{model_path}/superres/last.ckpt'\n",
|
|
|
+ "\n",
|
|
|
+ " download_url(url_conf, path_conf)\n",
|
|
|
+ " download_url(url_ckpt, path_ckpt)\n",
|
|
|
+ "\n",
|
|
|
+ " path_conf = path_conf + '/?dl=1' # fix it\n",
|
|
|
+ " path_ckpt = path_ckpt + '/?dl=1' # fix it\n",
|
|
|
+ " return path_conf, path_ckpt\n",
|
|
|
+ "\n",
|
|
|
+ " else:\n",
|
|
|
+ " raise NotImplementedError\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def load_model_from_config(config, ckpt):\n",
|
|
|
+ " print(f\"Loading model from {ckpt}\")\n",
|
|
|
+ " pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
|
|
|
+ " global_step = pl_sd[\"global_step\"]\n",
|
|
|
+ " sd = pl_sd[\"state_dict\"]\n",
|
|
|
+ " model = instantiate_from_config(config.model)\n",
|
|
|
+ " m, u = model.load_state_dict(sd, strict=False)\n",
|
|
|
+ " model.cuda()\n",
|
|
|
+ " model.eval()\n",
|
|
|
+ " return {\"model\": model}, global_step\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def get_model(mode):\n",
|
|
|
+ " path_conf, path_ckpt = download_models(mode)\n",
|
|
|
+ " config = OmegaConf.load(path_conf)\n",
|
|
|
+ " model, step = load_model_from_config(config, path_ckpt)\n",
|
|
|
+ " return model\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def get_custom_cond(mode):\n",
|
|
|
+ " dest = \"data/example_conditioning\"\n",
|
|
|
+ "\n",
|
|
|
+ " if mode == \"superresolution\":\n",
|
|
|
+ " uploaded_img = files.upload()\n",
|
|
|
+ " filename = next(iter(uploaded_img))\n",
|
|
|
+ " name, filetype = filename.split(\".\") # todo assumes just one dot in name !\n",
|
|
|
+ " os.rename(f\"{filename}\", f\"{dest}/{mode}/custom_{name}.{filetype}\")\n",
|
|
|
+ "\n",
|
|
|
+ " elif mode == \"text_conditional\":\n",
|
|
|
+ " w = widgets.Text(value='A cake with cream!', disabled=True)\n",
|
|
|
+ " display.display(w)\n",
|
|
|
+ "\n",
|
|
|
+ " with open(f\"{dest}/{mode}/custom_{w.value[:20]}.txt\", 'w') as f:\n",
|
|
|
+ " f.write(w.value)\n",
|
|
|
+ "\n",
|
|
|
+ " elif mode == \"class_conditional\":\n",
|
|
|
+ " w = widgets.IntSlider(min=0, max=1000)\n",
|
|
|
+ " display.display(w)\n",
|
|
|
+ " with open(f\"{dest}/{mode}/custom.txt\", 'w') as f:\n",
|
|
|
+ " f.write(w.value)\n",
|
|
|
+ "\n",
|
|
|
+ " else:\n",
|
|
|
+ " raise NotImplementedError(f\"cond not implemented for mode{mode}\")\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def get_cond_options(mode):\n",
|
|
|
+ " path = \"data/example_conditioning\"\n",
|
|
|
+ " path = os.path.join(path, mode)\n",
|
|
|
+ " onlyfiles = [f for f in sorted(os.listdir(path))]\n",
|
|
|
+ " return path, onlyfiles\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def select_cond_path(mode):\n",
|
|
|
+ " path = \"data/example_conditioning\" # todo\n",
|
|
|
+ " path = os.path.join(path, mode)\n",
|
|
|
+ " onlyfiles = [f for f in sorted(os.listdir(path))]\n",
|
|
|
+ "\n",
|
|
|
+ " selected = widgets.RadioButtons(\n",
|
|
|
+ " options=onlyfiles,\n",
|
|
|
+ " description='Select conditioning:',\n",
|
|
|
+ " disabled=False\n",
|
|
|
+ " )\n",
|
|
|
+ " display.display(selected)\n",
|
|
|
+ " selected_path = os.path.join(path, selected.value)\n",
|
|
|
+ " return selected_path\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def get_cond(mode, img):\n",
|
|
|
+ " example = dict()\n",
|
|
|
+ " if mode == \"superresolution\":\n",
|
|
|
+ " up_f = 4\n",
|
|
|
+ " # visualize_cond_img(selected_path)\n",
|
|
|
+ "\n",
|
|
|
+ " c = img\n",
|
|
|
+ " c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)\n",
|
|
|
+ " c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)\n",
|
|
|
+ " c_up = rearrange(c_up, '1 c h w -> 1 h w c')\n",
|
|
|
+ " c = rearrange(c, '1 c h w -> 1 h w c')\n",
|
|
|
+ " c = 2. * c - 1.\n",
|
|
|
+ "\n",
|
|
|
+ " c = c.to(torch.device(\"cuda\"))\n",
|
|
|
+ " example[\"LR_image\"] = c\n",
|
|
|
+ " example[\"image\"] = c_up\n",
|
|
|
+ "\n",
|
|
|
+ " return example\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def visualize_cond_img(path):\n",
|
|
|
+ " display.display(ipyimg(filename=path))\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):\n",
|
|
|
+ " # global stride\n",
|
|
|
+ "\n",
|
|
|
+ " example = get_cond(task, img)\n",
|
|
|
+ "\n",
|
|
|
+ " save_intermediate_vid = False\n",
|
|
|
+ " n_runs = 1\n",
|
|
|
+ " masked = False\n",
|
|
|
+ " guider = None\n",
|
|
|
+ " ckwargs = None\n",
|
|
|
+ " mode = 'ddim'\n",
|
|
|
+ " ddim_use_x0_pred = False\n",
|
|
|
+ " temperature = 1.\n",
|
|
|
+ " eta = eta\n",
|
|
|
+ " make_progrow = True\n",
|
|
|
+ " custom_shape = None\n",
|
|
|
+ "\n",
|
|
|
+ " height, width = example[\"image\"].shape[1:3]\n",
|
|
|
+ " split_input = height >= 128 and width >= 128\n",
|
|
|
+ "\n",
|
|
|
+ " if split_input:\n",
|
|
|
+ " ks = 128\n",
|
|
|
+ " stride = 64\n",
|
|
|
+ " vqf = 4 #\n",
|
|
|
+ " model.split_input_params = {\"ks\": (ks, ks), \"stride\": (stride, stride),\n",
|
|
|
+ " \"vqf\": vqf,\n",
|
|
|
+ " \"patch_distributed_vq\": True,\n",
|
|
|
+ " \"tie_braker\": False,\n",
|
|
|
+ " \"clip_max_weight\": 0.5,\n",
|
|
|
+ " \"clip_min_weight\": 0.01,\n",
|
|
|
+ " \"clip_max_tie_weight\": 0.5,\n",
|
|
|
+ " \"clip_min_tie_weight\": 0.01}\n",
|
|
|
+ " else:\n",
|
|
|
+ " if hasattr(model, \"split_input_params\"):\n",
|
|
|
+ " delattr(model, \"split_input_params\")\n",
|
|
|
+ "\n",
|
|
|
+ " invert_mask = False\n",
|
|
|
+ "\n",
|
|
|
+ " x_T = None\n",
|
|
|
+ " for n in range(n_runs):\n",
|
|
|
+ " if custom_shape is not None:\n",
|
|
|
+ " x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)\n",
|
|
|
+ " x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])\n",
|
|
|
+ "\n",
|
|
|
+ " logs = make_convolutional_sample(example, model,\n",
|
|
|
+ " mode=mode, custom_steps=custom_steps,\n",
|
|
|
+ " eta=eta, swap_mode=False , masked=masked,\n",
|
|
|
+ " invert_mask=invert_mask, quantize_x0=False,\n",
|
|
|
+ " custom_schedule=None, decode_interval=10,\n",
|
|
|
+ " resize_enabled=resize_enabled, custom_shape=custom_shape,\n",
|
|
|
+ " temperature=temperature, noise_dropout=0.,\n",
|
|
|
+ " corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,\n",
|
|
|
+ " make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred\n",
|
|
|
+ " )\n",
|
|
|
+ " return logs\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "@torch.no_grad()\n",
|
|
|
+ "def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,\n",
|
|
|
+ " mask=None, x0=None, quantize_x0=False, img_callback=None,\n",
|
|
|
+ " temperature=1., noise_dropout=0., score_corrector=None,\n",
|
|
|
+ " corrector_kwargs=None, x_T=None, log_every_t=None\n",
|
|
|
+ " ):\n",
|
|
|
+ "\n",
|
|
|
+ " ddim = DDIMSampler(model)\n",
|
|
|
+ " bs = shape[0] # dont know where this comes from but wayne\n",
|
|
|
+ " shape = shape[1:] # cut batch dim\n",
|
|
|
+ " # print(f\"Sampling with eta = {eta}; steps: {steps}\")\n",
|
|
|
+ " samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,\n",
|
|
|
+ " normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,\n",
|
|
|
+ " mask=mask, x0=x0, temperature=temperature, verbose=False,\n",
|
|
|
+ " score_corrector=score_corrector,\n",
|
|
|
+ " corrector_kwargs=corrector_kwargs, x_T=x_T)\n",
|
|
|
+ "\n",
|
|
|
+ " return samples, intermediates\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "@torch.no_grad()\n",
|
|
|
+ "def make_convolutional_sample(batch, model, mode=\"vanilla\", custom_steps=None, eta=1.0, swap_mode=False, masked=False,\n",
|
|
|
+ " invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,\n",
|
|
|
+ " resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,\n",
|
|
|
+ " corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):\n",
|
|
|
+ " log = dict()\n",
|
|
|
+ "\n",
|
|
|
+ " z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,\n",
|
|
|
+ " return_first_stage_outputs=True,\n",
|
|
|
+ " force_c_encode=not (hasattr(model, 'split_input_params')\n",
|
|
|
+ " and model.cond_stage_key == 'coordinates_bbox'),\n",
|
|
|
+ " return_original_cond=True)\n",
|
|
|
+ "\n",
|
|
|
+ " log_every_t = 1 if save_intermediate_vid else None\n",
|
|
|
+ "\n",
|
|
|
+ " if custom_shape is not None:\n",
|
|
|
+ " z = torch.randn(custom_shape)\n",
|
|
|
+ " # print(f\"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}\")\n",
|
|
|
+ "\n",
|
|
|
+ " z0 = None\n",
|
|
|
+ "\n",
|
|
|
+ " log[\"input\"] = x\n",
|
|
|
+ " log[\"reconstruction\"] = xrec\n",
|
|
|
+ "\n",
|
|
|
+ " if ismap(xc):\n",
|
|
|
+ " log[\"original_conditioning\"] = model.to_rgb(xc)\n",
|
|
|
+ " if hasattr(model, 'cond_stage_key'):\n",
|
|
|
+ " log[model.cond_stage_key] = model.to_rgb(xc)\n",
|
|
|
+ "\n",
|
|
|
+ " else:\n",
|
|
|
+ " log[\"original_conditioning\"] = xc if xc is not None else torch.zeros_like(x)\n",
|
|
|
+ " if model.cond_stage_model:\n",
|
|
|
+ " log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)\n",
|
|
|
+ " if model.cond_stage_key =='class_label':\n",
|
|
|
+ " log[model.cond_stage_key] = xc[model.cond_stage_key]\n",
|
|
|
+ "\n",
|
|
|
+ " with model.ema_scope(\"Plotting\"):\n",
|
|
|
+ " t0 = time.time()\n",
|
|
|
+ " img_cb = None\n",
|
|
|
+ "\n",
|
|
|
+ " sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,\n",
|
|
|
+ " eta=eta,\n",
|
|
|
+ " quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,\n",
|
|
|
+ " temperature=temperature, noise_dropout=noise_dropout,\n",
|
|
|
+ " score_corrector=corrector, corrector_kwargs=corrector_kwargs,\n",
|
|
|
+ " x_T=x_T, log_every_t=log_every_t)\n",
|
|
|
+ " t1 = time.time()\n",
|
|
|
+ "\n",
|
|
|
+ " if ddim_use_x0_pred:\n",
|
|
|
+ " sample = intermediates['pred_x0'][-1]\n",
|
|
|
+ "\n",
|
|
|
+ " x_sample = model.decode_first_stage(sample)\n",
|
|
|
+ "\n",
|
|
|
+ " try:\n",
|
|
|
+ " x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)\n",
|
|
|
+ " log[\"sample_noquant\"] = x_sample_noquant\n",
|
|
|
+ " log[\"sample_diff\"] = torch.abs(x_sample_noquant - x_sample)\n",
|
|
|
+ " except:\n",
|
|
|
+ " pass\n",
|
|
|
+ "\n",
|
|
|
+ " log[\"sample\"] = x_sample\n",
|
|
|
+ " log[\"time\"] = t1 - t0\n",
|
|
|
+ "\n",
|
|
|
+ " return log\n",
|
|
|
+ "\n",
|
|
|
+ "sr_diffMode = 'superresolution'\n",
|
|
|
+ "sr_model = get_model('superresolution')\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def do_superres(img, filepath):\n",
|
|
|
+ "\n",
|
|
|
+ " if args.sharpen_preset == 'Faster':\n",
|
|
|
+ " sr_diffusion_steps = \"25\" \n",
|
|
|
+ " sr_pre_downsample = '1/2' \n",
|
|
|
+ " if args.sharpen_preset == 'Fast':\n",
|
|
|
+ " sr_diffusion_steps = \"100\" \n",
|
|
|
+ " sr_pre_downsample = '1/2' \n",
|
|
|
+ " if args.sharpen_preset == 'Slow':\n",
|
|
|
+ " sr_diffusion_steps = \"25\" \n",
|
|
|
+ " sr_pre_downsample = 'None' \n",
|
|
|
+ " if args.sharpen_preset == 'Very Slow':\n",
|
|
|
+ " sr_diffusion_steps = \"100\" \n",
|
|
|
+ " sr_pre_downsample = 'None' \n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " sr_post_downsample = 'Original Size'\n",
|
|
|
+ " sr_diffusion_steps = int(sr_diffusion_steps)\n",
|
|
|
+ " sr_eta = 1.0 \n",
|
|
|
+ " sr_downsample_method = 'Lanczos' \n",
|
|
|
+ "\n",
|
|
|
+ " gc.collect()\n",
|
|
|
+ " torch.cuda.empty_cache()\n",
|
|
|
+ "\n",
|
|
|
+ " im_og = img\n",
|
|
|
+ " width_og, height_og = im_og.size\n",
|
|
|
+ "\n",
|
|
|
+ " #Downsample Pre\n",
|
|
|
+ " if sr_pre_downsample == '1/2':\n",
|
|
|
+ " downsample_rate = 2\n",
|
|
|
+ " elif sr_pre_downsample == '1/4':\n",
|
|
|
+ " downsample_rate = 4\n",
|
|
|
+ " else:\n",
|
|
|
+ " downsample_rate = 1\n",
|
|
|
+ "\n",
|
|
|
+ " width_downsampled_pre = width_og//downsample_rate\n",
|
|
|
+ " height_downsampled_pre = height_og//downsample_rate\n",
|
|
|
+ "\n",
|
|
|
+ " if downsample_rate != 1:\n",
|
|
|
+ " # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')\n",
|
|
|
+ " im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)\n",
|
|
|
+ " # im_og.save('/content/temp.png')\n",
|
|
|
+ " # filepath = '/content/temp.png'\n",
|
|
|
+ "\n",
|
|
|
+ " logs = sr_run(sr_model[\"model\"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)\n",
|
|
|
+ "\n",
|
|
|
+ " sample = logs[\"sample\"]\n",
|
|
|
+ " sample = sample.detach().cpu()\n",
|
|
|
+ " sample = torch.clamp(sample, -1., 1.)\n",
|
|
|
+ " sample = (sample + 1.) / 2. * 255\n",
|
|
|
+ " sample = sample.numpy().astype(np.uint8)\n",
|
|
|
+ " sample = np.transpose(sample, (0, 2, 3, 1))\n",
|
|
|
+ " a = Image.fromarray(sample[0])\n",
|
|
|
+ "\n",
|
|
|
+ " #Downsample Post\n",
|
|
|
+ " if sr_post_downsample == '1/2':\n",
|
|
|
+ " downsample_rate = 2\n",
|
|
|
+ " elif sr_post_downsample == '1/4':\n",
|
|
|
+ " downsample_rate = 4\n",
|
|
|
+ " else:\n",
|
|
|
+ " downsample_rate = 1\n",
|
|
|
+ "\n",
|
|
|
+ " width, height = a.size\n",
|
|
|
+ " width_downsampled_post = width//downsample_rate\n",
|
|
|
+ " height_downsampled_post = height//downsample_rate\n",
|
|
|
+ "\n",
|
|
|
+ " if sr_downsample_method == 'Lanczos':\n",
|
|
|
+ " aliasing = Image.LANCZOS\n",
|
|
|
+ " else:\n",
|
|
|
+ " aliasing = Image.NEAREST\n",
|
|
|
+ "\n",
|
|
|
+ " if downsample_rate != 1:\n",
|
|
|
+ " # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')\n",
|
|
|
+ " a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)\n",
|
|
|
+ " elif sr_post_downsample == 'Original Size':\n",
|
|
|
+ " # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')\n",
|
|
|
+ " a = a.resize((width_og, height_og), aliasing)\n",
|
|
|
+ "\n",
|
|
|
+ " display.display(a)\n",
|
|
|
+ " a.save(filepath)\n",
|
|
|
+ " return\n",
|
|
|
+ " print(f'Processing finished!')\n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "DiffClipSetTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# 2. Diffusion and CLIP model settings"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "ModelSettings"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@markdown ####**Models Settings:**\n",
|
|
|
+ "diffusion_model = \"512x512_diffusion_uncond_finetune_008100\" #@param [\"256x256_diffusion_uncond\", \"512x512_diffusion_uncond_finetune_008100\"]\n",
|
|
|
+ "use_secondary_model = True #@param {type: 'boolean'}\n",
|
|
|
+ "diffusion_sampling_mode = 'ddim' #@param ['plms','ddim'] \n",
|
|
|
+ "\n",
|
|
|
+ "timestep_respacing = '250' #@param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000'] \n",
|
|
|
+ "diffusion_steps = 300 #@param {type: 'number'}\n",
|
|
|
+ "use_checkpoint = True #@param {type: 'boolean'}\n",
|
|
|
+ "ViTB32 = True #@param{type:\"boolean\"}\n",
|
|
|
+ "ViTB16 = True #@param{type:\"boolean\"}\n",
|
|
|
+ "ViTL14 = False #@param{type:\"boolean\"}\n",
|
|
|
+ "RN101 = False #@param{type:\"boolean\"}\n",
|
|
|
+ "RN50 = True #@param{type:\"boolean\"}\n",
|
|
|
+ "RN50x4 = False #@param{type:\"boolean\"}\n",
|
|
|
+ "RN50x16 = False #@param{type:\"boolean\"}\n",
|
|
|
+ "RN50x64 = False #@param{type:\"boolean\"}\n",
|
|
|
+ "SLIPB16 = False #@param{type:\"boolean\"}\n",
|
|
|
+ "SLIPL16 = False #@param{type:\"boolean\"}\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown If you're having issues with model downloads, check this to compare SHA's:\n",
|
|
|
+ "check_model_SHA = False #@param{type:\"boolean\"}\n",
|
|
|
+ "\n",
|
|
|
+ "model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
|
|
|
+ "model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n",
|
|
|
+ "model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
|
|
|
+ "\n",
|
|
|
+ "model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'\n",
|
|
|
+ "model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n",
|
|
|
+ "model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n",
|
|
|
+ "\n",
|
|
|
+ "model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n",
|
|
|
+ "model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n",
|
|
|
+ "model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n",
|
|
|
+ "\n",
|
|
|
+ "# Download the diffusion model\n",
|
|
|
+ "if diffusion_model == '256x256_diffusion_uncond':\n",
|
|
|
+ " if os.path.exists(model_256_path) and check_model_SHA:\n",
|
|
|
+ " print('Checking 256 Diffusion File')\n",
|
|
|
+ " with open(model_256_path,\"rb\") as f:\n",
|
|
|
+ " bytes = f.read() \n",
|
|
|
+ " hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
+ " if hash == model_256_SHA:\n",
|
|
|
+ " print('256 Model SHA matches')\n",
|
|
|
+ " model_256_downloaded = True\n",
|
|
|
+ " else: \n",
|
|
|
+ " print(\"256 Model SHA doesn't match, redownloading...\")\n",
|
|
|
+ " wget(model_256_link, model_path)\n",
|
|
|
+ " model_256_downloaded = True\n",
|
|
|
+ " elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n",
|
|
|
+ " print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
+ " else: \n",
|
|
|
+ " wget(model_256_link, model_path)\n",
|
|
|
+ " model_256_downloaded = True\n",
|
|
|
+ "elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
|
|
|
+ " if os.path.exists(model_512_path) and check_model_SHA:\n",
|
|
|
+ " print('Checking 512 Diffusion File')\n",
|
|
|
+ " with open(model_512_path,\"rb\") as f:\n",
|
|
|
+ " bytes = f.read() \n",
|
|
|
+ " hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
+ " if hash == model_512_SHA:\n",
|
|
|
+ " print('512 Model SHA matches')\n",
|
|
|
+ " model_512_downloaded = True\n",
|
|
|
+ " else: \n",
|
|
|
+ " print(\"512 Model SHA doesn't match, redownloading...\")\n",
|
|
|
+ " wget(model_512_link, model_path)\n",
|
|
|
+ " model_512_downloaded = True\n",
|
|
|
+ " elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n",
|
|
|
+ " print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
+ " else: \n",
|
|
|
+ " wget(model_512_link, model_path)\n",
|
|
|
+ " model_512_downloaded = True\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "# Download the secondary diffusion model v2\n",
|
|
|
+ "if use_secondary_model == True:\n",
|
|
|
+ " if os.path.exists(model_secondary_path) and check_model_SHA:\n",
|
|
|
+ " print('Checking Secondary Diffusion File')\n",
|
|
|
+ " with open(model_secondary_path,\"rb\") as f:\n",
|
|
|
+ " bytes = f.read() \n",
|
|
|
+ " hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
+ " if hash == model_secondary_SHA:\n",
|
|
|
+ " print('Secondary Model SHA matches')\n",
|
|
|
+ " model_secondary_downloaded = True\n",
|
|
|
+ " else: \n",
|
|
|
+ " print(\"Secondary Model SHA doesn't match, redownloading...\")\n",
|
|
|
+ " wget(model_secondary_link, model_path)\n",
|
|
|
+ " model_secondary_downloaded = True\n",
|
|
|
+ " elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n",
|
|
|
+ " print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
+ " else: \n",
|
|
|
+ " wget(model_secondary_link, model_path)\n",
|
|
|
+ " model_secondary_downloaded = True\n",
|
|
|
+ "\n",
|
|
|
+ "model_config = model_and_diffusion_defaults()\n",
|
|
|
+ "if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
|
|
|
+ " model_config.update({\n",
|
|
|
+ " 'attention_resolutions': '32, 16, 8',\n",
|
|
|
+ " 'class_cond': False,\n",
|
|
|
+ " 'diffusion_steps': diffusion_steps,\n",
|
|
|
+ " 'rescale_timesteps': True,\n",
|
|
|
+ " 'timestep_respacing': timestep_respacing,\n",
|
|
|
+ " 'image_size': 512,\n",
|
|
|
+ " 'learn_sigma': True,\n",
|
|
|
+ " 'noise_schedule': 'linear',\n",
|
|
|
+ " 'num_channels': 256,\n",
|
|
|
+ " 'num_head_channels': 64,\n",
|
|
|
+ " 'num_res_blocks': 2,\n",
|
|
|
+ " 'resblock_updown': True,\n",
|
|
|
+ " 'use_checkpoint': use_checkpoint,\n",
|
|
|
+ " 'use_fp16': True,\n",
|
|
|
+ " 'use_scale_shift_norm': True,\n",
|
|
|
+ " })\n",
|
|
|
+ "elif diffusion_model == '256x256_diffusion_uncond':\n",
|
|
|
+ " model_config.update({\n",
|
|
|
+ " 'attention_resolutions': '32, 16, 8',\n",
|
|
|
+ " 'class_cond': False,\n",
|
|
|
+ " 'diffusion_steps': diffusion_steps,\n",
|
|
|
+ " 'rescale_timesteps': True,\n",
|
|
|
+ " 'timestep_respacing': timestep_respacing,\n",
|
|
|
+ " 'image_size': 256,\n",
|
|
|
+ " 'learn_sigma': True,\n",
|
|
|
+ " 'noise_schedule': 'linear',\n",
|
|
|
+ " 'num_channels': 256,\n",
|
|
|
+ " 'num_head_channels': 64,\n",
|
|
|
+ " 'num_res_blocks': 2,\n",
|
|
|
+ " 'resblock_updown': True,\n",
|
|
|
+ " 'use_checkpoint': use_checkpoint,\n",
|
|
|
+ " 'use_fp16': True,\n",
|
|
|
+ " 'use_scale_shift_norm': True,\n",
|
|
|
+ " })\n",
|
|
|
+ "\n",
|
|
|
+ "secondary_model_ver = 2\n",
|
|
|
+ "model_default = model_config['image_size']\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "if secondary_model_ver == 2:\n",
|
|
|
+ " secondary_model = SecondaryDiffusionImageNet2()\n",
|
|
|
+ " secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))\n",
|
|
|
+ "secondary_model.eval().requires_grad_(False).to(device)\n",
|
|
|
+ "\n",
|
|
|
+ "clip_models = []\n",
|
|
|
+ "if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+ "if ViTB16 is True: clip_models.append(clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device) ) \n",
|
|
|
+ "if ViTL14 is True: clip_models.append(clip.load('ViT-L/14', jit=False)[0].eval().requires_grad_(False).to(device) ) \n",
|
|
|
+ "if RN50 is True: clip_models.append(clip.load('RN50', jit=False)[0].eval().requires_grad_(False).to(device))\n",
|
|
|
+ "if RN50x4 is True: clip_models.append(clip.load('RN50x4', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+ "if RN50x16 is True: clip_models.append(clip.load('RN50x16', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+ "if RN50x64 is True: clip_models.append(clip.load('RN50x64', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+ "if RN101 is True: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+ "\n",
|
|
|
+ "if SLIPB16:\n",
|
|
|
+ " SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
|
|
|
+ " if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):\n",
|
|
|
+ " wget(\"https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt\", model_path)\n",
|
|
|
+ " sd = torch.load(f'{model_path}/slip_base_100ep.pt')\n",
|
|
|
+ " real_sd = {}\n",
|
|
|
+ " for k, v in sd['state_dict'].items():\n",
|
|
|
+ " real_sd['.'.join(k.split('.')[1:])] = v\n",
|
|
|
+ " del sd\n",
|
|
|
+ " SLIPB16model.load_state_dict(real_sd)\n",
|
|
|
+ " SLIPB16model.requires_grad_(False).eval().to(device)\n",
|
|
|
+ "\n",
|
|
|
+ " clip_models.append(SLIPB16model)\n",
|
|
|
+ "\n",
|
|
|
+ "if SLIPL16:\n",
|
|
|
+ " SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
|
|
|
+ " if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):\n",
|
|
|
+ " wget(\"https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt\", model_path)\n",
|
|
|
+ " sd = torch.load(f'{model_path}/slip_large_100ep.pt')\n",
|
|
|
+ " real_sd = {}\n",
|
|
|
+ " for k, v in sd['state_dict'].items():\n",
|
|
|
+ " real_sd['.'.join(k.split('.')[1:])] = v\n",
|
|
|
+ " del sd\n",
|
|
|
+ " SLIPL16model.load_state_dict(real_sd)\n",
|
|
|
+ " SLIPL16model.requires_grad_(False).eval().to(device)\n",
|
|
|
+ "\n",
|
|
|
+ " clip_models.append(SLIPL16model)\n",
|
|
|
+ "\n",
|
|
|
+ "normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n",
|
|
|
+ "lpips_model = lpips.LPIPS(net='vgg').to(device)\n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "SettingsTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# 3. Settings"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "BasicSettings"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@markdown ####**Basic Settings:**\n",
|
|
|
+ "batch_name = 'new_House' #@param{type: 'string'}\n",
|
|
|
+ "steps = 300#@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}\n",
|
|
|
+ "width_height = [1280, 720]#@param{type: 'raw'}\n",
|
|
|
+ "clip_guidance_scale = 5000 #@param{type: 'number'}\n",
|
|
|
+ "tv_scale = 0#@param{type: 'number'}\n",
|
|
|
+ "range_scale = 150#@param{type: 'number'}\n",
|
|
|
+ "sat_scale = 0#@param{type: 'number'}\n",
|
|
|
+ "cutn_batches = 4#@param{type: 'number'}\n",
|
|
|
+ "skip_augs = False#@param{type: 'boolean'}\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ---\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ####**Init Settings:**\n",
|
|
|
+ "init_image = \"/content/drive/MyDrive/AI/Disco_Diffusion/init_images/xv_1_decoupe_noback.jpg\" #@param{type: 'string'}\n",
|
|
|
+ "init_scale = 1000 #@param{type: 'integer'}\n",
|
|
|
+ "skip_steps = 50 #@param{type: 'integer'}\n",
|
|
|
+ "#@markdown *Make sure you set skip_steps to ~50% of your steps if you want to use an init image.*\n",
|
|
|
+ "\n",
|
|
|
+ "#Get corrected sizes\n",
|
|
|
+ "side_x = (width_height[0]//64)*64;\n",
|
|
|
+ "side_y = (width_height[1]//64)*64;\n",
|
|
|
+ "if side_x != width_height[0] or side_y != width_height[1]:\n",
|
|
|
+ " print(f'Changing output size to {side_x}x{side_y}. Dimensions must by multiples of 64.')\n",
|
|
|
+ "\n",
|
|
|
+ "#Update Model Settings\n",
|
|
|
+ "timestep_respacing = f'ddim{steps}'\n",
|
|
|
+ "diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n",
|
|
|
+ "model_config.update({\n",
|
|
|
+ " 'timestep_respacing': timestep_respacing,\n",
|
|
|
+ " 'diffusion_steps': diffusion_steps,\n",
|
|
|
+ "})\n",
|
|
|
+ "\n",
|
|
|
+ "#Make folder for batch\n",
|
|
|
+ "batchFolder = f'{outDirPath}/{batch_name}'\n",
|
|
|
+ "createPath(batchFolder)\n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "AnimSetTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "### Animation Settings"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "AnimSettings"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@markdown ####**Animation Mode:**\n",
|
|
|
+ "animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'}\n",
|
|
|
+ "#@markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.*\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ---\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ####**Video Input Settings:**\n",
|
|
|
+ "if is_colab:\n",
|
|
|
+ " video_init_path = \"/content/training.mp4\" #@param {type: 'string'}\n",
|
|
|
+ "else:\n",
|
|
|
+ " video_init_path = \"training.mp4\" #@param {type: 'string'}\n",
|
|
|
+ "extract_nth_frame = 2 #@param {type: 'number'}\n",
|
|
|
+ "video_init_seed_continuity = True #@param {type: 'boolean'}\n",
|
|
|
+ "\n",
|
|
|
+ "if animation_mode == \"Video Input\":\n",
|
|
|
+ " if is_colab:\n",
|
|
|
+ " videoFramesFolder = f'/content/videoFrames'\n",
|
|
|
+ " else:\n",
|
|
|
+ " videoFramesFolder = f'videoFrames'\n",
|
|
|
+ " createPath(videoFramesFolder)\n",
|
|
|
+ " print(f\"Exporting Video Frames (1 every {extract_nth_frame})...\")\n",
|
|
|
+ " try:\n",
|
|
|
+ " for f in pathlib.Path(f'{videoFramesFolder}').glob('*.jpg'):\n",
|
|
|
+ " f.unlink()\n",
|
|
|
+ " except:\n",
|
|
|
+ " print('')\n",
|
|
|
+ " vf = f'\"select=not(mod(n\\,{extract_nth_frame}))\"'\n",
|
|
|
+ " subprocess.run(['ffmpeg', '-i', f'{video_init_path}', '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', '-loglevel', 'error', '-stats', f'{videoFramesFolder}/%04d.jpg'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
+ " #!ffmpeg -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ---\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ####**2D Animation Settings:**\n",
|
|
|
+ "#@markdown `zoom` is a multiplier of dimensions, 1 is no zoom.\n",
|
|
|
+ "#@markdown All rotations are provided in degrees.\n",
|
|
|
+ "\n",
|
|
|
+ "key_frames = True #@param {type:\"boolean\"}\n",
|
|
|
+ "max_frames = 10000#@param {type:\"number\"}\n",
|
|
|
+ "\n",
|
|
|
+ "if animation_mode == \"Video Input\":\n",
|
|
|
+ " max_frames = len(glob(f'{videoFramesFolder}/*.jpg'))\n",
|
|
|
+ "\n",
|
|
|
+ "interp_spline = 'Linear' #Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:\"string\"}\n",
|
|
|
+ "angle = \"0:(0)\"#@param {type:\"string\"}\n",
|
|
|
+ "zoom = \"0: (1), 10: (1.05)\"#@param {type:\"string\"}\n",
|
|
|
+ "translation_x = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
+ "translation_y = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
+ "translation_z = \"0: (10.0)\"#@param {type:\"string\"}\n",
|
|
|
+ "rotation_3d_x = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
+ "rotation_3d_y = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
+ "rotation_3d_z = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
+ "midas_depth_model = \"dpt_large\"#@param {type:\"string\"}\n",
|
|
|
+ "midas_weight = 0.3#@param {type:\"number\"}\n",
|
|
|
+ "near_plane = 200#@param {type:\"number\"}\n",
|
|
|
+ "far_plane = 10000#@param {type:\"number\"}\n",
|
|
|
+ "fov = 40#@param {type:\"number\"}\n",
|
|
|
+ "padding_mode = 'border'#@param {type:\"string\"}\n",
|
|
|
+ "sampling_mode = 'bicubic'#@param {type:\"string\"}\n",
|
|
|
+ "\n",
|
|
|
+ "#======= TURBO MODE\n",
|
|
|
+ "#@markdown ---\n",
|
|
|
+ "#@markdown ####**Turbo Mode (3D anim only):**\n",
|
|
|
+ "#@markdown (Starts after frame 10,) skips diffusion steps and just uses depth map to warp images for skipped frames.\n",
|
|
|
+ "#@markdown Speeds up rendering by 2x-4x, and may improve image coherence between frames. frame_blend_mode smooths abrupt texture changes across 2 frames.\n",
|
|
|
+ "#@markdown For different settings tuned for Turbo Mode, refer to the original Disco-Turbo Github: https://github.com/zippy731/disco-diffusion-turbo\n",
|
|
|
+ "\n",
|
|
|
+ "turbo_mode = False #@param {type:\"boolean\"}\n",
|
|
|
+ "turbo_steps = \"3\" #@param [\"2\",\"3\",\"4\",\"5\",\"6\"] {type:\"string\"}\n",
|
|
|
+ "turbo_preroll = 10 # frames\n",
|
|
|
+ "\n",
|
|
|
+ "#insist turbo be used only w 3d anim.\n",
|
|
|
+ "if turbo_mode and animation_mode != '3D':\n",
|
|
|
+ " print('=====')\n",
|
|
|
+ " print('Turbo mode only available with 3D animations. Disabling Turbo.')\n",
|
|
|
+ " print('=====')\n",
|
|
|
+ " turbo_mode = False\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ---\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ####**Coherency Settings:**\n",
|
|
|
+ "#@markdown `frame_scale` tries to guide the new frame to looking like the old one. A good default is 1500.\n",
|
|
|
+ "frames_scale = 1500 #@param{type: 'integer'}\n",
|
|
|
+ "#@markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.\n",
|
|
|
+ "frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "def parse_key_frames(string, prompt_parser=None):\n",
|
|
|
+ " \"\"\"Given a string representing frame numbers paired with parameter values at that frame,\n",
|
|
|
+ " return a dictionary with the frame numbers as keys and the parameter values as the values.\n",
|
|
|
+ "\n",
|
|
|
+ " Parameters\n",
|
|
|
+ " ----------\n",
|
|
|
+ " string: string\n",
|
|
|
+ " Frame numbers paired with parameter values at that frame number, in the format\n",
|
|
|
+ " 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'\n",
|
|
|
+ " prompt_parser: function or None, optional\n",
|
|
|
+ " If provided, prompt_parser will be applied to each string of parameter values.\n",
|
|
|
+ " \n",
|
|
|
+ " Returns\n",
|
|
|
+ " -------\n",
|
|
|
+ " dict\n",
|
|
|
+ " Frame numbers as keys, parameter values at that frame number as values\n",
|
|
|
+ "\n",
|
|
|
+ " Raises\n",
|
|
|
+ " ------\n",
|
|
|
+ " RuntimeError\n",
|
|
|
+ " If the input string does not match the expected format.\n",
|
|
|
+ " \n",
|
|
|
+ " Examples\n",
|
|
|
+ " --------\n",
|
|
|
+ " >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\")\n",
|
|
|
+ " {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}\n",
|
|
|
+ "\n",
|
|
|
+ " >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\", prompt_parser=lambda x: x.lower()))\n",
|
|
|
+ " {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}\n",
|
|
|
+ " \"\"\"\n",
|
|
|
+ " import re\n",
|
|
|
+ " pattern = r'((?P<frame>[0-9]+):[\\s]*[\\(](?P<param>[\\S\\s]*?)[\\)])'\n",
|
|
|
+ " frames = dict()\n",
|
|
|
+ " for match_object in re.finditer(pattern, string):\n",
|
|
|
+ " frame = int(match_object.groupdict()['frame'])\n",
|
|
|
+ " param = match_object.groupdict()['param']\n",
|
|
|
+ " if prompt_parser:\n",
|
|
|
+ " frames[frame] = prompt_parser(param)\n",
|
|
|
+ " else:\n",
|
|
|
+ " frames[frame] = param\n",
|
|
|
+ "\n",
|
|
|
+ " if frames == {} and len(string) != 0:\n",
|
|
|
+ " raise RuntimeError('Key Frame string not correctly formatted')\n",
|
|
|
+ " return frames\n",
|
|
|
+ "\n",
|
|
|
+ "def get_inbetweens(key_frames, integer=False):\n",
|
|
|
+ " \"\"\"Given a dict with frame numbers as keys and a parameter value as values,\n",
|
|
|
+ " return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.\n",
|
|
|
+ " Any values not provided in the input dict are calculated by linear interpolation between\n",
|
|
|
+ " the values of the previous and next provided frames. If there is no previous provided frame, then\n",
|
|
|
+ " the value is equal to the value of the next provided frame, or if there is no next provided frame,\n",
|
|
|
+ " then the value is equal to the value of the previous provided frame. If no frames are provided,\n",
|
|
|
+ " all frame values are NaN.\n",
|
|
|
+ "\n",
|
|
|
+ " Parameters\n",
|
|
|
+ " ----------\n",
|
|
|
+ " key_frames: dict\n",
|
|
|
+ " A dict with integer frame numbers as keys and numerical values of a particular parameter as values.\n",
|
|
|
+ " integer: Bool, optional\n",
|
|
|
+ " If True, the values of the output series are converted to integers.\n",
|
|
|
+ " Otherwise, the values are floats.\n",
|
|
|
+ " \n",
|
|
|
+ " Returns\n",
|
|
|
+ " -------\n",
|
|
|
+ " pd.Series\n",
|
|
|
+ " A Series with length max_frames representing the parameter values for each frame.\n",
|
|
|
+ " \n",
|
|
|
+ " Examples\n",
|
|
|
+ " --------\n",
|
|
|
+ " >>> max_frames = 5\n",
|
|
|
+ " >>> get_inbetweens({1: 5, 3: 6})\n",
|
|
|
+ " 0 5.0\n",
|
|
|
+ " 1 5.0\n",
|
|
|
+ " 2 5.5\n",
|
|
|
+ " 3 6.0\n",
|
|
|
+ " 4 6.0\n",
|
|
|
+ " dtype: float64\n",
|
|
|
+ "\n",
|
|
|
+ " >>> get_inbetweens({1: 5, 3: 6}, integer=True)\n",
|
|
|
+ " 0 5\n",
|
|
|
+ " 1 5\n",
|
|
|
+ " 2 5\n",
|
|
|
+ " 3 6\n",
|
|
|
+ " 4 6\n",
|
|
|
+ " dtype: int64\n",
|
|
|
+ " \"\"\"\n",
|
|
|
+ " key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n",
|
|
|
+ "\n",
|
|
|
+ " for i, value in key_frames.items():\n",
|
|
|
+ " key_frame_series[i] = value\n",
|
|
|
+ " key_frame_series = key_frame_series.astype(float)\n",
|
|
|
+ " \n",
|
|
|
+ " interp_method = interp_spline\n",
|
|
|
+ "\n",
|
|
|
+ " if interp_method == 'Cubic' and len(key_frames.items()) <=3:\n",
|
|
|
+ " interp_method = 'Quadratic'\n",
|
|
|
+ " \n",
|
|
|
+ " if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n",
|
|
|
+ " interp_method = 'Linear'\n",
|
|
|
+ " \n",
|
|
|
+ " \n",
|
|
|
+ " key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n",
|
|
|
+ " key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n",
|
|
|
+ " # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')\n",
|
|
|
+ " key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')\n",
|
|
|
+ " if integer:\n",
|
|
|
+ " return key_frame_series.astype(int)\n",
|
|
|
+ " return key_frame_series\n",
|
|
|
+ "\n",
|
|
|
+ "def split_prompts(prompts):\n",
|
|
|
+ " prompt_series = pd.Series([np.nan for a in range(max_frames)])\n",
|
|
|
+ " for i, prompt in prompts.items():\n",
|
|
|
+ " prompt_series[i] = prompt\n",
|
|
|
+ " # prompt_series = prompt_series.astype(str)\n",
|
|
|
+ " prompt_series = prompt_series.ffill().bfill()\n",
|
|
|
+ " return prompt_series\n",
|
|
|
+ "\n",
|
|
|
+ "if key_frames:\n",
|
|
|
+ " try:\n",
|
|
|
+ " angle_series = get_inbetweens(parse_key_frames(angle))\n",
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
+ " print(\n",
|
|
|
+ " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+ " \"formatted `angle` correctly for key frames.\\n\"\n",
|
|
|
+ " \"Attempting to interpret `angle` as \"\n",
|
|
|
+ " f'\"0: ({angle})\"\\n'\n",
|
|
|
+ " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+ " \"correctly.\\n\"\n",
|
|
|
+ " )\n",
|
|
|
+ " angle = f\"0: ({angle})\"\n",
|
|
|
+ " angle_series = get_inbetweens(parse_key_frames(angle))\n",
|
|
|
+ "\n",
|
|
|
+ " try:\n",
|
|
|
+ " zoom_series = get_inbetweens(parse_key_frames(zoom))\n",
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
+ " print(\n",
|
|
|
+ " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+ " \"formatted `zoom` correctly for key frames.\\n\"\n",
|
|
|
+ " \"Attempting to interpret `zoom` as \"\n",
|
|
|
+ " f'\"0: ({zoom})\"\\n'\n",
|
|
|
+ " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+ " \"correctly.\\n\"\n",
|
|
|
+ " )\n",
|
|
|
+ " zoom = f\"0: ({zoom})\"\n",
|
|
|
+ " zoom_series = get_inbetweens(parse_key_frames(zoom))\n",
|
|
|
+ "\n",
|
|
|
+ " try:\n",
|
|
|
+ " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n",
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
+ " print(\n",
|
|
|
+ " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+ " \"formatted `translation_x` correctly for key frames.\\n\"\n",
|
|
|
+ " \"Attempting to interpret `translation_x` as \"\n",
|
|
|
+ " f'\"0: ({translation_x})\"\\n'\n",
|
|
|
+ " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+ " \"correctly.\\n\"\n",
|
|
|
+ " )\n",
|
|
|
+ " translation_x = f\"0: ({translation_x})\"\n",
|
|
|
+ " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n",
|
|
|
+ "\n",
|
|
|
+ " try:\n",
|
|
|
+ " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n",
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
+ " print(\n",
|
|
|
+ " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+ " \"formatted `translation_y` correctly for key frames.\\n\"\n",
|
|
|
+ " \"Attempting to interpret `translation_y` as \"\n",
|
|
|
+ " f'\"0: ({translation_y})\"\\n'\n",
|
|
|
+ " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+ " \"correctly.\\n\"\n",
|
|
|
+ " )\n",
|
|
|
+ " translation_y = f\"0: ({translation_y})\"\n",
|
|
|
+ " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n",
|
|
|
+ "\n",
|
|
|
+ " try:\n",
|
|
|
+ " translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n",
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
+ " print(\n",
|
|
|
+ " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+ " \"formatted `translation_z` correctly for key frames.\\n\"\n",
|
|
|
+ " \"Attempting to interpret `translation_z` as \"\n",
|
|
|
+ " f'\"0: ({translation_z})\"\\n'\n",
|
|
|
+ " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+ " \"correctly.\\n\"\n",
|
|
|
+ " )\n",
|
|
|
+ " translation_z = f\"0: ({translation_z})\"\n",
|
|
|
+ " translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n",
|
|
|
+ "\n",
|
|
|
+ " try:\n",
|
|
|
+ " rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n",
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
+ " print(\n",
|
|
|
+ " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+ " \"formatted `rotation_3d_x` correctly for key frames.\\n\"\n",
|
|
|
+ " \"Attempting to interpret `rotation_3d_x` as \"\n",
|
|
|
+ " f'\"0: ({rotation_3d_x})\"\\n'\n",
|
|
|
+ " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+ " \"correctly.\\n\"\n",
|
|
|
+ " )\n",
|
|
|
+ " rotation_3d_x = f\"0: ({rotation_3d_x})\"\n",
|
|
|
+ " rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n",
|
|
|
+ "\n",
|
|
|
+ " try:\n",
|
|
|
+ " rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n",
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
+ " print(\n",
|
|
|
+ " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+ " \"formatted `rotation_3d_y` correctly for key frames.\\n\"\n",
|
|
|
+ " \"Attempting to interpret `rotation_3d_y` as \"\n",
|
|
|
+ " f'\"0: ({rotation_3d_y})\"\\n'\n",
|
|
|
+ " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+ " \"correctly.\\n\"\n",
|
|
|
+ " )\n",
|
|
|
+ " rotation_3d_y = f\"0: ({rotation_3d_y})\"\n",
|
|
|
+ " rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n",
|
|
|
+ "\n",
|
|
|
+ " try:\n",
|
|
|
+ " rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n",
|
|
|
+ " except RuntimeError as e:\n",
|
|
|
+ " print(\n",
|
|
|
+ " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+ " \"formatted `rotation_3d_z` correctly for key frames.\\n\"\n",
|
|
|
+ " \"Attempting to interpret `rotation_3d_z` as \"\n",
|
|
|
+ " f'\"0: ({rotation_3d_z})\"\\n'\n",
|
|
|
+ " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+ " \"correctly.\\n\"\n",
|
|
|
+ " )\n",
|
|
|
+ " rotation_3d_z = f\"0: ({rotation_3d_z})\"\n",
|
|
|
+ " rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n",
|
|
|
+ "\n",
|
|
|
+ "else:\n",
|
|
|
+ " angle = float(angle)\n",
|
|
|
+ " zoom = float(zoom)\n",
|
|
|
+ " translation_x = float(translation_x)\n",
|
|
|
+ " translation_y = float(translation_y)\n",
|
|
|
+ " translation_z = float(translation_z)\n",
|
|
|
+ " rotation_3d_x = float(rotation_3d_x)\n",
|
|
|
+ " rotation_3d_y = float(rotation_3d_y)\n",
|
|
|
+ " rotation_3d_z = float(rotation_3d_z)\n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "ExtraSetTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "### Extra Settings\n",
|
|
|
+ " Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "ExtraSettings"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@markdown ####**Saving:**\n",
|
|
|
+ "\n",
|
|
|
+ "intermediate_saves = 4#@param{type: 'raw'}\n",
|
|
|
+ "intermediates_in_subfolder = True #@param{type: 'boolean'}\n",
|
|
|
+ "#@markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps \n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown A value of `2` will save a copy at 33% and 66%. 0 will save none.\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "if type(intermediate_saves) is not list:\n",
|
|
|
+ " if intermediate_saves:\n",
|
|
|
+ " steps_per_checkpoint = math.floor((steps - skip_steps - 1) // (intermediate_saves+1))\n",
|
|
|
+ " steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1\n",
|
|
|
+ " print(f'Will save every {steps_per_checkpoint} steps')\n",
|
|
|
+ " else:\n",
|
|
|
+ " steps_per_checkpoint = steps+10\n",
|
|
|
+ "else:\n",
|
|
|
+ " steps_per_checkpoint = None\n",
|
|
|
+ "\n",
|
|
|
+ "if intermediate_saves and intermediates_in_subfolder is True:\n",
|
|
|
+ " partialFolder = f'{batchFolder}/partials'\n",
|
|
|
+ " createPath(partialFolder)\n",
|
|
|
+ "\n",
|
|
|
+ " #@markdown ---\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ####**SuperRes Sharpening:**\n",
|
|
|
+ "#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*\n",
|
|
|
+ "sharpen_preset = 'Fast' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']\n",
|
|
|
+ "keep_unsharp = True #@param{type: 'boolean'}\n",
|
|
|
+ "\n",
|
|
|
+ "if sharpen_preset != 'Off' and keep_unsharp is True:\n",
|
|
|
+ " unsharpenFolder = f'{batchFolder}/unsharpened'\n",
|
|
|
+ " createPath(unsharpenFolder)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " #@markdown ---\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ####**Advanced Settings:**\n",
|
|
|
+ "#@markdown *There are a few extra advanced settings available if you double click this cell.*\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown *Perlin init will replace your init, so uncheck if using one.*\n",
|
|
|
+ "\n",
|
|
|
+ "perlin_init = False #@param{type: 'boolean'}\n",
|
|
|
+ "perlin_mode = 'mixed' #@param ['mixed', 'color', 'gray']\n",
|
|
|
+ "set_seed = 'random_seed' #@param{type: 'string'}\n",
|
|
|
+ "eta = 0.8#@param{type: 'number'}\n",
|
|
|
+ "clamp_grad = True #@param{type: 'boolean'}\n",
|
|
|
+ "clamp_max = 0.05 #@param{type: 'number'}\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "### EXTRA ADVANCED SETTINGS:\n",
|
|
|
+ "randomize_class = True\n",
|
|
|
+ "clip_denoised = False\n",
|
|
|
+ "fuzzy_prompt = False\n",
|
|
|
+ "rand_mag = 0.05\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " #@markdown ---\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ####**Cutn Scheduling:**\n",
|
|
|
+ "#@markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400 /1000 steps, then 20 for the last 600/1000\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown cut_overview and cut_innercut are cumulative for total cutn on any given step. Overview cuts see the entire image and are good for early structure, innercuts are your standard cutn.\n",
|
|
|
+ "\n",
|
|
|
+ "cut_overview = \"[12]*400+[4]*600\" #@param {type: 'string'} \n",
|
|
|
+ "cut_innercut =\"[4]*400+[12]*600\"#@param {type: 'string'} \n",
|
|
|
+ "cut_ic_pow = 1#@param {type: 'number'} \n",
|
|
|
+ "cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}\n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "PromptsTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "### Prompts\n",
|
|
|
+ "`animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one."
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "Prompts"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "text_prompts = {\n",
|
|
|
+ " 0: [\n",
|
|
|
+ " \"megastructure in the cloud, blame!, contemporary house in the mist, artstation\",\n",
|
|
|
+ " ]\n",
|
|
|
+ "}\n",
|
|
|
+ "\n",
|
|
|
+ "image_prompts = {\n",
|
|
|
+ " # 0:['ImagePromptsWorkButArentVeryGood.png:2',],\n",
|
|
|
+ "}\n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "DiffuseTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# 4. Diffuse!"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "DoTheRun"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "#@title Do the Run!\n",
|
|
|
+ "#@markdown `n_batches` ignored with animation modes.\n",
|
|
|
+ "display_rate = 50#@param{type: 'number'}\n",
|
|
|
+ "n_batches = 50#@param{type: 'number'}\n",
|
|
|
+ "\n",
|
|
|
+ "#Update Model Settings\n",
|
|
|
+ "timestep_respacing = f'ddim{steps}'\n",
|
|
|
+ "diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n",
|
|
|
+ "model_config.update({\n",
|
|
|
+ " 'timestep_respacing': timestep_respacing,\n",
|
|
|
+ " 'diffusion_steps': diffusion_steps,\n",
|
|
|
+ "})\n",
|
|
|
+ "\n",
|
|
|
+ "batch_size = 1 \n",
|
|
|
+ "\n",
|
|
|
+ "def move_files(start_num, end_num, old_folder, new_folder):\n",
|
|
|
+ " for i in range(start_num, end_num):\n",
|
|
|
+ " old_file = old_folder + f'/{batch_name}({batchNum})_{i:04}.png'\n",
|
|
|
+ " new_file = new_folder + f'/{batch_name}({batchNum})_{i:04}.png'\n",
|
|
|
+ " os.rename(old_file, new_file)\n",
|
|
|
+ "\n",
|
|
|
+ "#@markdown ---\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "resume_run = False #@param{type: 'boolean'}\n",
|
|
|
+ "run_to_resume = 'latest' #@param{type: 'string'}\n",
|
|
|
+ "resume_from_frame = 'latest' #@param{type: 'string'}\n",
|
|
|
+ "retain_overwritten_frames = False #@param{type: 'boolean'}\n",
|
|
|
+ "if retain_overwritten_frames is True:\n",
|
|
|
+ " retainFolder = f'{batchFolder}/retained'\n",
|
|
|
+ " createPath(retainFolder)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "skip_step_ratio = int(frames_skip_steps.rstrip(\"%\")) / 100\n",
|
|
|
+ "calc_frames_skip_steps = math.floor(steps * skip_step_ratio)\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ "if steps <= calc_frames_skip_steps:\n",
|
|
|
+ " sys.exit(\"ERROR: You can't skip more steps than your total steps\")\n",
|
|
|
+ "\n",
|
|
|
+ "if resume_run:\n",
|
|
|
+ " if run_to_resume == 'latest':\n",
|
|
|
+ " try:\n",
|
|
|
+ " batchNum\n",
|
|
|
+ " except:\n",
|
|
|
+ " batchNum = len(glob(f\"{batchFolder}/{batch_name}(*)_settings.txt\"))-1\n",
|
|
|
+ " else:\n",
|
|
|
+ " batchNum = int(run_to_resume)\n",
|
|
|
+ " if resume_from_frame == 'latest':\n",
|
|
|
+ " start_frame = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n",
|
|
|
+ " if animation_mode != '3D' and turbo_mode == True and start_frame > turbo_preroll and start_frame % int(turbo_steps) != 0:\n",
|
|
|
+ " start_frame = start_frame - (start_frame % int(turbo_steps))\n",
|
|
|
+ " else:\n",
|
|
|
+ " start_frame = int(resume_from_frame)+1\n",
|
|
|
+ " if animation_mode != '3D' and turbo_mode == True and start_frame > turbo_preroll and start_frame % int(turbo_steps) != 0:\n",
|
|
|
+ " start_frame = start_frame - (start_frame % int(turbo_steps))\n",
|
|
|
+ " if retain_overwritten_frames is True:\n",
|
|
|
+ " existing_frames = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n",
|
|
|
+ " frames_to_save = existing_frames - start_frame\n",
|
|
|
+ " print(f'Moving {frames_to_save} frames to the Retained folder')\n",
|
|
|
+ " move_files(start_frame, existing_frames, batchFolder, retainFolder)\n",
|
|
|
+ "else:\n",
|
|
|
+ " start_frame = 0\n",
|
|
|
+ " batchNum = len(glob(batchFolder+\"/*.txt\"))\n",
|
|
|
+ " while os.path.isfile(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\") is True or os.path.isfile(f\"{batchFolder}/{batch_name}-{batchNum}_settings.txt\") is True:\n",
|
|
|
+ " batchNum += 1\n",
|
|
|
+ "\n",
|
|
|
+ "print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')\n",
|
|
|
+ "\n",
|
|
|
+ "if set_seed == 'random_seed':\n",
|
|
|
+ " random.seed()\n",
|
|
|
+ " seed = random.randint(0, 2**32)\n",
|
|
|
+ " # print(f'Using seed: {seed}')\n",
|
|
|
+ "else:\n",
|
|
|
+ " seed = int(set_seed)\n",
|
|
|
+ "\n",
|
|
|
+ "args = {\n",
|
|
|
+ " 'batchNum': batchNum,\n",
|
|
|
+ " 'prompts_series':split_prompts(text_prompts) if text_prompts else None,\n",
|
|
|
+ " 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,\n",
|
|
|
+ " 'seed': seed,\n",
|
|
|
+ " 'display_rate':display_rate,\n",
|
|
|
+ " 'n_batches':n_batches if animation_mode == 'None' else 1,\n",
|
|
|
+ " 'batch_size':batch_size,\n",
|
|
|
+ " 'batch_name': batch_name,\n",
|
|
|
+ " 'steps': steps,\n",
|
|
|
+ " 'diffusion_sampling_mode': diffusion_sampling_mode,\n",
|
|
|
+ " 'width_height': width_height,\n",
|
|
|
+ " 'clip_guidance_scale': clip_guidance_scale,\n",
|
|
|
+ " 'tv_scale': tv_scale,\n",
|
|
|
+ " 'range_scale': range_scale,\n",
|
|
|
+ " 'sat_scale': sat_scale,\n",
|
|
|
+ " 'cutn_batches': cutn_batches,\n",
|
|
|
+ " 'init_image': init_image,\n",
|
|
|
+ " 'init_scale': init_scale,\n",
|
|
|
+ " 'skip_steps': skip_steps,\n",
|
|
|
+ " 'sharpen_preset': sharpen_preset,\n",
|
|
|
+ " 'keep_unsharp': keep_unsharp,\n",
|
|
|
+ " 'side_x': side_x,\n",
|
|
|
+ " 'side_y': side_y,\n",
|
|
|
+ " 'timestep_respacing': timestep_respacing,\n",
|
|
|
+ " 'diffusion_steps': diffusion_steps,\n",
|
|
|
+ " 'animation_mode': animation_mode,\n",
|
|
|
+ " 'video_init_path': video_init_path,\n",
|
|
|
+ " 'extract_nth_frame': extract_nth_frame,\n",
|
|
|
+ " 'video_init_seed_continuity': video_init_seed_continuity,\n",
|
|
|
+ " 'key_frames': key_frames,\n",
|
|
|
+ " 'max_frames': max_frames if animation_mode != \"None\" else 1,\n",
|
|
|
+ " 'interp_spline': interp_spline,\n",
|
|
|
+ " 'start_frame': start_frame,\n",
|
|
|
+ " 'angle': angle,\n",
|
|
|
+ " 'zoom': zoom,\n",
|
|
|
+ " 'translation_x': translation_x,\n",
|
|
|
+ " 'translation_y': translation_y,\n",
|
|
|
+ " 'translation_z': translation_z,\n",
|
|
|
+ " 'rotation_3d_x': rotation_3d_x,\n",
|
|
|
+ " 'rotation_3d_y': rotation_3d_y,\n",
|
|
|
+ " 'rotation_3d_z': rotation_3d_z,\n",
|
|
|
+ " 'midas_depth_model': midas_depth_model,\n",
|
|
|
+ " 'midas_weight': midas_weight,\n",
|
|
|
+ " 'near_plane': near_plane,\n",
|
|
|
+ " 'far_plane': far_plane,\n",
|
|
|
+ " 'fov': fov,\n",
|
|
|
+ " 'padding_mode': padding_mode,\n",
|
|
|
+ " 'sampling_mode': sampling_mode,\n",
|
|
|
+ " 'angle_series':angle_series,\n",
|
|
|
+ " 'zoom_series':zoom_series,\n",
|
|
|
+ " 'translation_x_series':translation_x_series,\n",
|
|
|
+ " 'translation_y_series':translation_y_series,\n",
|
|
|
+ " 'translation_z_series':translation_z_series,\n",
|
|
|
+ " 'rotation_3d_x_series':rotation_3d_x_series,\n",
|
|
|
+ " 'rotation_3d_y_series':rotation_3d_y_series,\n",
|
|
|
+ " 'rotation_3d_z_series':rotation_3d_z_series,\n",
|
|
|
+ " 'frames_scale': frames_scale,\n",
|
|
|
+ " 'calc_frames_skip_steps': calc_frames_skip_steps,\n",
|
|
|
+ " 'skip_step_ratio': skip_step_ratio,\n",
|
|
|
+ " 'calc_frames_skip_steps': calc_frames_skip_steps,\n",
|
|
|
+ " 'text_prompts': text_prompts,\n",
|
|
|
+ " 'image_prompts': image_prompts,\n",
|
|
|
+ " 'cut_overview': eval(cut_overview),\n",
|
|
|
+ " 'cut_innercut': eval(cut_innercut),\n",
|
|
|
+ " 'cut_ic_pow': cut_ic_pow,\n",
|
|
|
+ " 'cut_icgray_p': eval(cut_icgray_p),\n",
|
|
|
+ " 'intermediate_saves': intermediate_saves,\n",
|
|
|
+ " 'intermediates_in_subfolder': intermediates_in_subfolder,\n",
|
|
|
+ " 'steps_per_checkpoint': steps_per_checkpoint,\n",
|
|
|
+ " 'perlin_init': perlin_init,\n",
|
|
|
+ " 'perlin_mode': perlin_mode,\n",
|
|
|
+ " 'set_seed': set_seed,\n",
|
|
|
+ " 'eta': eta,\n",
|
|
|
+ " 'clamp_grad': clamp_grad,\n",
|
|
|
+ " 'clamp_max': clamp_max,\n",
|
|
|
+ " 'skip_augs': skip_augs,\n",
|
|
|
+ " 'randomize_class': randomize_class,\n",
|
|
|
+ " 'clip_denoised': clip_denoised,\n",
|
|
|
+ " 'fuzzy_prompt': fuzzy_prompt,\n",
|
|
|
+ " 'rand_mag': rand_mag,\n",
|
|
|
+ "}\n",
|
|
|
+ "\n",
|
|
|
+ "args = SimpleNamespace(**args)\n",
|
|
|
+ "\n",
|
|
|
+ "print('Prepping model...')\n",
|
|
|
+ "model, diffusion = create_model_and_diffusion(**model_config)\n",
|
|
|
+ "model.load_state_dict(torch.load(f'{model_path}/{diffusion_model}.pt', map_location='cpu'))\n",
|
|
|
+ "model.requires_grad_(False).eval().to(device)\n",
|
|
|
+ "for name, param in model.named_parameters():\n",
|
|
|
+ " if 'qkv' in name or 'norm' in name or 'proj' in name:\n",
|
|
|
+ " param.requires_grad_()\n",
|
|
|
+ "if model_config['use_fp16']:\n",
|
|
|
+ " model.convert_to_fp16()\n",
|
|
|
+ "\n",
|
|
|
+ "gc.collect()\n",
|
|
|
+ "torch.cuda.empty_cache()\n",
|
|
|
+ "try:\n",
|
|
|
+ " do_run()\n",
|
|
|
+ "except KeyboardInterrupt:\n",
|
|
|
+ " pass\n",
|
|
|
+ "finally:\n",
|
|
|
+ " print('Seed used:', seed)\n",
|
|
|
+ " gc.collect()\n",
|
|
|
+ " torch.cuda.empty_cache()\n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "markdown",
|
|
|
+ "metadata": {
|
|
|
+ "id": "CreateVidTop"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# 5. Create the video"
|
|
|
+ ]
|
|
|
+ },
|
|
|
+ {
|
|
|
+ "cell_type": "code",
|
|
|
+ "metadata": {
|
|
|
+ "id": "CreateVid"
|
|
|
+ },
|
|
|
+ "source": [
|
|
|
+ "# @title ### **Create video**\n",
|
|
|
+ "#@markdown Video file will save in the same folder as your images.\n",
|
|
|
+ "\n",
|
|
|
+ "skip_video_for_run_all = True #@param {type: 'boolean'}\n",
|
|
|
+ "\n",
|
|
|
+ "if skip_video_for_run_all == True:\n",
|
|
|
+ " print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
|
|
|
+ "\n",
|
|
|
+ "else:\n",
|
|
|
+ " # import subprocess in case this cell is run without the above cells\n",
|
|
|
+ " import subprocess\n",
|
|
|
+ " from base64 import b64encode\n",
|
|
|
+ "\n",
|
|
|
+ " latest_run = batchNum\n",
|
|
|
+ "\n",
|
|
|
+ " folder = batch_name #@param\n",
|
|
|
+ " run = latest_run #@param\n",
|
|
|
+ " final_frame = 'final_frame'\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " init_frame = 1#@param {type:\"number\"} This is the frame where the video will start\n",
|
|
|
+ " last_frame = final_frame#@param {type:\"number\"} You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.\n",
|
|
|
+ " fps = 12#@param {type:\"number\"}\n",
|
|
|
+ " # view_video_in_cell = True #@param {type: 'boolean'}\n",
|
|
|
+ "\n",
|
|
|
+ " frames = []\n",
|
|
|
+ " # tqdm.write('Generating video...')\n",
|
|
|
+ "\n",
|
|
|
+ " if last_frame == 'final_frame':\n",
|
|
|
+ " last_frame = len(glob(batchFolder+f\"/{folder}({run})_*.png\"))\n",
|
|
|
+ " print(f'Total frames: {last_frame}')\n",
|
|
|
+ "\n",
|
|
|
+ " image_path = f\"{outDirPath}/{folder}/{folder}({run})_%04d.png\"\n",
|
|
|
+ " filepath = f\"{outDirPath}/{folder}/{folder}({run}).mp4\"\n",
|
|
|
+ "\n",
|
|
|
+ "\n",
|
|
|
+ " cmd = [\n",
|
|
|
+ " 'ffmpeg',\n",
|
|
|
+ " '-y',\n",
|
|
|
+ " '-vcodec',\n",
|
|
|
+ " 'png',\n",
|
|
|
+ " '-r',\n",
|
|
|
+ " str(fps),\n",
|
|
|
+ " '-start_number',\n",
|
|
|
+ " str(init_frame),\n",
|
|
|
+ " '-i',\n",
|
|
|
+ " image_path,\n",
|
|
|
+ " '-frames:v',\n",
|
|
|
+ " str(last_frame+1),\n",
|
|
|
+ " '-c:v',\n",
|
|
|
+ " 'libx264',\n",
|
|
|
+ " '-vf',\n",
|
|
|
+ " f'fps={fps}',\n",
|
|
|
+ " '-pix_fmt',\n",
|
|
|
+ " 'yuv420p',\n",
|
|
|
+ " '-crf',\n",
|
|
|
+ " '17',\n",
|
|
|
+ " '-preset',\n",
|
|
|
+ " 'veryslow',\n",
|
|
|
+ " filepath\n",
|
|
|
+ " ]\n",
|
|
|
+ "\n",
|
|
|
+ " process = subprocess.Popen(cmd, cwd=f'{batchFolder}', stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
|
|
|
+ " stdout, stderr = process.communicate()\n",
|
|
|
+ " if process.returncode != 0:\n",
|
|
|
+ " print(stderr)\n",
|
|
|
+ " raise RuntimeError(stderr)\n",
|
|
|
+ " else:\n",
|
|
|
+ " print(\"The video is ready and saved to the images folder\")\n",
|
|
|
+ "\n",
|
|
|
+ " # if view_video_in_cell:\n",
|
|
|
+ " # mp4 = open(filepath,'rb').read()\n",
|
|
|
+ " # data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
|
|
|
+ " # display.HTML(f'<video width=400 controls><source src=\"{data_url}\" type=\"video/mp4\"></video>')\n",
|
|
|
+ " \n"
|
|
|
+ ],
|
|
|
+ "outputs": [],
|
|
|
+ "execution_count": null
|
|
|
+ }
|
|
|
+ ],
|
|
|
+ "metadata": {
|
|
|
+ "anaconda-cloud": {},
|
|
|
+ "accelerator": "GPU",
|
|
|
+ "colab": {
|
|
|
+ "collapsed_sections": [
|
|
|
+ "CreditsChTop",
|
|
|
+ "TutorialTop",
|
|
|
+ "CheckGPU",
|
|
|
+ "InstallDeps",
|
|
|
+ "DefMidasFns",
|
|
|
+ "DefFns",
|
|
|
+ "DefSecModel",
|
|
|
+ "DefSuperRes",
|
|
|
+ "AnimSetTop",
|
|
|
+ "ExtraSetTop"
|
|
|
+ ],
|
|
|
+ "machine_shape": "hm",
|
|
|
+ "name": "Copie de Disco Diffusion v5.1 [w/ Turbo]",
|
|
|
+ "private_outputs": true,
|
|
|
+ "provenance": []
|
|
|
+ },
|
|
|
+ "kernelspec": {
|
|
|
+ "display_name": "Python 3",
|
|
|
+ "language": "python",
|
|
|
+ "name": "python3"
|
|
|
+ },
|
|
|
+ "language_info": {
|
|
|
+ "codemirror_mode": {
|
|
|
+ "name": "ipython",
|
|
|
+ "version": 3
|
|
|
+ },
|
|
|
+ "file_extension": ".py",
|
|
|
+ "mimetype": "text/x-python",
|
|
|
+ "name": "python",
|
|
|
+ "nbconvert_exporter": "python",
|
|
|
+ "pygments_lexer": "ipython3",
|
|
|
+ "version": "3.6.1"
|
|
|
+ }
|
|
|
+ },
|
|
|
+ "nbformat": 4,
|
|
|
+ "nbformat_minor": 0
|
|
|
+}
|