|
@@ -1,3269 +1,3186 @@
|
|
|
{
|
|
|
- "cells": [
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "TitleTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "# Disco Diffusion v5.1 - Now with Turbo\n",
|
|
|
- "\n",
|
|
|
- "In case of confusion, Disco is the name of this notebook edit. The diffusion model in use is Katherine Crowson's fine-tuned 512x512 model\n",
|
|
|
- "\n",
|
|
|
- "For issues, join the [Disco Diffusion Discord](https://discord.gg/msEZBy4HxA) or message us on twitter at [@somnai_dreams](https://twitter.com/somnai_dreams) or [@gandamu](https://twitter.com/gandamu_ml)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "CreditsChTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "### Credits & Changelog ⬇️"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "Credits"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#### Credits\n",
|
|
|
- "\n",
|
|
|
- "Original notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet or Katherine Crowson's fine-tuned 512x512 diffusion model (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images.\n",
|
|
|
- "\n",
|
|
|
- "Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.\n",
|
|
|
- "\n",
|
|
|
- "Further improvements from Dango233 and nsheppard helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.\n",
|
|
|
- "\n",
|
|
|
- "Vark added code to load in multiple Clip models at once, which all prompts are evaluated against, which may greatly improve accuracy.\n",
|
|
|
- "\n",
|
|
|
- "The latest zoom, pan, rotation, and keyframes features were taken from Chigozie Nri's VQGAN Zoom Notebook (https://github.com/chigozienri, https://twitter.com/chigozienri)\n",
|
|
|
- "\n",
|
|
|
- "Advanced DangoCutn Cutout method is also from Dango223.\n",
|
|
|
- "\n",
|
|
|
- "--\n",
|
|
|
- "\n",
|
|
|
- "Disco:\n",
|
|
|
- "\n",
|
|
|
- "Somnai (https://twitter.com/Somnai_dreams) added Diffusion Animation techniques, QoL improvements and various implementations of tech and techniques, mostly listed in the changelog below.\n",
|
|
|
- "\n",
|
|
|
- "3D animation implementation added by Adam Letts (https://twitter.com/gandamu_ml) in collaboration with Somnai.\n",
|
|
|
- "\n",
|
|
|
- "Turbo feature by Chris Allen (https://twitter.com/zippy731)"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "LicenseTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#### License"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "License"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "Licensed under the MIT License\n",
|
|
|
- "\n",
|
|
|
- "Copyright (c) 2021 Katherine Crowson \n",
|
|
|
- "\n",
|
|
|
- "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
|
|
|
- "of this software and associated documentation files (the \"Software\"), to deal\n",
|
|
|
- "in the Software without restriction, including without limitation the rights\n",
|
|
|
- "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
|
|
|
- "copies of the Software, and to permit persons to whom the Software is\n",
|
|
|
- "furnished to do so, subject to the following conditions:\n",
|
|
|
- "\n",
|
|
|
- "The above copyright notice and this permission notice shall be included in\n",
|
|
|
- "all copies or substantial portions of the Software.\n",
|
|
|
- "\n",
|
|
|
- "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
|
|
|
- "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
|
|
|
- "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
|
|
|
- "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
|
|
|
- "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
|
|
|
- "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
|
|
|
- "THE SOFTWARE.\n",
|
|
|
- "\n",
|
|
|
- "--\n",
|
|
|
- "\n",
|
|
|
- "MIT License\n",
|
|
|
- "\n",
|
|
|
- "Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)\n",
|
|
|
- "\n",
|
|
|
- "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
|
|
|
- "of this software and associated documentation files (the \"Software\"), to deal\n",
|
|
|
- "in the Software without restriction, including without limitation the rights\n",
|
|
|
- "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
|
|
|
- "copies of the Software, and to permit persons to whom the Software is\n",
|
|
|
- "furnished to do so, subject to the following conditions:\n",
|
|
|
- "\n",
|
|
|
- "The above copyright notice and this permission notice shall be included in all\n",
|
|
|
- "copies or substantial portions of the Software.\n",
|
|
|
- "\n",
|
|
|
- "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
|
|
|
- "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
|
|
|
- "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
|
|
|
- "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
|
|
|
- "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
|
|
|
- "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE\n",
|
|
|
- "SOFTWARE.\n",
|
|
|
- "\n",
|
|
|
- "--\n",
|
|
|
- "\n",
|
|
|
- "Licensed under the MIT License\n",
|
|
|
- "\n",
|
|
|
- "Copyright (c) 2021 Maxwell Ingham\n",
|
|
|
- "\n",
|
|
|
- "Copyright (c) 2022 Adam Letts \n",
|
|
|
- "\n",
|
|
|
- "Permission is hereby granted, free of charge, to any person obtaining a copy\n",
|
|
|
- "of this software and associated documentation files (the \"Software\"), to deal\n",
|
|
|
- "in the Software without restriction, including without limitation the rights\n",
|
|
|
- "to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
|
|
|
- "copies of the Software, and to permit persons to whom the Software is\n",
|
|
|
- "furnished to do so, subject to the following conditions:\n",
|
|
|
- "\n",
|
|
|
- "The above copyright notice and this permission notice shall be included in\n",
|
|
|
- "all copies or substantial portions of the Software.\n",
|
|
|
- "\n",
|
|
|
- "THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
|
|
|
- "IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
|
|
|
- "FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
|
|
|
- "AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
|
|
|
- "LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
|
|
|
- "OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
|
|
|
- "THE SOFTWARE."
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "ChangelogTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#### Changelog"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "cellView": "form",
|
|
|
- "id": "Changelog"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title <- View Changelog\n",
|
|
|
- "skip_for_run_all = True #@param {type: 'boolean'}\n",
|
|
|
- "\n",
|
|
|
- "if skip_for_run_all == False:\n",
|
|
|
- " print(\n",
|
|
|
- " '''\n",
|
|
|
- " v1 Update: Oct 29th 2021 - Somnai\n",
|
|
|
- "\n",
|
|
|
- " QoL improvements added by Somnai (@somnai_dreams), including user friendly UI, settings+prompt saving and improved google drive folder organization.\n",
|
|
|
- "\n",
|
|
|
- " v1.1 Update: Nov 13th 2021 - Somnai\n",
|
|
|
- "\n",
|
|
|
- " Now includes sizing options, intermediate saves and fixed image prompts and perlin inits. unexposed batch option since it doesn't work\n",
|
|
|
- "\n",
|
|
|
- " v2 Update: Nov 22nd 2021 - Somnai\n",
|
|
|
- "\n",
|
|
|
- " Initial addition of Katherine Crowson's Secondary Model Method (https://colab.research.google.com/drive/1mpkrhOjoyzPeSWy2r7T8EYRaU7amYOOi#scrollTo=X5gODNAMEUCR)\n",
|
|
|
- "\n",
|
|
|
- " Noticed settings were saving with the wrong name so corrected it. Let me know if you preferred the old scheme.\n",
|
|
|
- "\n",
|
|
|
- " v3 Update: Dec 24th 2021 - Somnai\n",
|
|
|
- "\n",
|
|
|
- " Implemented Dango's advanced cutout method\n",
|
|
|
- "\n",
|
|
|
- " Added SLIP models, thanks to NeuralDivergent\n",
|
|
|
- "\n",
|
|
|
- " Fixed issue with NaNs resulting in black images, with massive help and testing from @Softology\n",
|
|
|
- "\n",
|
|
|
- " Perlin now changes properly within batches (not sure where this perlin_regen code came from originally, but thank you)\n",
|
|
|
- "\n",
|
|
|
- " v4 Update: Jan 2021 - Somnai\n",
|
|
|
- "\n",
|
|
|
- " Implemented Diffusion Zooming\n",
|
|
|
- "\n",
|
|
|
- " Added Chigozie keyframing\n",
|
|
|
- "\n",
|
|
|
- " Made a bunch of edits to processes\n",
|
|
|
- " \n",
|
|
|
- " v4.1 Update: Jan 14th 2021 - Somnai\n",
|
|
|
- "\n",
|
|
|
- " Added video input mode\n",
|
|
|
- "\n",
|
|
|
- " Added license that somehow went missing\n",
|
|
|
- "\n",
|
|
|
- " Added improved prompt keyframing, fixed image_prompts and multiple prompts\n",
|
|
|
- "\n",
|
|
|
- " Improved UI\n",
|
|
|
- "\n",
|
|
|
- " Significant under the hood cleanup and improvement\n",
|
|
|
- "\n",
|
|
|
- " Refined defaults for each mode\n",
|
|
|
- "\n",
|
|
|
- " Added latent-diffusion SuperRes for sharpening\n",
|
|
|
- "\n",
|
|
|
- " Added resume run mode\n",
|
|
|
- "\n",
|
|
|
- " v4.9 Update: Feb 5th 2022 - gandamu / Adam Letts\n",
|
|
|
- "\n",
|
|
|
- " Added 3D\n",
|
|
|
- "\n",
|
|
|
- " Added brightness corrections to prevent animation from steadily going dark over time\n",
|
|
|
- "\n",
|
|
|
- " v4.91 Update: Feb 19th 2022 - gandamu / Adam Letts\n",
|
|
|
- "\n",
|
|
|
- " Cleaned up 3D implementation and made associated args accessible via Colab UI elements\n",
|
|
|
- "\n",
|
|
|
- " v4.92 Update: Feb 20th 2022 - gandamu / Adam Letts\n",
|
|
|
- "\n",
|
|
|
- " Separated transform code\n",
|
|
|
- "\n",
|
|
|
- " v5.01 Update: Mar 10th 2022 - gandamu / Adam Letts\n",
|
|
|
- "\n",
|
|
|
- " IPython magic commands replaced by Python code\n",
|
|
|
- "\n",
|
|
|
- " v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts\n",
|
|
|
- "\n",
|
|
|
- " Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.\n",
|
|
|
- "\n",
|
|
|
- " Implemented resume of turbo animations in such a way that it's now possible to resume from different batch folders and batch numbers.\n",
|
|
|
- "\n",
|
|
|
- " 3D rotation parameter units are now degrees (rather than radians)\n",
|
|
|
- "\n",
|
|
|
- " Corrected name collision in sampling_mode (now diffusion_sampling_mode for plms/ddim, and sampling_mode for 3D transform sampling)\n",
|
|
|
- "\n",
|
|
|
- " Added video_init_seed_continuity option to make init video animations more continuous\n",
|
|
|
- "\n",
|
|
|
- " '''\n",
|
|
|
- " )\n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "TutorialTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "# Tutorial"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "DiffusionSet"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "**Diffusion settings (Defaults are heavily outdated)**\n",
|
|
|
- "---\n",
|
|
|
- "\n",
|
|
|
- "This section is outdated as of v2\n",
|
|
|
- "\n",
|
|
|
- "Setting | Description | Default\n",
|
|
|
- "--- | --- | ---\n",
|
|
|
- "**Your vision:**\n",
|
|
|
- "`text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A\n",
|
|
|
- "`image_prompts` | Think of these images more as a description of their contents. | N/A\n",
|
|
|
- "**Image quality:**\n",
|
|
|
- "`clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000\n",
|
|
|
- "`tv_scale` | Controls the smoothness of the final output. | 150\n",
|
|
|
- "`range_scale` | Controls how far out of range RGB values are allowed to be. | 150\n",
|
|
|
- "`sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0\n",
|
|
|
- "`cutn` | Controls how many crops to take from the image. | 16\n",
|
|
|
- "`cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts | 2\n",
|
|
|
- "**Init settings:**\n",
|
|
|
- "`init_image` | URL or local path | None\n",
|
|
|
- "`init_scale` | This enhances the effect of the init image, a good value is 1000 | 0\n",
|
|
|
- "`skip_steps Controls the starting point along the diffusion timesteps | 0\n",
|
|
|
- "`perlin_init` | Option to start with random perlin noise | False\n",
|
|
|
- "`perlin_mode` | ('gray', 'color') | 'mixed'\n",
|
|
|
- "**Advanced:**\n",
|
|
|
- "`skip_augs` |Controls whether to skip torchvision augmentations | False\n",
|
|
|
- "`randomize_class` |Controls whether the imagenet class is randomly changed each iteration | True\n",
|
|
|
- "`clip_denoised` |Determines whether CLIP discriminates a noisy or denoised image | False\n",
|
|
|
- "`clamp_grad` |Experimental: Using adaptive clip grad in the cond_fn | True\n",
|
|
|
- "`seed` | Choose a random seed and print it at end of run for reproduction | random_seed\n",
|
|
|
- "`fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses | False\n",
|
|
|
- "`rand_mag` |Controls the magnitude of the random noise | 0.1\n",
|
|
|
- "`eta` | DDIM hyperparameter | 0.5\n",
|
|
|
- "\n",
|
|
|
- "..\n",
|
|
|
- "\n",
|
|
|
- "**Model settings**\n",
|
|
|
- "---\n",
|
|
|
- "\n",
|
|
|
- "Setting | Description | Default\n",
|
|
|
- "--- | --- | ---\n",
|
|
|
- "**Diffusion:**\n",
|
|
|
- "`timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100\n",
|
|
|
- "`diffusion_steps` || 1000\n",
|
|
|
- "**Diffusion:**\n",
|
|
|
- "`clip_models` | Models of CLIP to load. Typically the more, the better but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "SetupTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "# 1. Set Up"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "cellView": "form",
|
|
|
- "id": "CheckGPU"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title 1.1 Check GPU Status\n",
|
|
|
- "import subprocess\n",
|
|
|
- "simple_nvidia_smi_display = False#@param {type:\"boolean\"}\n",
|
|
|
- "if simple_nvidia_smi_display:\n",
|
|
|
- " #!nvidia-smi\n",
|
|
|
- " nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " print(nvidiasmi_output)\n",
|
|
|
- "else:\n",
|
|
|
- " #!nvidia-smi -i 0 -e 0\n",
|
|
|
- " nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " print(nvidiasmi_output)\n",
|
|
|
- " nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " print(nvidiasmi_ecc_note)"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "cellView": "form",
|
|
|
- "id": "PrepFolders"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title 1.2 Prepare Folders\n",
|
|
|
- "import subprocess\n",
|
|
|
- "import sys\n",
|
|
|
- "import ipykernel\n",
|
|
|
- "\n",
|
|
|
- "def gitclone(url):\n",
|
|
|
- " res = subprocess.run(['git', 'clone', url], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " print(res)\n",
|
|
|
- "\n",
|
|
|
- "def pipi(modulestr):\n",
|
|
|
- " res = subprocess.run(['pip', 'install', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " print(res)\n",
|
|
|
- "\n",
|
|
|
- "def pipie(modulestr):\n",
|
|
|
- " res = subprocess.run(['git', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " print(res)\n",
|
|
|
- "\n",
|
|
|
- "def wget(url, outputdir):\n",
|
|
|
- " res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " print(res)\n",
|
|
|
- "\n",
|
|
|
- "try:\n",
|
|
|
- " from google.colab import drive\n",
|
|
|
- " print(\"Google Colab detected. Using Google Drive.\")\n",
|
|
|
- " is_colab = True\n",
|
|
|
- " #@markdown If you connect your Google Drive, you can save the final image of each run on your drive.\n",
|
|
|
- " google_drive = True #@param {type:\"boolean\"}\n",
|
|
|
- " #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n",
|
|
|
- " save_models_to_google_drive = True #@param {type:\"boolean\"}\n",
|
|
|
- "except:\n",
|
|
|
- " is_colab = False\n",
|
|
|
- " google_drive = False\n",
|
|
|
- " save_models_to_google_drive = False\n",
|
|
|
- " print(\"Google Colab not detected.\")\n",
|
|
|
- "\n",
|
|
|
- "if is_colab:\n",
|
|
|
- " if google_drive is True:\n",
|
|
|
- " drive.mount('/content/drive')\n",
|
|
|
- " root_path = '/content/drive/MyDrive/AI/Disco_Diffusion'\n",
|
|
|
- " else:\n",
|
|
|
- " root_path = '/content'\n",
|
|
|
- "else:\n",
|
|
|
- " root_path = '.'\n",
|
|
|
- "\n",
|
|
|
- "import os\n",
|
|
|
- "def createPath(filepath):\n",
|
|
|
- " os.makedirs(filepath, exist_ok=True)\n",
|
|
|
- "\n",
|
|
|
- "initDirPath = f'{root_path}/init_images'\n",
|
|
|
- "createPath(initDirPath)\n",
|
|
|
- "outDirPath = f'{root_path}/images_out'\n",
|
|
|
- "createPath(outDirPath)\n",
|
|
|
- "\n",
|
|
|
- "if is_colab:\n",
|
|
|
- " if google_drive and not save_models_to_google_drive or not google_drive:\n",
|
|
|
- " model_path = '/content/model'\n",
|
|
|
- " createPath(model_path)\n",
|
|
|
- " if google_drive and save_models_to_google_drive:\n",
|
|
|
- " model_path = f'{root_path}/model'\n",
|
|
|
- " createPath(model_path)\n",
|
|
|
- "else:\n",
|
|
|
- " model_path = f'{root_path}/model'\n",
|
|
|
- " createPath(model_path)\n",
|
|
|
- "\n",
|
|
|
- "# libraries = f'{root_path}/libraries'\n",
|
|
|
- "# createPath(libraries)"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "cellView": "form",
|
|
|
- "id": "InstallDeps"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title ### 1.3 Install and import dependencies\n",
|
|
|
- "\n",
|
|
|
- "import pathlib, shutil\n",
|
|
|
- "\n",
|
|
|
- "if not is_colab:\n",
|
|
|
- " # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.\n",
|
|
|
- " os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'\n",
|
|
|
- "\n",
|
|
|
- "PROJECT_DIR = os.path.abspath(os.getcwd())\n",
|
|
|
- "USE_ADABINS = True\n",
|
|
|
- "\n",
|
|
|
- "if is_colab:\n",
|
|
|
- " if google_drive is not True:\n",
|
|
|
- " root_path = f'/content'\n",
|
|
|
- " model_path = '/content/models' \n",
|
|
|
- "else:\n",
|
|
|
- " root_path = f'.'\n",
|
|
|
- " model_path = f'{root_path}/model'\n",
|
|
|
- "\n",
|
|
|
- "model_256_downloaded = False\n",
|
|
|
- "model_512_downloaded = False\n",
|
|
|
- "model_secondary_downloaded = False\n",
|
|
|
- "\n",
|
|
|
- "if is_colab:\n",
|
|
|
- " gitclone(\"https://github.com/openai/CLIP\")\n",
|
|
|
- " #gitclone(\"https://github.com/facebookresearch/SLIP.git\")\n",
|
|
|
- " gitclone(\"https://github.com/crowsonkb/guided-diffusion\")\n",
|
|
|
- " gitclone(\"https://github.com/assafshocher/ResizeRight.git\")\n",
|
|
|
- " gitclone(\"https://github.com/MSFTserver/pytorch3d-lite.git\")\n",
|
|
|
- " pipie(\"./CLIP\")\n",
|
|
|
- " pipie(\"./guided-diffusion\")\n",
|
|
|
- " multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " print(multipip_res)\n",
|
|
|
- " subprocess.run(['apt', 'install', 'imagemagick'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " gitclone(\"https://github.com/isl-org/MiDaS.git\")\n",
|
|
|
- " gitclone(\"https://github.com/alembics/disco-diffusion.git\")\n",
|
|
|
- " pipi(\"pytorch-lightning\")\n",
|
|
|
- " pipi(\"omegaconf\")\n",
|
|
|
- " pipi(\"einops\")\n",
|
|
|
- " # Rename a file to avoid a name conflict..\n",
|
|
|
- " try:\n",
|
|
|
- " os.rename(\"MiDaS/utils.py\", \"MiDaS/midas_utils.py\")\n",
|
|
|
- " shutil.copyfile(\"disco-diffusion/disco_xform_utils.py\", \"disco_xform_utils.py\")\n",
|
|
|
- " except:\n",
|
|
|
- " pass\n",
|
|
|
- "\n",
|
|
|
- "if not os.path.exists(f'{model_path}'):\n",
|
|
|
- " pathlib.Path(model_path).mkdir(parents=True, exist_ok=True)\n",
|
|
|
- "if not os.path.exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n",
|
|
|
- " wget(\"https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt\", model_path)\n",
|
|
|
- "\n",
|
|
|
- "import sys\n",
|
|
|
- "import torch\n",
|
|
|
- "\n",
|
|
|
- "# sys.path.append('./SLIP')\n",
|
|
|
- "sys.path.append('./pytorch3d-lite')\n",
|
|
|
- "sys.path.append('./ResizeRight')\n",
|
|
|
- "sys.path.append('./MiDaS')\n",
|
|
|
- "from dataclasses import dataclass\n",
|
|
|
- "from functools import partial\n",
|
|
|
- "import cv2\n",
|
|
|
- "import pandas as pd\n",
|
|
|
- "import gc\n",
|
|
|
- "import io\n",
|
|
|
- "import math\n",
|
|
|
- "import timm\n",
|
|
|
- "from IPython import display\n",
|
|
|
- "import lpips\n",
|
|
|
- "from PIL import Image, ImageOps\n",
|
|
|
- "import requests\n",
|
|
|
- "from glob import glob\n",
|
|
|
- "import json\n",
|
|
|
- "from types import SimpleNamespace\n",
|
|
|
- "from torch import nn\n",
|
|
|
- "from torch.nn import functional as F\n",
|
|
|
- "import torchvision.transforms as T\n",
|
|
|
- "import torchvision.transforms.functional as TF\n",
|
|
|
- "from tqdm.notebook import tqdm\n",
|
|
|
- "sys.path.append('./CLIP')\n",
|
|
|
- "sys.path.append('./guided-diffusion')\n",
|
|
|
- "import clip\n",
|
|
|
- "from resize_right import resize\n",
|
|
|
- "# from models import SLIP_VITB16, SLIP, SLIP_VITL16\n",
|
|
|
- "from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults\n",
|
|
|
- "from datetime import datetime\n",
|
|
|
- "import numpy as np\n",
|
|
|
- "import matplotlib.pyplot as plt\n",
|
|
|
- "import random\n",
|
|
|
- "from ipywidgets import Output\n",
|
|
|
- "import hashlib\n",
|
|
|
- "\n",
|
|
|
- "#SuperRes\n",
|
|
|
- "if is_colab:\n",
|
|
|
- " gitclone(\"https://github.com/CompVis/latent-diffusion.git\")\n",
|
|
|
- " gitclone(\"https://github.com/CompVis/taming-transformers\")\n",
|
|
|
- " pipie(\"./taming-transformers\")\n",
|
|
|
- " pipi(\"ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\")\n",
|
|
|
- "\n",
|
|
|
- "#SuperRes\n",
|
|
|
- "import ipywidgets as widgets\n",
|
|
|
- "import os\n",
|
|
|
- "sys.path.append(\".\")\n",
|
|
|
- "sys.path.append('./taming-transformers')\n",
|
|
|
- "from taming.models import vqgan # checking correct import from taming\n",
|
|
|
- "from torchvision.datasets.utils import download_url\n",
|
|
|
- "\n",
|
|
|
- "if is_colab:\n",
|
|
|
- " os.chdir('/content/latent-diffusion')\n",
|
|
|
- "else:\n",
|
|
|
- " #os.chdir('latent-diffusion')\n",
|
|
|
- " sys.path.append('latent-diffusion')\n",
|
|
|
- "from functools import partial\n",
|
|
|
- "from ldm.util import instantiate_from_config\n",
|
|
|
- "from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like\n",
|
|
|
- "# from ldm.models.diffusion.ddim import DDIMSampler\n",
|
|
|
- "from ldm.util import ismap\n",
|
|
|
- "if is_colab:\n",
|
|
|
- " os.chdir('/content')\n",
|
|
|
- " from google.colab import files\n",
|
|
|
- "else:\n",
|
|
|
- " os.chdir(f'{PROJECT_DIR}')\n",
|
|
|
- "from IPython.display import Image as ipyimg\n",
|
|
|
- "from numpy import asarray\n",
|
|
|
- "from einops import rearrange, repeat\n",
|
|
|
- "import torch, torchvision\n",
|
|
|
- "import time\n",
|
|
|
- "from omegaconf import OmegaConf\n",
|
|
|
- "import warnings\n",
|
|
|
- "warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
- "\n",
|
|
|
- "# AdaBins stuff\n",
|
|
|
- "if USE_ADABINS:\n",
|
|
|
- " if is_colab:\n",
|
|
|
- " gitclone(\"https://github.com/shariqfarooq123/AdaBins.git\")\n",
|
|
|
- " if not os.path.exists(f'{model_path}/AdaBins_nyu.pt'):\n",
|
|
|
- " wget(\"https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt\", model_path)\n",
|
|
|
- " pathlib.Path(\"pretrained\").mkdir(parents=True, exist_ok=True)\n",
|
|
|
- " shutil.copyfile(f\"{model_path}/AdaBins_nyu.pt\", \"pretrained/AdaBins_nyu.pt\")\n",
|
|
|
- " sys.path.append('./AdaBins')\n",
|
|
|
- " from infer import InferenceHelper\n",
|
|
|
- " MAX_ADABINS_AREA = 500000\n",
|
|
|
- "\n",
|
|
|
- "import torch\n",
|
|
|
- "DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
|
|
- "print('Using device:', DEVICE)\n",
|
|
|
- "device = DEVICE # At least one of the modules expects this name..\n",
|
|
|
- "\n",
|
|
|
- "if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad\n",
|
|
|
- " print('Disabling CUDNN for A100 gpu', file=sys.stderr)\n",
|
|
|
- " torch.backends.cudnn.enabled = False"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "cellView": "form",
|
|
|
- "id": "DefMidasFns"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title ### 1.4 Define Midas functions\n",
|
|
|
- "\n",
|
|
|
- "from midas.dpt_depth import DPTDepthModel\n",
|
|
|
- "from midas.midas_net import MidasNet\n",
|
|
|
- "from midas.midas_net_custom import MidasNet_small\n",
|
|
|
- "from midas.transforms import Resize, NormalizeImage, PrepareForNet\n",
|
|
|
- "\n",
|
|
|
- "# Initialize MiDaS depth model.\n",
|
|
|
- "# It remains resident in VRAM and likely takes around 2GB VRAM.\n",
|
|
|
- "# You could instead initialize it for each frame (and free it after each frame) to save VRAM.. but initializing it is slow.\n",
|
|
|
- "default_models = {\n",
|
|
|
- " \"midas_v21_small\": f\"{model_path}/midas_v21_small-70d6b9c8.pt\",\n",
|
|
|
- " \"midas_v21\": f\"{model_path}/midas_v21-f6b98070.pt\",\n",
|
|
|
- " \"dpt_large\": f\"{model_path}/dpt_large-midas-2f21e586.pt\",\n",
|
|
|
- " \"dpt_hybrid\": f\"{model_path}/dpt_hybrid-midas-501f0c75.pt\",\n",
|
|
|
- " \"dpt_hybrid_nyu\": f\"{model_path}/dpt_hybrid_nyu-2ce69ec7.pt\",}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def init_midas_depth_model(midas_model_type=\"dpt_large\", optimize=True):\n",
|
|
|
- " midas_model = None\n",
|
|
|
- " net_w = None\n",
|
|
|
- " net_h = None\n",
|
|
|
- " resize_mode = None\n",
|
|
|
- " normalization = None\n",
|
|
|
- "\n",
|
|
|
- " print(f\"Initializing MiDaS '{midas_model_type}' depth model...\")\n",
|
|
|
- " # load network\n",
|
|
|
- " midas_model_path = default_models[midas_model_type]\n",
|
|
|
- "\n",
|
|
|
- " if midas_model_type == \"dpt_large\": # DPT-Large\n",
|
|
|
- " midas_model = DPTDepthModel(\n",
|
|
|
- " path=midas_model_path,\n",
|
|
|
- " backbone=\"vitl16_384\",\n",
|
|
|
- " non_negative=True,\n",
|
|
|
- " )\n",
|
|
|
- " net_w, net_h = 384, 384\n",
|
|
|
- " resize_mode = \"minimal\"\n",
|
|
|
- " normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
- " elif midas_model_type == \"dpt_hybrid\": #DPT-Hybrid\n",
|
|
|
- " midas_model = DPTDepthModel(\n",
|
|
|
- " path=midas_model_path,\n",
|
|
|
- " backbone=\"vitb_rn50_384\",\n",
|
|
|
- " non_negative=True,\n",
|
|
|
- " )\n",
|
|
|
- " net_w, net_h = 384, 384\n",
|
|
|
- " resize_mode=\"minimal\"\n",
|
|
|
- " normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
- " elif midas_model_type == \"dpt_hybrid_nyu\": #DPT-Hybrid-NYU\n",
|
|
|
- " midas_model = DPTDepthModel(\n",
|
|
|
- " path=midas_model_path,\n",
|
|
|
- " backbone=\"vitb_rn50_384\",\n",
|
|
|
- " non_negative=True,\n",
|
|
|
- " )\n",
|
|
|
- " net_w, net_h = 384, 384\n",
|
|
|
- " resize_mode=\"minimal\"\n",
|
|
|
- " normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
- " elif midas_model_type == \"midas_v21\":\n",
|
|
|
- " midas_model = MidasNet(midas_model_path, non_negative=True)\n",
|
|
|
- " net_w, net_h = 384, 384\n",
|
|
|
- " resize_mode=\"upper_bound\"\n",
|
|
|
- " normalization = NormalizeImage(\n",
|
|
|
- " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
|
|
|
- " )\n",
|
|
|
- " elif midas_model_type == \"midas_v21_small\":\n",
|
|
|
- " midas_model = MidasNet_small(midas_model_path, features=64, backbone=\"efficientnet_lite3\", exportable=True, non_negative=True, blocks={'expand': True})\n",
|
|
|
- " net_w, net_h = 256, 256\n",
|
|
|
- " resize_mode=\"upper_bound\"\n",
|
|
|
- " normalization = NormalizeImage(\n",
|
|
|
- " mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
|
|
|
- " )\n",
|
|
|
- " else:\n",
|
|
|
- " print(f\"midas_model_type '{midas_model_type}' not implemented\")\n",
|
|
|
- " assert False\n",
|
|
|
- "\n",
|
|
|
- " midas_transform = T.Compose(\n",
|
|
|
- " [\n",
|
|
|
- " Resize(\n",
|
|
|
- " net_w,\n",
|
|
|
- " net_h,\n",
|
|
|
- " resize_target=None,\n",
|
|
|
- " keep_aspect_ratio=True,\n",
|
|
|
- " ensure_multiple_of=32,\n",
|
|
|
- " resize_method=resize_mode,\n",
|
|
|
- " image_interpolation_method=cv2.INTER_CUBIC,\n",
|
|
|
- " ),\n",
|
|
|
- " normalization,\n",
|
|
|
- " PrepareForNet(),\n",
|
|
|
- " ]\n",
|
|
|
- " )\n",
|
|
|
- "\n",
|
|
|
- " midas_model.eval()\n",
|
|
|
- " \n",
|
|
|
- " if optimize==True:\n",
|
|
|
- " if DEVICE == torch.device(\"cuda\"):\n",
|
|
|
- " midas_model = midas_model.to(memory_format=torch.channels_last) \n",
|
|
|
- " midas_model = midas_model.half()\n",
|
|
|
- "\n",
|
|
|
- " midas_model.to(DEVICE)\n",
|
|
|
- "\n",
|
|
|
- " print(f\"MiDaS '{midas_model_type}' depth model initialized.\")\n",
|
|
|
- " return midas_model, midas_transform, net_w, net_h, resize_mode, normalization"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "cellView": "form",
|
|
|
- "id": "DefFns"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title 1.5 Define necessary functions\n",
|
|
|
- "\n",
|
|
|
- "# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n",
|
|
|
- "\n",
|
|
|
- "import py3d_tools as p3dT\n",
|
|
|
- "import disco_xform_utils as dxf\n",
|
|
|
- "\n",
|
|
|
- "def interp(t):\n",
|
|
|
- " return 3 * t**2 - 2 * t ** 3\n",
|
|
|
- "\n",
|
|
|
- "def perlin(width, height, scale=10, device=None):\n",
|
|
|
- " gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n",
|
|
|
- " xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n",
|
|
|
- " ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n",
|
|
|
- " wx = 1 - interp(xs)\n",
|
|
|
- " wy = 1 - interp(ys)\n",
|
|
|
- " dots = 0\n",
|
|
|
- " dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)\n",
|
|
|
- " dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)\n",
|
|
|
- " dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))\n",
|
|
|
- " dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))\n",
|
|
|
- " return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n",
|
|
|
- "\n",
|
|
|
- "def perlin_ms(octaves, width, height, grayscale, device=device):\n",
|
|
|
- " out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]\n",
|
|
|
- " # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]\n",
|
|
|
- " for i in range(1 if grayscale else 3):\n",
|
|
|
- " scale = 2 ** len(octaves)\n",
|
|
|
- " oct_width = width\n",
|
|
|
- " oct_height = height\n",
|
|
|
- " for oct in octaves:\n",
|
|
|
- " p = perlin(oct_width, oct_height, scale, device)\n",
|
|
|
- " out_array[i] += p * oct\n",
|
|
|
- " scale //= 2\n",
|
|
|
- " oct_width *= 2\n",
|
|
|
- " oct_height *= 2\n",
|
|
|
- " return torch.cat(out_array)\n",
|
|
|
- "\n",
|
|
|
- "def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n",
|
|
|
- " out = perlin_ms(octaves, width, height, grayscale)\n",
|
|
|
- " if grayscale:\n",
|
|
|
- " out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))\n",
|
|
|
- " out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')\n",
|
|
|
- " else:\n",
|
|
|
- " out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])\n",
|
|
|
- " out = TF.resize(size=(side_y, side_x), img=out)\n",
|
|
|
- " out = TF.to_pil_image(out.clamp(0, 1).squeeze())\n",
|
|
|
- "\n",
|
|
|
- " out = ImageOps.autocontrast(out)\n",
|
|
|
- " return out\n",
|
|
|
- "\n",
|
|
|
- "def regen_perlin():\n",
|
|
|
- " if perlin_mode == 'color':\n",
|
|
|
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n",
|
|
|
- " elif perlin_mode == 'gray':\n",
|
|
|
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n",
|
|
|
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
- " else:\n",
|
|
|
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
- "\n",
|
|
|
- " init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
- " del init2\n",
|
|
|
- " return init.expand(batch_size, -1, -1, -1)\n",
|
|
|
- "\n",
|
|
|
- "def fetch(url_or_path):\n",
|
|
|
- " if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n",
|
|
|
- " r = requests.get(url_or_path)\n",
|
|
|
- " r.raise_for_status()\n",
|
|
|
- " fd = io.BytesIO()\n",
|
|
|
- " fd.write(r.content)\n",
|
|
|
- " fd.seek(0)\n",
|
|
|
- " return fd\n",
|
|
|
- " return open(url_or_path, 'rb')\n",
|
|
|
- "\n",
|
|
|
- "def read_image_workaround(path):\n",
|
|
|
- " \"\"\"OpenCV reads images as BGR, Pillow saves them as RGB. Work around\n",
|
|
|
- " this incompatibility to avoid colour inversions.\"\"\"\n",
|
|
|
- " im_tmp = cv2.imread(path)\n",
|
|
|
- " return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)\n",
|
|
|
- "\n",
|
|
|
- "def parse_prompt(prompt):\n",
|
|
|
- " if prompt.startswith('http://') or prompt.startswith('https://'):\n",
|
|
|
- " vals = prompt.rsplit(':', 2)\n",
|
|
|
- " vals = [vals[0] + ':' + vals[1], *vals[2:]]\n",
|
|
|
- " else:\n",
|
|
|
- " vals = prompt.rsplit(':', 1)\n",
|
|
|
- " vals = vals + ['', '1'][len(vals):]\n",
|
|
|
- " return vals[0], float(vals[1])\n",
|
|
|
- "\n",
|
|
|
- "def sinc(x):\n",
|
|
|
- " return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n",
|
|
|
- "\n",
|
|
|
- "def lanczos(x, a):\n",
|
|
|
- " cond = torch.logical_and(-a < x, x < a)\n",
|
|
|
- " out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n",
|
|
|
- " return out / out.sum()\n",
|
|
|
- "\n",
|
|
|
- "def ramp(ratio, width):\n",
|
|
|
- " n = math.ceil(width / ratio + 1)\n",
|
|
|
- " out = torch.empty([n])\n",
|
|
|
- " cur = 0\n",
|
|
|
- " for i in range(out.shape[0]):\n",
|
|
|
- " out[i] = cur\n",
|
|
|
- " cur += ratio\n",
|
|
|
- " return torch.cat([-out[1:].flip([0]), out])[1:-1]\n",
|
|
|
- "\n",
|
|
|
- "def resample(input, size, align_corners=True):\n",
|
|
|
- " n, c, h, w = input.shape\n",
|
|
|
- " dh, dw = size\n",
|
|
|
- "\n",
|
|
|
- " input = input.reshape([n * c, 1, h, w])\n",
|
|
|
- "\n",
|
|
|
- " if dh < h:\n",
|
|
|
- " kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n",
|
|
|
- " pad_h = (kernel_h.shape[0] - 1) // 2\n",
|
|
|
- " input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n",
|
|
|
- " input = F.conv2d(input, kernel_h[None, None, :, None])\n",
|
|
|
- "\n",
|
|
|
- " if dw < w:\n",
|
|
|
- " kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n",
|
|
|
- " pad_w = (kernel_w.shape[0] - 1) // 2\n",
|
|
|
- " input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n",
|
|
|
- " input = F.conv2d(input, kernel_w[None, None, None, :])\n",
|
|
|
- "\n",
|
|
|
- " input = input.reshape([n, c, h, w])\n",
|
|
|
- " return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n",
|
|
|
- "\n",
|
|
|
- "class MakeCutouts(nn.Module):\n",
|
|
|
- " def __init__(self, cut_size, cutn, skip_augs=False):\n",
|
|
|
- " super().__init__()\n",
|
|
|
- " self.cut_size = cut_size\n",
|
|
|
- " self.cutn = cutn\n",
|
|
|
- " self.skip_augs = skip_augs\n",
|
|
|
- " self.augs = T.Compose([\n",
|
|
|
- " T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomGrayscale(p=0.15),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
- " ])\n",
|
|
|
- "\n",
|
|
|
- " def forward(self, input):\n",
|
|
|
- " input = T.Pad(input.shape[2]//4, fill=0)(input)\n",
|
|
|
- " sideY, sideX = input.shape[2:4]\n",
|
|
|
- " max_size = min(sideX, sideY)\n",
|
|
|
- "\n",
|
|
|
- " cutouts = []\n",
|
|
|
- " for ch in range(self.cutn):\n",
|
|
|
- " if ch > self.cutn - self.cutn//4:\n",
|
|
|
- " cutout = input.clone()\n",
|
|
|
- " else:\n",
|
|
|
- " size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))\n",
|
|
|
- " offsetx = torch.randint(0, abs(sideX - size + 1), ())\n",
|
|
|
- " offsety = torch.randint(0, abs(sideY - size + 1), ())\n",
|
|
|
- " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
|
|
|
- "\n",
|
|
|
- " if not self.skip_augs:\n",
|
|
|
- " cutout = self.augs(cutout)\n",
|
|
|
- " cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n",
|
|
|
- " del cutout\n",
|
|
|
- "\n",
|
|
|
- " cutouts = torch.cat(cutouts, dim=0)\n",
|
|
|
- " return cutouts\n",
|
|
|
- "\n",
|
|
|
- "cutout_debug = False\n",
|
|
|
- "padargs = {}\n",
|
|
|
- "\n",
|
|
|
- "class MakeCutoutsDango(nn.Module):\n",
|
|
|
- " def __init__(self, cut_size,\n",
|
|
|
- " Overview=4, \n",
|
|
|
- " InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2\n",
|
|
|
- " ):\n",
|
|
|
- " super().__init__()\n",
|
|
|
- " self.cut_size = cut_size\n",
|
|
|
- " self.Overview = Overview\n",
|
|
|
- " self.InnerCrop = InnerCrop\n",
|
|
|
- " self.IC_Size_Pow = IC_Size_Pow\n",
|
|
|
- " self.IC_Grey_P = IC_Grey_P\n",
|
|
|
- " if args.animation_mode == 'None':\n",
|
|
|
- " self.augs = T.Compose([\n",
|
|
|
- " T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomGrayscale(p=0.1),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
- " ])\n",
|
|
|
- " elif args.animation_mode == 'Video Input':\n",
|
|
|
- " self.augs = T.Compose([\n",
|
|
|
- " T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomGrayscale(p=0.15),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
- " ])\n",
|
|
|
- " elif args.animation_mode == '2D' or args.animation_mode == '3D':\n",
|
|
|
- " self.augs = T.Compose([\n",
|
|
|
- " T.RandomHorizontalFlip(p=0.4),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.RandomGrayscale(p=0.1),\n",
|
|
|
- " T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
- " T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),\n",
|
|
|
- " ])\n",
|
|
|
- " \n",
|
|
|
- "\n",
|
|
|
- " def forward(self, input):\n",
|
|
|
- " cutouts = []\n",
|
|
|
- " gray = T.Grayscale(3)\n",
|
|
|
- " sideY, sideX = input.shape[2:4]\n",
|
|
|
- " max_size = min(sideX, sideY)\n",
|
|
|
- " min_size = min(sideX, sideY, self.cut_size)\n",
|
|
|
- " l_size = max(sideX, sideY)\n",
|
|
|
- " output_shape = [1,3,self.cut_size,self.cut_size] \n",
|
|
|
- " output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]\n",
|
|
|
- " pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **padargs)\n",
|
|
|
- " cutout = resize(pad_input, out_shape=output_shape)\n",
|
|
|
- "\n",
|
|
|
- " if self.Overview>0:\n",
|
|
|
- " if self.Overview<=4:\n",
|
|
|
- " if self.Overview>=1:\n",
|
|
|
- " cutouts.append(cutout)\n",
|
|
|
- " if self.Overview>=2:\n",
|
|
|
- " cutouts.append(gray(cutout))\n",
|
|
|
- " if self.Overview>=3:\n",
|
|
|
- " cutouts.append(TF.hflip(cutout))\n",
|
|
|
- " if self.Overview==4:\n",
|
|
|
- " cutouts.append(gray(TF.hflip(cutout)))\n",
|
|
|
- " else:\n",
|
|
|
- " cutout = resize(pad_input, out_shape=output_shape)\n",
|
|
|
- " for _ in range(self.Overview):\n",
|
|
|
- " cutouts.append(cutout)\n",
|
|
|
- "\n",
|
|
|
- " if cutout_debug:\n",
|
|
|
- " if is_colab:\n",
|
|
|
- " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"/content/cutout_overview0.jpg\",quality=99)\n",
|
|
|
- " else:\n",
|
|
|
- " TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"cutout_overview0.jpg\",quality=99)\n",
|
|
|
- "\n",
|
|
|
- " \n",
|
|
|
- " if self.InnerCrop >0:\n",
|
|
|
- " for i in range(self.InnerCrop):\n",
|
|
|
- " size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)\n",
|
|
|
- " offsetx = torch.randint(0, sideX - size + 1, ())\n",
|
|
|
- " offsety = torch.randint(0, sideY - size + 1, ())\n",
|
|
|
- " cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
|
|
|
- " if i <= int(self.IC_Grey_P * self.InnerCrop):\n",
|
|
|
- " cutout = gray(cutout)\n",
|
|
|
- " cutout = resize(cutout, out_shape=output_shape)\n",
|
|
|
- " cutouts.append(cutout)\n",
|
|
|
- " if cutout_debug:\n",
|
|
|
- " if is_colab:\n",
|
|
|
- " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"/content/cutout_InnerCrop.jpg\",quality=99)\n",
|
|
|
- " else:\n",
|
|
|
- " TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"cutout_InnerCrop.jpg\",quality=99)\n",
|
|
|
- " cutouts = torch.cat(cutouts)\n",
|
|
|
- " if skip_augs is not True: cutouts=self.augs(cutouts)\n",
|
|
|
- " return cutouts\n",
|
|
|
- "\n",
|
|
|
- "def spherical_dist_loss(x, y):\n",
|
|
|
- " x = F.normalize(x, dim=-1)\n",
|
|
|
- " y = F.normalize(y, dim=-1)\n",
|
|
|
- " return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) \n",
|
|
|
- "\n",
|
|
|
- "def tv_loss(input):\n",
|
|
|
- " \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n",
|
|
|
- " input = F.pad(input, (0, 1, 0, 1), 'replicate')\n",
|
|
|
- " x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n",
|
|
|
- " y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n",
|
|
|
- " return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def range_loss(input):\n",
|
|
|
- " return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n",
|
|
|
- "\n",
|
|
|
- "stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete\n",
|
|
|
- "\n",
|
|
|
- "def do_3d_step(img_filepath, frame_num, midas_model, midas_transform):\n",
|
|
|
- " if args.key_frames:\n",
|
|
|
- " translation_x = args.translation_x_series[frame_num]\n",
|
|
|
- " translation_y = args.translation_y_series[frame_num]\n",
|
|
|
- " translation_z = args.translation_z_series[frame_num]\n",
|
|
|
- " rotation_3d_x = args.rotation_3d_x_series[frame_num]\n",
|
|
|
- " rotation_3d_y = args.rotation_3d_y_series[frame_num]\n",
|
|
|
- " rotation_3d_z = args.rotation_3d_z_series[frame_num]\n",
|
|
|
- " print(\n",
|
|
|
- " f'translation_x: {translation_x}',\n",
|
|
|
- " f'translation_y: {translation_y}',\n",
|
|
|
- " f'translation_z: {translation_z}',\n",
|
|
|
- " f'rotation_3d_x: {rotation_3d_x}',\n",
|
|
|
- " f'rotation_3d_y: {rotation_3d_y}',\n",
|
|
|
- " f'rotation_3d_z: {rotation_3d_z}',\n",
|
|
|
- " )\n",
|
|
|
- "\n",
|
|
|
- " trans_scale = 1.0/200.0\n",
|
|
|
- " translate_xyz = [-translation_x*trans_scale, translation_y*trans_scale, -translation_z*trans_scale]\n",
|
|
|
- " rotate_xyz_degrees = [rotation_3d_x, rotation_3d_y, rotation_3d_z]\n",
|
|
|
- " print('translation:',translate_xyz)\n",
|
|
|
- " print('rotation:',rotate_xyz_degrees)\n",
|
|
|
- " rotate_xyz = [math.radians(rotate_xyz_degrees[0]), math.radians(rotate_xyz_degrees[1]), math.radians(rotate_xyz_degrees[2])]\n",
|
|
|
- " rot_mat = p3dT.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), \"XYZ\").unsqueeze(0)\n",
|
|
|
- " print(\"rot_mat: \" + str(rot_mat))\n",
|
|
|
- " next_step_pil = dxf.transform_image_3d(img_filepath, midas_model, midas_transform, DEVICE,\n",
|
|
|
- " rot_mat, translate_xyz, args.near_plane, args.far_plane,\n",
|
|
|
- " args.fov, padding_mode=args.padding_mode,\n",
|
|
|
- " sampling_mode=args.sampling_mode, midas_weight=args.midas_weight)\n",
|
|
|
- " return next_step_pil\n",
|
|
|
- "\n",
|
|
|
- "def do_run():\n",
|
|
|
- " seed = args.seed\n",
|
|
|
- " print(range(args.start_frame, args.max_frames))\n",
|
|
|
- "\n",
|
|
|
- " if (args.animation_mode == \"3D\") and (args.midas_weight > 0.0):\n",
|
|
|
- " midas_model, midas_transform, midas_net_w, midas_net_h, midas_resize_mode, midas_normalization = init_midas_depth_model(args.midas_depth_model)\n",
|
|
|
- " for frame_num in range(args.start_frame, args.max_frames):\n",
|
|
|
- " if stop_on_next_loop:\n",
|
|
|
- " break\n",
|
|
|
- " \n",
|
|
|
- " display.clear_output(wait=True)\n",
|
|
|
- "\n",
|
|
|
- " # Print Frame progress if animation mode is on\n",
|
|
|
- " if args.animation_mode != \"None\":\n",
|
|
|
- " batchBar = tqdm(range(args.max_frames), desc =\"Frames\")\n",
|
|
|
- " batchBar.n = frame_num\n",
|
|
|
- " batchBar.refresh()\n",
|
|
|
- "\n",
|
|
|
- " \n",
|
|
|
- " # Inits if not video frames\n",
|
|
|
- " if args.animation_mode != \"Video Input\":\n",
|
|
|
- " if args.init_image == '':\n",
|
|
|
- " init_image = None\n",
|
|
|
- " else:\n",
|
|
|
- " init_image = args.init_image\n",
|
|
|
- " init_scale = args.init_scale\n",
|
|
|
- " skip_steps = args.skip_steps\n",
|
|
|
- "\n",
|
|
|
- " if args.animation_mode == \"2D\":\n",
|
|
|
- " if args.key_frames:\n",
|
|
|
- " angle = args.angle_series[frame_num]\n",
|
|
|
- " zoom = args.zoom_series[frame_num]\n",
|
|
|
- " translation_x = args.translation_x_series[frame_num]\n",
|
|
|
- " translation_y = args.translation_y_series[frame_num]\n",
|
|
|
- " print(\n",
|
|
|
- " f'angle: {angle}',\n",
|
|
|
- " f'zoom: {zoom}',\n",
|
|
|
- " f'translation_x: {translation_x}',\n",
|
|
|
- " f'translation_y: {translation_y}',\n",
|
|
|
- " )\n",
|
|
|
- " \n",
|
|
|
- " if frame_num > 0:\n",
|
|
|
- " seed += 1\n",
|
|
|
- " if resume_run and frame_num == start_frame:\n",
|
|
|
- " img_0 = cv2.imread(batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:04}.png\")\n",
|
|
|
- " else:\n",
|
|
|
- " img_0 = cv2.imread('prevFrame.png')\n",
|
|
|
- " center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)\n",
|
|
|
- " trans_mat = np.float32(\n",
|
|
|
- " [[1, 0, translation_x],\n",
|
|
|
- " [0, 1, translation_y]]\n",
|
|
|
- " )\n",
|
|
|
- " rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )\n",
|
|
|
- " trans_mat = np.vstack([trans_mat, [0,0,1]])\n",
|
|
|
- " rot_mat = np.vstack([rot_mat, [0,0,1]])\n",
|
|
|
- " transformation_matrix = np.matmul(rot_mat, trans_mat)\n",
|
|
|
- " img_0 = cv2.warpPerspective(\n",
|
|
|
- " img_0,\n",
|
|
|
- " transformation_matrix,\n",
|
|
|
- " (img_0.shape[1], img_0.shape[0]),\n",
|
|
|
- " borderMode=cv2.BORDER_WRAP\n",
|
|
|
- " )\n",
|
|
|
- "\n",
|
|
|
- " cv2.imwrite('prevFrameScaled.png', img_0)\n",
|
|
|
- " init_image = 'prevFrameScaled.png'\n",
|
|
|
- " init_scale = args.frames_scale\n",
|
|
|
- " skip_steps = args.calc_frames_skip_steps\n",
|
|
|
- "\n",
|
|
|
- " if args.animation_mode == \"3D\":\n",
|
|
|
- " if frame_num == 0:\n",
|
|
|
- " pass\n",
|
|
|
- " else:\n",
|
|
|
- " seed += 1 \n",
|
|
|
- " if resume_run and frame_num == start_frame:\n",
|
|
|
- " img_filepath = batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:04}.png\"\n",
|
|
|
- " if turbo_mode and frame_num > turbo_preroll:\n",
|
|
|
- " shutil.copyfile(img_filepath, 'oldFrameScaled.png')\n",
|
|
|
- " else:\n",
|
|
|
- " img_filepath = '/content/prevFrame.png' if is_colab else 'prevFrame.png'\n",
|
|
|
- "\n",
|
|
|
- " next_step_pil = do_3d_step(img_filepath, frame_num, midas_model, midas_transform)\n",
|
|
|
- " next_step_pil.save('prevFrameScaled.png')\n",
|
|
|
- "\n",
|
|
|
- " ### Turbo mode - skip some diffusions, use 3d morph for clarity and to save time\n",
|
|
|
- " if turbo_mode:\n",
|
|
|
- " if frame_num == turbo_preroll: #start tracking oldframe\n",
|
|
|
- " next_step_pil.save('oldFrameScaled.png')#stash for later blending \n",
|
|
|
- " elif frame_num > turbo_preroll:\n",
|
|
|
- " #set up 2 warped image sequences, old & new, to blend toward new diff image\n",
|
|
|
- " old_frame = do_3d_step('oldFrameScaled.png', frame_num, midas_model, midas_transform)\n",
|
|
|
- " old_frame.save('oldFrameScaled.png')\n",
|
|
|
- " if frame_num % int(turbo_steps) != 0: \n",
|
|
|
- " print('turbo skip this frame: skipping clip diffusion steps')\n",
|
|
|
- " filename = f'{args.batch_name}({args.batchNum})_{frame_num:04}.png'\n",
|
|
|
- " blend_factor = ((frame_num % int(turbo_steps))+1)/int(turbo_steps)\n",
|
|
|
- " print('turbo skip this frame: skipping clip diffusion steps and saving blended frame')\n",
|
|
|
- " newWarpedImg = cv2.imread('prevFrameScaled.png')#this is already updated..\n",
|
|
|
- " oldWarpedImg = cv2.imread('oldFrameScaled.png')\n",
|
|
|
- " blendedImage = cv2.addWeighted(newWarpedImg, blend_factor, oldWarpedImg,1-blend_factor, 0.0)\n",
|
|
|
- " cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
|
|
|
- " next_step_pil.save(f'{img_filepath}') # save it also as prev_frame to feed next iteration\n",
|
|
|
- " continue\n",
|
|
|
- " else:\n",
|
|
|
- " #if not a skip frame, will run diffusion and need to blend.\n",
|
|
|
- " oldWarpedImg = cv2.imread('prevFrameScaled.png')\n",
|
|
|
- " cv2.imwrite(f'oldFrameScaled.png',oldWarpedImg)#swap in for blending later \n",
|
|
|
- " print('clip/diff this frame - generate clip diff image')\n",
|
|
|
- "\n",
|
|
|
- " init_image = 'prevFrameScaled.png'\n",
|
|
|
- " init_scale = args.frames_scale\n",
|
|
|
- " skip_steps = args.calc_frames_skip_steps\n",
|
|
|
- "\n",
|
|
|
- " if args.animation_mode == \"Video Input\":\n",
|
|
|
- " if not video_init_seed_continuity:\n",
|
|
|
- " seed += 1\n",
|
|
|
- " init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg'\n",
|
|
|
- " init_scale = args.frames_scale\n",
|
|
|
- " skip_steps = args.calc_frames_skip_steps\n",
|
|
|
- "\n",
|
|
|
- " loss_values = []\n",
|
|
|
- " \n",
|
|
|
- " if seed is not None:\n",
|
|
|
- " np.random.seed(seed)\n",
|
|
|
- " random.seed(seed)\n",
|
|
|
- " torch.manual_seed(seed)\n",
|
|
|
- " torch.cuda.manual_seed_all(seed)\n",
|
|
|
- " torch.backends.cudnn.deterministic = True\n",
|
|
|
- " \n",
|
|
|
- " target_embeds, weights = [], []\n",
|
|
|
- " \n",
|
|
|
- " if args.prompts_series is not None and frame_num >= len(args.prompts_series):\n",
|
|
|
- " frame_prompt = args.prompts_series[-1]\n",
|
|
|
- " elif args.prompts_series is not None:\n",
|
|
|
- " frame_prompt = args.prompts_series[frame_num]\n",
|
|
|
- " else:\n",
|
|
|
- " frame_prompt = []\n",
|
|
|
- " \n",
|
|
|
- " print(args.image_prompts_series)\n",
|
|
|
- " if args.image_prompts_series is not None and frame_num >= len(args.image_prompts_series):\n",
|
|
|
- " image_prompt = args.image_prompts_series[-1]\n",
|
|
|
- " elif args.image_prompts_series is not None:\n",
|
|
|
- " image_prompt = args.image_prompts_series[frame_num]\n",
|
|
|
- " else:\n",
|
|
|
- " image_prompt = []\n",
|
|
|
- "\n",
|
|
|
- " print(f'Frame {frame_num} Prompt: {frame_prompt}')\n",
|
|
|
- "\n",
|
|
|
- " model_stats = []\n",
|
|
|
- " for clip_model in clip_models:\n",
|
|
|
- " cutn = 16\n",
|
|
|
- " model_stat = {\"clip_model\":None,\"target_embeds\":[],\"make_cutouts\":None,\"weights\":[]}\n",
|
|
|
- " model_stat[\"clip_model\"] = clip_model\n",
|
|
|
- " \n",
|
|
|
- " \n",
|
|
|
- " for prompt in frame_prompt:\n",
|
|
|
- " txt, weight = parse_prompt(prompt)\n",
|
|
|
- " txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()\n",
|
|
|
- " \n",
|
|
|
- " if args.fuzzy_prompt:\n",
|
|
|
- " for i in range(25):\n",
|
|
|
- " model_stat[\"target_embeds\"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))\n",
|
|
|
- " model_stat[\"weights\"].append(weight)\n",
|
|
|
- " else:\n",
|
|
|
- " model_stat[\"target_embeds\"].append(txt)\n",
|
|
|
- " model_stat[\"weights\"].append(weight)\n",
|
|
|
- " \n",
|
|
|
- " if image_prompt:\n",
|
|
|
- " model_stat[\"make_cutouts\"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs) \n",
|
|
|
- " for prompt in image_prompt:\n",
|
|
|
- " path, weight = parse_prompt(prompt)\n",
|
|
|
- " img = Image.open(fetch(path)).convert('RGB')\n",
|
|
|
- " img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)\n",
|
|
|
- " batch = model_stat[\"make_cutouts\"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))\n",
|
|
|
- " embed = clip_model.encode_image(normalize(batch)).float()\n",
|
|
|
- " if fuzzy_prompt:\n",
|
|
|
- " for i in range(25):\n",
|
|
|
- " model_stat[\"target_embeds\"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))\n",
|
|
|
- " weights.extend([weight / cutn] * cutn)\n",
|
|
|
- " else:\n",
|
|
|
- " model_stat[\"target_embeds\"].append(embed)\n",
|
|
|
- " model_stat[\"weights\"].extend([weight / cutn] * cutn)\n",
|
|
|
- " \n",
|
|
|
- " model_stat[\"target_embeds\"] = torch.cat(model_stat[\"target_embeds\"])\n",
|
|
|
- " model_stat[\"weights\"] = torch.tensor(model_stat[\"weights\"], device=device)\n",
|
|
|
- " if model_stat[\"weights\"].sum().abs() < 1e-3:\n",
|
|
|
- " raise RuntimeError('The weights must not sum to 0.')\n",
|
|
|
- " model_stat[\"weights\"] /= model_stat[\"weights\"].sum().abs()\n",
|
|
|
- " model_stats.append(model_stat)\n",
|
|
|
- " \n",
|
|
|
- " init = None\n",
|
|
|
- " if init_image is not None:\n",
|
|
|
- " init = Image.open(fetch(init_image)).convert('RGB')\n",
|
|
|
- " init = init.resize((args.side_x, args.side_y), Image.LANCZOS)\n",
|
|
|
- " init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
- " \n",
|
|
|
- " if args.perlin_init:\n",
|
|
|
- " if args.perlin_mode == 'color':\n",
|
|
|
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n",
|
|
|
- " elif args.perlin_mode == 'gray':\n",
|
|
|
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n",
|
|
|
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
- " else:\n",
|
|
|
- " init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
- " init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
- " # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)\n",
|
|
|
- " init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
- " del init2\n",
|
|
|
- " \n",
|
|
|
- " cur_t = None\n",
|
|
|
- " \n",
|
|
|
- " def cond_fn(x, t, y=None):\n",
|
|
|
- " with torch.enable_grad():\n",
|
|
|
- " x_is_NaN = False\n",
|
|
|
- " x = x.detach().requires_grad_()\n",
|
|
|
- " n = x.shape[0]\n",
|
|
|
- " if use_secondary_model is True:\n",
|
|
|
- " alpha = torch.tensor(diffusion.sqrt_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n",
|
|
|
- " sigma = torch.tensor(diffusion.sqrt_one_minus_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n",
|
|
|
- " cosine_t = alpha_sigma_to_t(alpha, sigma)\n",
|
|
|
- " out = secondary_model(x, cosine_t[None].repeat([n])).pred\n",
|
|
|
- " fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n",
|
|
|
- " x_in = out * fac + x * (1 - fac)\n",
|
|
|
- " x_in_grad = torch.zeros_like(x_in)\n",
|
|
|
- " else:\n",
|
|
|
- " my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t\n",
|
|
|
- " out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})\n",
|
|
|
- " fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n",
|
|
|
- " x_in = out['pred_xstart'] * fac + x * (1 - fac)\n",
|
|
|
- " x_in_grad = torch.zeros_like(x_in)\n",
|
|
|
- " for model_stat in model_stats:\n",
|
|
|
- " for i in range(args.cutn_batches):\n",
|
|
|
- " t_int = int(t.item())+1 #errors on last step without +1, need to find source\n",
|
|
|
- " #when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution'\n",
|
|
|
- " try:\n",
|
|
|
- " input_resolution=model_stat[\"clip_model\"].visual.input_resolution\n",
|
|
|
- " except:\n",
|
|
|
- " input_resolution=224\n",
|
|
|
- "\n",
|
|
|
- " cuts = MakeCutoutsDango(input_resolution,\n",
|
|
|
- " Overview= args.cut_overview[1000-t_int], \n",
|
|
|
- " InnerCrop = args.cut_innercut[1000-t_int], IC_Size_Pow=args.cut_ic_pow, IC_Grey_P = args.cut_icgray_p[1000-t_int]\n",
|
|
|
- " )\n",
|
|
|
- " clip_in = normalize(cuts(x_in.add(1).div(2)))\n",
|
|
|
- " image_embeds = model_stat[\"clip_model\"].encode_image(clip_in).float()\n",
|
|
|
- " dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat[\"target_embeds\"].unsqueeze(0))\n",
|
|
|
- " dists = dists.view([args.cut_overview[1000-t_int]+args.cut_innercut[1000-t_int], n, -1])\n",
|
|
|
- " losses = dists.mul(model_stat[\"weights\"]).sum(2).mean(0)\n",
|
|
|
- " loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch\n",
|
|
|
- " x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches\n",
|
|
|
- " tv_losses = tv_loss(x_in)\n",
|
|
|
- " if use_secondary_model is True:\n",
|
|
|
- " range_losses = range_loss(out)\n",
|
|
|
- " else:\n",
|
|
|
- " range_losses = range_loss(out['pred_xstart'])\n",
|
|
|
- " sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()\n",
|
|
|
- " loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale\n",
|
|
|
- " if init is not None and args.init_scale:\n",
|
|
|
- " init_losses = lpips_model(x_in, init)\n",
|
|
|
- " loss = loss + init_losses.sum() * args.init_scale\n",
|
|
|
- " x_in_grad += torch.autograd.grad(loss, x_in)[0]\n",
|
|
|
- " if torch.isnan(x_in_grad).any()==False:\n",
|
|
|
- " grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]\n",
|
|
|
- " else:\n",
|
|
|
- " # print(\"NaN'd\")\n",
|
|
|
- " x_is_NaN = True\n",
|
|
|
- " grad = torch.zeros_like(x)\n",
|
|
|
- " if args.clamp_grad and x_is_NaN == False:\n",
|
|
|
- " magnitude = grad.square().mean().sqrt()\n",
|
|
|
- " return grad * magnitude.clamp(max=args.clamp_max) / magnitude #min=-0.02, min=-clamp_max, \n",
|
|
|
- " return grad\n",
|
|
|
- " \n",
|
|
|
- " if args.diffusion_sampling_mode == 'ddim':\n",
|
|
|
- " sample_fn = diffusion.ddim_sample_loop_progressive\n",
|
|
|
- " else:\n",
|
|
|
- " sample_fn = diffusion.plms_sample_loop_progressive\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- " image_display = Output()\n",
|
|
|
- " for i in range(args.n_batches):\n",
|
|
|
- " if args.animation_mode == 'None':\n",
|
|
|
- " display.clear_output(wait=True)\n",
|
|
|
- " batchBar = tqdm(range(args.n_batches), desc =\"Batches\")\n",
|
|
|
- " batchBar.n = i\n",
|
|
|
- " batchBar.refresh()\n",
|
|
|
- " print('')\n",
|
|
|
- " display.display(image_display)\n",
|
|
|
- " gc.collect()\n",
|
|
|
- " torch.cuda.empty_cache()\n",
|
|
|
- " cur_t = diffusion.num_timesteps - skip_steps - 1\n",
|
|
|
- " total_steps = cur_t\n",
|
|
|
- "\n",
|
|
|
- " if perlin_init:\n",
|
|
|
- " init = regen_perlin()\n",
|
|
|
- "\n",
|
|
|
- " if args.diffusion_sampling_mode == 'ddim':\n",
|
|
|
- " samples = sample_fn(\n",
|
|
|
- " model,\n",
|
|
|
- " (batch_size, 3, args.side_y, args.side_x),\n",
|
|
|
- " clip_denoised=clip_denoised,\n",
|
|
|
- " model_kwargs={},\n",
|
|
|
- " cond_fn=cond_fn,\n",
|
|
|
- " progress=True,\n",
|
|
|
- " skip_timesteps=skip_steps,\n",
|
|
|
- " init_image=init,\n",
|
|
|
- " randomize_class=randomize_class,\n",
|
|
|
- " eta=eta,\n",
|
|
|
- " )\n",
|
|
|
- " else:\n",
|
|
|
- " samples = sample_fn(\n",
|
|
|
- " model,\n",
|
|
|
- " (batch_size, 3, args.side_y, args.side_x),\n",
|
|
|
- " clip_denoised=clip_denoised,\n",
|
|
|
- " model_kwargs={},\n",
|
|
|
- " cond_fn=cond_fn,\n",
|
|
|
- " progress=True,\n",
|
|
|
- " skip_timesteps=skip_steps,\n",
|
|
|
- " init_image=init,\n",
|
|
|
- " randomize_class=randomize_class,\n",
|
|
|
- " order=2,\n",
|
|
|
- " )\n",
|
|
|
- " \n",
|
|
|
- " \n",
|
|
|
- " # with run_display:\n",
|
|
|
- " # display.clear_output(wait=True)\n",
|
|
|
- " imgToSharpen = None\n",
|
|
|
- " for j, sample in enumerate(samples): \n",
|
|
|
- " cur_t -= 1\n",
|
|
|
- " intermediateStep = False\n",
|
|
|
- " if args.steps_per_checkpoint is not None:\n",
|
|
|
- " if j % steps_per_checkpoint == 0 and j > 0:\n",
|
|
|
- " intermediateStep = True\n",
|
|
|
- " elif j in args.intermediate_saves:\n",
|
|
|
- " intermediateStep = True\n",
|
|
|
- " with image_display:\n",
|
|
|
- " if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True:\n",
|
|
|
- " for k, image in enumerate(sample['pred_xstart']):\n",
|
|
|
- " # tqdm.write(f'Batch {i}, step {j}, output {k}:')\n",
|
|
|
- " current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')\n",
|
|
|
- " percent = math.ceil(j/total_steps*100)\n",
|
|
|
- " if args.n_batches > 0:\n",
|
|
|
- " #if intermediates are saved to the subfolder, don't append a step or percentage to the name\n",
|
|
|
- " if cur_t == -1 and args.intermediates_in_subfolder is True:\n",
|
|
|
- " save_num = f'{frame_num:04}' if animation_mode != \"None\" else i\n",
|
|
|
- " filename = f'{args.batch_name}({args.batchNum})_{save_num}.png'\n",
|
|
|
- " else:\n",
|
|
|
- " #If we're working with percentages, append it\n",
|
|
|
- " if args.steps_per_checkpoint is not None:\n",
|
|
|
- " filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png'\n",
|
|
|
- " # Or else, iIf we're working with specific steps, append those\n",
|
|
|
- " else:\n",
|
|
|
- " filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png'\n",
|
|
|
- " image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))\n",
|
|
|
- " if j % args.display_rate == 0 or cur_t == -1:\n",
|
|
|
- " image.save('progress.png')\n",
|
|
|
- " display.clear_output(wait=True)\n",
|
|
|
- " display.display(display.Image('progress.png'))\n",
|
|
|
- " if args.steps_per_checkpoint is not None:\n",
|
|
|
- " if j % args.steps_per_checkpoint == 0 and j > 0:\n",
|
|
|
- " if args.intermediates_in_subfolder is True:\n",
|
|
|
- " image.save(f'{partialFolder}/{filename}')\n",
|
|
|
- " else:\n",
|
|
|
- " image.save(f'{batchFolder}/{filename}')\n",
|
|
|
- " else:\n",
|
|
|
- " if j in args.intermediate_saves:\n",
|
|
|
- " if args.intermediates_in_subfolder is True:\n",
|
|
|
- " image.save(f'{partialFolder}/{filename}')\n",
|
|
|
- " else:\n",
|
|
|
- " image.save(f'{batchFolder}/{filename}')\n",
|
|
|
- " if cur_t == -1:\n",
|
|
|
- " if frame_num == 0:\n",
|
|
|
- " save_settings()\n",
|
|
|
- " if args.animation_mode != \"None\":\n",
|
|
|
- " image.save('prevFrame.png')\n",
|
|
|
- " if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
|
|
|
- " imgToSharpen = image\n",
|
|
|
- " if args.keep_unsharp is True:\n",
|
|
|
- " image.save(f'{unsharpenFolder}/{filename}')\n",
|
|
|
- " else:\n",
|
|
|
- " image.save(f'{batchFolder}/{filename}')\n",
|
|
|
- " if args.animation_mode == \"3D\":\n",
|
|
|
- " # If turbo, save a blended image\n",
|
|
|
- " if turbo_mode:\n",
|
|
|
- " # Mix new image with prevFrameScaled\n",
|
|
|
- " blend_factor = (1)/int(turbo_steps)\n",
|
|
|
- " newFrame = cv2.imread('prevFrame.png') # This is already updated..\n",
|
|
|
- " prev_frame_warped = cv2.imread('prevFrameScaled.png')\n",
|
|
|
- " blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)\n",
|
|
|
- " cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
|
|
|
- " else:\n",
|
|
|
- " image.save(f'{batchFolder}/{filename}')\n",
|
|
|
- " # if frame_num != args.max_frames-1:\n",
|
|
|
- " # display.clear_output()\n",
|
|
|
- "\n",
|
|
|
- " with image_display: \n",
|
|
|
- " if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
|
|
|
- " print('Starting Diffusion Sharpening...')\n",
|
|
|
- " do_superres(imgToSharpen, f'{batchFolder}/{filename}')\n",
|
|
|
- " display.clear_output()\n",
|
|
|
- " \n",
|
|
|
- " plt.plot(np.array(loss_values), 'r')\n",
|
|
|
- "\n",
|
|
|
- "def save_settings():\n",
|
|
|
- " setting_list = {\n",
|
|
|
- " 'text_prompts': text_prompts,\n",
|
|
|
- " 'image_prompts': image_prompts,\n",
|
|
|
- " 'clip_guidance_scale': clip_guidance_scale,\n",
|
|
|
- " 'tv_scale': tv_scale,\n",
|
|
|
- " 'range_scale': range_scale,\n",
|
|
|
- " 'sat_scale': sat_scale,\n",
|
|
|
- " # 'cutn': cutn,\n",
|
|
|
- " 'cutn_batches': cutn_batches,\n",
|
|
|
- " 'max_frames': max_frames,\n",
|
|
|
- " 'interp_spline': interp_spline,\n",
|
|
|
- " # 'rotation_per_frame': rotation_per_frame,\n",
|
|
|
- " 'init_image': init_image,\n",
|
|
|
- " 'init_scale': init_scale,\n",
|
|
|
- " 'skip_steps': skip_steps,\n",
|
|
|
- " # 'zoom_per_frame': zoom_per_frame,\n",
|
|
|
- " 'frames_scale': frames_scale,\n",
|
|
|
- " 'frames_skip_steps': frames_skip_steps,\n",
|
|
|
- " 'perlin_init': perlin_init,\n",
|
|
|
- " 'perlin_mode': perlin_mode,\n",
|
|
|
- " 'skip_augs': skip_augs,\n",
|
|
|
- " 'randomize_class': randomize_class,\n",
|
|
|
- " 'clip_denoised': clip_denoised,\n",
|
|
|
- " 'clamp_grad': clamp_grad,\n",
|
|
|
- " 'clamp_max': clamp_max,\n",
|
|
|
- " 'seed': seed,\n",
|
|
|
- " 'fuzzy_prompt': fuzzy_prompt,\n",
|
|
|
- " 'rand_mag': rand_mag,\n",
|
|
|
- " 'eta': eta,\n",
|
|
|
- " 'width': width_height[0],\n",
|
|
|
- " 'height': width_height[1],\n",
|
|
|
- " 'diffusion_model': diffusion_model,\n",
|
|
|
- " 'use_secondary_model': use_secondary_model,\n",
|
|
|
- " 'steps': steps,\n",
|
|
|
- " 'diffusion_steps': diffusion_steps,\n",
|
|
|
- " 'diffusion_sampling_mode': diffusion_sampling_mode,\n",
|
|
|
- " 'ViTB32': ViTB32,\n",
|
|
|
- " 'ViTB16': ViTB16,\n",
|
|
|
- " 'ViTL14': ViTL14,\n",
|
|
|
- " 'RN101': RN101,\n",
|
|
|
- " 'RN50': RN50,\n",
|
|
|
- " 'RN50x4': RN50x4,\n",
|
|
|
- " 'RN50x16': RN50x16,\n",
|
|
|
- " 'RN50x64': RN50x64,\n",
|
|
|
- " 'cut_overview': str(cut_overview),\n",
|
|
|
- " 'cut_innercut': str(cut_innercut),\n",
|
|
|
- " 'cut_ic_pow': cut_ic_pow,\n",
|
|
|
- " 'cut_icgray_p': str(cut_icgray_p),\n",
|
|
|
- " 'key_frames': key_frames,\n",
|
|
|
- " 'max_frames': max_frames,\n",
|
|
|
- " 'angle': angle,\n",
|
|
|
- " 'zoom': zoom,\n",
|
|
|
- " 'translation_x': translation_x,\n",
|
|
|
- " 'translation_y': translation_y,\n",
|
|
|
- " 'translation_z': translation_z,\n",
|
|
|
- " 'rotation_3d_x': rotation_3d_x,\n",
|
|
|
- " 'rotation_3d_y': rotation_3d_y,\n",
|
|
|
- " 'rotation_3d_z': rotation_3d_z,\n",
|
|
|
- " 'midas_depth_model': midas_depth_model,\n",
|
|
|
- " 'midas_weight': midas_weight,\n",
|
|
|
- " 'near_plane': near_plane,\n",
|
|
|
- " 'far_plane': far_plane,\n",
|
|
|
- " 'fov': fov,\n",
|
|
|
- " 'padding_mode': padding_mode,\n",
|
|
|
- " 'sampling_mode': sampling_mode,\n",
|
|
|
- " 'video_init_path':video_init_path,\n",
|
|
|
- " 'extract_nth_frame':extract_nth_frame,\n",
|
|
|
- " 'video_init_seed_continuity': video_init_seed_continuity,\n",
|
|
|
- " 'turbo_mode':turbo_mode,\n",
|
|
|
- " 'turbo_steps':turbo_steps,\n",
|
|
|
- " 'turbo_preroll':turbo_preroll,\n",
|
|
|
- " }\n",
|
|
|
- " # print('Settings:', setting_list)\n",
|
|
|
- " with open(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\", \"w+\") as f: #save settings\n",
|
|
|
- " json.dump(setting_list, f, ensure_ascii=False, indent=4)"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "cellView": "form",
|
|
|
- "id": "DefSecModel"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title 1.6 Define the secondary diffusion model\n",
|
|
|
- "\n",
|
|
|
- "def append_dims(x, n):\n",
|
|
|
- " return x[(Ellipsis, *(None,) * (n - x.ndim))]\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def expand_to_planes(x, shape):\n",
|
|
|
- " return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def alpha_sigma_to_t(alpha, sigma):\n",
|
|
|
- " return torch.atan2(sigma, alpha) * 2 / math.pi\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def t_to_alpha_sigma(t):\n",
|
|
|
- " return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "@dataclass\n",
|
|
|
- "class DiffusionOutput:\n",
|
|
|
- " v: torch.Tensor\n",
|
|
|
- " pred: torch.Tensor\n",
|
|
|
- " eps: torch.Tensor\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "class ConvBlock(nn.Sequential):\n",
|
|
|
- " def __init__(self, c_in, c_out):\n",
|
|
|
- " super().__init__(\n",
|
|
|
- " nn.Conv2d(c_in, c_out, 3, padding=1),\n",
|
|
|
- " nn.ReLU(inplace=True),\n",
|
|
|
- " )\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "class SkipBlock(nn.Module):\n",
|
|
|
- " def __init__(self, main, skip=None):\n",
|
|
|
- " super().__init__()\n",
|
|
|
- " self.main = nn.Sequential(*main)\n",
|
|
|
- " self.skip = skip if skip else nn.Identity()\n",
|
|
|
- "\n",
|
|
|
- " def forward(self, input):\n",
|
|
|
- " return torch.cat([self.main(input), self.skip(input)], dim=1)\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "class FourierFeatures(nn.Module):\n",
|
|
|
- " def __init__(self, in_features, out_features, std=1.):\n",
|
|
|
- " super().__init__()\n",
|
|
|
- " assert out_features % 2 == 0\n",
|
|
|
- " self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)\n",
|
|
|
- "\n",
|
|
|
- " def forward(self, input):\n",
|
|
|
- " f = 2 * math.pi * input @ self.weight.T\n",
|
|
|
- " return torch.cat([f.cos(), f.sin()], dim=-1)\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "class SecondaryDiffusionImageNet(nn.Module):\n",
|
|
|
- " def __init__(self):\n",
|
|
|
- " super().__init__()\n",
|
|
|
- " c = 64 # The base channel count\n",
|
|
|
- "\n",
|
|
|
- " self.timestep_embed = FourierFeatures(1, 16)\n",
|
|
|
- "\n",
|
|
|
- " self.net = nn.Sequential(\n",
|
|
|
- " ConvBlock(3 + 16, c),\n",
|
|
|
- " ConvBlock(c, c),\n",
|
|
|
- " SkipBlock([\n",
|
|
|
- " nn.AvgPool2d(2),\n",
|
|
|
- " ConvBlock(c, c * 2),\n",
|
|
|
- " ConvBlock(c * 2, c * 2),\n",
|
|
|
- " SkipBlock([\n",
|
|
|
- " nn.AvgPool2d(2),\n",
|
|
|
- " ConvBlock(c * 2, c * 4),\n",
|
|
|
- " ConvBlock(c * 4, c * 4),\n",
|
|
|
- " SkipBlock([\n",
|
|
|
- " nn.AvgPool2d(2),\n",
|
|
|
- " ConvBlock(c * 4, c * 8),\n",
|
|
|
- " ConvBlock(c * 8, c * 4),\n",
|
|
|
- " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
- " ]),\n",
|
|
|
- " ConvBlock(c * 8, c * 4),\n",
|
|
|
- " ConvBlock(c * 4, c * 2),\n",
|
|
|
- " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
- " ]),\n",
|
|
|
- " ConvBlock(c * 4, c * 2),\n",
|
|
|
- " ConvBlock(c * 2, c),\n",
|
|
|
- " nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
- " ]),\n",
|
|
|
- " ConvBlock(c * 2, c),\n",
|
|
|
- " nn.Conv2d(c, 3, 3, padding=1),\n",
|
|
|
- " )\n",
|
|
|
- "\n",
|
|
|
- " def forward(self, input, t):\n",
|
|
|
- " timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
|
|
|
- " v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
|
|
|
- " alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
|
|
|
- " pred = input * alphas - v * sigmas\n",
|
|
|
- " eps = input * sigmas + v * alphas\n",
|
|
|
- " return DiffusionOutput(v, pred, eps)\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "class SecondaryDiffusionImageNet2(nn.Module):\n",
|
|
|
- " def __init__(self):\n",
|
|
|
- " super().__init__()\n",
|
|
|
- " c = 64 # The base channel count\n",
|
|
|
- " cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]\n",
|
|
|
- "\n",
|
|
|
- " self.timestep_embed = FourierFeatures(1, 16)\n",
|
|
|
- " self.down = nn.AvgPool2d(2)\n",
|
|
|
- " self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n",
|
|
|
- "\n",
|
|
|
- " self.net = nn.Sequential(\n",
|
|
|
- " ConvBlock(3 + 16, cs[0]),\n",
|
|
|
- " ConvBlock(cs[0], cs[0]),\n",
|
|
|
- " SkipBlock([\n",
|
|
|
- " self.down,\n",
|
|
|
- " ConvBlock(cs[0], cs[1]),\n",
|
|
|
- " ConvBlock(cs[1], cs[1]),\n",
|
|
|
- " SkipBlock([\n",
|
|
|
- " self.down,\n",
|
|
|
- " ConvBlock(cs[1], cs[2]),\n",
|
|
|
- " ConvBlock(cs[2], cs[2]),\n",
|
|
|
- " SkipBlock([\n",
|
|
|
- " self.down,\n",
|
|
|
- " ConvBlock(cs[2], cs[3]),\n",
|
|
|
- " ConvBlock(cs[3], cs[3]),\n",
|
|
|
- " SkipBlock([\n",
|
|
|
- " self.down,\n",
|
|
|
- " ConvBlock(cs[3], cs[4]),\n",
|
|
|
- " ConvBlock(cs[4], cs[4]),\n",
|
|
|
- " SkipBlock([\n",
|
|
|
- " self.down,\n",
|
|
|
- " ConvBlock(cs[4], cs[5]),\n",
|
|
|
- " ConvBlock(cs[5], cs[5]),\n",
|
|
|
- " ConvBlock(cs[5], cs[5]),\n",
|
|
|
- " ConvBlock(cs[5], cs[4]),\n",
|
|
|
- " self.up,\n",
|
|
|
- " ]),\n",
|
|
|
- " ConvBlock(cs[4] * 2, cs[4]),\n",
|
|
|
- " ConvBlock(cs[4], cs[3]),\n",
|
|
|
- " self.up,\n",
|
|
|
- " ]),\n",
|
|
|
- " ConvBlock(cs[3] * 2, cs[3]),\n",
|
|
|
- " ConvBlock(cs[3], cs[2]),\n",
|
|
|
- " self.up,\n",
|
|
|
- " ]),\n",
|
|
|
- " ConvBlock(cs[2] * 2, cs[2]),\n",
|
|
|
- " ConvBlock(cs[2], cs[1]),\n",
|
|
|
- " self.up,\n",
|
|
|
- " ]),\n",
|
|
|
- " ConvBlock(cs[1] * 2, cs[1]),\n",
|
|
|
- " ConvBlock(cs[1], cs[0]),\n",
|
|
|
- " self.up,\n",
|
|
|
- " ]),\n",
|
|
|
- " ConvBlock(cs[0] * 2, cs[0]),\n",
|
|
|
- " nn.Conv2d(cs[0], 3, 3, padding=1),\n",
|
|
|
- " )\n",
|
|
|
- "\n",
|
|
|
- " def forward(self, input, t):\n",
|
|
|
- " timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
|
|
|
- " v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
|
|
|
- " alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
|
|
|
- " pred = input * alphas - v * sigmas\n",
|
|
|
- " eps = input * sigmas + v * alphas\n",
|
|
|
- " return DiffusionOutput(v, pred, eps)"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "cellView": "form",
|
|
|
- "id": "DefSuperRes"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title 1.7 SuperRes Define\n",
|
|
|
- "class DDIMSampler(object):\n",
|
|
|
- " def __init__(self, model, schedule=\"linear\", **kwargs):\n",
|
|
|
- " super().__init__()\n",
|
|
|
- " self.model = model\n",
|
|
|
- " self.ddpm_num_timesteps = model.num_timesteps\n",
|
|
|
- " self.schedule = schedule\n",
|
|
|
- "\n",
|
|
|
- " def register_buffer(self, name, attr):\n",
|
|
|
- " if type(attr) == torch.Tensor:\n",
|
|
|
- " if attr.device != torch.device(\"cuda\"):\n",
|
|
|
- " attr = attr.to(torch.device(\"cuda\"))\n",
|
|
|
- " setattr(self, name, attr)\n",
|
|
|
- "\n",
|
|
|
- " def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n",
|
|
|
- " self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n",
|
|
|
- " num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n",
|
|
|
- " alphas_cumprod = self.model.alphas_cumprod\n",
|
|
|
- " assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n",
|
|
|
- " to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n",
|
|
|
- "\n",
|
|
|
- " self.register_buffer('betas', to_torch(self.model.betas))\n",
|
|
|
- " self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n",
|
|
|
- " self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n",
|
|
|
- "\n",
|
|
|
- " # calculations for diffusion q(x_t | x_{t-1}) and others\n",
|
|
|
- " self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n",
|
|
|
- " self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n",
|
|
|
- " self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n",
|
|
|
- " self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n",
|
|
|
- " self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n",
|
|
|
- "\n",
|
|
|
- " # ddim sampling parameters\n",
|
|
|
- " ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n",
|
|
|
- " ddim_timesteps=self.ddim_timesteps,\n",
|
|
|
- " eta=ddim_eta,verbose=verbose)\n",
|
|
|
- " self.register_buffer('ddim_sigmas', ddim_sigmas)\n",
|
|
|
- " self.register_buffer('ddim_alphas', ddim_alphas)\n",
|
|
|
- " self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n",
|
|
|
- " self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n",
|
|
|
- " sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n",
|
|
|
- " (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n",
|
|
|
- " 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n",
|
|
|
- " self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n",
|
|
|
- "\n",
|
|
|
- " @torch.no_grad()\n",
|
|
|
- " def sample(self,\n",
|
|
|
- " S,\n",
|
|
|
- " batch_size,\n",
|
|
|
- " shape,\n",
|
|
|
- " conditioning=None,\n",
|
|
|
- " callback=None,\n",
|
|
|
- " normals_sequence=None,\n",
|
|
|
- " img_callback=None,\n",
|
|
|
- " quantize_x0=False,\n",
|
|
|
- " eta=0.,\n",
|
|
|
- " mask=None,\n",
|
|
|
- " x0=None,\n",
|
|
|
- " temperature=1.,\n",
|
|
|
- " noise_dropout=0.,\n",
|
|
|
- " score_corrector=None,\n",
|
|
|
- " corrector_kwargs=None,\n",
|
|
|
- " verbose=True,\n",
|
|
|
- " x_T=None,\n",
|
|
|
- " log_every_t=100,\n",
|
|
|
- " **kwargs\n",
|
|
|
- " ):\n",
|
|
|
- " if conditioning is not None:\n",
|
|
|
- " if isinstance(conditioning, dict):\n",
|
|
|
- " cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
|
|
|
- " if cbs != batch_size:\n",
|
|
|
- " print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
|
|
|
- " else:\n",
|
|
|
- " if conditioning.shape[0] != batch_size:\n",
|
|
|
- " print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n",
|
|
|
- "\n",
|
|
|
- " self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
|
|
|
- " # sampling\n",
|
|
|
- " C, H, W = shape\n",
|
|
|
- " size = (batch_size, C, H, W)\n",
|
|
|
- " # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
|
|
|
- "\n",
|
|
|
- " samples, intermediates = self.ddim_sampling(conditioning, size,\n",
|
|
|
- " callback=callback,\n",
|
|
|
- " img_callback=img_callback,\n",
|
|
|
- " quantize_denoised=quantize_x0,\n",
|
|
|
- " mask=mask, x0=x0,\n",
|
|
|
- " ddim_use_original_steps=False,\n",
|
|
|
- " noise_dropout=noise_dropout,\n",
|
|
|
- " temperature=temperature,\n",
|
|
|
- " score_corrector=score_corrector,\n",
|
|
|
- " corrector_kwargs=corrector_kwargs,\n",
|
|
|
- " x_T=x_T,\n",
|
|
|
- " log_every_t=log_every_t\n",
|
|
|
- " )\n",
|
|
|
- " return samples, intermediates\n",
|
|
|
- "\n",
|
|
|
- " @torch.no_grad()\n",
|
|
|
- " def ddim_sampling(self, cond, shape,\n",
|
|
|
- " x_T=None, ddim_use_original_steps=False,\n",
|
|
|
- " callback=None, timesteps=None, quantize_denoised=False,\n",
|
|
|
- " mask=None, x0=None, img_callback=None, log_every_t=100,\n",
|
|
|
- " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
|
|
|
- " device = self.model.betas.device\n",
|
|
|
- " b = shape[0]\n",
|
|
|
- " if x_T is None:\n",
|
|
|
- " img = torch.randn(shape, device=device)\n",
|
|
|
- " else:\n",
|
|
|
- " img = x_T\n",
|
|
|
- "\n",
|
|
|
- " if timesteps is None:\n",
|
|
|
- " timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n",
|
|
|
- " elif timesteps is not None and not ddim_use_original_steps:\n",
|
|
|
- " subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n",
|
|
|
- " timesteps = self.ddim_timesteps[:subset_end]\n",
|
|
|
- "\n",
|
|
|
- " intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
|
|
|
- " time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n",
|
|
|
- " total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
|
|
|
- " print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n",
|
|
|
- "\n",
|
|
|
- " iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n",
|
|
|
- "\n",
|
|
|
- " for i, step in enumerate(iterator):\n",
|
|
|
- " index = total_steps - i - 1\n",
|
|
|
- " ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
|
|
|
- "\n",
|
|
|
- " if mask is not None:\n",
|
|
|
- " assert x0 is not None\n",
|
|
|
- " img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n",
|
|
|
- " img = img_orig * mask + (1. - mask) * img\n",
|
|
|
- "\n",
|
|
|
- " outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n",
|
|
|
- " quantize_denoised=quantize_denoised, temperature=temperature,\n",
|
|
|
- " noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
|
|
|
- " corrector_kwargs=corrector_kwargs)\n",
|
|
|
- " img, pred_x0 = outs\n",
|
|
|
- " if callback: callback(i)\n",
|
|
|
- " if img_callback: img_callback(pred_x0, i)\n",
|
|
|
- "\n",
|
|
|
- " if index % log_every_t == 0 or index == total_steps - 1:\n",
|
|
|
- " intermediates['x_inter'].append(img)\n",
|
|
|
- " intermediates['pred_x0'].append(pred_x0)\n",
|
|
|
- "\n",
|
|
|
- " return img, intermediates\n",
|
|
|
- "\n",
|
|
|
- " @torch.no_grad()\n",
|
|
|
- " def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n",
|
|
|
- " temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
|
|
|
- " b, *_, device = *x.shape, x.device\n",
|
|
|
- " e_t = self.model.apply_model(x, t, c)\n",
|
|
|
- " if score_corrector is not None:\n",
|
|
|
- " assert self.model.parameterization == \"eps\"\n",
|
|
|
- " e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n",
|
|
|
- "\n",
|
|
|
- " alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n",
|
|
|
- " alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n",
|
|
|
- " sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n",
|
|
|
- " sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n",
|
|
|
- " # select parameters corresponding to the currently considered timestep\n",
|
|
|
- " a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n",
|
|
|
- " a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n",
|
|
|
- " sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n",
|
|
|
- " sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n",
|
|
|
- "\n",
|
|
|
- " # current prediction for x_0\n",
|
|
|
- " pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n",
|
|
|
- " if quantize_denoised:\n",
|
|
|
- " pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n",
|
|
|
- " # direction pointing to x_t\n",
|
|
|
- " dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n",
|
|
|
- " noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n",
|
|
|
- " if noise_dropout > 0.:\n",
|
|
|
- " noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n",
|
|
|
- " x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n",
|
|
|
- " return x_prev, pred_x0\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def download_models(mode):\n",
|
|
|
- "\n",
|
|
|
- " if mode == \"superresolution\":\n",
|
|
|
- " # this is the small bsr light model\n",
|
|
|
- " url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n",
|
|
|
- " url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n",
|
|
|
- "\n",
|
|
|
- " path_conf = f'{model_path}/superres/project.yaml'\n",
|
|
|
- " path_ckpt = f'{model_path}/superres/last.ckpt'\n",
|
|
|
- "\n",
|
|
|
- " download_url(url_conf, path_conf)\n",
|
|
|
- " download_url(url_ckpt, path_ckpt)\n",
|
|
|
- "\n",
|
|
|
- " path_conf = path_conf + '/?dl=1' # fix it\n",
|
|
|
- " path_ckpt = path_ckpt + '/?dl=1' # fix it\n",
|
|
|
- " return path_conf, path_ckpt\n",
|
|
|
- "\n",
|
|
|
- " else:\n",
|
|
|
- " raise NotImplementedError\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def load_model_from_config(config, ckpt):\n",
|
|
|
- " print(f\"Loading model from {ckpt}\")\n",
|
|
|
- " pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
|
|
|
- " global_step = pl_sd[\"global_step\"]\n",
|
|
|
- " sd = pl_sd[\"state_dict\"]\n",
|
|
|
- " model = instantiate_from_config(config.model)\n",
|
|
|
- " m, u = model.load_state_dict(sd, strict=False)\n",
|
|
|
- " model.cuda()\n",
|
|
|
- " model.eval()\n",
|
|
|
- " return {\"model\": model}, global_step\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def get_model(mode):\n",
|
|
|
- " path_conf, path_ckpt = download_models(mode)\n",
|
|
|
- " config = OmegaConf.load(path_conf)\n",
|
|
|
- " model, step = load_model_from_config(config, path_ckpt)\n",
|
|
|
- " return model\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def get_custom_cond(mode):\n",
|
|
|
- " dest = \"data/example_conditioning\"\n",
|
|
|
- "\n",
|
|
|
- " if mode == \"superresolution\":\n",
|
|
|
- " uploaded_img = files.upload()\n",
|
|
|
- " filename = next(iter(uploaded_img))\n",
|
|
|
- " name, filetype = filename.split(\".\") # todo assumes just one dot in name !\n",
|
|
|
- " os.rename(f\"{filename}\", f\"{dest}/{mode}/custom_{name}.{filetype}\")\n",
|
|
|
- "\n",
|
|
|
- " elif mode == \"text_conditional\":\n",
|
|
|
- " w = widgets.Text(value='A cake with cream!', disabled=True)\n",
|
|
|
- " display.display(w)\n",
|
|
|
- "\n",
|
|
|
- " with open(f\"{dest}/{mode}/custom_{w.value[:20]}.txt\", 'w') as f:\n",
|
|
|
- " f.write(w.value)\n",
|
|
|
- "\n",
|
|
|
- " elif mode == \"class_conditional\":\n",
|
|
|
- " w = widgets.IntSlider(min=0, max=1000)\n",
|
|
|
- " display.display(w)\n",
|
|
|
- " with open(f\"{dest}/{mode}/custom.txt\", 'w') as f:\n",
|
|
|
- " f.write(w.value)\n",
|
|
|
- "\n",
|
|
|
- " else:\n",
|
|
|
- " raise NotImplementedError(f\"cond not implemented for mode{mode}\")\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def get_cond_options(mode):\n",
|
|
|
- " path = \"data/example_conditioning\"\n",
|
|
|
- " path = os.path.join(path, mode)\n",
|
|
|
- " onlyfiles = [f for f in sorted(os.listdir(path))]\n",
|
|
|
- " return path, onlyfiles\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def select_cond_path(mode):\n",
|
|
|
- " path = \"data/example_conditioning\" # todo\n",
|
|
|
- " path = os.path.join(path, mode)\n",
|
|
|
- " onlyfiles = [f for f in sorted(os.listdir(path))]\n",
|
|
|
- "\n",
|
|
|
- " selected = widgets.RadioButtons(\n",
|
|
|
- " options=onlyfiles,\n",
|
|
|
- " description='Select conditioning:',\n",
|
|
|
- " disabled=False\n",
|
|
|
- " )\n",
|
|
|
- " display.display(selected)\n",
|
|
|
- " selected_path = os.path.join(path, selected.value)\n",
|
|
|
- " return selected_path\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def get_cond(mode, img):\n",
|
|
|
- " example = dict()\n",
|
|
|
- " if mode == \"superresolution\":\n",
|
|
|
- " up_f = 4\n",
|
|
|
- " # visualize_cond_img(selected_path)\n",
|
|
|
- "\n",
|
|
|
- " c = img\n",
|
|
|
- " c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)\n",
|
|
|
- " c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)\n",
|
|
|
- " c_up = rearrange(c_up, '1 c h w -> 1 h w c')\n",
|
|
|
- " c = rearrange(c, '1 c h w -> 1 h w c')\n",
|
|
|
- " c = 2. * c - 1.\n",
|
|
|
- "\n",
|
|
|
- " c = c.to(torch.device(\"cuda\"))\n",
|
|
|
- " example[\"LR_image\"] = c\n",
|
|
|
- " example[\"image\"] = c_up\n",
|
|
|
- "\n",
|
|
|
- " return example\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def visualize_cond_img(path):\n",
|
|
|
- " display.display(ipyimg(filename=path))\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):\n",
|
|
|
- " # global stride\n",
|
|
|
- "\n",
|
|
|
- " example = get_cond(task, img)\n",
|
|
|
- "\n",
|
|
|
- " save_intermediate_vid = False\n",
|
|
|
- " n_runs = 1\n",
|
|
|
- " masked = False\n",
|
|
|
- " guider = None\n",
|
|
|
- " ckwargs = None\n",
|
|
|
- " mode = 'ddim'\n",
|
|
|
- " ddim_use_x0_pred = False\n",
|
|
|
- " temperature = 1.\n",
|
|
|
- " eta = eta\n",
|
|
|
- " make_progrow = True\n",
|
|
|
- " custom_shape = None\n",
|
|
|
- "\n",
|
|
|
- " height, width = example[\"image\"].shape[1:3]\n",
|
|
|
- " split_input = height >= 128 and width >= 128\n",
|
|
|
- "\n",
|
|
|
- " if split_input:\n",
|
|
|
- " ks = 128\n",
|
|
|
- " stride = 64\n",
|
|
|
- " vqf = 4 #\n",
|
|
|
- " model.split_input_params = {\"ks\": (ks, ks), \"stride\": (stride, stride),\n",
|
|
|
- " \"vqf\": vqf,\n",
|
|
|
- " \"patch_distributed_vq\": True,\n",
|
|
|
- " \"tie_braker\": False,\n",
|
|
|
- " \"clip_max_weight\": 0.5,\n",
|
|
|
- " \"clip_min_weight\": 0.01,\n",
|
|
|
- " \"clip_max_tie_weight\": 0.5,\n",
|
|
|
- " \"clip_min_tie_weight\": 0.01}\n",
|
|
|
- " else:\n",
|
|
|
- " if hasattr(model, \"split_input_params\"):\n",
|
|
|
- " delattr(model, \"split_input_params\")\n",
|
|
|
- "\n",
|
|
|
- " invert_mask = False\n",
|
|
|
- "\n",
|
|
|
- " x_T = None\n",
|
|
|
- " for n in range(n_runs):\n",
|
|
|
- " if custom_shape is not None:\n",
|
|
|
- " x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)\n",
|
|
|
- " x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])\n",
|
|
|
- "\n",
|
|
|
- " logs = make_convolutional_sample(example, model,\n",
|
|
|
- " mode=mode, custom_steps=custom_steps,\n",
|
|
|
- " eta=eta, swap_mode=False , masked=masked,\n",
|
|
|
- " invert_mask=invert_mask, quantize_x0=False,\n",
|
|
|
- " custom_schedule=None, decode_interval=10,\n",
|
|
|
- " resize_enabled=resize_enabled, custom_shape=custom_shape,\n",
|
|
|
- " temperature=temperature, noise_dropout=0.,\n",
|
|
|
- " corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,\n",
|
|
|
- " make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred\n",
|
|
|
- " )\n",
|
|
|
- " return logs\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "@torch.no_grad()\n",
|
|
|
- "def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,\n",
|
|
|
- " mask=None, x0=None, quantize_x0=False, img_callback=None,\n",
|
|
|
- " temperature=1., noise_dropout=0., score_corrector=None,\n",
|
|
|
- " corrector_kwargs=None, x_T=None, log_every_t=None\n",
|
|
|
- " ):\n",
|
|
|
- "\n",
|
|
|
- " ddim = DDIMSampler(model)\n",
|
|
|
- " bs = shape[0] # dont know where this comes from but wayne\n",
|
|
|
- " shape = shape[1:] # cut batch dim\n",
|
|
|
- " # print(f\"Sampling with eta = {eta}; steps: {steps}\")\n",
|
|
|
- " samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,\n",
|
|
|
- " normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,\n",
|
|
|
- " mask=mask, x0=x0, temperature=temperature, verbose=False,\n",
|
|
|
- " score_corrector=score_corrector,\n",
|
|
|
- " corrector_kwargs=corrector_kwargs, x_T=x_T)\n",
|
|
|
- "\n",
|
|
|
- " return samples, intermediates\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "@torch.no_grad()\n",
|
|
|
- "def make_convolutional_sample(batch, model, mode=\"vanilla\", custom_steps=None, eta=1.0, swap_mode=False, masked=False,\n",
|
|
|
- " invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,\n",
|
|
|
- " resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,\n",
|
|
|
- " corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):\n",
|
|
|
- " log = dict()\n",
|
|
|
- "\n",
|
|
|
- " z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,\n",
|
|
|
- " return_first_stage_outputs=True,\n",
|
|
|
- " force_c_encode=not (hasattr(model, 'split_input_params')\n",
|
|
|
- " and model.cond_stage_key == 'coordinates_bbox'),\n",
|
|
|
- " return_original_cond=True)\n",
|
|
|
- "\n",
|
|
|
- " log_every_t = 1 if save_intermediate_vid else None\n",
|
|
|
- "\n",
|
|
|
- " if custom_shape is not None:\n",
|
|
|
- " z = torch.randn(custom_shape)\n",
|
|
|
- " # print(f\"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}\")\n",
|
|
|
- "\n",
|
|
|
- " z0 = None\n",
|
|
|
- "\n",
|
|
|
- " log[\"input\"] = x\n",
|
|
|
- " log[\"reconstruction\"] = xrec\n",
|
|
|
- "\n",
|
|
|
- " if ismap(xc):\n",
|
|
|
- " log[\"original_conditioning\"] = model.to_rgb(xc)\n",
|
|
|
- " if hasattr(model, 'cond_stage_key'):\n",
|
|
|
- " log[model.cond_stage_key] = model.to_rgb(xc)\n",
|
|
|
- "\n",
|
|
|
- " else:\n",
|
|
|
- " log[\"original_conditioning\"] = xc if xc is not None else torch.zeros_like(x)\n",
|
|
|
- " if model.cond_stage_model:\n",
|
|
|
- " log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)\n",
|
|
|
- " if model.cond_stage_key =='class_label':\n",
|
|
|
- " log[model.cond_stage_key] = xc[model.cond_stage_key]\n",
|
|
|
- "\n",
|
|
|
- " with model.ema_scope(\"Plotting\"):\n",
|
|
|
- " t0 = time.time()\n",
|
|
|
- " img_cb = None\n",
|
|
|
- "\n",
|
|
|
- " sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,\n",
|
|
|
- " eta=eta,\n",
|
|
|
- " quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,\n",
|
|
|
- " temperature=temperature, noise_dropout=noise_dropout,\n",
|
|
|
- " score_corrector=corrector, corrector_kwargs=corrector_kwargs,\n",
|
|
|
- " x_T=x_T, log_every_t=log_every_t)\n",
|
|
|
- " t1 = time.time()\n",
|
|
|
- "\n",
|
|
|
- " if ddim_use_x0_pred:\n",
|
|
|
- " sample = intermediates['pred_x0'][-1]\n",
|
|
|
- "\n",
|
|
|
- " x_sample = model.decode_first_stage(sample)\n",
|
|
|
- "\n",
|
|
|
- " try:\n",
|
|
|
- " x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)\n",
|
|
|
- " log[\"sample_noquant\"] = x_sample_noquant\n",
|
|
|
- " log[\"sample_diff\"] = torch.abs(x_sample_noquant - x_sample)\n",
|
|
|
- " except:\n",
|
|
|
- " pass\n",
|
|
|
- "\n",
|
|
|
- " log[\"sample\"] = x_sample\n",
|
|
|
- " log[\"time\"] = t1 - t0\n",
|
|
|
- "\n",
|
|
|
- " return log\n",
|
|
|
- "\n",
|
|
|
- "sr_diffMode = 'superresolution'\n",
|
|
|
- "sr_model = get_model('superresolution')\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def do_superres(img, filepath):\n",
|
|
|
- "\n",
|
|
|
- " if args.sharpen_preset == 'Faster':\n",
|
|
|
- " sr_diffusion_steps = \"25\" \n",
|
|
|
- " sr_pre_downsample = '1/2' \n",
|
|
|
- " if args.sharpen_preset == 'Fast':\n",
|
|
|
- " sr_diffusion_steps = \"100\" \n",
|
|
|
- " sr_pre_downsample = '1/2' \n",
|
|
|
- " if args.sharpen_preset == 'Slow':\n",
|
|
|
- " sr_diffusion_steps = \"25\" \n",
|
|
|
- " sr_pre_downsample = 'None' \n",
|
|
|
- " if args.sharpen_preset == 'Very Slow':\n",
|
|
|
- " sr_diffusion_steps = \"100\" \n",
|
|
|
- " sr_pre_downsample = 'None' \n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- " sr_post_downsample = 'Original Size'\n",
|
|
|
- " sr_diffusion_steps = int(sr_diffusion_steps)\n",
|
|
|
- " sr_eta = 1.0 \n",
|
|
|
- " sr_downsample_method = 'Lanczos' \n",
|
|
|
- "\n",
|
|
|
- " gc.collect()\n",
|
|
|
- " torch.cuda.empty_cache()\n",
|
|
|
- "\n",
|
|
|
- " im_og = img\n",
|
|
|
- " width_og, height_og = im_og.size\n",
|
|
|
- "\n",
|
|
|
- " #Downsample Pre\n",
|
|
|
- " if sr_pre_downsample == '1/2':\n",
|
|
|
- " downsample_rate = 2\n",
|
|
|
- " elif sr_pre_downsample == '1/4':\n",
|
|
|
- " downsample_rate = 4\n",
|
|
|
- " else:\n",
|
|
|
- " downsample_rate = 1\n",
|
|
|
- "\n",
|
|
|
- " width_downsampled_pre = width_og//downsample_rate\n",
|
|
|
- " height_downsampled_pre = height_og//downsample_rate\n",
|
|
|
- "\n",
|
|
|
- " if downsample_rate != 1:\n",
|
|
|
- " # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')\n",
|
|
|
- " im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)\n",
|
|
|
- " # im_og.save('/content/temp.png')\n",
|
|
|
- " # filepath = '/content/temp.png'\n",
|
|
|
- "\n",
|
|
|
- " logs = sr_run(sr_model[\"model\"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)\n",
|
|
|
- "\n",
|
|
|
- " sample = logs[\"sample\"]\n",
|
|
|
- " sample = sample.detach().cpu()\n",
|
|
|
- " sample = torch.clamp(sample, -1., 1.)\n",
|
|
|
- " sample = (sample + 1.) / 2. * 255\n",
|
|
|
- " sample = sample.numpy().astype(np.uint8)\n",
|
|
|
- " sample = np.transpose(sample, (0, 2, 3, 1))\n",
|
|
|
- " a = Image.fromarray(sample[0])\n",
|
|
|
- "\n",
|
|
|
- " #Downsample Post\n",
|
|
|
- " if sr_post_downsample == '1/2':\n",
|
|
|
- " downsample_rate = 2\n",
|
|
|
- " elif sr_post_downsample == '1/4':\n",
|
|
|
- " downsample_rate = 4\n",
|
|
|
- " else:\n",
|
|
|
- " downsample_rate = 1\n",
|
|
|
- "\n",
|
|
|
- " width, height = a.size\n",
|
|
|
- " width_downsampled_post = width//downsample_rate\n",
|
|
|
- " height_downsampled_post = height//downsample_rate\n",
|
|
|
- "\n",
|
|
|
- " if sr_downsample_method == 'Lanczos':\n",
|
|
|
- " aliasing = Image.LANCZOS\n",
|
|
|
- " else:\n",
|
|
|
- " aliasing = Image.NEAREST\n",
|
|
|
- "\n",
|
|
|
- " if downsample_rate != 1:\n",
|
|
|
- " # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')\n",
|
|
|
- " a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)\n",
|
|
|
- " elif sr_post_downsample == 'Original Size':\n",
|
|
|
- " # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')\n",
|
|
|
- " a = a.resize((width_og, height_og), aliasing)\n",
|
|
|
- "\n",
|
|
|
- " display.display(a)\n",
|
|
|
- " a.save(filepath)\n",
|
|
|
- " return\n",
|
|
|
- " print(f'Processing finished!')\n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "DiffClipSetTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "# 2. Diffusion and CLIP model settings"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "id": "ModelSettings"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@markdown ####**Models Settings:**\n",
|
|
|
- "diffusion_model = \"512x512_diffusion_uncond_finetune_008100\" #@param [\"256x256_diffusion_uncond\", \"512x512_diffusion_uncond_finetune_008100\"]\n",
|
|
|
- "use_secondary_model = True #@param {type: 'boolean'}\n",
|
|
|
- "diffusion_sampling_mode = 'ddim' #@param ['plms','ddim'] \n",
|
|
|
- "\n",
|
|
|
- "timestep_respacing = '250' #@param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000'] \n",
|
|
|
- "diffusion_steps = 300 #@param {type: 'number'}\n",
|
|
|
- "use_checkpoint = True #@param {type: 'boolean'}\n",
|
|
|
- "ViTB32 = True #@param{type:\"boolean\"}\n",
|
|
|
- "ViTB16 = True #@param{type:\"boolean\"}\n",
|
|
|
- "ViTL14 = False #@param{type:\"boolean\"}\n",
|
|
|
- "RN101 = False #@param{type:\"boolean\"}\n",
|
|
|
- "RN50 = True #@param{type:\"boolean\"}\n",
|
|
|
- "RN50x4 = False #@param{type:\"boolean\"}\n",
|
|
|
- "RN50x16 = False #@param{type:\"boolean\"}\n",
|
|
|
- "RN50x64 = False #@param{type:\"boolean\"}\n",
|
|
|
- "SLIPB16 = False #@param{type:\"boolean\"}\n",
|
|
|
- "SLIPL16 = False #@param{type:\"boolean\"}\n",
|
|
|
- "\n",
|
|
|
- "#@markdown If you're having issues with model downloads, check this to compare SHA's:\n",
|
|
|
- "check_model_SHA = False #@param{type:\"boolean\"}\n",
|
|
|
- "\n",
|
|
|
- "model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
|
|
|
- "model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n",
|
|
|
- "model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
|
|
|
- "\n",
|
|
|
- "model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'\n",
|
|
|
- "model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n",
|
|
|
- "model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n",
|
|
|
- "\n",
|
|
|
- "model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n",
|
|
|
- "model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n",
|
|
|
- "model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n",
|
|
|
- "\n",
|
|
|
- "# Download the diffusion model\n",
|
|
|
- "if diffusion_model == '256x256_diffusion_uncond':\n",
|
|
|
- " if os.path.exists(model_256_path) and check_model_SHA:\n",
|
|
|
- " print('Checking 256 Diffusion File')\n",
|
|
|
- " with open(model_256_path,\"rb\") as f:\n",
|
|
|
- " bytes = f.read() \n",
|
|
|
- " hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
- " if hash == model_256_SHA:\n",
|
|
|
- " print('256 Model SHA matches')\n",
|
|
|
- " model_256_downloaded = True\n",
|
|
|
- " else: \n",
|
|
|
- " print(\"256 Model SHA doesn't match, redownloading...\")\n",
|
|
|
- " wget(model_256_link, model_path)\n",
|
|
|
- " model_256_downloaded = True\n",
|
|
|
- " elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n",
|
|
|
- " print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
- " else: \n",
|
|
|
- " wget(model_256_link, model_path)\n",
|
|
|
- " model_256_downloaded = True\n",
|
|
|
- "elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
|
|
|
- " if os.path.exists(model_512_path) and check_model_SHA:\n",
|
|
|
- " print('Checking 512 Diffusion File')\n",
|
|
|
- " with open(model_512_path,\"rb\") as f:\n",
|
|
|
- " bytes = f.read() \n",
|
|
|
- " hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
- " if hash == model_512_SHA:\n",
|
|
|
- " print('512 Model SHA matches')\n",
|
|
|
- " model_512_downloaded = True\n",
|
|
|
- " else: \n",
|
|
|
- " print(\"512 Model SHA doesn't match, redownloading...\")\n",
|
|
|
- " wget(model_512_link, model_path)\n",
|
|
|
- " model_512_downloaded = True\n",
|
|
|
- " elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n",
|
|
|
- " print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
- " else: \n",
|
|
|
- " wget(model_512_link, model_path)\n",
|
|
|
- " model_512_downloaded = True\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "# Download the secondary diffusion model v2\n",
|
|
|
- "if use_secondary_model == True:\n",
|
|
|
- " if os.path.exists(model_secondary_path) and check_model_SHA:\n",
|
|
|
- " print('Checking Secondary Diffusion File')\n",
|
|
|
- " with open(model_secondary_path,\"rb\") as f:\n",
|
|
|
- " bytes = f.read() \n",
|
|
|
- " hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
- " if hash == model_secondary_SHA:\n",
|
|
|
- " print('Secondary Model SHA matches')\n",
|
|
|
- " model_secondary_downloaded = True\n",
|
|
|
- " else: \n",
|
|
|
- " print(\"Secondary Model SHA doesn't match, redownloading...\")\n",
|
|
|
- " wget(model_secondary_link, model_path)\n",
|
|
|
- " model_secondary_downloaded = True\n",
|
|
|
- " elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n",
|
|
|
- " print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
- " else: \n",
|
|
|
- " wget(model_secondary_link, model_path)\n",
|
|
|
- " model_secondary_downloaded = True\n",
|
|
|
- "\n",
|
|
|
- "model_config = model_and_diffusion_defaults()\n",
|
|
|
- "if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
|
|
|
- " model_config.update({\n",
|
|
|
- " 'attention_resolutions': '32, 16, 8',\n",
|
|
|
- " 'class_cond': False,\n",
|
|
|
- " 'diffusion_steps': diffusion_steps,\n",
|
|
|
- " 'rescale_timesteps': True,\n",
|
|
|
- " 'timestep_respacing': timestep_respacing,\n",
|
|
|
- " 'image_size': 512,\n",
|
|
|
- " 'learn_sigma': True,\n",
|
|
|
- " 'noise_schedule': 'linear',\n",
|
|
|
- " 'num_channels': 256,\n",
|
|
|
- " 'num_head_channels': 64,\n",
|
|
|
- " 'num_res_blocks': 2,\n",
|
|
|
- " 'resblock_updown': True,\n",
|
|
|
- " 'use_checkpoint': use_checkpoint,\n",
|
|
|
- " 'use_fp16': True,\n",
|
|
|
- " 'use_scale_shift_norm': True,\n",
|
|
|
- " })\n",
|
|
|
- "elif diffusion_model == '256x256_diffusion_uncond':\n",
|
|
|
- " model_config.update({\n",
|
|
|
- " 'attention_resolutions': '32, 16, 8',\n",
|
|
|
- " 'class_cond': False,\n",
|
|
|
- " 'diffusion_steps': diffusion_steps,\n",
|
|
|
- " 'rescale_timesteps': True,\n",
|
|
|
- " 'timestep_respacing': timestep_respacing,\n",
|
|
|
- " 'image_size': 256,\n",
|
|
|
- " 'learn_sigma': True,\n",
|
|
|
- " 'noise_schedule': 'linear',\n",
|
|
|
- " 'num_channels': 256,\n",
|
|
|
- " 'num_head_channels': 64,\n",
|
|
|
- " 'num_res_blocks': 2,\n",
|
|
|
- " 'resblock_updown': True,\n",
|
|
|
- " 'use_checkpoint': use_checkpoint,\n",
|
|
|
- " 'use_fp16': True,\n",
|
|
|
- " 'use_scale_shift_norm': True,\n",
|
|
|
- " })\n",
|
|
|
- "\n",
|
|
|
- "secondary_model_ver = 2\n",
|
|
|
- "model_default = model_config['image_size']\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "if secondary_model_ver == 2:\n",
|
|
|
- " secondary_model = SecondaryDiffusionImageNet2()\n",
|
|
|
- " secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))\n",
|
|
|
- "secondary_model.eval().requires_grad_(False).to(device)\n",
|
|
|
- "\n",
|
|
|
- "clip_models = []\n",
|
|
|
- "if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
- "if ViTB16 is True: clip_models.append(clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device) ) \n",
|
|
|
- "if ViTL14 is True: clip_models.append(clip.load('ViT-L/14', jit=False)[0].eval().requires_grad_(False).to(device) ) \n",
|
|
|
- "if RN50 is True: clip_models.append(clip.load('RN50', jit=False)[0].eval().requires_grad_(False).to(device))\n",
|
|
|
- "if RN50x4 is True: clip_models.append(clip.load('RN50x4', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
- "if RN50x16 is True: clip_models.append(clip.load('RN50x16', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
- "if RN50x64 is True: clip_models.append(clip.load('RN50x64', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
- "if RN101 is True: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
- "\n",
|
|
|
- "if SLIPB16:\n",
|
|
|
- " SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
|
|
|
- " if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):\n",
|
|
|
- " wget(\"https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt\", model_path)\n",
|
|
|
- " sd = torch.load(f'{model_path}/slip_base_100ep.pt')\n",
|
|
|
- " real_sd = {}\n",
|
|
|
- " for k, v in sd['state_dict'].items():\n",
|
|
|
- " real_sd['.'.join(k.split('.')[1:])] = v\n",
|
|
|
- " del sd\n",
|
|
|
- " SLIPB16model.load_state_dict(real_sd)\n",
|
|
|
- " SLIPB16model.requires_grad_(False).eval().to(device)\n",
|
|
|
- "\n",
|
|
|
- " clip_models.append(SLIPB16model)\n",
|
|
|
- "\n",
|
|
|
- "if SLIPL16:\n",
|
|
|
- " SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
|
|
|
- " if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):\n",
|
|
|
- " wget(\"https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt\", model_path)\n",
|
|
|
- " sd = torch.load(f'{model_path}/slip_large_100ep.pt')\n",
|
|
|
- " real_sd = {}\n",
|
|
|
- " for k, v in sd['state_dict'].items():\n",
|
|
|
- " real_sd['.'.join(k.split('.')[1:])] = v\n",
|
|
|
- " del sd\n",
|
|
|
- " SLIPL16model.load_state_dict(real_sd)\n",
|
|
|
- " SLIPL16model.requires_grad_(False).eval().to(device)\n",
|
|
|
- "\n",
|
|
|
- " clip_models.append(SLIPL16model)\n",
|
|
|
- "\n",
|
|
|
- "normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n",
|
|
|
- "lpips_model = lpips.LPIPS(net='vgg').to(device)\n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "SettingsTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "# 3. Settings"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "id": "BasicSettings"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@markdown ####**Basic Settings:**\n",
|
|
|
- "batch_name = 'new_House' #@param{type: 'string'}\n",
|
|
|
- "steps = 300#@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}\n",
|
|
|
- "width_height = [1280, 720]#@param{type: 'raw'}\n",
|
|
|
- "clip_guidance_scale = 5000 #@param{type: 'number'}\n",
|
|
|
- "tv_scale = 0#@param{type: 'number'}\n",
|
|
|
- "range_scale = 150#@param{type: 'number'}\n",
|
|
|
- "sat_scale = 0#@param{type: 'number'}\n",
|
|
|
- "cutn_batches = 4#@param{type: 'number'}\n",
|
|
|
- "skip_augs = False#@param{type: 'boolean'}\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ---\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ####**Init Settings:**\n",
|
|
|
- "init_image = \"/content/drive/MyDrive/AI/Disco_Diffusion/init_images/xv_1_decoupe_noback.jpg\" #@param{type: 'string'}\n",
|
|
|
- "init_scale = 1000 #@param{type: 'integer'}\n",
|
|
|
- "skip_steps = 50 #@param{type: 'integer'}\n",
|
|
|
- "#@markdown *Make sure you set skip_steps to ~50% of your steps if you want to use an init image.*\n",
|
|
|
- "\n",
|
|
|
- "#Get corrected sizes\n",
|
|
|
- "side_x = (width_height[0]//64)*64;\n",
|
|
|
- "side_y = (width_height[1]//64)*64;\n",
|
|
|
- "if side_x != width_height[0] or side_y != width_height[1]:\n",
|
|
|
- " print(f'Changing output size to {side_x}x{side_y}. Dimensions must by multiples of 64.')\n",
|
|
|
- "\n",
|
|
|
- "#Update Model Settings\n",
|
|
|
- "timestep_respacing = f'ddim{steps}'\n",
|
|
|
- "diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n",
|
|
|
- "model_config.update({\n",
|
|
|
- " 'timestep_respacing': timestep_respacing,\n",
|
|
|
- " 'diffusion_steps': diffusion_steps,\n",
|
|
|
- "})\n",
|
|
|
- "\n",
|
|
|
- "#Make folder for batch\n",
|
|
|
- "batchFolder = f'{outDirPath}/{batch_name}'\n",
|
|
|
- "createPath(batchFolder)\n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "AnimSetTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "### Animation Settings"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "id": "AnimSettings"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@markdown ####**Animation Mode:**\n",
|
|
|
- "animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'}\n",
|
|
|
- "#@markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.*\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ---\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ####**Video Input Settings:**\n",
|
|
|
- "if is_colab:\n",
|
|
|
- " video_init_path = \"/content/training.mp4\" #@param {type: 'string'}\n",
|
|
|
- "else:\n",
|
|
|
- " video_init_path = \"training.mp4\" #@param {type: 'string'}\n",
|
|
|
- "extract_nth_frame = 2 #@param {type: 'number'}\n",
|
|
|
- "video_init_seed_continuity = True #@param {type: 'boolean'}\n",
|
|
|
- "\n",
|
|
|
- "if animation_mode == \"Video Input\":\n",
|
|
|
- " if is_colab:\n",
|
|
|
- " videoFramesFolder = f'/content/videoFrames'\n",
|
|
|
- " else:\n",
|
|
|
- " videoFramesFolder = f'videoFrames'\n",
|
|
|
- " createPath(videoFramesFolder)\n",
|
|
|
- " print(f\"Exporting Video Frames (1 every {extract_nth_frame})...\")\n",
|
|
|
- " try:\n",
|
|
|
- " for f in pathlib.Path(f'{videoFramesFolder}').glob('*.jpg'):\n",
|
|
|
- " f.unlink()\n",
|
|
|
- " except:\n",
|
|
|
- " print('')\n",
|
|
|
- " vf = f'\"select=not(mod(n\\,{extract_nth_frame}))\"'\n",
|
|
|
- " subprocess.run(['ffmpeg', '-i', f'{video_init_path}', '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', '-loglevel', 'error', '-stats', f'{videoFramesFolder}/%04d.jpg'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n",
|
|
|
- " #!ffmpeg -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ---\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ####**2D Animation Settings:**\n",
|
|
|
- "#@markdown `zoom` is a multiplier of dimensions, 1 is no zoom.\n",
|
|
|
- "#@markdown All rotations are provided in degrees.\n",
|
|
|
- "\n",
|
|
|
- "key_frames = True #@param {type:\"boolean\"}\n",
|
|
|
- "max_frames = 10000#@param {type:\"number\"}\n",
|
|
|
- "\n",
|
|
|
- "if animation_mode == \"Video Input\":\n",
|
|
|
- " max_frames = len(glob(f'{videoFramesFolder}/*.jpg'))\n",
|
|
|
- "\n",
|
|
|
- "interp_spline = 'Linear' #Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:\"string\"}\n",
|
|
|
- "angle = \"0:(0)\"#@param {type:\"string\"}\n",
|
|
|
- "zoom = \"0: (1), 10: (1.05)\"#@param {type:\"string\"}\n",
|
|
|
- "translation_x = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
- "translation_y = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
- "translation_z = \"0: (10.0)\"#@param {type:\"string\"}\n",
|
|
|
- "rotation_3d_x = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
- "rotation_3d_y = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
- "rotation_3d_z = \"0: (0)\"#@param {type:\"string\"}\n",
|
|
|
- "midas_depth_model = \"dpt_large\"#@param {type:\"string\"}\n",
|
|
|
- "midas_weight = 0.3#@param {type:\"number\"}\n",
|
|
|
- "near_plane = 200#@param {type:\"number\"}\n",
|
|
|
- "far_plane = 10000#@param {type:\"number\"}\n",
|
|
|
- "fov = 40#@param {type:\"number\"}\n",
|
|
|
- "padding_mode = 'border'#@param {type:\"string\"}\n",
|
|
|
- "sampling_mode = 'bicubic'#@param {type:\"string\"}\n",
|
|
|
- "\n",
|
|
|
- "#======= TURBO MODE\n",
|
|
|
- "#@markdown ---\n",
|
|
|
- "#@markdown ####**Turbo Mode (3D anim only):**\n",
|
|
|
- "#@markdown (Starts after frame 10,) skips diffusion steps and just uses depth map to warp images for skipped frames.\n",
|
|
|
- "#@markdown Speeds up rendering by 2x-4x, and may improve image coherence between frames. frame_blend_mode smooths abrupt texture changes across 2 frames.\n",
|
|
|
- "#@markdown For different settings tuned for Turbo Mode, refer to the original Disco-Turbo Github: https://github.com/zippy731/disco-diffusion-turbo\n",
|
|
|
- "\n",
|
|
|
- "turbo_mode = False #@param {type:\"boolean\"}\n",
|
|
|
- "turbo_steps = \"3\" #@param [\"2\",\"3\",\"4\",\"5\",\"6\"] {type:\"string\"}\n",
|
|
|
- "turbo_preroll = 10 # frames\n",
|
|
|
- "\n",
|
|
|
- "#insist turbo be used only w 3d anim.\n",
|
|
|
- "if turbo_mode and animation_mode != '3D':\n",
|
|
|
- " print('=====')\n",
|
|
|
- " print('Turbo mode only available with 3D animations. Disabling Turbo.')\n",
|
|
|
- " print('=====')\n",
|
|
|
- " turbo_mode = False\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ---\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ####**Coherency Settings:**\n",
|
|
|
- "#@markdown `frame_scale` tries to guide the new frame to looking like the old one. A good default is 1500.\n",
|
|
|
- "frames_scale = 1500 #@param{type: 'integer'}\n",
|
|
|
- "#@markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.\n",
|
|
|
- "frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "def parse_key_frames(string, prompt_parser=None):\n",
|
|
|
- " \"\"\"Given a string representing frame numbers paired with parameter values at that frame,\n",
|
|
|
- " return a dictionary with the frame numbers as keys and the parameter values as the values.\n",
|
|
|
- "\n",
|
|
|
- " Parameters\n",
|
|
|
- " ----------\n",
|
|
|
- " string: string\n",
|
|
|
- " Frame numbers paired with parameter values at that frame number, in the format\n",
|
|
|
- " 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'\n",
|
|
|
- " prompt_parser: function or None, optional\n",
|
|
|
- " If provided, prompt_parser will be applied to each string of parameter values.\n",
|
|
|
- " \n",
|
|
|
- " Returns\n",
|
|
|
- " -------\n",
|
|
|
- " dict\n",
|
|
|
- " Frame numbers as keys, parameter values at that frame number as values\n",
|
|
|
- "\n",
|
|
|
- " Raises\n",
|
|
|
- " ------\n",
|
|
|
- " RuntimeError\n",
|
|
|
- " If the input string does not match the expected format.\n",
|
|
|
- " \n",
|
|
|
- " Examples\n",
|
|
|
- " --------\n",
|
|
|
- " >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\")\n",
|
|
|
- " {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}\n",
|
|
|
- "\n",
|
|
|
- " >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\", prompt_parser=lambda x: x.lower()))\n",
|
|
|
- " {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}\n",
|
|
|
- " \"\"\"\n",
|
|
|
- " import re\n",
|
|
|
- " pattern = r'((?P<frame>[0-9]+):[\\s]*[\\(](?P<param>[\\S\\s]*?)[\\)])'\n",
|
|
|
- " frames = dict()\n",
|
|
|
- " for match_object in re.finditer(pattern, string):\n",
|
|
|
- " frame = int(match_object.groupdict()['frame'])\n",
|
|
|
- " param = match_object.groupdict()['param']\n",
|
|
|
- " if prompt_parser:\n",
|
|
|
- " frames[frame] = prompt_parser(param)\n",
|
|
|
- " else:\n",
|
|
|
- " frames[frame] = param\n",
|
|
|
- "\n",
|
|
|
- " if frames == {} and len(string) != 0:\n",
|
|
|
- " raise RuntimeError('Key Frame string not correctly formatted')\n",
|
|
|
- " return frames\n",
|
|
|
- "\n",
|
|
|
- "def get_inbetweens(key_frames, integer=False):\n",
|
|
|
- " \"\"\"Given a dict with frame numbers as keys and a parameter value as values,\n",
|
|
|
- " return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.\n",
|
|
|
- " Any values not provided in the input dict are calculated by linear interpolation between\n",
|
|
|
- " the values of the previous and next provided frames. If there is no previous provided frame, then\n",
|
|
|
- " the value is equal to the value of the next provided frame, or if there is no next provided frame,\n",
|
|
|
- " then the value is equal to the value of the previous provided frame. If no frames are provided,\n",
|
|
|
- " all frame values are NaN.\n",
|
|
|
- "\n",
|
|
|
- " Parameters\n",
|
|
|
- " ----------\n",
|
|
|
- " key_frames: dict\n",
|
|
|
- " A dict with integer frame numbers as keys and numerical values of a particular parameter as values.\n",
|
|
|
- " integer: Bool, optional\n",
|
|
|
- " If True, the values of the output series are converted to integers.\n",
|
|
|
- " Otherwise, the values are floats.\n",
|
|
|
- " \n",
|
|
|
- " Returns\n",
|
|
|
- " -------\n",
|
|
|
- " pd.Series\n",
|
|
|
- " A Series with length max_frames representing the parameter values for each frame.\n",
|
|
|
- " \n",
|
|
|
- " Examples\n",
|
|
|
- " --------\n",
|
|
|
- " >>> max_frames = 5\n",
|
|
|
- " >>> get_inbetweens({1: 5, 3: 6})\n",
|
|
|
- " 0 5.0\n",
|
|
|
- " 1 5.0\n",
|
|
|
- " 2 5.5\n",
|
|
|
- " 3 6.0\n",
|
|
|
- " 4 6.0\n",
|
|
|
- " dtype: float64\n",
|
|
|
- "\n",
|
|
|
- " >>> get_inbetweens({1: 5, 3: 6}, integer=True)\n",
|
|
|
- " 0 5\n",
|
|
|
- " 1 5\n",
|
|
|
- " 2 5\n",
|
|
|
- " 3 6\n",
|
|
|
- " 4 6\n",
|
|
|
- " dtype: int64\n",
|
|
|
- " \"\"\"\n",
|
|
|
- " key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n",
|
|
|
- "\n",
|
|
|
- " for i, value in key_frames.items():\n",
|
|
|
- " key_frame_series[i] = value\n",
|
|
|
- " key_frame_series = key_frame_series.astype(float)\n",
|
|
|
- " \n",
|
|
|
- " interp_method = interp_spline\n",
|
|
|
- "\n",
|
|
|
- " if interp_method == 'Cubic' and len(key_frames.items()) <=3:\n",
|
|
|
- " interp_method = 'Quadratic'\n",
|
|
|
- " \n",
|
|
|
- " if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n",
|
|
|
- " interp_method = 'Linear'\n",
|
|
|
- " \n",
|
|
|
- " \n",
|
|
|
- " key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n",
|
|
|
- " key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n",
|
|
|
- " # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')\n",
|
|
|
- " key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')\n",
|
|
|
- " if integer:\n",
|
|
|
- " return key_frame_series.astype(int)\n",
|
|
|
- " return key_frame_series\n",
|
|
|
- "\n",
|
|
|
- "def split_prompts(prompts):\n",
|
|
|
- " prompt_series = pd.Series([np.nan for a in range(max_frames)])\n",
|
|
|
- " for i, prompt in prompts.items():\n",
|
|
|
- " prompt_series[i] = prompt\n",
|
|
|
- " # prompt_series = prompt_series.astype(str)\n",
|
|
|
- " prompt_series = prompt_series.ffill().bfill()\n",
|
|
|
- " return prompt_series\n",
|
|
|
- "\n",
|
|
|
- "if key_frames:\n",
|
|
|
- " try:\n",
|
|
|
- " angle_series = get_inbetweens(parse_key_frames(angle))\n",
|
|
|
- " except RuntimeError as e:\n",
|
|
|
- " print(\n",
|
|
|
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
- " \"formatted `angle` correctly for key frames.\\n\"\n",
|
|
|
- " \"Attempting to interpret `angle` as \"\n",
|
|
|
- " f'\"0: ({angle})\"\\n'\n",
|
|
|
- " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
- " \"correctly.\\n\"\n",
|
|
|
- " )\n",
|
|
|
- " angle = f\"0: ({angle})\"\n",
|
|
|
- " angle_series = get_inbetweens(parse_key_frames(angle))\n",
|
|
|
- "\n",
|
|
|
- " try:\n",
|
|
|
- " zoom_series = get_inbetweens(parse_key_frames(zoom))\n",
|
|
|
- " except RuntimeError as e:\n",
|
|
|
- " print(\n",
|
|
|
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
- " \"formatted `zoom` correctly for key frames.\\n\"\n",
|
|
|
- " \"Attempting to interpret `zoom` as \"\n",
|
|
|
- " f'\"0: ({zoom})\"\\n'\n",
|
|
|
- " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
- " \"correctly.\\n\"\n",
|
|
|
- " )\n",
|
|
|
- " zoom = f\"0: ({zoom})\"\n",
|
|
|
- " zoom_series = get_inbetweens(parse_key_frames(zoom))\n",
|
|
|
- "\n",
|
|
|
- " try:\n",
|
|
|
- " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n",
|
|
|
- " except RuntimeError as e:\n",
|
|
|
- " print(\n",
|
|
|
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
- " \"formatted `translation_x` correctly for key frames.\\n\"\n",
|
|
|
- " \"Attempting to interpret `translation_x` as \"\n",
|
|
|
- " f'\"0: ({translation_x})\"\\n'\n",
|
|
|
- " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
- " \"correctly.\\n\"\n",
|
|
|
- " )\n",
|
|
|
- " translation_x = f\"0: ({translation_x})\"\n",
|
|
|
- " translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n",
|
|
|
- "\n",
|
|
|
- " try:\n",
|
|
|
- " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n",
|
|
|
- " except RuntimeError as e:\n",
|
|
|
- " print(\n",
|
|
|
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
- " \"formatted `translation_y` correctly for key frames.\\n\"\n",
|
|
|
- " \"Attempting to interpret `translation_y` as \"\n",
|
|
|
- " f'\"0: ({translation_y})\"\\n'\n",
|
|
|
- " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
- " \"correctly.\\n\"\n",
|
|
|
- " )\n",
|
|
|
- " translation_y = f\"0: ({translation_y})\"\n",
|
|
|
- " translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n",
|
|
|
- "\n",
|
|
|
- " try:\n",
|
|
|
- " translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n",
|
|
|
- " except RuntimeError as e:\n",
|
|
|
- " print(\n",
|
|
|
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
- " \"formatted `translation_z` correctly for key frames.\\n\"\n",
|
|
|
- " \"Attempting to interpret `translation_z` as \"\n",
|
|
|
- " f'\"0: ({translation_z})\"\\n'\n",
|
|
|
- " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
- " \"correctly.\\n\"\n",
|
|
|
- " )\n",
|
|
|
- " translation_z = f\"0: ({translation_z})\"\n",
|
|
|
- " translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n",
|
|
|
- "\n",
|
|
|
- " try:\n",
|
|
|
- " rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n",
|
|
|
- " except RuntimeError as e:\n",
|
|
|
- " print(\n",
|
|
|
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
- " \"formatted `rotation_3d_x` correctly for key frames.\\n\"\n",
|
|
|
- " \"Attempting to interpret `rotation_3d_x` as \"\n",
|
|
|
- " f'\"0: ({rotation_3d_x})\"\\n'\n",
|
|
|
- " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
- " \"correctly.\\n\"\n",
|
|
|
- " )\n",
|
|
|
- " rotation_3d_x = f\"0: ({rotation_3d_x})\"\n",
|
|
|
- " rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n",
|
|
|
- "\n",
|
|
|
- " try:\n",
|
|
|
- " rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n",
|
|
|
- " except RuntimeError as e:\n",
|
|
|
- " print(\n",
|
|
|
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
- " \"formatted `rotation_3d_y` correctly for key frames.\\n\"\n",
|
|
|
- " \"Attempting to interpret `rotation_3d_y` as \"\n",
|
|
|
- " f'\"0: ({rotation_3d_y})\"\\n'\n",
|
|
|
- " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
- " \"correctly.\\n\"\n",
|
|
|
- " )\n",
|
|
|
- " rotation_3d_y = f\"0: ({rotation_3d_y})\"\n",
|
|
|
- " rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n",
|
|
|
- "\n",
|
|
|
- " try:\n",
|
|
|
- " rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n",
|
|
|
- " except RuntimeError as e:\n",
|
|
|
- " print(\n",
|
|
|
- " \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
- " \"formatted `rotation_3d_z` correctly for key frames.\\n\"\n",
|
|
|
- " \"Attempting to interpret `rotation_3d_z` as \"\n",
|
|
|
- " f'\"0: ({rotation_3d_z})\"\\n'\n",
|
|
|
- " \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
- " \"correctly.\\n\"\n",
|
|
|
- " )\n",
|
|
|
- " rotation_3d_z = f\"0: ({rotation_3d_z})\"\n",
|
|
|
- " rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n",
|
|
|
- "\n",
|
|
|
- "else:\n",
|
|
|
- " angle = float(angle)\n",
|
|
|
- " zoom = float(zoom)\n",
|
|
|
- " translation_x = float(translation_x)\n",
|
|
|
- " translation_y = float(translation_y)\n",
|
|
|
- " translation_z = float(translation_z)\n",
|
|
|
- " rotation_3d_x = float(rotation_3d_x)\n",
|
|
|
- " rotation_3d_y = float(rotation_3d_y)\n",
|
|
|
- " rotation_3d_z = float(rotation_3d_z)\n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "ExtraSetTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "### Extra Settings\n",
|
|
|
- " Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "id": "ExtraSettings"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@markdown ####**Saving:**\n",
|
|
|
- "\n",
|
|
|
- "intermediate_saves = 4#@param{type: 'raw'}\n",
|
|
|
- "intermediates_in_subfolder = True #@param{type: 'boolean'}\n",
|
|
|
- "#@markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps \n",
|
|
|
- "\n",
|
|
|
- "#@markdown A value of `2` will save a copy at 33% and 66%. 0 will save none.\n",
|
|
|
- "\n",
|
|
|
- "#@markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets)\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "if type(intermediate_saves) is not list:\n",
|
|
|
- " if intermediate_saves:\n",
|
|
|
- " steps_per_checkpoint = math.floor((steps - skip_steps - 1) // (intermediate_saves+1))\n",
|
|
|
- " steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1\n",
|
|
|
- " print(f'Will save every {steps_per_checkpoint} steps')\n",
|
|
|
- " else:\n",
|
|
|
- " steps_per_checkpoint = steps+10\n",
|
|
|
- "else:\n",
|
|
|
- " steps_per_checkpoint = None\n",
|
|
|
- "\n",
|
|
|
- "if intermediate_saves and intermediates_in_subfolder is True:\n",
|
|
|
- " partialFolder = f'{batchFolder}/partials'\n",
|
|
|
- " createPath(partialFolder)\n",
|
|
|
- "\n",
|
|
|
- " #@markdown ---\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ####**SuperRes Sharpening:**\n",
|
|
|
- "#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*\n",
|
|
|
- "sharpen_preset = 'Fast' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']\n",
|
|
|
- "keep_unsharp = True #@param{type: 'boolean'}\n",
|
|
|
- "\n",
|
|
|
- "if sharpen_preset != 'Off' and keep_unsharp is True:\n",
|
|
|
- " unsharpenFolder = f'{batchFolder}/unsharpened'\n",
|
|
|
- " createPath(unsharpenFolder)\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- " #@markdown ---\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ####**Advanced Settings:**\n",
|
|
|
- "#@markdown *There are a few extra advanced settings available if you double click this cell.*\n",
|
|
|
- "\n",
|
|
|
- "#@markdown *Perlin init will replace your init, so uncheck if using one.*\n",
|
|
|
- "\n",
|
|
|
- "perlin_init = False #@param{type: 'boolean'}\n",
|
|
|
- "perlin_mode = 'mixed' #@param ['mixed', 'color', 'gray']\n",
|
|
|
- "set_seed = 'random_seed' #@param{type: 'string'}\n",
|
|
|
- "eta = 0.8#@param{type: 'number'}\n",
|
|
|
- "clamp_grad = True #@param{type: 'boolean'}\n",
|
|
|
- "clamp_max = 0.05 #@param{type: 'number'}\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "### EXTRA ADVANCED SETTINGS:\n",
|
|
|
- "randomize_class = True\n",
|
|
|
- "clip_denoised = False\n",
|
|
|
- "fuzzy_prompt = False\n",
|
|
|
- "rand_mag = 0.05\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- " #@markdown ---\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ####**Cutn Scheduling:**\n",
|
|
|
- "#@markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400 /1000 steps, then 20 for the last 600/1000\n",
|
|
|
- "\n",
|
|
|
- "#@markdown cut_overview and cut_innercut are cumulative for total cutn on any given step. Overview cuts see the entire image and are good for early structure, innercuts are your standard cutn.\n",
|
|
|
- "\n",
|
|
|
- "cut_overview = \"[12]*400+[4]*600\" #@param {type: 'string'} \n",
|
|
|
- "cut_innercut =\"[4]*400+[12]*600\"#@param {type: 'string'} \n",
|
|
|
- "cut_ic_pow = 1#@param {type: 'number'} \n",
|
|
|
- "cut_icgray_p = \"[0.2]*400+[0]*600\"#@param {type: 'string'}\n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "PromptsTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "### Prompts\n",
|
|
|
- "`animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one."
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "id": "Prompts"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "text_prompts = {\n",
|
|
|
- " 0: [\n",
|
|
|
- " \"megastructure in the cloud, blame!, contemporary house in the mist, artstation\",\n",
|
|
|
- " ]\n",
|
|
|
- "}\n",
|
|
|
- "\n",
|
|
|
- "image_prompts = {\n",
|
|
|
- " # 0:['ImagePromptsWorkButArentVeryGood.png:2',],\n",
|
|
|
- "}\n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "DiffuseTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "# 4. Diffuse!"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "id": "DoTheRun"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "#@title Do the Run!\n",
|
|
|
- "#@markdown `n_batches` ignored with animation modes.\n",
|
|
|
- "display_rate = 50#@param{type: 'number'}\n",
|
|
|
- "n_batches = 50#@param{type: 'number'}\n",
|
|
|
- "\n",
|
|
|
- "#Update Model Settings\n",
|
|
|
- "timestep_respacing = f'ddim{steps}'\n",
|
|
|
- "diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n",
|
|
|
- "model_config.update({\n",
|
|
|
- " 'timestep_respacing': timestep_respacing,\n",
|
|
|
- " 'diffusion_steps': diffusion_steps,\n",
|
|
|
- "})\n",
|
|
|
- "\n",
|
|
|
- "batch_size = 1 \n",
|
|
|
- "\n",
|
|
|
- "def move_files(start_num, end_num, old_folder, new_folder):\n",
|
|
|
- " for i in range(start_num, end_num):\n",
|
|
|
- " old_file = old_folder + f'/{batch_name}({batchNum})_{i:04}.png'\n",
|
|
|
- " new_file = new_folder + f'/{batch_name}({batchNum})_{i:04}.png'\n",
|
|
|
- " os.rename(old_file, new_file)\n",
|
|
|
- "\n",
|
|
|
- "#@markdown ---\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "resume_run = False #@param{type: 'boolean'}\n",
|
|
|
- "run_to_resume = 'latest' #@param{type: 'string'}\n",
|
|
|
- "resume_from_frame = 'latest' #@param{type: 'string'}\n",
|
|
|
- "retain_overwritten_frames = False #@param{type: 'boolean'}\n",
|
|
|
- "if retain_overwritten_frames is True:\n",
|
|
|
- " retainFolder = f'{batchFolder}/retained'\n",
|
|
|
- " createPath(retainFolder)\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "skip_step_ratio = int(frames_skip_steps.rstrip(\"%\")) / 100\n",
|
|
|
- "calc_frames_skip_steps = math.floor(steps * skip_step_ratio)\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- "if steps <= calc_frames_skip_steps:\n",
|
|
|
- " sys.exit(\"ERROR: You can't skip more steps than your total steps\")\n",
|
|
|
- "\n",
|
|
|
- "if resume_run:\n",
|
|
|
- " if run_to_resume == 'latest':\n",
|
|
|
- " try:\n",
|
|
|
- " batchNum\n",
|
|
|
- " except:\n",
|
|
|
- " batchNum = len(glob(f\"{batchFolder}/{batch_name}(*)_settings.txt\"))-1\n",
|
|
|
- " else:\n",
|
|
|
- " batchNum = int(run_to_resume)\n",
|
|
|
- " if resume_from_frame == 'latest':\n",
|
|
|
- " start_frame = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n",
|
|
|
- " if animation_mode != '3D' and turbo_mode == True and start_frame > turbo_preroll and start_frame % int(turbo_steps) != 0:\n",
|
|
|
- " start_frame = start_frame - (start_frame % int(turbo_steps))\n",
|
|
|
- " else:\n",
|
|
|
- " start_frame = int(resume_from_frame)+1\n",
|
|
|
- " if animation_mode != '3D' and turbo_mode == True and start_frame > turbo_preroll and start_frame % int(turbo_steps) != 0:\n",
|
|
|
- " start_frame = start_frame - (start_frame % int(turbo_steps))\n",
|
|
|
- " if retain_overwritten_frames is True:\n",
|
|
|
- " existing_frames = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n",
|
|
|
- " frames_to_save = existing_frames - start_frame\n",
|
|
|
- " print(f'Moving {frames_to_save} frames to the Retained folder')\n",
|
|
|
- " move_files(start_frame, existing_frames, batchFolder, retainFolder)\n",
|
|
|
- "else:\n",
|
|
|
- " start_frame = 0\n",
|
|
|
- " batchNum = len(glob(batchFolder+\"/*.txt\"))\n",
|
|
|
- " while os.path.isfile(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\") is True or os.path.isfile(f\"{batchFolder}/{batch_name}-{batchNum}_settings.txt\") is True:\n",
|
|
|
- " batchNum += 1\n",
|
|
|
- "\n",
|
|
|
- "print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')\n",
|
|
|
- "\n",
|
|
|
- "if set_seed == 'random_seed':\n",
|
|
|
- " random.seed()\n",
|
|
|
- " seed = random.randint(0, 2**32)\n",
|
|
|
- " # print(f'Using seed: {seed}')\n",
|
|
|
- "else:\n",
|
|
|
- " seed = int(set_seed)\n",
|
|
|
- "\n",
|
|
|
- "args = {\n",
|
|
|
- " 'batchNum': batchNum,\n",
|
|
|
- " 'prompts_series':split_prompts(text_prompts) if text_prompts else None,\n",
|
|
|
- " 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,\n",
|
|
|
- " 'seed': seed,\n",
|
|
|
- " 'display_rate':display_rate,\n",
|
|
|
- " 'n_batches':n_batches if animation_mode == 'None' else 1,\n",
|
|
|
- " 'batch_size':batch_size,\n",
|
|
|
- " 'batch_name': batch_name,\n",
|
|
|
- " 'steps': steps,\n",
|
|
|
- " 'diffusion_sampling_mode': diffusion_sampling_mode,\n",
|
|
|
- " 'width_height': width_height,\n",
|
|
|
- " 'clip_guidance_scale': clip_guidance_scale,\n",
|
|
|
- " 'tv_scale': tv_scale,\n",
|
|
|
- " 'range_scale': range_scale,\n",
|
|
|
- " 'sat_scale': sat_scale,\n",
|
|
|
- " 'cutn_batches': cutn_batches,\n",
|
|
|
- " 'init_image': init_image,\n",
|
|
|
- " 'init_scale': init_scale,\n",
|
|
|
- " 'skip_steps': skip_steps,\n",
|
|
|
- " 'sharpen_preset': sharpen_preset,\n",
|
|
|
- " 'keep_unsharp': keep_unsharp,\n",
|
|
|
- " 'side_x': side_x,\n",
|
|
|
- " 'side_y': side_y,\n",
|
|
|
- " 'timestep_respacing': timestep_respacing,\n",
|
|
|
- " 'diffusion_steps': diffusion_steps,\n",
|
|
|
- " 'animation_mode': animation_mode,\n",
|
|
|
- " 'video_init_path': video_init_path,\n",
|
|
|
- " 'extract_nth_frame': extract_nth_frame,\n",
|
|
|
- " 'video_init_seed_continuity': video_init_seed_continuity,\n",
|
|
|
- " 'key_frames': key_frames,\n",
|
|
|
- " 'max_frames': max_frames if animation_mode != \"None\" else 1,\n",
|
|
|
- " 'interp_spline': interp_spline,\n",
|
|
|
- " 'start_frame': start_frame,\n",
|
|
|
- " 'angle': angle,\n",
|
|
|
- " 'zoom': zoom,\n",
|
|
|
- " 'translation_x': translation_x,\n",
|
|
|
- " 'translation_y': translation_y,\n",
|
|
|
- " 'translation_z': translation_z,\n",
|
|
|
- " 'rotation_3d_x': rotation_3d_x,\n",
|
|
|
- " 'rotation_3d_y': rotation_3d_y,\n",
|
|
|
- " 'rotation_3d_z': rotation_3d_z,\n",
|
|
|
- " 'midas_depth_model': midas_depth_model,\n",
|
|
|
- " 'midas_weight': midas_weight,\n",
|
|
|
- " 'near_plane': near_plane,\n",
|
|
|
- " 'far_plane': far_plane,\n",
|
|
|
- " 'fov': fov,\n",
|
|
|
- " 'padding_mode': padding_mode,\n",
|
|
|
- " 'sampling_mode': sampling_mode,\n",
|
|
|
- " 'angle_series':angle_series,\n",
|
|
|
- " 'zoom_series':zoom_series,\n",
|
|
|
- " 'translation_x_series':translation_x_series,\n",
|
|
|
- " 'translation_y_series':translation_y_series,\n",
|
|
|
- " 'translation_z_series':translation_z_series,\n",
|
|
|
- " 'rotation_3d_x_series':rotation_3d_x_series,\n",
|
|
|
- " 'rotation_3d_y_series':rotation_3d_y_series,\n",
|
|
|
- " 'rotation_3d_z_series':rotation_3d_z_series,\n",
|
|
|
- " 'frames_scale': frames_scale,\n",
|
|
|
- " 'calc_frames_skip_steps': calc_frames_skip_steps,\n",
|
|
|
- " 'skip_step_ratio': skip_step_ratio,\n",
|
|
|
- " 'calc_frames_skip_steps': calc_frames_skip_steps,\n",
|
|
|
- " 'text_prompts': text_prompts,\n",
|
|
|
- " 'image_prompts': image_prompts,\n",
|
|
|
- " 'cut_overview': eval(cut_overview),\n",
|
|
|
- " 'cut_innercut': eval(cut_innercut),\n",
|
|
|
- " 'cut_ic_pow': cut_ic_pow,\n",
|
|
|
- " 'cut_icgray_p': eval(cut_icgray_p),\n",
|
|
|
- " 'intermediate_saves': intermediate_saves,\n",
|
|
|
- " 'intermediates_in_subfolder': intermediates_in_subfolder,\n",
|
|
|
- " 'steps_per_checkpoint': steps_per_checkpoint,\n",
|
|
|
- " 'perlin_init': perlin_init,\n",
|
|
|
- " 'perlin_mode': perlin_mode,\n",
|
|
|
- " 'set_seed': set_seed,\n",
|
|
|
- " 'eta': eta,\n",
|
|
|
- " 'clamp_grad': clamp_grad,\n",
|
|
|
- " 'clamp_max': clamp_max,\n",
|
|
|
- " 'skip_augs': skip_augs,\n",
|
|
|
- " 'randomize_class': randomize_class,\n",
|
|
|
- " 'clip_denoised': clip_denoised,\n",
|
|
|
- " 'fuzzy_prompt': fuzzy_prompt,\n",
|
|
|
- " 'rand_mag': rand_mag,\n",
|
|
|
- "}\n",
|
|
|
- "\n",
|
|
|
- "args = SimpleNamespace(**args)\n",
|
|
|
- "\n",
|
|
|
- "print('Prepping model...')\n",
|
|
|
- "model, diffusion = create_model_and_diffusion(**model_config)\n",
|
|
|
- "model.load_state_dict(torch.load(f'{model_path}/{diffusion_model}.pt', map_location='cpu'))\n",
|
|
|
- "model.requires_grad_(False).eval().to(device)\n",
|
|
|
- "for name, param in model.named_parameters():\n",
|
|
|
- " if 'qkv' in name or 'norm' in name or 'proj' in name:\n",
|
|
|
- " param.requires_grad_()\n",
|
|
|
- "if model_config['use_fp16']:\n",
|
|
|
- " model.convert_to_fp16()\n",
|
|
|
- "\n",
|
|
|
- "gc.collect()\n",
|
|
|
- "torch.cuda.empty_cache()\n",
|
|
|
- "try:\n",
|
|
|
- " do_run()\n",
|
|
|
- "except KeyboardInterrupt:\n",
|
|
|
- " pass\n",
|
|
|
- "finally:\n",
|
|
|
- " print('Seed used:', seed)\n",
|
|
|
- " gc.collect()\n",
|
|
|
- " torch.cuda.empty_cache()\n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "markdown",
|
|
|
- "metadata": {
|
|
|
- "id": "CreateVidTop"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "# 5. Create the video"
|
|
|
- ]
|
|
|
- },
|
|
|
- {
|
|
|
- "cell_type": "code",
|
|
|
- "metadata": {
|
|
|
- "id": "CreateVid"
|
|
|
- },
|
|
|
- "source": [
|
|
|
- "# @title ### **Create video**\n",
|
|
|
- "#@markdown Video file will save in the same folder as your images.\n",
|
|
|
- "\n",
|
|
|
- "skip_video_for_run_all = True #@param {type: 'boolean'}\n",
|
|
|
- "\n",
|
|
|
- "if skip_video_for_run_all == True:\n",
|
|
|
- " print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
|
|
|
- "\n",
|
|
|
- "else:\n",
|
|
|
- " # import subprocess in case this cell is run without the above cells\n",
|
|
|
- " import subprocess\n",
|
|
|
- " from base64 import b64encode\n",
|
|
|
- "\n",
|
|
|
- " latest_run = batchNum\n",
|
|
|
- "\n",
|
|
|
- " folder = batch_name #@param\n",
|
|
|
- " run = latest_run #@param\n",
|
|
|
- " final_frame = 'final_frame'\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- " init_frame = 1#@param {type:\"number\"} This is the frame where the video will start\n",
|
|
|
- " last_frame = final_frame#@param {type:\"number\"} You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.\n",
|
|
|
- " fps = 12#@param {type:\"number\"}\n",
|
|
|
- " # view_video_in_cell = True #@param {type: 'boolean'}\n",
|
|
|
- "\n",
|
|
|
- " frames = []\n",
|
|
|
- " # tqdm.write('Generating video...')\n",
|
|
|
- "\n",
|
|
|
- " if last_frame == 'final_frame':\n",
|
|
|
- " last_frame = len(glob(batchFolder+f\"/{folder}({run})_*.png\"))\n",
|
|
|
- " print(f'Total frames: {last_frame}')\n",
|
|
|
- "\n",
|
|
|
- " image_path = f\"{outDirPath}/{folder}/{folder}({run})_%04d.png\"\n",
|
|
|
- " filepath = f\"{outDirPath}/{folder}/{folder}({run}).mp4\"\n",
|
|
|
- "\n",
|
|
|
- "\n",
|
|
|
- " cmd = [\n",
|
|
|
- " 'ffmpeg',\n",
|
|
|
- " '-y',\n",
|
|
|
- " '-vcodec',\n",
|
|
|
- " 'png',\n",
|
|
|
- " '-r',\n",
|
|
|
- " str(fps),\n",
|
|
|
- " '-start_number',\n",
|
|
|
- " str(init_frame),\n",
|
|
|
- " '-i',\n",
|
|
|
- " image_path,\n",
|
|
|
- " '-frames:v',\n",
|
|
|
- " str(last_frame+1),\n",
|
|
|
- " '-c:v',\n",
|
|
|
- " 'libx264',\n",
|
|
|
- " '-vf',\n",
|
|
|
- " f'fps={fps}',\n",
|
|
|
- " '-pix_fmt',\n",
|
|
|
- " 'yuv420p',\n",
|
|
|
- " '-crf',\n",
|
|
|
- " '17',\n",
|
|
|
- " '-preset',\n",
|
|
|
- " 'veryslow',\n",
|
|
|
- " filepath\n",
|
|
|
- " ]\n",
|
|
|
- "\n",
|
|
|
- " process = subprocess.Popen(cmd, cwd=f'{batchFolder}', stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
|
|
|
- " stdout, stderr = process.communicate()\n",
|
|
|
- " if process.returncode != 0:\n",
|
|
|
- " print(stderr)\n",
|
|
|
- " raise RuntimeError(stderr)\n",
|
|
|
- " else:\n",
|
|
|
- " print(\"The video is ready and saved to the images folder\")\n",
|
|
|
- "\n",
|
|
|
- " # if view_video_in_cell:\n",
|
|
|
- " # mp4 = open(filepath,'rb').read()\n",
|
|
|
- " # data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
|
|
|
- " # display.HTML(f'<video width=400 controls><source src=\"{data_url}\" type=\"video/mp4\"></video>')\n",
|
|
|
- " \n"
|
|
|
- ],
|
|
|
- "outputs": [],
|
|
|
- "execution_count": null
|
|
|
- }
|
|
|
- ],
|
|
|
- "metadata": {
|
|
|
- "anaconda-cloud": {},
|
|
|
- "accelerator": "GPU",
|
|
|
- "colab": {
|
|
|
- "collapsed_sections": [
|
|
|
- "CreditsChTop",
|
|
|
- "TutorialTop",
|
|
|
- "CheckGPU",
|
|
|
- "InstallDeps",
|
|
|
- "DefMidasFns",
|
|
|
- "DefFns",
|
|
|
- "DefSecModel",
|
|
|
- "DefSuperRes",
|
|
|
- "AnimSetTop",
|
|
|
- "ExtraSetTop"
|
|
|
- ],
|
|
|
- "machine_shape": "hm",
|
|
|
- "name": "Copie de Disco Diffusion v5.1 [w/ Turbo]",
|
|
|
- "private_outputs": true,
|
|
|
- "provenance": []
|
|
|
- },
|
|
|
- "kernelspec": {
|
|
|
- "display_name": "Python 3",
|
|
|
- "language": "python",
|
|
|
- "name": "python3"
|
|
|
- },
|
|
|
- "language_info": {
|
|
|
- "codemirror_mode": {
|
|
|
- "name": "ipython",
|
|
|
- "version": 3
|
|
|
- },
|
|
|
- "file_extension": ".py",
|
|
|
- "mimetype": "text/x-python",
|
|
|
- "name": "python",
|
|
|
- "nbconvert_exporter": "python",
|
|
|
- "pygments_lexer": "ipython3",
|
|
|
- "version": "3.6.1"
|
|
|
- }
|
|
|
- },
|
|
|
- "nbformat": 4,
|
|
|
- "nbformat_minor": 0
|
|
|
+"cells": [
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "1YwMUyt9LHG1"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"# Disco Diffusion v5 (Turbo+Smooth) - Now with 3D animation\n",
|
|
|
+"\n",
|
|
|
+"In case of confusion, Disco is the name of this notebook edit. The diffusion model in use is Katherine Crowson's fine-tuned 512x512 model\n",
|
|
|
+"\n",
|
|
|
+"For issues, join the [Disco Diffusion Discord](https://discord.gg/msEZBy4HxA) or message us on twitter at [@somnai_dreams](https://twitter.com/somnai_dreams) or [@gandamu](https://twitter.com/gandamu_ml)\n",
|
|
|
+"\n",
|
|
|
+"Credits & Changelog ⬇️\n"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "wX5omb9C7Bjz"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"Original notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet or Katherine Crowson's fine-tuned 512x512 diffusion model (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images.\n",
|
|
|
+"\n",
|
|
|
+"Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.\n",
|
|
|
+"\n",
|
|
|
+"Further improvements from Dango233 and nsheppard helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.\n",
|
|
|
+"\n",
|
|
|
+"Vark added code to load in multiple Clip models at once, which all prompts are evaluated against, which may greatly improve accuracy.\n",
|
|
|
+"\n",
|
|
|
+"The latest zoom, pan, rotation, and keyframes features were taken from Chigozie Nri's VQGAN Zoom Notebook (https://github.com/chigozienri, https://twitter.com/chigozienri)\n",
|
|
|
+"\n",
|
|
|
+"Advanced DangoCutn Cutout method is also from Dango223.\n",
|
|
|
+"\n",
|
|
|
+"--\n",
|
|
|
+"\n",
|
|
|
+"Disco:\n",
|
|
|
+"\n",
|
|
|
+"Somnai (https://twitter.com/Somnai_dreams) added Diffusion Animation techniques, QoL improvements and various implementations of tech and techniques, mostly listed in the changelog below.\n",
|
|
|
+"\n",
|
|
|
+"3D animation implementation added by Adam Letts (https://twitter.com/gandamu_ml) in collaboration with Somnai.\n",
|
|
|
+"\n",
|
|
|
+"Turbo feature by Chris Allen (https://twitter.com/zippy731) "
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"id": "wDSYhyjqZQI9"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"# @title Licensed under the MIT License\n",
|
|
|
+"\n",
|
|
|
+"# Copyright (c) 2021 Katherine Crowson \n",
|
|
|
+"\n",
|
|
|
+"# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
|
|
|
+"# of this software and associated documentation files (the \"Software\"), to deal\n",
|
|
|
+"# in the Software without restriction, including without limitation the rights\n",
|
|
|
+"# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
|
|
|
+"# copies of the Software, and to permit persons to whom the Software is\n",
|
|
|
+"# furnished to do so, subject to the following conditions:\n",
|
|
|
+"\n",
|
|
|
+"# The above copyright notice and this permission notice shall be included in\n",
|
|
|
+"# all copies or substantial portions of the Software.\n",
|
|
|
+"\n",
|
|
|
+"# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
|
|
|
+"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
|
|
|
+"# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
|
|
|
+"# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
|
|
|
+"# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
|
|
|
+"# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
|
|
|
+"# THE SOFTWARE.\n",
|
|
|
+"\n",
|
|
|
+"# --\n",
|
|
|
+"\n",
|
|
|
+"# @title Licensed under the MIT License\n",
|
|
|
+"\n",
|
|
|
+"# Copyright (c) 2021 Maxwell Ingham \n",
|
|
|
+"# Copyright (c) 2022 Adam Letts \n",
|
|
|
+"\n",
|
|
|
+"# Permission is hereby granted, free of charge, to any person obtaining a copy\n",
|
|
|
+"# of this software and associated documentation files (the \"Software\"), to deal\n",
|
|
|
+"# in the Software without restriction, including without limitation the rights\n",
|
|
|
+"# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell\n",
|
|
|
+"# copies of the Software, and to permit persons to whom the Software is\n",
|
|
|
+"# furnished to do so, subject to the following conditions:\n",
|
|
|
+"\n",
|
|
|
+"# The above copyright notice and this permission notice shall be included in\n",
|
|
|
+"# all copies or substantial portions of the Software.\n",
|
|
|
+"\n",
|
|
|
+"# THE SOFTWARE IS PROVIDED \"AS IS\", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR\n",
|
|
|
+"# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,\n",
|
|
|
+"# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE\n",
|
|
|
+"# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER\n",
|
|
|
+"# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,\n",
|
|
|
+"# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN\n",
|
|
|
+"# THE SOFTWARE."
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "qFB3nwLSQI8X"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title <- View Changelog\n",
|
|
|
+"\n",
|
|
|
+"skip_for_run_all = True #@param {type: 'boolean'}\n",
|
|
|
+"\n",
|
|
|
+"if skip_for_run_all == False:\n",
|
|
|
+" print(\n",
|
|
|
+" '''\n",
|
|
|
+" v1 Update: Oct 29th 2021 - Somnai\n",
|
|
|
+"\n",
|
|
|
+" QoL improvements added by Somnai (@somnai_dreams), including user friendly UI, settings+prompt saving and improved google drive folder organization.\n",
|
|
|
+"\n",
|
|
|
+" v1.1 Update: Nov 13th 2021 - Somnai\n",
|
|
|
+"\n",
|
|
|
+" Now includes sizing options, intermediate saves and fixed image prompts and perlin inits. unexposed batch option since it doesn't work\n",
|
|
|
+"\n",
|
|
|
+" v2 Update: Nov 22nd 2021 - Somnai\n",
|
|
|
+"\n",
|
|
|
+" Initial addition of Katherine Crowson's Secondary Model Method (https://colab.research.google.com/drive/1mpkrhOjoyzPeSWy2r7T8EYRaU7amYOOi#scrollTo=X5gODNAMEUCR)\n",
|
|
|
+"\n",
|
|
|
+" Noticed settings were saving with the wrong name so corrected it. Let me know if you preferred the old scheme.\n",
|
|
|
+"\n",
|
|
|
+" v3 Update: Dec 24th 2021 - Somnai\n",
|
|
|
+"\n",
|
|
|
+" Implemented Dango's advanced cutout method\n",
|
|
|
+"\n",
|
|
|
+" Added SLIP models, thanks to NeuralDivergent\n",
|
|
|
+"\n",
|
|
|
+" Fixed issue with NaNs resulting in black images, with massive help and testing from @Softology\n",
|
|
|
+"\n",
|
|
|
+" Perlin now changes properly within batches (not sure where this perlin_regen code came from originally, but thank you)\n",
|
|
|
+"\n",
|
|
|
+" v4 Update: Jan 2021 - Somnai\n",
|
|
|
+"\n",
|
|
|
+" Implemented Diffusion Zooming\n",
|
|
|
+"\n",
|
|
|
+" Added Chigozie keyframing\n",
|
|
|
+"\n",
|
|
|
+" Made a bunch of edits to processes\n",
|
|
|
+" \n",
|
|
|
+" v4.1 Update: Jan 14th 2021 - Somnai\n",
|
|
|
+"\n",
|
|
|
+" Added video input mode\n",
|
|
|
+"\n",
|
|
|
+" Added license that somehow went missing\n",
|
|
|
+"\n",
|
|
|
+" Added improved prompt keyframing, fixed image_prompts and multiple prompts\n",
|
|
|
+"\n",
|
|
|
+" Improved UI\n",
|
|
|
+"\n",
|
|
|
+" Significant under the hood cleanup and improvement\n",
|
|
|
+"\n",
|
|
|
+" Refined defaults for each mode\n",
|
|
|
+"\n",
|
|
|
+" Added latent-diffusion SuperRes for sharpening\n",
|
|
|
+"\n",
|
|
|
+" Added resume run mode\n",
|
|
|
+"\n",
|
|
|
+" v4.9 Update: Feb 5th 2022 - gandamu / Adam Letts\n",
|
|
|
+"\n",
|
|
|
+" Added 3D\n",
|
|
|
+"\n",
|
|
|
+" Added brightness corrections to prevent animation from steadily going dark over time\n",
|
|
|
+"\n",
|
|
|
+" v4.91 Update: Feb 19th 2022 - gandamu / Adam Letts\n",
|
|
|
+"\n",
|
|
|
+" Cleaned up 3D implementation and made associated args accessible via Colab UI elements\n",
|
|
|
+"\n",
|
|
|
+" v4.92 Update: Feb 20th 2022 - gandamu / Adam Letts\n",
|
|
|
+"\n",
|
|
|
+" Separated transform code\n",
|
|
|
+" '''\n",
|
|
|
+" )"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "XTu6AjLyFQUq"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"#Tutorial"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "YR806W0wi3He"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"**Diffusion settings (Defaults are heavily outdated)**\n",
|
|
|
+"---\n",
|
|
|
+"\n",
|
|
|
+"This section is outdated as of v2\n",
|
|
|
+"\n",
|
|
|
+"Setting | Description | Default\n",
|
|
|
+"--- | --- | ---\n",
|
|
|
+"**Your vision:**\n",
|
|
|
+"`text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A\n",
|
|
|
+"`image_prompts` | Think of these images more as a description of their contents. | N/A\n",
|
|
|
+"**Image quality:**\n",
|
|
|
+"`clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000\n",
|
|
|
+"`tv_scale` | Controls the smoothness of the final output. | 150\n",
|
|
|
+"`range_scale` | Controls how far out of range RGB values are allowed to be. | 150\n",
|
|
|
+"`sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0\n",
|
|
|
+"`cutn` | Controls how many crops to take from the image. | 16\n",
|
|
|
+"`cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts | 2\n",
|
|
|
+"**Init settings:**\n",
|
|
|
+"`init_image` | URL or local path | None\n",
|
|
|
+"`init_scale` | This enhances the effect of the init image, a good value is 1000 | 0\n",
|
|
|
+"`skip_steps Controls the starting point along the diffusion timesteps | 0\n",
|
|
|
+"`perlin_init` | Option to start with random perlin noise | False\n",
|
|
|
+"`perlin_mode` | ('gray', 'color') | 'mixed'\n",
|
|
|
+"**Advanced:**\n",
|
|
|
+"`skip_augs` |Controls whether to skip torchvision augmentations | False\n",
|
|
|
+"`randomize_class` |Controls whether the imagenet class is randomly changed each iteration | True\n",
|
|
|
+"`clip_denoised` |Determines whether CLIP discriminates a noisy or denoised image | False\n",
|
|
|
+"`clamp_grad` |Experimental: Using adaptive clip grad in the cond_fn | True\n",
|
|
|
+"`seed` | Choose a random seed and print it at end of run for reproduction | random_seed\n",
|
|
|
+"`fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses | False\n",
|
|
|
+"`rand_mag` |Controls the magnitude of the random noise | 0.1\n",
|
|
|
+"`eta` | DDIM hyperparameter | 0.5\n",
|
|
|
+"\n",
|
|
|
+"..\n",
|
|
|
+"\n",
|
|
|
+"**Model settings**\n",
|
|
|
+"---\n",
|
|
|
+"\n",
|
|
|
+"Setting | Description | Default\n",
|
|
|
+"--- | --- | ---\n",
|
|
|
+"**Diffusion:**\n",
|
|
|
+"`timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100\n",
|
|
|
+"`diffusion_steps` || 1000\n",
|
|
|
+"**Diffusion:**\n",
|
|
|
+"`clip_models` | Models of CLIP to load. Typically the more, the better but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "_9Eg9Kf5FlfK"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"# 1. Set Up"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "qZ3rNuAWAewx"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title 1.1 Check GPU Status\n",
|
|
|
+"!nvidia-smi -L"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "yZsjzwS0YGo6"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title 1.2 Prepare Folders\n",
|
|
|
+"\n",
|
|
|
+"try:\n",
|
|
|
+" from google.colab import drive\n",
|
|
|
+" print(\"Google Colab detected. Using Google Drive.\")\n",
|
|
|
+" is_colab = True\n",
|
|
|
+" #@markdown If you connect your Google Drive, you can save the final image of each run on your drive.\n",
|
|
|
+" google_drive = True #@param {type:\"boolean\"}\n",
|
|
|
+" #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:\n",
|
|
|
+" save_models_to_google_drive = True #@param {type:\"boolean\"}\n",
|
|
|
+"except:\n",
|
|
|
+" is_colab = False\n",
|
|
|
+" google_drive = False\n",
|
|
|
+" save_models_to_google_drive = False\n",
|
|
|
+" print(\"Google Colab not detected.\")\n",
|
|
|
+"\n",
|
|
|
+"if is_colab:\n",
|
|
|
+" if google_drive is True:\n",
|
|
|
+" drive.mount('/content/drive')\n",
|
|
|
+" root_path = '/content/drive/MyDrive/AI/Disco_Diffusion'\n",
|
|
|
+" else:\n",
|
|
|
+" root_path = '/content'\n",
|
|
|
+"else:\n",
|
|
|
+" root_path = '.'\n",
|
|
|
+"\n",
|
|
|
+"import os\n",
|
|
|
+"from os import path\n",
|
|
|
+"#Simple create paths taken with modifications from Datamosh's Batch VQGAN+CLIP notebook\n",
|
|
|
+"def createPath(filepath):\n",
|
|
|
+" if path.exists(filepath) == False:\n",
|
|
|
+" os.makedirs(filepath)\n",
|
|
|
+" print(f'Made {filepath}')\n",
|
|
|
+" else:\n",
|
|
|
+" print(f'filepath {filepath} exists.')\n",
|
|
|
+"\n",
|
|
|
+"initDirPath = f'{root_path}/init_images'\n",
|
|
|
+"createPath(initDirPath)\n",
|
|
|
+"outDirPath = f'{root_path}/images_out'\n",
|
|
|
+"createPath(outDirPath)\n",
|
|
|
+"\n",
|
|
|
+"if is_colab:\n",
|
|
|
+" if google_drive and not save_models_to_google_drive or not google_drive:\n",
|
|
|
+" model_path = '/content/model'\n",
|
|
|
+" createPath(model_path)\n",
|
|
|
+" if google_drive and save_models_to_google_drive:\n",
|
|
|
+" model_path = f'{root_path}/model'\n",
|
|
|
+" createPath(model_path)\n",
|
|
|
+"else:\n",
|
|
|
+" model_path = f'{root_path}/model'\n",
|
|
|
+" createPath(model_path)\n",
|
|
|
+"\n",
|
|
|
+"# libraries = f'{root_path}/libraries'\n",
|
|
|
+"# createPath(libraries)"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "JmbrcrhpBPC6",
|
|
|
+"scrolled": true
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title ### 1.3 Install and import dependencies\n",
|
|
|
+"\n",
|
|
|
+"from os.path import exists as path_exists\n",
|
|
|
+"\n",
|
|
|
+"if not is_colab:\n",
|
|
|
+" # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.\n",
|
|
|
+" os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'\n",
|
|
|
+"\n",
|
|
|
+"PROJECT_DIR = os.path.abspath(os.getcwd())\n",
|
|
|
+"USE_ADABINS = True\n",
|
|
|
+"\n",
|
|
|
+"if is_colab:\n",
|
|
|
+" if google_drive is not True:\n",
|
|
|
+" root_path = f'/content'\n",
|
|
|
+" model_path = '/content/models' \n",
|
|
|
+"else:\n",
|
|
|
+" root_path = f'.'\n",
|
|
|
+" model_path = f'{root_path}/model'\n",
|
|
|
+"\n",
|
|
|
+"model_256_downloaded = False\n",
|
|
|
+"model_512_downloaded = False\n",
|
|
|
+"model_secondary_downloaded = False\n",
|
|
|
+"\n",
|
|
|
+"#if is_colab:\n",
|
|
|
+"if True:\n",
|
|
|
+" !git clone https://github.com/openai/CLIP\n",
|
|
|
+" # !git clone https://github.com/facebookresearch/SLIP.git\n",
|
|
|
+" !git clone https://github.com/crowsonkb/guided-diffusion\n",
|
|
|
+" !git clone https://github.com/assafshocher/ResizeRight.git\n",
|
|
|
+" !pip install -e ./CLIP\n",
|
|
|
+" !pip install -e ./guided-diffusion\n",
|
|
|
+" !pip install lpips datetime timm\n",
|
|
|
+" !apt install imagemagick\n",
|
|
|
+" !git clone https://github.com/isl-org/MiDaS.git\n",
|
|
|
+" !git clone https://github.com/alembics/disco-diffusion.git\n",
|
|
|
+" # Rename a file to avoid a name conflict..\n",
|
|
|
+" !mv MiDaS/utils.py MiDaS/midas_utils.py\n",
|
|
|
+" !cp disco-diffusion/disco_xform_utils.py disco_xform_utils.py\n",
|
|
|
+"\n",
|
|
|
+"!mkdir model\n",
|
|
|
+"if not path_exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):\n",
|
|
|
+" !wget https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt -P {model_path}\n",
|
|
|
+"\n",
|
|
|
+"import sys\n",
|
|
|
+"import torch\n",
|
|
|
+"\n",
|
|
|
+"#Install pytorch3d\n",
|
|
|
+"#if is_colab:\n",
|
|
|
+"if True:\n",
|
|
|
+" pyt_version_str=torch.__version__.split(\"+\")[0].replace(\".\", \"\")\n",
|
|
|
+" version_str=\"\".join([\n",
|
|
|
+" f\"py3{sys.version_info.minor}_cu\",\n",
|
|
|
+" torch.version.cuda.replace(\".\",\"\"),\n",
|
|
|
+" f\"_pyt{pyt_version_str}\"\n",
|
|
|
+" ])\n",
|
|
|
+" !pip install fvcore iopath\n",
|
|
|
+" !pip install --no-index --no-cache-dir pytorch3d -f https://dl.fbaipublicfiles.com/pytorch3d/packaging/wheels/{version_str}/download.html\n",
|
|
|
+"\n",
|
|
|
+"# sys.path.append('./SLIP')\n",
|
|
|
+"sys.path.append('ResizeRight')\n",
|
|
|
+"sys.path.append('MiDaS')\n",
|
|
|
+"from dataclasses import dataclass\n",
|
|
|
+"from functools import partial\n",
|
|
|
+"import cv2\n",
|
|
|
+"import pandas as pd\n",
|
|
|
+"import gc\n",
|
|
|
+"import io\n",
|
|
|
+"import math\n",
|
|
|
+"import timm\n",
|
|
|
+"from IPython import display\n",
|
|
|
+"import lpips\n",
|
|
|
+"from PIL import Image, ImageOps\n",
|
|
|
+"import requests\n",
|
|
|
+"from glob import glob\n",
|
|
|
+"import json\n",
|
|
|
+"from types import SimpleNamespace\n",
|
|
|
+"from torch import nn\n",
|
|
|
+"from torch.nn import functional as F\n",
|
|
|
+"import torchvision.transforms as T\n",
|
|
|
+"import torchvision.transforms.functional as TF\n",
|
|
|
+"from tqdm.notebook import tqdm\n",
|
|
|
+"sys.path.append('CLIP')\n",
|
|
|
+"sys.path.append('guided-diffusion')\n",
|
|
|
+"import clip\n",
|
|
|
+"from resize_right import resize\n",
|
|
|
+"# from models import SLIP_VITB16, SLIP, SLIP_VITL16\n",
|
|
|
+"from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults\n",
|
|
|
+"from datetime import datetime\n",
|
|
|
+"import numpy as np\n",
|
|
|
+"import matplotlib.pyplot as plt\n",
|
|
|
+"import random\n",
|
|
|
+"from ipywidgets import Output\n",
|
|
|
+"import hashlib\n",
|
|
|
+"\n",
|
|
|
+"#SuperRes\n",
|
|
|
+"#if is_colab:\n",
|
|
|
+"if True:\n",
|
|
|
+" !git clone https://github.com/CompVis/latent-diffusion.git\n",
|
|
|
+" !git clone https://github.com/CompVis/taming-transformers\n",
|
|
|
+" !pip install -e ./taming-transformers\n",
|
|
|
+" !pip install ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb\n",
|
|
|
+"\n",
|
|
|
+"#SuperRes\n",
|
|
|
+"import ipywidgets as widgets\n",
|
|
|
+"import os\n",
|
|
|
+"sys.path.append(\".\")\n",
|
|
|
+"sys.path.append('taming-transformers')\n",
|
|
|
+"from taming.models import vqgan # checking correct import from taming\n",
|
|
|
+"from torchvision.datasets.utils import download_url\n",
|
|
|
+"if is_colab:\n",
|
|
|
+" %cd '/content/latent-diffusion'\n",
|
|
|
+"else:\n",
|
|
|
+" %cd 'latent-diffusion'\n",
|
|
|
+"from functools import partial\n",
|
|
|
+"from ldm.util import instantiate_from_config\n",
|
|
|
+"from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like\n",
|
|
|
+"# from ldm.models.diffusion.ddim import DDIMSampler\n",
|
|
|
+"from ldm.util import ismap\n",
|
|
|
+"if is_colab:\n",
|
|
|
+" %cd '/content'\n",
|
|
|
+" from google.colab import files\n",
|
|
|
+"else:\n",
|
|
|
+" %cd $PROJECT_DIR\n",
|
|
|
+"from IPython.display import Image as ipyimg\n",
|
|
|
+"from numpy import asarray\n",
|
|
|
+"from einops import rearrange, repeat\n",
|
|
|
+"import torch, torchvision\n",
|
|
|
+"import time\n",
|
|
|
+"from omegaconf import OmegaConf\n",
|
|
|
+"import warnings\n",
|
|
|
+"warnings.filterwarnings(\"ignore\", category=UserWarning)\n",
|
|
|
+"\n",
|
|
|
+"# AdaBins stuff\n",
|
|
|
+"if USE_ADABINS:\n",
|
|
|
+" #if is_colab:\n",
|
|
|
+" if True:\n",
|
|
|
+" !git clone https://github.com/shariqfarooq123/AdaBins.git\n",
|
|
|
+" if not path_exists(f'{model_path}/AdaBins_nyu.pt'):\n",
|
|
|
+" !wget https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt -P {model_path}\n",
|
|
|
+" !mkdir pretrained\n",
|
|
|
+" !cp -P {model_path}/AdaBins_nyu.pt pretrained/AdaBins_nyu.pt\n",
|
|
|
+" sys.path.append('AdaBins')\n",
|
|
|
+" from infer import InferenceHelper\n",
|
|
|
+" MAX_ADABINS_AREA = 500000\n",
|
|
|
+"\n",
|
|
|
+"import torch\n",
|
|
|
+"DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
|
|
|
+"print('Using device:', DEVICE)\n",
|
|
|
+"device = DEVICE # At least one of the modules expects this name..\n",
|
|
|
+"\n",
|
|
|
+"if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad\n",
|
|
|
+" print('Disabling CUDNN for A100 gpu', file=sys.stderr)\n",
|
|
|
+" torch.backends.cudnn.enabled = False\n",
|
|
|
+" \n",
|
|
|
+"%load_ext autoreload \n",
|
|
|
+"%autoreload 2"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"# This should print \"True\" - if not, pytorch can't see your GPU\n",
|
|
|
+"import torch\n",
|
|
|
+"torch.cuda.is_available()"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "BLk3J0h3MtON"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title ### 1.4 Define Midas functions\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"from midas.dpt_depth import DPTDepthModel\n",
|
|
|
+"from midas.midas_net import MidasNet\n",
|
|
|
+"from midas.midas_net_custom import MidasNet_small\n",
|
|
|
+"from midas.transforms import Resize, NormalizeImage, PrepareForNet\n",
|
|
|
+"\n",
|
|
|
+"# Initialize MiDaS depth model.\n",
|
|
|
+"# It remains resident in VRAM and likely takes around 2GB VRAM.\n",
|
|
|
+"# You could instead initialize it for each frame (and free it after each frame) to save VRAM.. but initializing it is slow.\n",
|
|
|
+"default_models = {\n",
|
|
|
+" \"midas_v21_small\": f\"{model_path}/midas_v21_small-70d6b9c8.pt\",\n",
|
|
|
+" \"midas_v21\": f\"{model_path}/midas_v21-f6b98070.pt\",\n",
|
|
|
+" \"dpt_large\": f\"{model_path}/dpt_large-midas-2f21e586.pt\",\n",
|
|
|
+" \"dpt_hybrid\": f\"{model_path}/dpt_hybrid-midas-501f0c75.pt\",\n",
|
|
|
+" \"dpt_hybrid_nyu\": f\"{model_path}/dpt_hybrid_nyu-2ce69ec7.pt\",}\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def init_midas_depth_model(midas_model_type=\"dpt_large\", optimize=True):\n",
|
|
|
+" midas_model = None\n",
|
|
|
+" net_w = None\n",
|
|
|
+" net_h = None\n",
|
|
|
+" resize_mode = None\n",
|
|
|
+" normalization = None\n",
|
|
|
+"\n",
|
|
|
+" print(f\"Initializing MiDaS '{midas_model_type}' depth model...\")\n",
|
|
|
+" # load network\n",
|
|
|
+" midas_model_path = default_models[midas_model_type]\n",
|
|
|
+"\n",
|
|
|
+" if midas_model_type == \"dpt_large\": # DPT-Large\n",
|
|
|
+" midas_model = DPTDepthModel(\n",
|
|
|
+" path=midas_model_path,\n",
|
|
|
+" backbone=\"vitl16_384\",\n",
|
|
|
+" non_negative=True,\n",
|
|
|
+" )\n",
|
|
|
+" net_w, net_h = 384, 384\n",
|
|
|
+" resize_mode = \"minimal\"\n",
|
|
|
+" normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
+" elif midas_model_type == \"dpt_hybrid\": #DPT-Hybrid\n",
|
|
|
+" midas_model = DPTDepthModel(\n",
|
|
|
+" path=midas_model_path,\n",
|
|
|
+" backbone=\"vitb_rn50_384\",\n",
|
|
|
+" non_negative=True,\n",
|
|
|
+" )\n",
|
|
|
+" net_w, net_h = 384, 384\n",
|
|
|
+" resize_mode=\"minimal\"\n",
|
|
|
+" normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
+" elif midas_model_type == \"dpt_hybrid_nyu\": #DPT-Hybrid-NYU\n",
|
|
|
+" midas_model = DPTDepthModel(\n",
|
|
|
+" path=midas_model_path,\n",
|
|
|
+" backbone=\"vitb_rn50_384\",\n",
|
|
|
+" non_negative=True,\n",
|
|
|
+" )\n",
|
|
|
+" net_w, net_h = 384, 384\n",
|
|
|
+" resize_mode=\"minimal\"\n",
|
|
|
+" normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
|
|
+" elif midas_model_type == \"midas_v21\":\n",
|
|
|
+" midas_model = MidasNet(midas_model_path, non_negative=True)\n",
|
|
|
+" net_w, net_h = 384, 384\n",
|
|
|
+" resize_mode=\"upper_bound\"\n",
|
|
|
+" normalization = NormalizeImage(\n",
|
|
|
+" mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
|
|
|
+" )\n",
|
|
|
+" elif midas_model_type == \"midas_v21_small\":\n",
|
|
|
+" midas_model = MidasNet_small(midas_model_path, features=64, backbone=\"efficientnet_lite3\", exportable=True, non_negative=True, blocks={'expand': True})\n",
|
|
|
+" net_w, net_h = 256, 256\n",
|
|
|
+" resize_mode=\"upper_bound\"\n",
|
|
|
+" normalization = NormalizeImage(\n",
|
|
|
+" mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]\n",
|
|
|
+" )\n",
|
|
|
+" else:\n",
|
|
|
+" print(f\"midas_model_type '{midas_model_type}' not implemented\")\n",
|
|
|
+" assert False\n",
|
|
|
+"\n",
|
|
|
+" midas_transform = T.Compose(\n",
|
|
|
+" [\n",
|
|
|
+" Resize(\n",
|
|
|
+" net_w,\n",
|
|
|
+" net_h,\n",
|
|
|
+" resize_target=None,\n",
|
|
|
+" keep_aspect_ratio=True,\n",
|
|
|
+" ensure_multiple_of=32,\n",
|
|
|
+" resize_method=resize_mode,\n",
|
|
|
+" image_interpolation_method=cv2.INTER_CUBIC,\n",
|
|
|
+" ),\n",
|
|
|
+" normalization,\n",
|
|
|
+" PrepareForNet(),\n",
|
|
|
+" ]\n",
|
|
|
+" )\n",
|
|
|
+"\n",
|
|
|
+" midas_model.eval()\n",
|
|
|
+" \n",
|
|
|
+" if optimize==True:\n",
|
|
|
+" if DEVICE == torch.device(\"cuda\"):\n",
|
|
|
+" midas_model = midas_model.to(memory_format=torch.channels_last) \n",
|
|
|
+" midas_model = midas_model.half()\n",
|
|
|
+"\n",
|
|
|
+" midas_model.to(DEVICE)\n",
|
|
|
+"\n",
|
|
|
+" print(f\"MiDaS '{midas_model_type}' depth model initialized.\")\n",
|
|
|
+" return midas_model, midas_transform, net_w, net_h, resize_mode, normalization\n"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "FpZczxnOnPIU"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title 1.5 Define necessary functions\n",
|
|
|
+"\n",
|
|
|
+"# https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869\n",
|
|
|
+"\n",
|
|
|
+"import pytorch3d.transforms as p3dT\n",
|
|
|
+"import disco_xform_utils as dxf\n",
|
|
|
+"\n",
|
|
|
+"def interp(t):\n",
|
|
|
+" return 3 * t**2 - 2 * t ** 3\n",
|
|
|
+"\n",
|
|
|
+"def perlin(width, height, scale=10, device=None):\n",
|
|
|
+" gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)\n",
|
|
|
+" xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)\n",
|
|
|
+" ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)\n",
|
|
|
+" wx = 1 - interp(xs)\n",
|
|
|
+" wy = 1 - interp(ys)\n",
|
|
|
+" dots = 0\n",
|
|
|
+" dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)\n",
|
|
|
+" dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)\n",
|
|
|
+" dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))\n",
|
|
|
+" dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))\n",
|
|
|
+" return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)\n",
|
|
|
+"\n",
|
|
|
+"def perlin_ms(octaves, width, height, grayscale, device=device):\n",
|
|
|
+" out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]\n",
|
|
|
+" # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]\n",
|
|
|
+" for i in range(1 if grayscale else 3):\n",
|
|
|
+" scale = 2 ** len(octaves)\n",
|
|
|
+" oct_width = width\n",
|
|
|
+" oct_height = height\n",
|
|
|
+" for oct in octaves:\n",
|
|
|
+" p = perlin(oct_width, oct_height, scale, device)\n",
|
|
|
+" out_array[i] += p * oct\n",
|
|
|
+" scale //= 2\n",
|
|
|
+" oct_width *= 2\n",
|
|
|
+" oct_height *= 2\n",
|
|
|
+" return torch.cat(out_array)\n",
|
|
|
+"\n",
|
|
|
+"def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):\n",
|
|
|
+" out = perlin_ms(octaves, width, height, grayscale)\n",
|
|
|
+" if grayscale:\n",
|
|
|
+" out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))\n",
|
|
|
+" out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')\n",
|
|
|
+" else:\n",
|
|
|
+" out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])\n",
|
|
|
+" out = TF.resize(size=(side_y, side_x), img=out)\n",
|
|
|
+" out = TF.to_pil_image(out.clamp(0, 1).squeeze())\n",
|
|
|
+"\n",
|
|
|
+" out = ImageOps.autocontrast(out)\n",
|
|
|
+" return out\n",
|
|
|
+"\n",
|
|
|
+"def regen_perlin():\n",
|
|
|
+" if perlin_mode == 'color':\n",
|
|
|
+" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
+" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n",
|
|
|
+" elif perlin_mode == 'gray':\n",
|
|
|
+" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n",
|
|
|
+" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
+" else:\n",
|
|
|
+" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
+" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
+"\n",
|
|
|
+" init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
+" del init2\n",
|
|
|
+" return init.expand(batch_size, -1, -1, -1)\n",
|
|
|
+"\n",
|
|
|
+"def fetch(url_or_path):\n",
|
|
|
+" if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):\n",
|
|
|
+" r = requests.get(url_or_path)\n",
|
|
|
+" r.raise_for_status()\n",
|
|
|
+" fd = io.BytesIO()\n",
|
|
|
+" fd.write(r.content)\n",
|
|
|
+" fd.seek(0)\n",
|
|
|
+" return fd\n",
|
|
|
+" return open(url_or_path, 'rb')\n",
|
|
|
+"\n",
|
|
|
+"def read_image_workaround(path):\n",
|
|
|
+" \"\"\"OpenCV reads images as BGR, Pillow saves them as RGB. Work around\n",
|
|
|
+" this incompatibility to avoid colour inversions.\"\"\"\n",
|
|
|
+" im_tmp = cv2.imread(path)\n",
|
|
|
+" return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)\n",
|
|
|
+"\n",
|
|
|
+"def parse_prompt(prompt):\n",
|
|
|
+" if prompt.startswith('http://') or prompt.startswith('https://'):\n",
|
|
|
+" vals = prompt.rsplit(':', 2)\n",
|
|
|
+" vals = [vals[0] + ':' + vals[1], *vals[2:]]\n",
|
|
|
+" else:\n",
|
|
|
+" vals = prompt.rsplit(':', 1)\n",
|
|
|
+" vals = vals + ['', '1'][len(vals):]\n",
|
|
|
+" return vals[0], float(vals[1])\n",
|
|
|
+"\n",
|
|
|
+"def sinc(x):\n",
|
|
|
+" return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))\n",
|
|
|
+"\n",
|
|
|
+"def lanczos(x, a):\n",
|
|
|
+" cond = torch.logical_and(-a < x, x < a)\n",
|
|
|
+" out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))\n",
|
|
|
+" return out / out.sum()\n",
|
|
|
+"\n",
|
|
|
+"def ramp(ratio, width):\n",
|
|
|
+" n = math.ceil(width / ratio + 1)\n",
|
|
|
+" out = torch.empty([n])\n",
|
|
|
+" cur = 0\n",
|
|
|
+" for i in range(out.shape[0]):\n",
|
|
|
+" out[i] = cur\n",
|
|
|
+" cur += ratio\n",
|
|
|
+" return torch.cat([-out[1:].flip([0]), out])[1:-1]\n",
|
|
|
+"\n",
|
|
|
+"def resample(input, size, align_corners=True):\n",
|
|
|
+" n, c, h, w = input.shape\n",
|
|
|
+" dh, dw = size\n",
|
|
|
+"\n",
|
|
|
+" input = input.reshape([n * c, 1, h, w])\n",
|
|
|
+"\n",
|
|
|
+" if dh < h:\n",
|
|
|
+" kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)\n",
|
|
|
+" pad_h = (kernel_h.shape[0] - 1) // 2\n",
|
|
|
+" input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')\n",
|
|
|
+" input = F.conv2d(input, kernel_h[None, None, :, None])\n",
|
|
|
+"\n",
|
|
|
+" if dw < w:\n",
|
|
|
+" kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)\n",
|
|
|
+" pad_w = (kernel_w.shape[0] - 1) // 2\n",
|
|
|
+" input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')\n",
|
|
|
+" input = F.conv2d(input, kernel_w[None, None, None, :])\n",
|
|
|
+"\n",
|
|
|
+" input = input.reshape([n, c, h, w])\n",
|
|
|
+" return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)\n",
|
|
|
+"\n",
|
|
|
+"class MakeCutouts(nn.Module):\n",
|
|
|
+" def __init__(self, cut_size, cutn, skip_augs=False):\n",
|
|
|
+" super().__init__()\n",
|
|
|
+" self.cut_size = cut_size\n",
|
|
|
+" self.cutn = cutn\n",
|
|
|
+" self.skip_augs = skip_augs\n",
|
|
|
+" self.augs = T.Compose([\n",
|
|
|
+" T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomGrayscale(p=0.15),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
+" ])\n",
|
|
|
+"\n",
|
|
|
+" def forward(self, input):\n",
|
|
|
+" input = T.Pad(input.shape[2]//4, fill=0)(input)\n",
|
|
|
+" sideY, sideX = input.shape[2:4]\n",
|
|
|
+" max_size = min(sideX, sideY)\n",
|
|
|
+"\n",
|
|
|
+" cutouts = []\n",
|
|
|
+" for ch in range(self.cutn):\n",
|
|
|
+" if ch > self.cutn - self.cutn//4:\n",
|
|
|
+" cutout = input.clone()\n",
|
|
|
+" else:\n",
|
|
|
+" size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))\n",
|
|
|
+" offsetx = torch.randint(0, abs(sideX - size + 1), ())\n",
|
|
|
+" offsety = torch.randint(0, abs(sideY - size + 1), ())\n",
|
|
|
+" cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
|
|
|
+"\n",
|
|
|
+" if not self.skip_augs:\n",
|
|
|
+" cutout = self.augs(cutout)\n",
|
|
|
+" cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))\n",
|
|
|
+" del cutout\n",
|
|
|
+"\n",
|
|
|
+" cutouts = torch.cat(cutouts, dim=0)\n",
|
|
|
+" return cutouts\n",
|
|
|
+"\n",
|
|
|
+"cutout_debug = False\n",
|
|
|
+"padargs = {}\n",
|
|
|
+"\n",
|
|
|
+"class MakeCutoutsDango(nn.Module):\n",
|
|
|
+" def __init__(self, cut_size,\n",
|
|
|
+" Overview=4, \n",
|
|
|
+" InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2\n",
|
|
|
+" ):\n",
|
|
|
+" super().__init__()\n",
|
|
|
+" self.cut_size = cut_size\n",
|
|
|
+" self.Overview = Overview\n",
|
|
|
+" self.InnerCrop = InnerCrop\n",
|
|
|
+" self.IC_Size_Pow = IC_Size_Pow\n",
|
|
|
+" self.IC_Grey_P = IC_Grey_P\n",
|
|
|
+" if args.animation_mode == 'None':\n",
|
|
|
+" self.augs = T.Compose([\n",
|
|
|
+" T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomGrayscale(p=0.1),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
+" ])\n",
|
|
|
+" elif args.animation_mode == 'Video Input':\n",
|
|
|
+" self.augs = T.Compose([\n",
|
|
|
+" T.RandomHorizontalFlip(p=0.5),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomAffine(degrees=15, translate=(0.1, 0.1)),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomPerspective(distortion_scale=0.4, p=0.7),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomGrayscale(p=0.15),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),\n",
|
|
|
+" ])\n",
|
|
|
+" elif args.animation_mode == '2D' or args.animation_mode == '3D':\n",
|
|
|
+" self.augs = T.Compose([\n",
|
|
|
+" T.RandomHorizontalFlip(p=0.4),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.RandomGrayscale(p=0.1),\n",
|
|
|
+" T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),\n",
|
|
|
+" T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),\n",
|
|
|
+" ])\n",
|
|
|
+" \n",
|
|
|
+"\n",
|
|
|
+" def forward(self, input):\n",
|
|
|
+" cutouts = []\n",
|
|
|
+" gray = T.Grayscale(3)\n",
|
|
|
+" sideY, sideX = input.shape[2:4]\n",
|
|
|
+" max_size = min(sideX, sideY)\n",
|
|
|
+" min_size = min(sideX, sideY, self.cut_size)\n",
|
|
|
+" l_size = max(sideX, sideY)\n",
|
|
|
+" output_shape = [1,3,self.cut_size,self.cut_size] \n",
|
|
|
+" output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]\n",
|
|
|
+" pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **padargs)\n",
|
|
|
+" cutout = resize(pad_input, out_shape=output_shape)\n",
|
|
|
+"\n",
|
|
|
+" if self.Overview>0:\n",
|
|
|
+" if self.Overview<=4:\n",
|
|
|
+" if self.Overview>=1:\n",
|
|
|
+" cutouts.append(cutout)\n",
|
|
|
+" if self.Overview>=2:\n",
|
|
|
+" cutouts.append(gray(cutout))\n",
|
|
|
+" if self.Overview>=3:\n",
|
|
|
+" cutouts.append(TF.hflip(cutout))\n",
|
|
|
+" if self.Overview==4:\n",
|
|
|
+" cutouts.append(gray(TF.hflip(cutout)))\n",
|
|
|
+" else:\n",
|
|
|
+" cutout = resize(pad_input, out_shape=output_shape)\n",
|
|
|
+" for _ in range(self.Overview):\n",
|
|
|
+" cutouts.append(cutout)\n",
|
|
|
+"\n",
|
|
|
+" if cutout_debug:\n",
|
|
|
+" if is_colab:\n",
|
|
|
+" TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"/content/cutout_overview0.jpg\",quality=99)\n",
|
|
|
+" else:\n",
|
|
|
+" TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save(\"cutout_overview0.jpg\",quality=99)\n",
|
|
|
+"\n",
|
|
|
+" \n",
|
|
|
+" if self.InnerCrop >0:\n",
|
|
|
+" for i in range(self.InnerCrop):\n",
|
|
|
+" size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)\n",
|
|
|
+" offsetx = torch.randint(0, sideX - size + 1, ())\n",
|
|
|
+" offsety = torch.randint(0, sideY - size + 1, ())\n",
|
|
|
+" cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]\n",
|
|
|
+" if i <= int(self.IC_Grey_P * self.InnerCrop):\n",
|
|
|
+" cutout = gray(cutout)\n",
|
|
|
+" cutout = resize(cutout, out_shape=output_shape)\n",
|
|
|
+" cutouts.append(cutout)\n",
|
|
|
+" if cutout_debug:\n",
|
|
|
+" if is_colab:\n",
|
|
|
+" TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"/content/cutout_InnerCrop.jpg\",quality=99)\n",
|
|
|
+" else:\n",
|
|
|
+" TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save(\"cutout_InnerCrop.jpg\",quality=99)\n",
|
|
|
+" cutouts = torch.cat(cutouts)\n",
|
|
|
+" if skip_augs is not True: cutouts=self.augs(cutouts)\n",
|
|
|
+" return cutouts\n",
|
|
|
+"\n",
|
|
|
+"def spherical_dist_loss(x, y):\n",
|
|
|
+" x = F.normalize(x, dim=-1)\n",
|
|
|
+" y = F.normalize(y, dim=-1)\n",
|
|
|
+" return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2) \n",
|
|
|
+"\n",
|
|
|
+"def tv_loss(input):\n",
|
|
|
+" \"\"\"L2 total variation loss, as in Mahendran et al.\"\"\"\n",
|
|
|
+" input = F.pad(input, (0, 1, 0, 1), 'replicate')\n",
|
|
|
+" x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]\n",
|
|
|
+" y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]\n",
|
|
|
+" return (x_diff**2 + y_diff**2).mean([1, 2, 3])\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def range_loss(input):\n",
|
|
|
+" return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])\n",
|
|
|
+"\n",
|
|
|
+"stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete\n",
|
|
|
+"\n",
|
|
|
+"def do_run():\n",
|
|
|
+" seed = args.seed\n",
|
|
|
+" print(range(args.start_frame, args.max_frames))\n",
|
|
|
+"\n",
|
|
|
+" if (args.animation_mode == \"3D\") and (args.midas_weight > 0.0):\n",
|
|
|
+" midas_model, midas_transform, midas_net_w, midas_net_h, midas_resize_mode, midas_normalization = init_midas_depth_model(args.midas_depth_model)\n",
|
|
|
+" for frame_num in range(args.start_frame, args.max_frames):\n",
|
|
|
+" if stop_on_next_loop:\n",
|
|
|
+" break\n",
|
|
|
+" \n",
|
|
|
+" display.clear_output(wait=True)\n",
|
|
|
+"\n",
|
|
|
+" # Print Frame progress if animation mode is on\n",
|
|
|
+" if args.animation_mode != \"None\":\n",
|
|
|
+" batchBar = tqdm(range(args.max_frames), desc =\"Frames\")\n",
|
|
|
+" batchBar.n = frame_num\n",
|
|
|
+" batchBar.refresh()\n",
|
|
|
+"\n",
|
|
|
+" \n",
|
|
|
+" # Inits if not video frames\n",
|
|
|
+" if args.animation_mode != \"Video Input\":\n",
|
|
|
+" if args.init_image == '':\n",
|
|
|
+" init_image = None\n",
|
|
|
+" else:\n",
|
|
|
+" init_image = args.init_image\n",
|
|
|
+" init_scale = args.init_scale\n",
|
|
|
+" skip_steps = args.skip_steps\n",
|
|
|
+"\n",
|
|
|
+" if args.animation_mode == \"2D\":\n",
|
|
|
+" if args.key_frames:\n",
|
|
|
+" angle = args.angle_series[frame_num]\n",
|
|
|
+" zoom = args.zoom_series[frame_num]\n",
|
|
|
+" translation_x = args.translation_x_series[frame_num]\n",
|
|
|
+" translation_y = args.translation_y_series[frame_num]\n",
|
|
|
+" print(\n",
|
|
|
+" f'angle: {angle}',\n",
|
|
|
+" f'zoom: {zoom}',\n",
|
|
|
+" f'translation_x: {translation_x}',\n",
|
|
|
+" f'translation_y: {translation_y}',\n",
|
|
|
+" )\n",
|
|
|
+" \n",
|
|
|
+" if frame_num > 0:\n",
|
|
|
+" seed = seed + 1 \n",
|
|
|
+" if resume_run and frame_num == start_frame:\n",
|
|
|
+" img_0 = cv2.imread(batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:04}.png\")\n",
|
|
|
+" else:\n",
|
|
|
+" img_0 = cv2.imread('prevFrame.png')\n",
|
|
|
+" center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)\n",
|
|
|
+" trans_mat = np.float32(\n",
|
|
|
+" [[1, 0, translation_x],\n",
|
|
|
+" [0, 1, translation_y]]\n",
|
|
|
+" )\n",
|
|
|
+" rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )\n",
|
|
|
+" trans_mat = np.vstack([trans_mat, [0,0,1]])\n",
|
|
|
+" rot_mat = np.vstack([rot_mat, [0,0,1]])\n",
|
|
|
+" transformation_matrix = np.matmul(rot_mat, trans_mat)\n",
|
|
|
+" img_0 = cv2.warpPerspective(\n",
|
|
|
+" img_0,\n",
|
|
|
+" transformation_matrix,\n",
|
|
|
+" (img_0.shape[1], img_0.shape[0]),\n",
|
|
|
+" borderMode=cv2.BORDER_WRAP\n",
|
|
|
+" )\n",
|
|
|
+"\n",
|
|
|
+" cv2.imwrite('prevFrameScaled.png', img_0)\n",
|
|
|
+" init_image = 'prevFrameScaled.png'\n",
|
|
|
+" init_scale = args.frames_scale\n",
|
|
|
+" skip_steps = args.calc_frames_skip_steps\n",
|
|
|
+"\n",
|
|
|
+" if args.animation_mode == \"3D\":\n",
|
|
|
+" if args.key_frames:\n",
|
|
|
+" angle = args.angle_series[frame_num]\n",
|
|
|
+" #zoom = args.zoom_series[frame_num]\n",
|
|
|
+" translation_x = args.translation_x_series[frame_num]\n",
|
|
|
+" translation_y = args.translation_y_series[frame_num]\n",
|
|
|
+" translation_z = args.translation_z_series[frame_num]\n",
|
|
|
+" rotation_3d_x = args.rotation_3d_x_series[frame_num]\n",
|
|
|
+" rotation_3d_y = args.rotation_3d_y_series[frame_num]\n",
|
|
|
+" rotation_3d_z = args.rotation_3d_z_series[frame_num]\n",
|
|
|
+" print(\n",
|
|
|
+" f'angle: {angle}',\n",
|
|
|
+" #f'zoom: {zoom}',\n",
|
|
|
+" f'translation_x: {translation_x}',\n",
|
|
|
+" f'translation_y: {translation_y}',\n",
|
|
|
+" f'translation_z: {translation_z}',\n",
|
|
|
+" f'rotation_3d_x: {rotation_3d_x}',\n",
|
|
|
+" f'rotation_3d_y: {rotation_3d_y}',\n",
|
|
|
+" f'rotation_3d_z: {rotation_3d_z}',\n",
|
|
|
+" )\n",
|
|
|
+"\n",
|
|
|
+" if frame_num > 0:\n",
|
|
|
+" seed = seed + 1 \n",
|
|
|
+" ### Turbo mode prep\n",
|
|
|
+" turbo_prevScaled_path = batchFolder+f\"/turbo/prevFrameScaled.png\" if is_colab else 'prevFrameScaled.png'\n",
|
|
|
+" turbo_oldScaled_path = batchFolder+f\"/turbo/oldFrameScaled.png\" if is_colab else 'oldFrameScaled.png'\n",
|
|
|
+" turbo_preroll = 10 # frames\n",
|
|
|
+" non_turbo_prevFrm = '/content/prevFrame.png' if is_colab else 'prevFrame.png' \n",
|
|
|
+" if resume_run and frame_num == start_frame:\n",
|
|
|
+" img_filepath = batchFolder+f\"/{batch_name}({batchNum})_{start_frame-1:04}.png\"\n",
|
|
|
+" if turbo_mode == True and frame_num > turbo_preroll:\n",
|
|
|
+" img_filepath = turbo_prevScaled_path\n",
|
|
|
+" else:\n",
|
|
|
+" img_filepath = '/content/prevFrame.png' if is_colab else 'prevFrame.png'\n",
|
|
|
+" \n",
|
|
|
+" \n",
|
|
|
+" #warp prior frame\n",
|
|
|
+" trans_scale = 1.0/200.0\n",
|
|
|
+" translate_xyz = [-translation_x*trans_scale, translation_y*trans_scale, -translation_z*trans_scale]\n",
|
|
|
+" rotate_xyz = [rotation_3d_x, rotation_3d_y, rotation_3d_z]\n",
|
|
|
+" print('translation:',translate_xyz)\n",
|
|
|
+" print('rotation:',rotate_xyz)\n",
|
|
|
+" rot_mat = p3dT.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), \"XYZ\").unsqueeze(0)\n",
|
|
|
+" print(\"rot_mat: \" + str(rot_mat))\n",
|
|
|
+" next_step_pil = dxf.transform_image_3d(img_filepath, midas_model, midas_transform, DEVICE,\n",
|
|
|
+" rot_mat, translate_xyz, args.near_plane, args.far_plane,\n",
|
|
|
+" args.fov, padding_mode=args.padding_mode,\n",
|
|
|
+" sampling_mode=args.sampling_mode, midas_weight=args.midas_weight)\n",
|
|
|
+" next_step_pil.save('prevFrameScaled.png')\n",
|
|
|
+" if turbo_mode == True:\n",
|
|
|
+" next_step_pil.save(turbo_prevScaled_path)#stash for turbo\n",
|
|
|
+" turbo_blend = False # default for non-turbo frame saving\n",
|
|
|
+" if turbo_mode == True and frame_num == turbo_preroll: #start tracking oldframe\n",
|
|
|
+" next_step_pil.save(turbo_oldScaled_path)#stash for later blending\n",
|
|
|
+" if turbo_mode == True and frame_num > turbo_preroll:\n",
|
|
|
+" \n",
|
|
|
+" #set up 2 warped image sequences, old & new, to blend toward new diff image\n",
|
|
|
+" old_frame = dxf.transform_image_3d(turbo_oldScaled_path, midas_model, midas_transform, DEVICE,\n",
|
|
|
+" rot_mat, translate_xyz, args.near_plane, args.far_plane,\n",
|
|
|
+" args.fov, padding_mode=args.padding_mode,\n",
|
|
|
+" sampling_mode=args.sampling_mode, midas_weight=args.midas_weight)\n",
|
|
|
+" old_frame.save(turbo_oldScaled_path)\n",
|
|
|
+" if frame_num % int(turbo_steps) != 0: \n",
|
|
|
+" print('turbo skip this frame: skipping clip diffusion steps')\n",
|
|
|
+" filename = f'{args.batch_name}({args.batchNum})_{frame_num:04}.png'\n",
|
|
|
+" blend_factor = ((frame_num % int(turbo_steps))+1)/int(turbo_steps)\n",
|
|
|
+" print('turbo skip this frame: skipping clip diffusion steps and saving blended frame')\n",
|
|
|
+" newWarpedImg = cv2.imread(turbo_prevScaled_path)#this is already updated..\n",
|
|
|
+" oldWarpedImg = cv2.imread(turbo_oldScaled_path)\n",
|
|
|
+" blendedImage = cv2.addWeighted(newWarpedImg, blend_factor, oldWarpedImg,1-blend_factor, 0.0)\n",
|
|
|
+" cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
|
|
|
+" \n",
|
|
|
+" next_step_pil.save(f'{non_turbo_prevFrm}') # save it also as prev_frame to feed next iteration\n",
|
|
|
+" #turbo_blend = False\n",
|
|
|
+" continue # done. exit frame loop\n",
|
|
|
+" else:\n",
|
|
|
+" #if not a skip frame, will run diff and need to blend.\n",
|
|
|
+" oldWarpedImg = cv2.imread(turbo_prevScaled_path)#swap old img\n",
|
|
|
+" cv2.imwrite(turbo_oldScaled_path ,oldWarpedImg)#swap in for blending later \n",
|
|
|
+" #turbo_blend = True # flag to blend frames after diff generated...\n",
|
|
|
+" print('clip/diff this frame - generate clip diff image')\n",
|
|
|
+" init_image = 'prevFrameScaled.png'\n",
|
|
|
+" init_scale = args.frames_scale\n",
|
|
|
+" skip_steps = args.calc_frames_skip_steps\n",
|
|
|
+"\n",
|
|
|
+" if args.animation_mode == \"Video Input\":\n",
|
|
|
+" seed = seed + 1 \n",
|
|
|
+" init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg'\n",
|
|
|
+" init_scale = args.frames_scale\n",
|
|
|
+" skip_steps = args.calc_frames_skip_steps\n",
|
|
|
+"\n",
|
|
|
+" loss_values = []\n",
|
|
|
+" \n",
|
|
|
+" if seed is not None:\n",
|
|
|
+" np.random.seed(seed)\n",
|
|
|
+" random.seed(seed)\n",
|
|
|
+" torch.manual_seed(seed)\n",
|
|
|
+" torch.cuda.manual_seed_all(seed)\n",
|
|
|
+" torch.backends.cudnn.deterministic = True\n",
|
|
|
+" \n",
|
|
|
+" target_embeds, weights = [], []\n",
|
|
|
+" \n",
|
|
|
+" if args.prompts_series is not None and frame_num >= len(args.prompts_series):\n",
|
|
|
+" frame_prompt = args.prompts_series[-1]\n",
|
|
|
+" elif args.prompts_series is not None:\n",
|
|
|
+" frame_prompt = args.prompts_series[frame_num]\n",
|
|
|
+" else:\n",
|
|
|
+" frame_prompt = []\n",
|
|
|
+" \n",
|
|
|
+" print(args.image_prompts_series)\n",
|
|
|
+" if args.image_prompts_series is not None and frame_num >= len(args.image_prompts_series):\n",
|
|
|
+" image_prompt = args.image_prompts_series[-1]\n",
|
|
|
+" elif args.image_prompts_series is not None:\n",
|
|
|
+" image_prompt = args.image_prompts_series[frame_num]\n",
|
|
|
+" else:\n",
|
|
|
+" image_prompt = []\n",
|
|
|
+"\n",
|
|
|
+" print(f'Frame Prompt: {frame_prompt}')\n",
|
|
|
+"\n",
|
|
|
+" model_stats = []\n",
|
|
|
+" for clip_model in clip_models:\n",
|
|
|
+" cutn = 16\n",
|
|
|
+" model_stat = {\"clip_model\":None,\"target_embeds\":[],\"make_cutouts\":None,\"weights\":[]}\n",
|
|
|
+" model_stat[\"clip_model\"] = clip_model\n",
|
|
|
+" \n",
|
|
|
+" \n",
|
|
|
+" for prompt in frame_prompt:\n",
|
|
|
+" txt, weight = parse_prompt(prompt)\n",
|
|
|
+" txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()\n",
|
|
|
+" \n",
|
|
|
+" if args.fuzzy_prompt:\n",
|
|
|
+" for i in range(25):\n",
|
|
|
+" model_stat[\"target_embeds\"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))\n",
|
|
|
+" model_stat[\"weights\"].append(weight)\n",
|
|
|
+" else:\n",
|
|
|
+" model_stat[\"target_embeds\"].append(txt)\n",
|
|
|
+" model_stat[\"weights\"].append(weight)\n",
|
|
|
+" \n",
|
|
|
+" if image_prompt:\n",
|
|
|
+" model_stat[\"make_cutouts\"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs) \n",
|
|
|
+" for prompt in image_prompt:\n",
|
|
|
+" path, weight = parse_prompt(prompt)\n",
|
|
|
+" img = Image.open(fetch(path)).convert('RGB')\n",
|
|
|
+" img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)\n",
|
|
|
+" batch = model_stat[\"make_cutouts\"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))\n",
|
|
|
+" embed = clip_model.encode_image(normalize(batch)).float()\n",
|
|
|
+" if fuzzy_prompt:\n",
|
|
|
+" for i in range(25):\n",
|
|
|
+" model_stat[\"target_embeds\"].append((embed + torch.randn(embed.shape).cuda() * rand_mag).clamp(0,1))\n",
|
|
|
+" weights.extend([weight / cutn] * cutn)\n",
|
|
|
+" else:\n",
|
|
|
+" model_stat[\"target_embeds\"].append(embed)\n",
|
|
|
+" model_stat[\"weights\"].extend([weight / cutn] * cutn)\n",
|
|
|
+" \n",
|
|
|
+" model_stat[\"target_embeds\"] = torch.cat(model_stat[\"target_embeds\"])\n",
|
|
|
+" model_stat[\"weights\"] = torch.tensor(model_stat[\"weights\"], device=device)\n",
|
|
|
+" if model_stat[\"weights\"].sum().abs() < 1e-3:\n",
|
|
|
+" raise RuntimeError('The weights must not sum to 0.')\n",
|
|
|
+" model_stat[\"weights\"] /= model_stat[\"weights\"].sum().abs()\n",
|
|
|
+" model_stats.append(model_stat)\n",
|
|
|
+" \n",
|
|
|
+" init = None\n",
|
|
|
+" if init_image is not None:\n",
|
|
|
+" init = Image.open(fetch(init_image)).convert('RGB')\n",
|
|
|
+" init = init.resize((args.side_x, args.side_y), Image.LANCZOS)\n",
|
|
|
+" init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
+" \n",
|
|
|
+" if args.perlin_init:\n",
|
|
|
+" if args.perlin_mode == 'color':\n",
|
|
|
+" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
+" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)\n",
|
|
|
+" elif args.perlin_mode == 'gray':\n",
|
|
|
+" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)\n",
|
|
|
+" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
+" else:\n",
|
|
|
+" init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)\n",
|
|
|
+" init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)\n",
|
|
|
+" # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)\n",
|
|
|
+" init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)\n",
|
|
|
+" del init2\n",
|
|
|
+" \n",
|
|
|
+" cur_t = None\n",
|
|
|
+" \n",
|
|
|
+" def cond_fn(x, t, y=None):\n",
|
|
|
+" with torch.enable_grad():\n",
|
|
|
+" x_is_NaN = False\n",
|
|
|
+" x = x.detach().requires_grad_()\n",
|
|
|
+" n = x.shape[0]\n",
|
|
|
+" if use_secondary_model is True:\n",
|
|
|
+" alpha = torch.tensor(diffusion.sqrt_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n",
|
|
|
+" sigma = torch.tensor(diffusion.sqrt_one_minus_alphas_cumprod[cur_t], device=device, dtype=torch.float32)\n",
|
|
|
+" cosine_t = alpha_sigma_to_t(alpha, sigma)\n",
|
|
|
+" out = secondary_model(x, cosine_t[None].repeat([n])).pred\n",
|
|
|
+" fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n",
|
|
|
+" x_in = out * fac + x * (1 - fac)\n",
|
|
|
+" x_in_grad = torch.zeros_like(x_in)\n",
|
|
|
+" else:\n",
|
|
|
+" my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t\n",
|
|
|
+" out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})\n",
|
|
|
+" fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]\n",
|
|
|
+" x_in = out['pred_xstart'] * fac + x * (1 - fac)\n",
|
|
|
+" x_in_grad = torch.zeros_like(x_in)\n",
|
|
|
+" for model_stat in model_stats:\n",
|
|
|
+" for i in range(args.cutn_batches):\n",
|
|
|
+" t_int = int(t.item())+1 #errors on last step without +1, need to find source\n",
|
|
|
+" #when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution'\n",
|
|
|
+" try:\n",
|
|
|
+" input_resolution=model_stat[\"clip_model\"].visual.input_resolution\n",
|
|
|
+" except:\n",
|
|
|
+" input_resolution=224\n",
|
|
|
+"\n",
|
|
|
+" cuts = MakeCutoutsDango(input_resolution,\n",
|
|
|
+" Overview= args.cut_overview[1000-t_int], \n",
|
|
|
+" InnerCrop = args.cut_innercut[1000-t_int], IC_Size_Pow=args.cut_ic_pow, IC_Grey_P = args.cut_icgray_p[1000-t_int]\n",
|
|
|
+" )\n",
|
|
|
+" clip_in = normalize(cuts(x_in.add(1).div(2)))\n",
|
|
|
+" image_embeds = model_stat[\"clip_model\"].encode_image(clip_in).float()\n",
|
|
|
+" dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat[\"target_embeds\"].unsqueeze(0))\n",
|
|
|
+" dists = dists.view([args.cut_overview[1000-t_int]+args.cut_innercut[1000-t_int], n, -1])\n",
|
|
|
+" losses = dists.mul(model_stat[\"weights\"]).sum(2).mean(0)\n",
|
|
|
+" loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch\n",
|
|
|
+" x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches\n",
|
|
|
+" tv_losses = tv_loss(x_in)\n",
|
|
|
+" if use_secondary_model is True:\n",
|
|
|
+" range_losses = range_loss(out)\n",
|
|
|
+" else:\n",
|
|
|
+" range_losses = range_loss(out['pred_xstart'])\n",
|
|
|
+" sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()\n",
|
|
|
+" loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale\n",
|
|
|
+" if init is not None and args.init_scale:\n",
|
|
|
+" init_losses = lpips_model(x_in, init)\n",
|
|
|
+" loss = loss + init_losses.sum() * args.init_scale\n",
|
|
|
+" x_in_grad += torch.autograd.grad(loss, x_in)[0]\n",
|
|
|
+" if torch.isnan(x_in_grad).any()==False:\n",
|
|
|
+" grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]\n",
|
|
|
+" else:\n",
|
|
|
+" # print(\"NaN'd\")\n",
|
|
|
+" x_is_NaN = True\n",
|
|
|
+" grad = torch.zeros_like(x)\n",
|
|
|
+" if args.clamp_grad and x_is_NaN == False:\n",
|
|
|
+" magnitude = grad.square().mean().sqrt()\n",
|
|
|
+" return grad * magnitude.clamp(max=args.clamp_max) / magnitude #min=-0.02, min=-clamp_max, \n",
|
|
|
+" return grad\n",
|
|
|
+" \n",
|
|
|
+" if model_config['timestep_respacing'].startswith('ddim'):\n",
|
|
|
+" sample_fn = diffusion.ddim_sample_loop_progressive\n",
|
|
|
+" else:\n",
|
|
|
+" sample_fn = diffusion.p_sample_loop_progressive\n",
|
|
|
+" \n",
|
|
|
+"\n",
|
|
|
+" image_display = Output()\n",
|
|
|
+" for i in range(args.n_batches):\n",
|
|
|
+" if args.animation_mode == 'None':\n",
|
|
|
+" display.clear_output(wait=True)\n",
|
|
|
+" batchBar = tqdm(range(args.n_batches), desc =\"Batches\")\n",
|
|
|
+" batchBar.n = i\n",
|
|
|
+" batchBar.refresh()\n",
|
|
|
+" print('')\n",
|
|
|
+" display.display(image_display)\n",
|
|
|
+" gc.collect()\n",
|
|
|
+" torch.cuda.empty_cache()\n",
|
|
|
+" cur_t = diffusion.num_timesteps - skip_steps - 1\n",
|
|
|
+" total_steps = cur_t\n",
|
|
|
+"\n",
|
|
|
+" if perlin_init:\n",
|
|
|
+" init = regen_perlin()\n",
|
|
|
+"\n",
|
|
|
+" if model_config['timestep_respacing'].startswith('ddim'):\n",
|
|
|
+" samples = sample_fn(\n",
|
|
|
+" model,\n",
|
|
|
+" (batch_size, 3, args.side_y, args.side_x),\n",
|
|
|
+" clip_denoised=clip_denoised,\n",
|
|
|
+" model_kwargs={},\n",
|
|
|
+" cond_fn=cond_fn,\n",
|
|
|
+" progress=True,\n",
|
|
|
+" skip_timesteps=skip_steps,\n",
|
|
|
+" init_image=init,\n",
|
|
|
+" randomize_class=randomize_class,\n",
|
|
|
+" eta=eta,\n",
|
|
|
+" )\n",
|
|
|
+" else:\n",
|
|
|
+" samples = sample_fn(\n",
|
|
|
+" model,\n",
|
|
|
+" (batch_size, 3, args.side_y, args.side_x),\n",
|
|
|
+" clip_denoised=clip_denoised,\n",
|
|
|
+" model_kwargs={},\n",
|
|
|
+" cond_fn=cond_fn,\n",
|
|
|
+" progress=True,\n",
|
|
|
+" skip_timesteps=skip_steps,\n",
|
|
|
+" init_image=init,\n",
|
|
|
+" randomize_class=randomize_class,\n",
|
|
|
+" )\n",
|
|
|
+" \n",
|
|
|
+" \n",
|
|
|
+" # with run_display:\n",
|
|
|
+" # display.clear_output(wait=True)\n",
|
|
|
+" imgToSharpen = None\n",
|
|
|
+" for j, sample in enumerate(samples): \n",
|
|
|
+" cur_t -= 1\n",
|
|
|
+" intermediateStep = False\n",
|
|
|
+" if args.steps_per_checkpoint is not None:\n",
|
|
|
+" if j % steps_per_checkpoint == 0 and j > 0:\n",
|
|
|
+" intermediateStep = True\n",
|
|
|
+" elif j in args.intermediate_saves:\n",
|
|
|
+" intermediateStep = True\n",
|
|
|
+" with image_display:\n",
|
|
|
+" if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True:\n",
|
|
|
+" for k, image in enumerate(sample['pred_xstart']):\n",
|
|
|
+" # tqdm.write(f'Batch {i}, step {j}, output {k}:')\n",
|
|
|
+" current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')\n",
|
|
|
+" percent = math.ceil(j/total_steps*100)\n",
|
|
|
+" if args.n_batches > 0:\n",
|
|
|
+" #if intermediates are saved to the subfolder, don't append a step or percentage to the name\n",
|
|
|
+" if cur_t == -1 and args.intermediates_in_subfolder is True:\n",
|
|
|
+" save_num = f'{frame_num:04}' if animation_mode != \"None\" else i\n",
|
|
|
+" filename = f'{args.batch_name}({args.batchNum})_{save_num}.png'\n",
|
|
|
+" else:\n",
|
|
|
+" #If we're working with percentages, append it\n",
|
|
|
+" if args.steps_per_checkpoint is not None:\n",
|
|
|
+" filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png'\n",
|
|
|
+" # Or else, iIf we're working with specific steps, append those\n",
|
|
|
+" else:\n",
|
|
|
+" filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png'\n",
|
|
|
+" image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))\n",
|
|
|
+" if j % args.display_rate == 0 or cur_t == -1:\n",
|
|
|
+" image.save('progress.png')\n",
|
|
|
+" display.clear_output(wait=True)\n",
|
|
|
+" display.display(display.Image('progress.png'))\n",
|
|
|
+" if args.steps_per_checkpoint is not None:\n",
|
|
|
+" if j % args.steps_per_checkpoint == 0 and j > 0:\n",
|
|
|
+" if args.intermediates_in_subfolder is True:\n",
|
|
|
+" image.save(f'{partialFolder}/{filename}')\n",
|
|
|
+" else:\n",
|
|
|
+" image.save(f'{batchFolder}/{filename}')\n",
|
|
|
+" else:\n",
|
|
|
+" if j in args.intermediate_saves:\n",
|
|
|
+" if args.intermediates_in_subfolder is True:\n",
|
|
|
+" image.save(f'{partialFolder}/{filename}')\n",
|
|
|
+" else:\n",
|
|
|
+" image.save(f'{batchFolder}/{filename}')\n",
|
|
|
+" if cur_t == -1:\n",
|
|
|
+" if frame_num == 0:\n",
|
|
|
+" save_settings()\n",
|
|
|
+" if args.animation_mode != \"None\":\n",
|
|
|
+" image.save('prevFrame.png')\n",
|
|
|
+" if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
|
|
|
+" imgToSharpen = image\n",
|
|
|
+" if args.keep_unsharp is True:\n",
|
|
|
+" image.save(f'{unsharpenFolder}/{filename}')\n",
|
|
|
+" else:\n",
|
|
|
+" #if turbo_mode, save a blended image \n",
|
|
|
+" if turbo_mode == True:\n",
|
|
|
+" #mix new image with prevFrameScaled\n",
|
|
|
+" blend_factor = (1)/int(turbo_steps)\n",
|
|
|
+" newFrame = cv2.imread('prevFrame.png')#this got updated just above..\n",
|
|
|
+" prev_frame_warped = cv2.imread(turbo_prevScaled_path)\n",
|
|
|
+" blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)\n",
|
|
|
+" cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)\n",
|
|
|
+" #turbo_blend = False # reset to false\n",
|
|
|
+" else:\n",
|
|
|
+" #non-turbo, just save normally\n",
|
|
|
+" image.save(f'{batchFolder}/{filename}')\n",
|
|
|
+" # if frame_num != args.max_frames-1:\n",
|
|
|
+" # display.clear_output()\n",
|
|
|
+"\n",
|
|
|
+" with image_display: \n",
|
|
|
+" if args.sharpen_preset != \"Off\" and animation_mode == \"None\":\n",
|
|
|
+" print('Starting Diffusion Sharpening...')\n",
|
|
|
+" do_superres(imgToSharpen, f'{batchFolder}/{filename}')\n",
|
|
|
+" display.clear_output()\n",
|
|
|
+" \n",
|
|
|
+" plt.plot(np.array(loss_values), 'r')\n",
|
|
|
+"\n",
|
|
|
+"def save_settings():\n",
|
|
|
+" setting_list = {\n",
|
|
|
+" 'text_prompts': text_prompts,\n",
|
|
|
+" 'image_prompts': image_prompts,\n",
|
|
|
+" 'clip_guidance_scale': clip_guidance_scale,\n",
|
|
|
+" 'tv_scale': tv_scale,\n",
|
|
|
+" 'range_scale': range_scale,\n",
|
|
|
+" 'sat_scale': sat_scale,\n",
|
|
|
+" # 'cutn': cutn,\n",
|
|
|
+" 'cutn_batches': cutn_batches,\n",
|
|
|
+" 'max_frames': max_frames,\n",
|
|
|
+" 'interp_spline': interp_spline,\n",
|
|
|
+" # 'rotation_per_frame': rotation_per_frame,\n",
|
|
|
+" 'init_image': init_image,\n",
|
|
|
+" 'init_scale': init_scale,\n",
|
|
|
+" 'skip_steps': skip_steps,\n",
|
|
|
+" # 'zoom_per_frame': zoom_per_frame,\n",
|
|
|
+" 'frames_scale': frames_scale,\n",
|
|
|
+" 'frames_skip_steps': frames_skip_steps,\n",
|
|
|
+" 'perlin_init': perlin_init,\n",
|
|
|
+" 'perlin_mode': perlin_mode,\n",
|
|
|
+" 'skip_augs': skip_augs,\n",
|
|
|
+" 'randomize_class': randomize_class,\n",
|
|
|
+" 'clip_denoised': clip_denoised,\n",
|
|
|
+" 'clamp_grad': clamp_grad,\n",
|
|
|
+" 'clamp_max': clamp_max,\n",
|
|
|
+" 'seed': seed,\n",
|
|
|
+" 'fuzzy_prompt': fuzzy_prompt,\n",
|
|
|
+" 'rand_mag': rand_mag,\n",
|
|
|
+" 'eta': eta,\n",
|
|
|
+" 'width': width_height[0],\n",
|
|
|
+" 'height': width_height[1],\n",
|
|
|
+" 'diffusion_model': diffusion_model,\n",
|
|
|
+" 'use_secondary_model': use_secondary_model,\n",
|
|
|
+" 'steps': steps,\n",
|
|
|
+" 'diffusion_steps': diffusion_steps,\n",
|
|
|
+" 'ViTB32': ViTB32,\n",
|
|
|
+" 'ViTB16': ViTB16,\n",
|
|
|
+" 'ViTL14': ViTL14,\n",
|
|
|
+" 'RN101': RN101,\n",
|
|
|
+" 'RN50': RN50,\n",
|
|
|
+" 'RN50x4': RN50x4,\n",
|
|
|
+" 'RN50x16': RN50x16,\n",
|
|
|
+" 'RN50x64': RN50x64,\n",
|
|
|
+" 'cut_overview': str(cut_overview),\n",
|
|
|
+" 'cut_innercut': str(cut_innercut),\n",
|
|
|
+" 'cut_ic_pow': cut_ic_pow,\n",
|
|
|
+" 'cut_icgray_p': str(cut_icgray_p),\n",
|
|
|
+" 'key_frames': key_frames,\n",
|
|
|
+" 'max_frames': max_frames,\n",
|
|
|
+" 'angle': angle,\n",
|
|
|
+" 'zoom': zoom,\n",
|
|
|
+" 'translation_x': translation_x,\n",
|
|
|
+" 'translation_y': translation_y,\n",
|
|
|
+" 'translation_z': translation_z,\n",
|
|
|
+" 'rotation_3d_x': rotation_3d_x,\n",
|
|
|
+" 'rotation_3d_y': rotation_3d_y,\n",
|
|
|
+" 'rotation_3d_z': rotation_3d_z,\n",
|
|
|
+" 'midas_depth_model': midas_depth_model,\n",
|
|
|
+" 'midas_weight': midas_weight,\n",
|
|
|
+" 'near_plane': near_plane,\n",
|
|
|
+" 'far_plane': far_plane,\n",
|
|
|
+" 'fov': fov,\n",
|
|
|
+" 'padding_mode': padding_mode,\n",
|
|
|
+" 'sampling_mode': sampling_mode,\n",
|
|
|
+" 'video_init_path':video_init_path,\n",
|
|
|
+" 'extract_nth_frame':extract_nth_frame,\n",
|
|
|
+" 'turbo_mode':turbo_mode,\n",
|
|
|
+" 'turbo_steps':turbo_steps,\n",
|
|
|
+" }\n",
|
|
|
+" # print('Settings:', setting_list)\n",
|
|
|
+" with open(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\", \"w+\") as f: #save settings\n",
|
|
|
+" json.dump(setting_list, f, ensure_ascii=False, indent=4)\n",
|
|
|
+" "
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "TI4oAu0N4ksZ"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title 1.6 Define the secondary diffusion model\n",
|
|
|
+"\n",
|
|
|
+"def append_dims(x, n):\n",
|
|
|
+" return x[(Ellipsis, *(None,) * (n - x.ndim))]\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def expand_to_planes(x, shape):\n",
|
|
|
+" return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def alpha_sigma_to_t(alpha, sigma):\n",
|
|
|
+" return torch.atan2(sigma, alpha) * 2 / math.pi\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def t_to_alpha_sigma(t):\n",
|
|
|
+" return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"@dataclass\n",
|
|
|
+"class DiffusionOutput:\n",
|
|
|
+" v: torch.Tensor\n",
|
|
|
+" pred: torch.Tensor\n",
|
|
|
+" eps: torch.Tensor\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"class ConvBlock(nn.Sequential):\n",
|
|
|
+" def __init__(self, c_in, c_out):\n",
|
|
|
+" super().__init__(\n",
|
|
|
+" nn.Conv2d(c_in, c_out, 3, padding=1),\n",
|
|
|
+" nn.ReLU(inplace=True),\n",
|
|
|
+" )\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"class SkipBlock(nn.Module):\n",
|
|
|
+" def __init__(self, main, skip=None):\n",
|
|
|
+" super().__init__()\n",
|
|
|
+" self.main = nn.Sequential(*main)\n",
|
|
|
+" self.skip = skip if skip else nn.Identity()\n",
|
|
|
+"\n",
|
|
|
+" def forward(self, input):\n",
|
|
|
+" return torch.cat([self.main(input), self.skip(input)], dim=1)\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"class FourierFeatures(nn.Module):\n",
|
|
|
+" def __init__(self, in_features, out_features, std=1.):\n",
|
|
|
+" super().__init__()\n",
|
|
|
+" assert out_features % 2 == 0\n",
|
|
|
+" self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)\n",
|
|
|
+"\n",
|
|
|
+" def forward(self, input):\n",
|
|
|
+" f = 2 * math.pi * input @ self.weight.T\n",
|
|
|
+" return torch.cat([f.cos(), f.sin()], dim=-1)\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"class SecondaryDiffusionImageNet(nn.Module):\n",
|
|
|
+" def __init__(self):\n",
|
|
|
+" super().__init__()\n",
|
|
|
+" c = 64 # The base channel count\n",
|
|
|
+"\n",
|
|
|
+" self.timestep_embed = FourierFeatures(1, 16)\n",
|
|
|
+"\n",
|
|
|
+" self.net = nn.Sequential(\n",
|
|
|
+" ConvBlock(3 + 16, c),\n",
|
|
|
+" ConvBlock(c, c),\n",
|
|
|
+" SkipBlock([\n",
|
|
|
+" nn.AvgPool2d(2),\n",
|
|
|
+" ConvBlock(c, c * 2),\n",
|
|
|
+" ConvBlock(c * 2, c * 2),\n",
|
|
|
+" SkipBlock([\n",
|
|
|
+" nn.AvgPool2d(2),\n",
|
|
|
+" ConvBlock(c * 2, c * 4),\n",
|
|
|
+" ConvBlock(c * 4, c * 4),\n",
|
|
|
+" SkipBlock([\n",
|
|
|
+" nn.AvgPool2d(2),\n",
|
|
|
+" ConvBlock(c * 4, c * 8),\n",
|
|
|
+" ConvBlock(c * 8, c * 4),\n",
|
|
|
+" nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
+" ]),\n",
|
|
|
+" ConvBlock(c * 8, c * 4),\n",
|
|
|
+" ConvBlock(c * 4, c * 2),\n",
|
|
|
+" nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
+" ]),\n",
|
|
|
+" ConvBlock(c * 4, c * 2),\n",
|
|
|
+" ConvBlock(c * 2, c),\n",
|
|
|
+" nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),\n",
|
|
|
+" ]),\n",
|
|
|
+" ConvBlock(c * 2, c),\n",
|
|
|
+" nn.Conv2d(c, 3, 3, padding=1),\n",
|
|
|
+" )\n",
|
|
|
+"\n",
|
|
|
+" def forward(self, input, t):\n",
|
|
|
+" timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
|
|
|
+" v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
|
|
|
+" alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
|
|
|
+" pred = input * alphas - v * sigmas\n",
|
|
|
+" eps = input * sigmas + v * alphas\n",
|
|
|
+" return DiffusionOutput(v, pred, eps)\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"class SecondaryDiffusionImageNet2(nn.Module):\n",
|
|
|
+" def __init__(self):\n",
|
|
|
+" super().__init__()\n",
|
|
|
+" c = 64 # The base channel count\n",
|
|
|
+" cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]\n",
|
|
|
+"\n",
|
|
|
+" self.timestep_embed = FourierFeatures(1, 16)\n",
|
|
|
+" self.down = nn.AvgPool2d(2)\n",
|
|
|
+" self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)\n",
|
|
|
+"\n",
|
|
|
+" self.net = nn.Sequential(\n",
|
|
|
+" ConvBlock(3 + 16, cs[0]),\n",
|
|
|
+" ConvBlock(cs[0], cs[0]),\n",
|
|
|
+" SkipBlock([\n",
|
|
|
+" self.down,\n",
|
|
|
+" ConvBlock(cs[0], cs[1]),\n",
|
|
|
+" ConvBlock(cs[1], cs[1]),\n",
|
|
|
+" SkipBlock([\n",
|
|
|
+" self.down,\n",
|
|
|
+" ConvBlock(cs[1], cs[2]),\n",
|
|
|
+" ConvBlock(cs[2], cs[2]),\n",
|
|
|
+" SkipBlock([\n",
|
|
|
+" self.down,\n",
|
|
|
+" ConvBlock(cs[2], cs[3]),\n",
|
|
|
+" ConvBlock(cs[3], cs[3]),\n",
|
|
|
+" SkipBlock([\n",
|
|
|
+" self.down,\n",
|
|
|
+" ConvBlock(cs[3], cs[4]),\n",
|
|
|
+" ConvBlock(cs[4], cs[4]),\n",
|
|
|
+" SkipBlock([\n",
|
|
|
+" self.down,\n",
|
|
|
+" ConvBlock(cs[4], cs[5]),\n",
|
|
|
+" ConvBlock(cs[5], cs[5]),\n",
|
|
|
+" ConvBlock(cs[5], cs[5]),\n",
|
|
|
+" ConvBlock(cs[5], cs[4]),\n",
|
|
|
+" self.up,\n",
|
|
|
+" ]),\n",
|
|
|
+" ConvBlock(cs[4] * 2, cs[4]),\n",
|
|
|
+" ConvBlock(cs[4], cs[3]),\n",
|
|
|
+" self.up,\n",
|
|
|
+" ]),\n",
|
|
|
+" ConvBlock(cs[3] * 2, cs[3]),\n",
|
|
|
+" ConvBlock(cs[3], cs[2]),\n",
|
|
|
+" self.up,\n",
|
|
|
+" ]),\n",
|
|
|
+" ConvBlock(cs[2] * 2, cs[2]),\n",
|
|
|
+" ConvBlock(cs[2], cs[1]),\n",
|
|
|
+" self.up,\n",
|
|
|
+" ]),\n",
|
|
|
+" ConvBlock(cs[1] * 2, cs[1]),\n",
|
|
|
+" ConvBlock(cs[1], cs[0]),\n",
|
|
|
+" self.up,\n",
|
|
|
+" ]),\n",
|
|
|
+" ConvBlock(cs[0] * 2, cs[0]),\n",
|
|
|
+" nn.Conv2d(cs[0], 3, 3, padding=1),\n",
|
|
|
+" )\n",
|
|
|
+"\n",
|
|
|
+" def forward(self, input, t):\n",
|
|
|
+" timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)\n",
|
|
|
+" v = self.net(torch.cat([input, timestep_embed], dim=1))\n",
|
|
|
+" alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))\n",
|
|
|
+" pred = input * alphas - v * sigmas\n",
|
|
|
+" eps = input * sigmas + v * alphas\n",
|
|
|
+" return DiffusionOutput(v, pred, eps)\n"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "NJS2AUAnvn-D",
|
|
|
+"scrolled": true
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title 1.7 SuperRes Define\n",
|
|
|
+"class DDIMSampler(object):\n",
|
|
|
+" def __init__(self, model, schedule=\"linear\", **kwargs):\n",
|
|
|
+" super().__init__()\n",
|
|
|
+" self.model = model\n",
|
|
|
+" self.ddpm_num_timesteps = model.num_timesteps\n",
|
|
|
+" self.schedule = schedule\n",
|
|
|
+"\n",
|
|
|
+" def register_buffer(self, name, attr):\n",
|
|
|
+" if type(attr) == torch.Tensor:\n",
|
|
|
+" if attr.device != torch.device(\"cuda\"):\n",
|
|
|
+" attr = attr.to(torch.device(\"cuda\"))\n",
|
|
|
+" setattr(self, name, attr)\n",
|
|
|
+"\n",
|
|
|
+" def make_schedule(self, ddim_num_steps, ddim_discretize=\"uniform\", ddim_eta=0., verbose=True):\n",
|
|
|
+" self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,\n",
|
|
|
+" num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)\n",
|
|
|
+" alphas_cumprod = self.model.alphas_cumprod\n",
|
|
|
+" assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'\n",
|
|
|
+" to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)\n",
|
|
|
+"\n",
|
|
|
+" self.register_buffer('betas', to_torch(self.model.betas))\n",
|
|
|
+" self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))\n",
|
|
|
+" self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))\n",
|
|
|
+"\n",
|
|
|
+" # calculations for diffusion q(x_t | x_{t-1}) and others\n",
|
|
|
+" self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))\n",
|
|
|
+" self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))\n",
|
|
|
+" self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))\n",
|
|
|
+" self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))\n",
|
|
|
+" self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))\n",
|
|
|
+"\n",
|
|
|
+" # ddim sampling parameters\n",
|
|
|
+" ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),\n",
|
|
|
+" ddim_timesteps=self.ddim_timesteps,\n",
|
|
|
+" eta=ddim_eta,verbose=verbose)\n",
|
|
|
+" self.register_buffer('ddim_sigmas', ddim_sigmas)\n",
|
|
|
+" self.register_buffer('ddim_alphas', ddim_alphas)\n",
|
|
|
+" self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)\n",
|
|
|
+" self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))\n",
|
|
|
+" sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(\n",
|
|
|
+" (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (\n",
|
|
|
+" 1 - self.alphas_cumprod / self.alphas_cumprod_prev))\n",
|
|
|
+" self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)\n",
|
|
|
+"\n",
|
|
|
+" @torch.no_grad()\n",
|
|
|
+" def sample(self,\n",
|
|
|
+" S,\n",
|
|
|
+" batch_size,\n",
|
|
|
+" shape,\n",
|
|
|
+" conditioning=None,\n",
|
|
|
+" callback=None,\n",
|
|
|
+" normals_sequence=None,\n",
|
|
|
+" img_callback=None,\n",
|
|
|
+" quantize_x0=False,\n",
|
|
|
+" eta=0.,\n",
|
|
|
+" mask=None,\n",
|
|
|
+" x0=None,\n",
|
|
|
+" temperature=1.,\n",
|
|
|
+" noise_dropout=0.,\n",
|
|
|
+" score_corrector=None,\n",
|
|
|
+" corrector_kwargs=None,\n",
|
|
|
+" verbose=True,\n",
|
|
|
+" x_T=None,\n",
|
|
|
+" log_every_t=100,\n",
|
|
|
+" **kwargs\n",
|
|
|
+" ):\n",
|
|
|
+" if conditioning is not None:\n",
|
|
|
+" if isinstance(conditioning, dict):\n",
|
|
|
+" cbs = conditioning[list(conditioning.keys())[0]].shape[0]\n",
|
|
|
+" if cbs != batch_size:\n",
|
|
|
+" print(f\"Warning: Got {cbs} conditionings but batch-size is {batch_size}\")\n",
|
|
|
+" else:\n",
|
|
|
+" if conditioning.shape[0] != batch_size:\n",
|
|
|
+" print(f\"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}\")\n",
|
|
|
+"\n",
|
|
|
+" self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)\n",
|
|
|
+" # sampling\n",
|
|
|
+" C, H, W = shape\n",
|
|
|
+" size = (batch_size, C, H, W)\n",
|
|
|
+" # print(f'Data shape for DDIM sampling is {size}, eta {eta}')\n",
|
|
|
+"\n",
|
|
|
+" samples, intermediates = self.ddim_sampling(conditioning, size,\n",
|
|
|
+" callback=callback,\n",
|
|
|
+" img_callback=img_callback,\n",
|
|
|
+" quantize_denoised=quantize_x0,\n",
|
|
|
+" mask=mask, x0=x0,\n",
|
|
|
+" ddim_use_original_steps=False,\n",
|
|
|
+" noise_dropout=noise_dropout,\n",
|
|
|
+" temperature=temperature,\n",
|
|
|
+" score_corrector=score_corrector,\n",
|
|
|
+" corrector_kwargs=corrector_kwargs,\n",
|
|
|
+" x_T=x_T,\n",
|
|
|
+" log_every_t=log_every_t\n",
|
|
|
+" )\n",
|
|
|
+" return samples, intermediates\n",
|
|
|
+"\n",
|
|
|
+" @torch.no_grad()\n",
|
|
|
+" def ddim_sampling(self, cond, shape,\n",
|
|
|
+" x_T=None, ddim_use_original_steps=False,\n",
|
|
|
+" callback=None, timesteps=None, quantize_denoised=False,\n",
|
|
|
+" mask=None, x0=None, img_callback=None, log_every_t=100,\n",
|
|
|
+" temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
|
|
|
+" device = self.model.betas.device\n",
|
|
|
+" b = shape[0]\n",
|
|
|
+" if x_T is None:\n",
|
|
|
+" img = torch.randn(shape, device=device)\n",
|
|
|
+" else:\n",
|
|
|
+" img = x_T\n",
|
|
|
+"\n",
|
|
|
+" if timesteps is None:\n",
|
|
|
+" timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps\n",
|
|
|
+" elif timesteps is not None and not ddim_use_original_steps:\n",
|
|
|
+" subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1\n",
|
|
|
+" timesteps = self.ddim_timesteps[:subset_end]\n",
|
|
|
+"\n",
|
|
|
+" intermediates = {'x_inter': [img], 'pred_x0': [img]}\n",
|
|
|
+" time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)\n",
|
|
|
+" total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]\n",
|
|
|
+" print(f\"Running DDIM Sharpening with {total_steps} timesteps\")\n",
|
|
|
+"\n",
|
|
|
+" iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)\n",
|
|
|
+"\n",
|
|
|
+" for i, step in enumerate(iterator):\n",
|
|
|
+" index = total_steps - i - 1\n",
|
|
|
+" ts = torch.full((b,), step, device=device, dtype=torch.long)\n",
|
|
|
+"\n",
|
|
|
+" if mask is not None:\n",
|
|
|
+" assert x0 is not None\n",
|
|
|
+" img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?\n",
|
|
|
+" img = img_orig * mask + (1. - mask) * img\n",
|
|
|
+"\n",
|
|
|
+" outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,\n",
|
|
|
+" quantize_denoised=quantize_denoised, temperature=temperature,\n",
|
|
|
+" noise_dropout=noise_dropout, score_corrector=score_corrector,\n",
|
|
|
+" corrector_kwargs=corrector_kwargs)\n",
|
|
|
+" img, pred_x0 = outs\n",
|
|
|
+" if callback: callback(i)\n",
|
|
|
+" if img_callback: img_callback(pred_x0, i)\n",
|
|
|
+"\n",
|
|
|
+" if index % log_every_t == 0 or index == total_steps - 1:\n",
|
|
|
+" intermediates['x_inter'].append(img)\n",
|
|
|
+" intermediates['pred_x0'].append(pred_x0)\n",
|
|
|
+"\n",
|
|
|
+" return img, intermediates\n",
|
|
|
+"\n",
|
|
|
+" @torch.no_grad()\n",
|
|
|
+" def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,\n",
|
|
|
+" temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):\n",
|
|
|
+" b, *_, device = *x.shape, x.device\n",
|
|
|
+" e_t = self.model.apply_model(x, t, c)\n",
|
|
|
+" if score_corrector is not None:\n",
|
|
|
+" assert self.model.parameterization == \"eps\"\n",
|
|
|
+" e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)\n",
|
|
|
+"\n",
|
|
|
+" alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas\n",
|
|
|
+" alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev\n",
|
|
|
+" sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas\n",
|
|
|
+" sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas\n",
|
|
|
+" # select parameters corresponding to the currently considered timestep\n",
|
|
|
+" a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)\n",
|
|
|
+" a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)\n",
|
|
|
+" sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)\n",
|
|
|
+" sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)\n",
|
|
|
+"\n",
|
|
|
+" # current prediction for x_0\n",
|
|
|
+" pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()\n",
|
|
|
+" if quantize_denoised:\n",
|
|
|
+" pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)\n",
|
|
|
+" # direction pointing to x_t\n",
|
|
|
+" dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t\n",
|
|
|
+" noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature\n",
|
|
|
+" if noise_dropout > 0.:\n",
|
|
|
+" noise = torch.nn.functional.dropout(noise, p=noise_dropout)\n",
|
|
|
+" x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise\n",
|
|
|
+" return x_prev, pred_x0\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def download_models(mode):\n",
|
|
|
+"\n",
|
|
|
+" if mode == \"superresolution\":\n",
|
|
|
+" # this is the small bsr light model\n",
|
|
|
+" url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'\n",
|
|
|
+" url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'\n",
|
|
|
+"\n",
|
|
|
+" path_conf = f'{model_path}/superres/project.yaml'\n",
|
|
|
+" path_ckpt = f'{model_path}/superres/last.ckpt'\n",
|
|
|
+"\n",
|
|
|
+" download_url(url_conf, path_conf)\n",
|
|
|
+" download_url(url_ckpt, path_ckpt)\n",
|
|
|
+"\n",
|
|
|
+" path_conf = path_conf + '/?dl=1' # fix it\n",
|
|
|
+" path_ckpt = path_ckpt + '/?dl=1' # fix it\n",
|
|
|
+" return path_conf, path_ckpt\n",
|
|
|
+"\n",
|
|
|
+" else:\n",
|
|
|
+" raise NotImplementedError\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def load_model_from_config(config, ckpt):\n",
|
|
|
+" print(f\"Loading model from {ckpt}\")\n",
|
|
|
+" pl_sd = torch.load(ckpt, map_location=\"cpu\")\n",
|
|
|
+" global_step = pl_sd[\"global_step\"]\n",
|
|
|
+" sd = pl_sd[\"state_dict\"]\n",
|
|
|
+" model = instantiate_from_config(config.model)\n",
|
|
|
+" m, u = model.load_state_dict(sd, strict=False)\n",
|
|
|
+" model.cuda()\n",
|
|
|
+" model.eval()\n",
|
|
|
+" return {\"model\": model}, global_step\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def get_model(mode):\n",
|
|
|
+" path_conf, path_ckpt = download_models(mode)\n",
|
|
|
+" config = OmegaConf.load(path_conf)\n",
|
|
|
+" model, step = load_model_from_config(config, path_ckpt)\n",
|
|
|
+" return model\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def get_custom_cond(mode):\n",
|
|
|
+" dest = \"data/example_conditioning\"\n",
|
|
|
+"\n",
|
|
|
+" if mode == \"superresolution\":\n",
|
|
|
+" uploaded_img = files.upload()\n",
|
|
|
+" filename = next(iter(uploaded_img))\n",
|
|
|
+" name, filetype = filename.split(\".\") # todo assumes just one dot in name !\n",
|
|
|
+" os.rename(f\"{filename}\", f\"{dest}/{mode}/custom_{name}.{filetype}\")\n",
|
|
|
+"\n",
|
|
|
+" elif mode == \"text_conditional\":\n",
|
|
|
+" w = widgets.Text(value='A cake with cream!', disabled=True)\n",
|
|
|
+" display.display(w)\n",
|
|
|
+"\n",
|
|
|
+" with open(f\"{dest}/{mode}/custom_{w.value[:20]}.txt\", 'w') as f:\n",
|
|
|
+" f.write(w.value)\n",
|
|
|
+"\n",
|
|
|
+" elif mode == \"class_conditional\":\n",
|
|
|
+" w = widgets.IntSlider(min=0, max=1000)\n",
|
|
|
+" display.display(w)\n",
|
|
|
+" with open(f\"{dest}/{mode}/custom.txt\", 'w') as f:\n",
|
|
|
+" f.write(w.value)\n",
|
|
|
+"\n",
|
|
|
+" else:\n",
|
|
|
+" raise NotImplementedError(f\"cond not implemented for mode{mode}\")\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def get_cond_options(mode):\n",
|
|
|
+" path = \"data/example_conditioning\"\n",
|
|
|
+" path = os.path.join(path, mode)\n",
|
|
|
+" onlyfiles = [f for f in sorted(os.listdir(path))]\n",
|
|
|
+" return path, onlyfiles\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def select_cond_path(mode):\n",
|
|
|
+" path = \"data/example_conditioning\" # todo\n",
|
|
|
+" path = os.path.join(path, mode)\n",
|
|
|
+" onlyfiles = [f for f in sorted(os.listdir(path))]\n",
|
|
|
+"\n",
|
|
|
+" selected = widgets.RadioButtons(\n",
|
|
|
+" options=onlyfiles,\n",
|
|
|
+" description='Select conditioning:',\n",
|
|
|
+" disabled=False\n",
|
|
|
+" )\n",
|
|
|
+" display.display(selected)\n",
|
|
|
+" selected_path = os.path.join(path, selected.value)\n",
|
|
|
+" return selected_path\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def get_cond(mode, img):\n",
|
|
|
+" example = dict()\n",
|
|
|
+" if mode == \"superresolution\":\n",
|
|
|
+" up_f = 4\n",
|
|
|
+" # visualize_cond_img(selected_path)\n",
|
|
|
+"\n",
|
|
|
+" c = img\n",
|
|
|
+" c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)\n",
|
|
|
+" c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)\n",
|
|
|
+" c_up = rearrange(c_up, '1 c h w -> 1 h w c')\n",
|
|
|
+" c = rearrange(c, '1 c h w -> 1 h w c')\n",
|
|
|
+" c = 2. * c - 1.\n",
|
|
|
+"\n",
|
|
|
+" c = c.to(torch.device(\"cuda\"))\n",
|
|
|
+" example[\"LR_image\"] = c\n",
|
|
|
+" example[\"image\"] = c_up\n",
|
|
|
+"\n",
|
|
|
+" return example\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def visualize_cond_img(path):\n",
|
|
|
+" display.display(ipyimg(filename=path))\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):\n",
|
|
|
+" # global stride\n",
|
|
|
+"\n",
|
|
|
+" example = get_cond(task, img)\n",
|
|
|
+"\n",
|
|
|
+" save_intermediate_vid = False\n",
|
|
|
+" n_runs = 1\n",
|
|
|
+" masked = False\n",
|
|
|
+" guider = None\n",
|
|
|
+" ckwargs = None\n",
|
|
|
+" mode = 'ddim'\n",
|
|
|
+" ddim_use_x0_pred = False\n",
|
|
|
+" temperature = 1.\n",
|
|
|
+" eta = eta\n",
|
|
|
+" make_progrow = True\n",
|
|
|
+" custom_shape = None\n",
|
|
|
+"\n",
|
|
|
+" height, width = example[\"image\"].shape[1:3]\n",
|
|
|
+" split_input = height >= 128 and width >= 128\n",
|
|
|
+"\n",
|
|
|
+" if split_input:\n",
|
|
|
+" ks = 128\n",
|
|
|
+" stride = 64\n",
|
|
|
+" vqf = 4 #\n",
|
|
|
+" model.split_input_params = {\"ks\": (ks, ks), \"stride\": (stride, stride),\n",
|
|
|
+" \"vqf\": vqf,\n",
|
|
|
+" \"patch_distributed_vq\": True,\n",
|
|
|
+" \"tie_braker\": False,\n",
|
|
|
+" \"clip_max_weight\": 0.5,\n",
|
|
|
+" \"clip_min_weight\": 0.01,\n",
|
|
|
+" \"clip_max_tie_weight\": 0.5,\n",
|
|
|
+" \"clip_min_tie_weight\": 0.01}\n",
|
|
|
+" else:\n",
|
|
|
+" if hasattr(model, \"split_input_params\"):\n",
|
|
|
+" delattr(model, \"split_input_params\")\n",
|
|
|
+"\n",
|
|
|
+" invert_mask = False\n",
|
|
|
+"\n",
|
|
|
+" x_T = None\n",
|
|
|
+" for n in range(n_runs):\n",
|
|
|
+" if custom_shape is not None:\n",
|
|
|
+" x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)\n",
|
|
|
+" x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])\n",
|
|
|
+"\n",
|
|
|
+" logs = make_convolutional_sample(example, model,\n",
|
|
|
+" mode=mode, custom_steps=custom_steps,\n",
|
|
|
+" eta=eta, swap_mode=False , masked=masked,\n",
|
|
|
+" invert_mask=invert_mask, quantize_x0=False,\n",
|
|
|
+" custom_schedule=None, decode_interval=10,\n",
|
|
|
+" resize_enabled=resize_enabled, custom_shape=custom_shape,\n",
|
|
|
+" temperature=temperature, noise_dropout=0.,\n",
|
|
|
+" corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,\n",
|
|
|
+" make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred\n",
|
|
|
+" )\n",
|
|
|
+" return logs\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"@torch.no_grad()\n",
|
|
|
+"def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,\n",
|
|
|
+" mask=None, x0=None, quantize_x0=False, img_callback=None,\n",
|
|
|
+" temperature=1., noise_dropout=0., score_corrector=None,\n",
|
|
|
+" corrector_kwargs=None, x_T=None, log_every_t=None\n",
|
|
|
+" ):\n",
|
|
|
+"\n",
|
|
|
+" ddim = DDIMSampler(model)\n",
|
|
|
+" bs = shape[0] # dont know where this comes from but wayne\n",
|
|
|
+" shape = shape[1:] # cut batch dim\n",
|
|
|
+" # print(f\"Sampling with eta = {eta}; steps: {steps}\")\n",
|
|
|
+" samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,\n",
|
|
|
+" normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,\n",
|
|
|
+" mask=mask, x0=x0, temperature=temperature, verbose=False,\n",
|
|
|
+" score_corrector=score_corrector,\n",
|
|
|
+" corrector_kwargs=corrector_kwargs, x_T=x_T)\n",
|
|
|
+"\n",
|
|
|
+" return samples, intermediates\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"@torch.no_grad()\n",
|
|
|
+"def make_convolutional_sample(batch, model, mode=\"vanilla\", custom_steps=None, eta=1.0, swap_mode=False, masked=False,\n",
|
|
|
+" invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,\n",
|
|
|
+" resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,\n",
|
|
|
+" corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):\n",
|
|
|
+" log = dict()\n",
|
|
|
+"\n",
|
|
|
+" z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,\n",
|
|
|
+" return_first_stage_outputs=True,\n",
|
|
|
+" force_c_encode=not (hasattr(model, 'split_input_params')\n",
|
|
|
+" and model.cond_stage_key == 'coordinates_bbox'),\n",
|
|
|
+" return_original_cond=True)\n",
|
|
|
+"\n",
|
|
|
+" log_every_t = 1 if save_intermediate_vid else None\n",
|
|
|
+"\n",
|
|
|
+" if custom_shape is not None:\n",
|
|
|
+" z = torch.randn(custom_shape)\n",
|
|
|
+" # print(f\"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}\")\n",
|
|
|
+"\n",
|
|
|
+" z0 = None\n",
|
|
|
+"\n",
|
|
|
+" log[\"input\"] = x\n",
|
|
|
+" log[\"reconstruction\"] = xrec\n",
|
|
|
+"\n",
|
|
|
+" if ismap(xc):\n",
|
|
|
+" log[\"original_conditioning\"] = model.to_rgb(xc)\n",
|
|
|
+" if hasattr(model, 'cond_stage_key'):\n",
|
|
|
+" log[model.cond_stage_key] = model.to_rgb(xc)\n",
|
|
|
+"\n",
|
|
|
+" else:\n",
|
|
|
+" log[\"original_conditioning\"] = xc if xc is not None else torch.zeros_like(x)\n",
|
|
|
+" if model.cond_stage_model:\n",
|
|
|
+" log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)\n",
|
|
|
+" if model.cond_stage_key =='class_label':\n",
|
|
|
+" log[model.cond_stage_key] = xc[model.cond_stage_key]\n",
|
|
|
+"\n",
|
|
|
+" with model.ema_scope(\"Plotting\"):\n",
|
|
|
+" t0 = time.time()\n",
|
|
|
+" img_cb = None\n",
|
|
|
+"\n",
|
|
|
+" sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,\n",
|
|
|
+" eta=eta,\n",
|
|
|
+" quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,\n",
|
|
|
+" temperature=temperature, noise_dropout=noise_dropout,\n",
|
|
|
+" score_corrector=corrector, corrector_kwargs=corrector_kwargs,\n",
|
|
|
+" x_T=x_T, log_every_t=log_every_t)\n",
|
|
|
+" t1 = time.time()\n",
|
|
|
+"\n",
|
|
|
+" if ddim_use_x0_pred:\n",
|
|
|
+" sample = intermediates['pred_x0'][-1]\n",
|
|
|
+"\n",
|
|
|
+" x_sample = model.decode_first_stage(sample)\n",
|
|
|
+"\n",
|
|
|
+" try:\n",
|
|
|
+" x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)\n",
|
|
|
+" log[\"sample_noquant\"] = x_sample_noquant\n",
|
|
|
+" log[\"sample_diff\"] = torch.abs(x_sample_noquant - x_sample)\n",
|
|
|
+" except:\n",
|
|
|
+" pass\n",
|
|
|
+"\n",
|
|
|
+" log[\"sample\"] = x_sample\n",
|
|
|
+" log[\"time\"] = t1 - t0\n",
|
|
|
+"\n",
|
|
|
+" return log\n",
|
|
|
+"\n",
|
|
|
+"sr_diffMode = 'superresolution'\n",
|
|
|
+"sr_model = get_model('superresolution')\n",
|
|
|
+"\n",
|
|
|
+"def do_superres(img, filepath):\n",
|
|
|
+"\n",
|
|
|
+" if args.sharpen_preset == 'Faster':\n",
|
|
|
+" sr_diffusion_steps = \"25\" \n",
|
|
|
+" sr_pre_downsample = '1/2' \n",
|
|
|
+" if args.sharpen_preset == 'Fast':\n",
|
|
|
+" sr_diffusion_steps = \"100\" \n",
|
|
|
+" sr_pre_downsample = '1/2' \n",
|
|
|
+" if args.sharpen_preset == 'Slow':\n",
|
|
|
+" sr_diffusion_steps = \"25\" \n",
|
|
|
+" sr_pre_downsample = 'None' \n",
|
|
|
+" if args.sharpen_preset == 'Very Slow':\n",
|
|
|
+" sr_diffusion_steps = \"100\" \n",
|
|
|
+" sr_pre_downsample = 'None' \n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+" sr_post_downsample = 'Original Size'\n",
|
|
|
+" sr_diffusion_steps = int(sr_diffusion_steps)\n",
|
|
|
+" sr_eta = 1.0 \n",
|
|
|
+" sr_downsample_method = 'Lanczos' \n",
|
|
|
+"\n",
|
|
|
+" gc.collect()\n",
|
|
|
+" torch.cuda.empty_cache()\n",
|
|
|
+"\n",
|
|
|
+" im_og = img\n",
|
|
|
+" width_og, height_og = im_og.size\n",
|
|
|
+"\n",
|
|
|
+" #Downsample Pre\n",
|
|
|
+" if sr_pre_downsample == '1/2':\n",
|
|
|
+" downsample_rate = 2\n",
|
|
|
+" elif sr_pre_downsample == '1/4':\n",
|
|
|
+" downsample_rate = 4\n",
|
|
|
+" else:\n",
|
|
|
+" downsample_rate = 1\n",
|
|
|
+"\n",
|
|
|
+" width_downsampled_pre = width_og//downsample_rate\n",
|
|
|
+" height_downsampled_pre = height_og//downsample_rate\n",
|
|
|
+"\n",
|
|
|
+" if downsample_rate != 1:\n",
|
|
|
+" # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')\n",
|
|
|
+" im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)\n",
|
|
|
+" # im_og.save('/content/temp.png')\n",
|
|
|
+" # filepath = '/content/temp.png'\n",
|
|
|
+"\n",
|
|
|
+" logs = sr_run(sr_model[\"model\"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)\n",
|
|
|
+"\n",
|
|
|
+" sample = logs[\"sample\"]\n",
|
|
|
+" sample = sample.detach().cpu()\n",
|
|
|
+" sample = torch.clamp(sample, -1., 1.)\n",
|
|
|
+" sample = (sample + 1.) / 2. * 255\n",
|
|
|
+" sample = sample.numpy().astype(np.uint8)\n",
|
|
|
+" sample = np.transpose(sample, (0, 2, 3, 1))\n",
|
|
|
+" a = Image.fromarray(sample[0])\n",
|
|
|
+"\n",
|
|
|
+" #Downsample Post\n",
|
|
|
+" if sr_post_downsample == '1/2':\n",
|
|
|
+" downsample_rate = 2\n",
|
|
|
+" elif sr_post_downsample == '1/4':\n",
|
|
|
+" downsample_rate = 4\n",
|
|
|
+" else:\n",
|
|
|
+" downsample_rate = 1\n",
|
|
|
+"\n",
|
|
|
+" width, height = a.size\n",
|
|
|
+" width_downsampled_post = width//downsample_rate\n",
|
|
|
+" height_downsampled_post = height//downsample_rate\n",
|
|
|
+"\n",
|
|
|
+" if sr_downsample_method == 'Lanczos':\n",
|
|
|
+" aliasing = Image.LANCZOS\n",
|
|
|
+" else:\n",
|
|
|
+" aliasing = Image.NEAREST\n",
|
|
|
+"\n",
|
|
|
+" if downsample_rate != 1:\n",
|
|
|
+" # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')\n",
|
|
|
+" a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)\n",
|
|
|
+" elif sr_post_downsample == 'Original Size':\n",
|
|
|
+" # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')\n",
|
|
|
+" a = a.resize((width_og, height_og), aliasing)\n",
|
|
|
+"\n",
|
|
|
+" display.display(a)\n",
|
|
|
+" a.save(filepath)\n",
|
|
|
+" return\n",
|
|
|
+" print(f'Processing finished!')\n"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "CQVtY1Ixnqx4"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"# 2. Diffusion and CLIP model settings"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "Fpbody2NCR7w",
|
|
|
+"scrolled": true
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@markdown ####**Models Settings:**\n",
|
|
|
+"diffusion_model = \"512x512_diffusion_uncond_finetune_008100\" #@param [\"256x256_diffusion_uncond\", \"512x512_diffusion_uncond_finetune_008100\"]\n",
|
|
|
+"use_secondary_model = True #@param {type: 'boolean'}\n",
|
|
|
+"\n",
|
|
|
+"timestep_respacing = '50' # param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000'] \n",
|
|
|
+"diffusion_steps = 1000 # param {type: 'number'}\n",
|
|
|
+"use_checkpoint = True #@param {type: 'boolean'}\n",
|
|
|
+"ViTB32 = True #@param{type:\"boolean\"}\n",
|
|
|
+"ViTB16 = True #@param{type:\"boolean\"}\n",
|
|
|
+"ViTL14 = False #@param{type:\"boolean\"} # Default False\n",
|
|
|
+"RN101 = True #@param{type:\"boolean\"} # Default False\n",
|
|
|
+"RN50 = True #@param{type:\"boolean\"} # Default True\n",
|
|
|
+"RN50x4 = True #@param{type:\"boolean\"} # Default False\n",
|
|
|
+"RN50x16 = False #@param{type:\"boolean\"}\n",
|
|
|
+"RN50x64 = False #@param{type:\"boolean\"}\n",
|
|
|
+"SLIPB16 = False # param{type:\"boolean\"} # Default False. Looks broken, likely related to commented import of SLIP_VITB16\n",
|
|
|
+"SLIPL16 = False # param{type:\"boolean\"}\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"#@markdown If you're having issues with model downloads, check this to compare SHA's:\n",
|
|
|
+"check_model_SHA = False #@param{type:\"boolean\"}\n",
|
|
|
+"\n",
|
|
|
+"model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
|
|
|
+"model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'\n",
|
|
|
+"model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'\n",
|
|
|
+"\n",
|
|
|
+"model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'\n",
|
|
|
+"model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'\n",
|
|
|
+"model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'\n",
|
|
|
+"\n",
|
|
|
+"model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'\n",
|
|
|
+"model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'\n",
|
|
|
+"model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'\n",
|
|
|
+"\n",
|
|
|
+"# Download the diffusion model\n",
|
|
|
+"if diffusion_model == '256x256_diffusion_uncond':\n",
|
|
|
+" if os.path.exists(model_256_path) and check_model_SHA:\n",
|
|
|
+" print('Checking 256 Diffusion File')\n",
|
|
|
+" with open(model_256_path,\"rb\") as f:\n",
|
|
|
+" bytes = f.read() \n",
|
|
|
+" hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
+" if hash == model_256_SHA:\n",
|
|
|
+" print('256 Model SHA matches')\n",
|
|
|
+" model_256_downloaded = True\n",
|
|
|
+" else: \n",
|
|
|
+" print(\"256 Model SHA doesn't match, redownloading...\")\n",
|
|
|
+" !wget --continue {model_256_link} -P {model_path}\n",
|
|
|
+" model_256_downloaded = True\n",
|
|
|
+" elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:\n",
|
|
|
+" print('256 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
+" else: \n",
|
|
|
+" !wget --continue {model_256_link} -P {model_path}\n",
|
|
|
+" model_256_downloaded = True\n",
|
|
|
+"elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
|
|
|
+" if os.path.exists(model_512_path) and check_model_SHA:\n",
|
|
|
+" print('Checking 512 Diffusion File')\n",
|
|
|
+" with open(model_512_path,\"rb\") as f:\n",
|
|
|
+" bytes = f.read() \n",
|
|
|
+" hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
+" if hash == model_512_SHA:\n",
|
|
|
+" print('512 Model SHA matches')\n",
|
|
|
+" model_512_downloaded = True\n",
|
|
|
+" else: \n",
|
|
|
+" print(\"512 Model SHA doesn't match, redownloading...\")\n",
|
|
|
+" !wget --continue {model_512_link} -P {model_path}\n",
|
|
|
+" model_512_downloaded = True\n",
|
|
|
+" elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:\n",
|
|
|
+" print('512 Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
+" else: \n",
|
|
|
+" !wget --continue {model_512_link} -P {model_path}\n",
|
|
|
+" model_512_downloaded = True\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"# Download the secondary diffusion model v2\n",
|
|
|
+"if use_secondary_model == True:\n",
|
|
|
+" if os.path.exists(model_secondary_path) and check_model_SHA:\n",
|
|
|
+" print('Checking Secondary Diffusion File')\n",
|
|
|
+" with open(model_secondary_path,\"rb\") as f:\n",
|
|
|
+" bytes = f.read() \n",
|
|
|
+" hash = hashlib.sha256(bytes).hexdigest();\n",
|
|
|
+" if hash == model_secondary_SHA:\n",
|
|
|
+" print('Secondary Model SHA matches')\n",
|
|
|
+" model_secondary_downloaded = True\n",
|
|
|
+" else: \n",
|
|
|
+" print(\"Secondary Model SHA doesn't match, redownloading...\")\n",
|
|
|
+" !wget --continue {model_secondary_link} -P {model_path}\n",
|
|
|
+" model_secondary_downloaded = True\n",
|
|
|
+" elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:\n",
|
|
|
+" print('Secondary Model already downloaded, check check_model_SHA if the file is corrupt')\n",
|
|
|
+" else: \n",
|
|
|
+" !wget --continue {model_secondary_link} -P {model_path}\n",
|
|
|
+" model_secondary_downloaded = True\n",
|
|
|
+"\n",
|
|
|
+"model_config = model_and_diffusion_defaults()\n",
|
|
|
+"if diffusion_model == '512x512_diffusion_uncond_finetune_008100':\n",
|
|
|
+" model_config.update({\n",
|
|
|
+" 'attention_resolutions': '32, 16, 8',\n",
|
|
|
+" 'class_cond': False,\n",
|
|
|
+" 'diffusion_steps': diffusion_steps,\n",
|
|
|
+" 'rescale_timesteps': True,\n",
|
|
|
+" 'timestep_respacing': timestep_respacing,\n",
|
|
|
+" 'image_size': 512,\n",
|
|
|
+" 'learn_sigma': True,\n",
|
|
|
+" 'noise_schedule': 'linear',\n",
|
|
|
+" 'num_channels': 256,\n",
|
|
|
+" 'num_head_channels': 64,\n",
|
|
|
+" 'num_res_blocks': 2,\n",
|
|
|
+" 'resblock_updown': True,\n",
|
|
|
+" 'use_checkpoint': use_checkpoint,\n",
|
|
|
+" 'use_fp16': True,\n",
|
|
|
+" 'use_scale_shift_norm': True,\n",
|
|
|
+" })\n",
|
|
|
+"elif diffusion_model == '256x256_diffusion_uncond':\n",
|
|
|
+" model_config.update({\n",
|
|
|
+" 'attention_resolutions': '32, 16, 8',\n",
|
|
|
+" 'class_cond': False,\n",
|
|
|
+" 'diffusion_steps': diffusion_steps,\n",
|
|
|
+" 'rescale_timesteps': True,\n",
|
|
|
+" 'timestep_respacing': timestep_respacing,\n",
|
|
|
+" 'image_size': 256,\n",
|
|
|
+" 'learn_sigma': True,\n",
|
|
|
+" 'noise_schedule': 'linear',\n",
|
|
|
+" 'num_channels': 256,\n",
|
|
|
+" 'num_head_channels': 64,\n",
|
|
|
+" 'num_res_blocks': 2,\n",
|
|
|
+" 'resblock_updown': True,\n",
|
|
|
+" 'use_checkpoint': use_checkpoint,\n",
|
|
|
+" 'use_fp16': True,\n",
|
|
|
+" 'use_scale_shift_norm': True,\n",
|
|
|
+" })\n",
|
|
|
+"\n",
|
|
|
+"secondary_model_ver = 2\n",
|
|
|
+"model_default = model_config['image_size']\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"if secondary_model_ver == 2:\n",
|
|
|
+" secondary_model = SecondaryDiffusionImageNet2()\n",
|
|
|
+" secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))\n",
|
|
|
+"secondary_model.eval().requires_grad_(False).to(device)\n",
|
|
|
+"\n",
|
|
|
+"clip_models = []\n",
|
|
|
+"if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+"if ViTB16 is True: clip_models.append(clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device) ) \n",
|
|
|
+"if ViTL14 is True: clip_models.append(clip.load('ViT-L/14', jit=False)[0].eval().requires_grad_(False).to(device) ) \n",
|
|
|
+"if RN50 is True: clip_models.append(clip.load('RN50', jit=False)[0].eval().requires_grad_(False).to(device))\n",
|
|
|
+"if RN50x4 is True: clip_models.append(clip.load('RN50x4', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+"if RN50x16 is True: clip_models.append(clip.load('RN50x16', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+"if RN50x64 is True: clip_models.append(clip.load('RN50x64', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+"if RN101 is True: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device)) \n",
|
|
|
+"\n",
|
|
|
+"if SLIPB16:\n",
|
|
|
+" SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
|
|
|
+" if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):\n",
|
|
|
+" !wget https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt -P {model_path}\n",
|
|
|
+" sd = torch.load(f'{model_path}/slip_base_100ep.pt')\n",
|
|
|
+" real_sd = {}\n",
|
|
|
+" for k, v in sd['state_dict'].items():\n",
|
|
|
+" real_sd['.'.join(k.split('.')[1:])] = v\n",
|
|
|
+" del sd\n",
|
|
|
+" SLIPB16model.load_state_dict(real_sd)\n",
|
|
|
+" SLIPB16model.requires_grad_(False).eval().to(device)\n",
|
|
|
+"\n",
|
|
|
+" clip_models.append(SLIPB16model)\n",
|
|
|
+"\n",
|
|
|
+"if SLIPL16:\n",
|
|
|
+" SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)\n",
|
|
|
+" if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):\n",
|
|
|
+" !wget https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt -P {model_path}\n",
|
|
|
+" sd = torch.load(f'{model_path}/slip_large_100ep.pt')\n",
|
|
|
+" real_sd = {}\n",
|
|
|
+" for k, v in sd['state_dict'].items():\n",
|
|
|
+" real_sd['.'.join(k.split('.')[1:])] = v\n",
|
|
|
+" del sd\n",
|
|
|
+" SLIPL16model.load_state_dict(real_sd)\n",
|
|
|
+" SLIPL16model.requires_grad_(False).eval().to(device)\n",
|
|
|
+"\n",
|
|
|
+" clip_models.append(SLIPL16model)\n",
|
|
|
+"\n",
|
|
|
+"normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])\n",
|
|
|
+"lpips_model = lpips.LPIPS(net='vgg').to(device)"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "kjtsXaszn-bB"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"# 3. Settings"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "U0PwzFZbLfcy"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@markdown ####**Basic Settings:**\n",
|
|
|
+"batch_name = 'TimeToDiscoTurbo3' #@param{type: 'string'}\n",
|
|
|
+"steps = 500 #@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}\n",
|
|
|
+"width_height = [800, 450]#[1280, 720]# [800, 450] #[1600, 900] #@param{type: 'raw'}\n",
|
|
|
+"clip_guidance_scale = 35000 #@param{type: 'number'}\n",
|
|
|
+"tv_scale = 1#@param{type: 'number'}\n",
|
|
|
+"range_scale = 450#@param{type: 'number'}\n",
|
|
|
+"sat_scale = 10000#@param{type: 'number'}\n",
|
|
|
+"cutn_batches = 1 #@param{type: 'number'}\n",
|
|
|
+"skip_augs = False#@param{type: 'boolean'}\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ---\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ####**Init Settings:**\n",
|
|
|
+"init_image = None #@param{type: 'string'}\n",
|
|
|
+"init_scale = 1000 #@param{type: 'integer'} default 1000\n",
|
|
|
+"skip_steps = int(steps * 0.5) if init_image else steps // 5 #@param{type: 'integer'}\n",
|
|
|
+"#@markdown *Make sure you set skip_steps to ~50% of your steps if you want to use an init image.*\n",
|
|
|
+"\n",
|
|
|
+"#Get corrected sizes\n",
|
|
|
+"side_x = (width_height[0]//64)*64;\n",
|
|
|
+"side_y = (width_height[1]//64)*64;\n",
|
|
|
+"if side_x != width_height[0] or side_y != width_height[1]:\n",
|
|
|
+" print(f'Changing output size to {side_x}x{side_y}. Dimensions must by multiples of 64.')\n",
|
|
|
+"\n",
|
|
|
+"#Update Model Settings\n",
|
|
|
+"timestep_respacing = f'ddim{steps}'\n",
|
|
|
+"diffusion_steps = (1000//steps)*steps if steps < 1000 else steps\n",
|
|
|
+"model_config.update({\n",
|
|
|
+" 'timestep_respacing': timestep_respacing,\n",
|
|
|
+" 'diffusion_steps': diffusion_steps,\n",
|
|
|
+"})\n",
|
|
|
+"\n",
|
|
|
+"#Make folder for batch\n",
|
|
|
+"batchFolder = f'{outDirPath}/{batch_name}'\n",
|
|
|
+"createPath(batchFolder)\n"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "CnkTNXJAPzL2"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"###Animation Settings"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "djPY2_4kHgV2"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@markdown ####**Animation Mode:**\n",
|
|
|
+"animation_mode = '3D' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'}\n",
|
|
|
+"#@markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.*\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ---\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ####**Video Input Settings:**\n",
|
|
|
+"if is_colab:\n",
|
|
|
+" video_init_path = \"/content/training.mp4\" #@param {type: 'string'}\n",
|
|
|
+"else:\n",
|
|
|
+" video_init_path = \"training.mp4\" #@param {type: 'string'}\n",
|
|
|
+"extract_nth_frame = 2 #@param {type:\"number\"} \n",
|
|
|
+"\n",
|
|
|
+"if animation_mode == \"Video Input\":\n",
|
|
|
+" if is_colab:\n",
|
|
|
+" videoFramesFolder = f'/content/videoFrames'\n",
|
|
|
+" else:\n",
|
|
|
+" videoFramesFolder = f'videoFrames'\n",
|
|
|
+" createPath(videoFramesFolder)\n",
|
|
|
+" print(f\"Exporting Video Frames (1 every {extract_nth_frame})...\")\n",
|
|
|
+" try:\n",
|
|
|
+" !rm {videoFramesFolder}/*.jpg\n",
|
|
|
+" except:\n",
|
|
|
+" print('')\n",
|
|
|
+" vf = f'\"select=not(mod(n\\,{extract_nth_frame}))\"'\n",
|
|
|
+" !ffmpeg -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ---\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ####**2D/3D Animation Settings:**\n",
|
|
|
+"#@markdown `zoom` is a multiplier of dimensions, 1 is no zoom.\n",
|
|
|
+"\n",
|
|
|
+"key_frames = True #@param {type:\"boolean\"}\n",
|
|
|
+"max_frames = 100000#@param {type:\"number\"}\n",
|
|
|
+"\n",
|
|
|
+"if animation_mode == \"Video Input\":\n",
|
|
|
+" max_frames = len(glob(f'{videoFramesFolder}/*.jpg'))\n",
|
|
|
+"\n",
|
|
|
+"interp_spline = 'Linear' #Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:\"string\"}\n",
|
|
|
+"angle = \"0:(0)\"#@param {type:\"string\"}\n",
|
|
|
+"zoom = \"0: (1)\"#@param {type:\"string\"}\n",
|
|
|
+"translation_x = \"0:(0.1111)\"#\"0:(0),22:(4.465),41:(0.355),61:(1.163),69:(-1.358),85:(0.079),107:(-0.843),116:(-4.123),136:(1.029),157:(1.074),166:(-3.439),187:(-0.214),209:(0.357),219:(-4.708),239:(0.49)\"#@param {type:\"string\"}\n",
|
|
|
+"translation_y = \"0:(0)\"#\"0:(0.35), 2000:(1.4)\"#\"0:(0),22:(2.42),41:(-0.019),61:(0.24),69:(-2.381),85:(-0.358),107:(0.097),116:(1.479),136:(0.425),157:(-0.401),166:(-2.366),187:(-0.508),209:(-0.525),219:(0.683),239:(0.351)\"#@param {type:\"string\"}\n",
|
|
|
+"translation_z = \"0:(1)\"#\"0:(2.5), 2000:(10)\"#@param {type:\"string\"}\n",
|
|
|
+"rotation_3d_x = \"0:(0)\"#\"0:(0),22:(0.013),41:(-0.004),61:(-0.001),69:(-0.022),85:(0.005),107:(-0.002),116:(0.026),136:(0.004),157:(0.001),166:(0.027),187:(0.002),209:(-0.005),219:(-0.01),239:(-0.004)\"#@param {type:\"string\"}\n",
|
|
|
+"rotation_3d_y = \"0:(-0.0003)\"#\"0:(0),21:(0.02),38:(0.001),53:(0.001),62:(0.016),82:(-0.004),102:(0.005),113:(0.012),130:(0.006),149:(0.002),159:(0.006),179:(0.005),200:(0.001),210:(-0.002),231:(0.005)\"#@param {type:\"string\"}\n",
|
|
|
+"rotation_3d_z = \"0:(0)\"#\"0:(0),22:(0.007),41:(0.001),61:(0.005),69:(0.014),85:(-0.0),107:(-0.002),116:(0.028),136:(0.0),157:(0.003),166:(0.02),187:(-0.001),209:(-0.004),219:(-0.001),239:(-0.001)\"#@param {type:\"string\"}\n",
|
|
|
+"midas_depth_model = \"dpt_large\"#@param {type:\"string\"}\n",
|
|
|
+"midas_weight = 0.3#@param {type:\"number\"}\n",
|
|
|
+"near_plane = 200#@param {type:\"number\"}\n",
|
|
|
+"far_plane = 10000#@param {type:\"number\"}\n",
|
|
|
+"fov = 40#120#@param {type:\"number\"}\n",
|
|
|
+"padding_mode = 'border'#@param {type:\"string\"}\n",
|
|
|
+"sampling_mode = 'bicubic'#@param {type:\"string\"}\n",
|
|
|
+"#======= TURBO MODE\n",
|
|
|
+"#@markdown ---\n",
|
|
|
+"#@markdown ####**Turbo Mode (3D anim only):**\n",
|
|
|
+"#@markdown (Starts after frame 10,) skips diffusion steps and just uses depth map to warp images for skipped frames.\n",
|
|
|
+"#@markdown Speeds up rendering by 2x-4x, and may improve frame coherence.\n",
|
|
|
+"\n",
|
|
|
+"turbo_mode = True #@param {type:\"boolean\"}\n",
|
|
|
+"turbo_steps = \"3\" #@param [\"2\",\"3\",\"4\",\"5\",\"6\"] {type:'string'}\n",
|
|
|
+"if turbo_mode == True:\n",
|
|
|
+" try:\n",
|
|
|
+" #Make folder for turbo\n",
|
|
|
+" turboFolder = f'{outDirPath}/{batch_name}/turbo'\n",
|
|
|
+" createPath(turboFolder)\n",
|
|
|
+" except OSError:\n",
|
|
|
+" pass # already exists\n",
|
|
|
+"#@markdown ---\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ####**Coherency Settings:**\n",
|
|
|
+"#@markdown `frame_scale` tries to guide the new frame to looking like the old one. A good default is 1500.\n",
|
|
|
+"frames_scale = 1500 #@param{type: 'integer'} Default was 35000\n",
|
|
|
+"#@markdown `frame_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.\n",
|
|
|
+"frames_skip_steps = '70%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"def parse_key_frames(string, prompt_parser=None):\n",
|
|
|
+" \"\"\"Given a string representing frame numbers paired with parameter values at that frame,\n",
|
|
|
+" return a dictionary with the frame numbers as keys and the parameter values as the values.\n",
|
|
|
+"\n",
|
|
|
+" Parameters\n",
|
|
|
+" ----------\n",
|
|
|
+" string: string\n",
|
|
|
+" Frame numbers paired with parameter values at that frame number, in the format\n",
|
|
|
+" 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'\n",
|
|
|
+" prompt_parser: function or None, optional\n",
|
|
|
+" If provided, prompt_parser will be applied to each string of parameter values.\n",
|
|
|
+" \n",
|
|
|
+" Returns\n",
|
|
|
+" -------\n",
|
|
|
+" dict\n",
|
|
|
+" Frame numbers as keys, parameter values at that frame number as values\n",
|
|
|
+"\n",
|
|
|
+" Raises\n",
|
|
|
+" ------\n",
|
|
|
+" RuntimeError\n",
|
|
|
+" If the input string does not match the expected format.\n",
|
|
|
+" \n",
|
|
|
+" Examples\n",
|
|
|
+" --------\n",
|
|
|
+" >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\")\n",
|
|
|
+" {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}\n",
|
|
|
+"\n",
|
|
|
+" >>> parse_key_frames(\"10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)\", prompt_parser=lambda x: x.lower()))\n",
|
|
|
+" {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}\n",
|
|
|
+" \"\"\"\n",
|
|
|
+" import re\n",
|
|
|
+" pattern = r'((?P<frame>[0-9]+):[\\s]*[\\(](?P<param>[\\S\\s]*?)[\\)])'\n",
|
|
|
+" frames = dict()\n",
|
|
|
+" for match_object in re.finditer(pattern, string):\n",
|
|
|
+" frame = int(match_object.groupdict()['frame'])\n",
|
|
|
+" param = match_object.groupdict()['param']\n",
|
|
|
+" if prompt_parser:\n",
|
|
|
+" frames[frame] = prompt_parser(param)\n",
|
|
|
+" else:\n",
|
|
|
+" frames[frame] = param\n",
|
|
|
+"\n",
|
|
|
+" if frames == {} and len(string) != 0:\n",
|
|
|
+" raise RuntimeError('Key Frame string not correctly formatted')\n",
|
|
|
+" return frames\n",
|
|
|
+"\n",
|
|
|
+"def get_inbetweens(key_frames, integer=False):\n",
|
|
|
+" \"\"\"Given a dict with frame numbers as keys and a parameter value as values,\n",
|
|
|
+" return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.\n",
|
|
|
+" Any values not provided in the input dict are calculated by linear interpolation between\n",
|
|
|
+" the values of the previous and next provided frames. If there is no previous provided frame, then\n",
|
|
|
+" the value is equal to the value of the next provided frame, or if there is no next provided frame,\n",
|
|
|
+" then the value is equal to the value of the previous provided frame. If no frames are provided,\n",
|
|
|
+" all frame values are NaN.\n",
|
|
|
+"\n",
|
|
|
+" Parameters\n",
|
|
|
+" ----------\n",
|
|
|
+" key_frames: dict\n",
|
|
|
+" A dict with integer frame numbers as keys and numerical values of a particular parameter as values.\n",
|
|
|
+" integer: Bool, optional\n",
|
|
|
+" If True, the values of the output series are converted to integers.\n",
|
|
|
+" Otherwise, the values are floats.\n",
|
|
|
+" \n",
|
|
|
+" Returns\n",
|
|
|
+" -------\n",
|
|
|
+" pd.Series\n",
|
|
|
+" A Series with length max_frames representing the parameter values for each frame.\n",
|
|
|
+" \n",
|
|
|
+" Examples\n",
|
|
|
+" --------\n",
|
|
|
+" >>> max_frames = 5\n",
|
|
|
+" >>> get_inbetweens({1: 5, 3: 6})\n",
|
|
|
+" 0 5.0\n",
|
|
|
+" 1 5.0\n",
|
|
|
+" 2 5.5\n",
|
|
|
+" 3 6.0\n",
|
|
|
+" 4 6.0\n",
|
|
|
+" dtype: float64\n",
|
|
|
+"\n",
|
|
|
+" >>> get_inbetweens({1: 5, 3: 6}, integer=True)\n",
|
|
|
+" 0 5\n",
|
|
|
+" 1 5\n",
|
|
|
+" 2 5\n",
|
|
|
+" 3 6\n",
|
|
|
+" 4 6\n",
|
|
|
+" dtype: int64\n",
|
|
|
+" \"\"\"\n",
|
|
|
+" key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n",
|
|
|
+"\n",
|
|
|
+" for i, value in key_frames.items():\n",
|
|
|
+" key_frame_series[i] = value\n",
|
|
|
+" key_frame_series = key_frame_series.astype(float)\n",
|
|
|
+" \n",
|
|
|
+" interp_method = interp_spline\n",
|
|
|
+"\n",
|
|
|
+" if interp_method == 'Cubic' and len(key_frames.items()) <=3:\n",
|
|
|
+" interp_method = 'Quadratic'\n",
|
|
|
+" \n",
|
|
|
+" if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n",
|
|
|
+" interp_method = 'Linear'\n",
|
|
|
+" \n",
|
|
|
+" \n",
|
|
|
+" key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n",
|
|
|
+" key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]\n",
|
|
|
+" # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')\n",
|
|
|
+" key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')\n",
|
|
|
+" if integer:\n",
|
|
|
+" return key_frame_series.astype(int)\n",
|
|
|
+" return key_frame_series\n",
|
|
|
+"\n",
|
|
|
+"def split_prompts(prompts):\n",
|
|
|
+" prompt_series = pd.Series([np.nan for a in range(max_frames)])\n",
|
|
|
+" for i, prompt in prompts.items():\n",
|
|
|
+" prompt_series[i] = prompt\n",
|
|
|
+" # prompt_series = prompt_series.astype(str)\n",
|
|
|
+" prompt_series = prompt_series.ffill().bfill()\n",
|
|
|
+" return prompt_series\n",
|
|
|
+"\n",
|
|
|
+"if key_frames:\n",
|
|
|
+" try:\n",
|
|
|
+" angle_series = get_inbetweens(parse_key_frames(angle))\n",
|
|
|
+" except RuntimeError as e:\n",
|
|
|
+" print(\n",
|
|
|
+" \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+" \"formatted `angle` correctly for key frames.\\n\"\n",
|
|
|
+" \"Attempting to interpret `angle` as \"\n",
|
|
|
+" f'\"0: ({angle})\"\\n'\n",
|
|
|
+" \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+" \"correctly.\\n\"\n",
|
|
|
+" )\n",
|
|
|
+" angle = f\"0: ({angle})\"\n",
|
|
|
+" angle_series = get_inbetweens(parse_key_frames(angle))\n",
|
|
|
+"\n",
|
|
|
+" try:\n",
|
|
|
+" zoom_series = get_inbetweens(parse_key_frames(zoom))\n",
|
|
|
+" except RuntimeError as e:\n",
|
|
|
+" print(\n",
|
|
|
+" \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+" \"formatted `zoom` correctly for key frames.\\n\"\n",
|
|
|
+" \"Attempting to interpret `zoom` as \"\n",
|
|
|
+" f'\"0: ({zoom})\"\\n'\n",
|
|
|
+" \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+" \"correctly.\\n\"\n",
|
|
|
+" )\n",
|
|
|
+" zoom = f\"0: ({zoom})\"\n",
|
|
|
+" zoom_series = get_inbetweens(parse_key_frames(zoom))\n",
|
|
|
+"\n",
|
|
|
+" try:\n",
|
|
|
+" translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n",
|
|
|
+" except RuntimeError as e:\n",
|
|
|
+" print(\n",
|
|
|
+" \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+" \"formatted `translation_x` correctly for key frames.\\n\"\n",
|
|
|
+" \"Attempting to interpret `translation_x` as \"\n",
|
|
|
+" f'\"0: ({translation_x})\"\\n'\n",
|
|
|
+" \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+" \"correctly.\\n\"\n",
|
|
|
+" )\n",
|
|
|
+" translation_x = f\"0: ({translation_x})\"\n",
|
|
|
+" translation_x_series = get_inbetweens(parse_key_frames(translation_x))\n",
|
|
|
+"\n",
|
|
|
+" try:\n",
|
|
|
+" translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n",
|
|
|
+" except RuntimeError as e:\n",
|
|
|
+" print(\n",
|
|
|
+" \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+" \"formatted `translation_y` correctly for key frames.\\n\"\n",
|
|
|
+" \"Attempting to interpret `translation_y` as \"\n",
|
|
|
+" f'\"0: ({translation_y})\"\\n'\n",
|
|
|
+" \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+" \"correctly.\\n\"\n",
|
|
|
+" )\n",
|
|
|
+" translation_y = f\"0: ({translation_y})\"\n",
|
|
|
+" translation_y_series = get_inbetweens(parse_key_frames(translation_y))\n",
|
|
|
+"\n",
|
|
|
+" try:\n",
|
|
|
+" translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n",
|
|
|
+" except RuntimeError as e:\n",
|
|
|
+" print(\n",
|
|
|
+" \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+" \"formatted `translation_z` correctly for key frames.\\n\"\n",
|
|
|
+" \"Attempting to interpret `translation_z` as \"\n",
|
|
|
+" f'\"0: ({translation_z})\"\\n'\n",
|
|
|
+" \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+" \"correctly.\\n\"\n",
|
|
|
+" )\n",
|
|
|
+" translation_z = f\"0: ({translation_z})\"\n",
|
|
|
+" translation_z_series = get_inbetweens(parse_key_frames(translation_z))\n",
|
|
|
+"\n",
|
|
|
+" try:\n",
|
|
|
+" rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n",
|
|
|
+" except RuntimeError as e:\n",
|
|
|
+" print(\n",
|
|
|
+" \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+" \"formatted `rotation_3d_x` correctly for key frames.\\n\"\n",
|
|
|
+" \"Attempting to interpret `rotation_3d_x` as \"\n",
|
|
|
+" f'\"0: ({rotation_3d_x})\"\\n'\n",
|
|
|
+" \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+" \"correctly.\\n\"\n",
|
|
|
+" )\n",
|
|
|
+" rotation_3d_x = f\"0: ({rotation_3d_x})\"\n",
|
|
|
+" rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))\n",
|
|
|
+"\n",
|
|
|
+" try:\n",
|
|
|
+" rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n",
|
|
|
+" except RuntimeError as e:\n",
|
|
|
+" print(\n",
|
|
|
+" \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+" \"formatted `rotation_3d_y` correctly for key frames.\\n\"\n",
|
|
|
+" \"Attempting to interpret `rotation_3d_y` as \"\n",
|
|
|
+" f'\"0: ({rotation_3d_y})\"\\n'\n",
|
|
|
+" \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+" \"correctly.\\n\"\n",
|
|
|
+" )\n",
|
|
|
+" rotation_3d_y = f\"0: ({rotation_3d_y})\"\n",
|
|
|
+" rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))\n",
|
|
|
+"\n",
|
|
|
+" try:\n",
|
|
|
+" rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n",
|
|
|
+" except RuntimeError as e:\n",
|
|
|
+" print(\n",
|
|
|
+" \"WARNING: You have selected to use key frames, but you have not \"\n",
|
|
|
+" \"formatted `rotation_3d_z` correctly for key frames.\\n\"\n",
|
|
|
+" \"Attempting to interpret `rotation_3d_z` as \"\n",
|
|
|
+" f'\"0: ({rotation_3d_z})\"\\n'\n",
|
|
|
+" \"Please read the instructions to find out how to use key frames \"\n",
|
|
|
+" \"correctly.\\n\"\n",
|
|
|
+" )\n",
|
|
|
+" rotation_3d_z = f\"0: ({rotation_3d_z})\"\n",
|
|
|
+" rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))\n",
|
|
|
+"\n",
|
|
|
+"else:\n",
|
|
|
+" angle = float(angle)\n",
|
|
|
+" zoom = float(zoom)\n",
|
|
|
+" translation_x = float(translation_x)\n",
|
|
|
+" translation_y = float(translation_y)\n",
|
|
|
+" translation_z = float(translation_z)\n",
|
|
|
+" rotation_3d_x = float(rotation_3d_x)\n",
|
|
|
+" rotation_3d_y = float(rotation_3d_y)\n",
|
|
|
+" rotation_3d_z = float(rotation_3d_z)"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "u1VHzHvNx5fd"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"### Extra Settings\n",
|
|
|
+" Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "lCLMxtILyAHA"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@markdown ####**Saving:**\n",
|
|
|
+"\n",
|
|
|
+"intermediate_saves = 0#@param{type: 'raw'}\n",
|
|
|
+"intermediates_in_subfolder = True #@param{type: 'boolean'}\n",
|
|
|
+"#@markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps \n",
|
|
|
+"\n",
|
|
|
+"#@markdown A value of `2` will save a copy at 33% and 66%. 0 will save none.\n",
|
|
|
+"\n",
|
|
|
+"#@markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets)\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"if type(intermediate_saves) is not list:\n",
|
|
|
+" if intermediate_saves:\n",
|
|
|
+" steps_per_checkpoint = math.floor((steps - skip_steps - 1) // (intermediate_saves+1))\n",
|
|
|
+" steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1\n",
|
|
|
+" print(f'Will save every {steps_per_checkpoint} steps')\n",
|
|
|
+" else:\n",
|
|
|
+" steps_per_checkpoint = steps+10\n",
|
|
|
+"else:\n",
|
|
|
+" steps_per_checkpoint = None\n",
|
|
|
+"\n",
|
|
|
+"if intermediate_saves and intermediates_in_subfolder is True:\n",
|
|
|
+" partialFolder = f'{batchFolder}/partials'\n",
|
|
|
+" createPath(partialFolder)\n",
|
|
|
+"\n",
|
|
|
+" #@markdown ---\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ####**SuperRes Sharpening:**\n",
|
|
|
+"#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*\n",
|
|
|
+"sharpen_preset = 'Off' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']\n",
|
|
|
+"keep_unsharp = True #@param{type: 'boolean'}\n",
|
|
|
+"\n",
|
|
|
+"if sharpen_preset != 'Off' and keep_unsharp is True:\n",
|
|
|
+" unsharpenFolder = f'{batchFolder}/unsharpened'\n",
|
|
|
+" createPath(unsharpenFolder)\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+" #@markdown ---\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ####**Advanced Settings:**\n",
|
|
|
+"#@markdown *There are a few extra advanced settings available if you double click this cell.*\n",
|
|
|
+"\n",
|
|
|
+"#@markdown *Perlin init will replace your init, so uncheck if using one.*\n",
|
|
|
+"\n",
|
|
|
+"perlin_init = False #@param{type: 'boolean'}\n",
|
|
|
+"perlin_mode = 'mixed' #@param ['mixed', 'color', 'gray']\n",
|
|
|
+"set_seed = 'random_seed' #@param{type: 'string'}\n",
|
|
|
+"eta = 0.2#@param{type: 'number'}\n",
|
|
|
+"clamp_grad = True #@param{type: 'boolean'}\n",
|
|
|
+"clamp_max = 0.25 #@param{type: 'number'}\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"### EXTRA ADVANCED SETTINGS:\n",
|
|
|
+"randomize_class = True\n",
|
|
|
+"clip_denoised = False\n",
|
|
|
+"fuzzy_prompt = False\n",
|
|
|
+"rand_mag = 0.1\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+" #@markdown ---\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ####**Cutn Scheduling:**\n",
|
|
|
+"#@markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400 /1000 steps, then 20 for the last 600/1000\n",
|
|
|
+"\n",
|
|
|
+"#@markdown cut_overview and cut_innercut are cumulative for total cutn on any given step. Overview cuts see the entire image and are good for early structure, innercuts are your standard cutn.\n",
|
|
|
+"# default overview = \"[8]*30+[0]*297000, innercut = \"[8]*30+[32]*297000\"\n",
|
|
|
+"cut_overview = \"[16]*30+[8]*297000\" #@param {type: 'string'} #\"[8]*30+[0]*2970\" #@param {type: 'string'} \n",
|
|
|
+"cut_innercut = \"[8]*30+[24]*297000\"#@param {type: 'string'} #\"[8]*30+[32]*2970\"#@param {type: 'string'} \n",
|
|
|
+"cut_ic_pow = 1#@param {type: 'number'} \n",
|
|
|
+"cut_icgray_p = \"[0.2]*30+[0]*2970\"#@param {type: 'string'} \n",
|
|
|
+"\n"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "XIwh5RvNpk4K"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"###Prompts\n",
|
|
|
+"`animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one."
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"id": "BGBzhk3dpcGO"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"text_prompts = {\n",
|
|
|
+" \n",
|
|
|
+" 0: [\"The most majestic being ever spotted in the wild\", \"trending on artstation\"],\n",
|
|
|
+" #100: [\"This set of prompts start at frame 100\",\"This prompt has weight five:5\"],\n",
|
|
|
+"}\n",
|
|
|
+"\n",
|
|
|
+"image_prompts = {\n",
|
|
|
+" # 0:['ImagePromptsWorkButArentVeryGood.png:2',],\n",
|
|
|
+"}"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "Nf9hTc8YLoLx"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"# 4. Diffuse!"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "LHLiO56OfwgD",
|
|
|
+"scrolled": false
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"#@title Do the Run!\n",
|
|
|
+"#@markdown `n_batches` ignored with animation modes.\n",
|
|
|
+"display_rate = 40 #@param{type: 'number'}\n",
|
|
|
+"n_batches = 1 #@param{type: 'number'}\n",
|
|
|
+"\n",
|
|
|
+"batch_size = 1\n",
|
|
|
+"\n",
|
|
|
+"def move_files(start_num, end_num, old_folder, new_folder):\n",
|
|
|
+" for i in range(start_num, end_num):\n",
|
|
|
+" old_file = old_folder + f'/{batch_name}({batchNum})_{i:04}.png'\n",
|
|
|
+" new_file = new_folder + f'/{batch_name}({batchNum})_{i:04}.png'\n",
|
|
|
+" os.rename(old_file, new_file)\n",
|
|
|
+"\n",
|
|
|
+"#@markdown ---\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"resume_run = False #@param{type: 'boolean'}\n",
|
|
|
+"run_to_resume = 'latest' #@param{type: 'string'}\n",
|
|
|
+"resume_from_frame = 'latest' #@param{type: 'string'}\n",
|
|
|
+"retain_overwritten_frames = False #@param{type: 'boolean'}\n",
|
|
|
+"if retain_overwritten_frames is True:\n",
|
|
|
+" retainFolder = f'{batchFolder}/retained'\n",
|
|
|
+" createPath(retainFolder)\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"skip_step_ratio = int(frames_skip_steps.rstrip(\"%\")) / 100\n",
|
|
|
+"calc_frames_skip_steps = math.floor(steps * skip_step_ratio)\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+"if steps <= calc_frames_skip_steps:\n",
|
|
|
+" sys.exit(\"ERROR: You can't skip more steps than your total steps\")\n",
|
|
|
+"\n",
|
|
|
+"if resume_run:\n",
|
|
|
+" if run_to_resume == 'latest':\n",
|
|
|
+" try:\n",
|
|
|
+" batchNum\n",
|
|
|
+" except:\n",
|
|
|
+" batchNum = len(glob(f\"{batchFolder}/{batch_name}(*)_settings.txt\"))-1\n",
|
|
|
+" else:\n",
|
|
|
+" batchNum = int(run_to_resume)\n",
|
|
|
+" if resume_from_frame == 'latest':\n",
|
|
|
+" start_frame = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n",
|
|
|
+" else:\n",
|
|
|
+" start_frame = int(resume_from_frame)+1\n",
|
|
|
+" if retain_overwritten_frames is True:\n",
|
|
|
+" existing_frames = len(glob(batchFolder+f\"/{batch_name}({batchNum})_*.png\"))\n",
|
|
|
+" frames_to_save = existing_frames - start_frame\n",
|
|
|
+" print(f'Moving {frames_to_save} frames to the Retained folder')\n",
|
|
|
+" move_files(start_frame, existing_frames, batchFolder, retainFolder)\n",
|
|
|
+"else:\n",
|
|
|
+" start_frame = 0\n",
|
|
|
+" batchNum = len(glob(batchFolder+\"/*.txt\"))\n",
|
|
|
+" while path.isfile(f\"{batchFolder}/{batch_name}({batchNum})_settings.txt\") is True or path.isfile(f\"{batchFolder}/{batch_name}-{batchNum}_settings.txt\") is True:\n",
|
|
|
+" batchNum += 1\n",
|
|
|
+"\n",
|
|
|
+"print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')\n",
|
|
|
+"\n",
|
|
|
+"if set_seed == 'random_seed':\n",
|
|
|
+" random.seed()\n",
|
|
|
+" seed = random.randint(0, 2**32)\n",
|
|
|
+" # print(f'Using seed: {seed}')\n",
|
|
|
+"else:\n",
|
|
|
+" seed = int(set_seed)\n",
|
|
|
+"\n",
|
|
|
+"args = {\n",
|
|
|
+" 'batchNum': batchNum,\n",
|
|
|
+" 'prompts_series':split_prompts(text_prompts) if text_prompts else None,\n",
|
|
|
+" 'image_prompts_series':split_prompts(image_prompts) if image_prompts else None,\n",
|
|
|
+" 'seed': seed,\n",
|
|
|
+" 'display_rate':display_rate,\n",
|
|
|
+" 'n_batches':n_batches if animation_mode == 'None' else 1,\n",
|
|
|
+" 'batch_size':batch_size,\n",
|
|
|
+" 'batch_name': batch_name,\n",
|
|
|
+" 'steps': steps,\n",
|
|
|
+" 'width_height': width_height,\n",
|
|
|
+" 'clip_guidance_scale': clip_guidance_scale,\n",
|
|
|
+" 'tv_scale': tv_scale,\n",
|
|
|
+" 'range_scale': range_scale,\n",
|
|
|
+" 'sat_scale': sat_scale,\n",
|
|
|
+" 'cutn_batches': cutn_batches,\n",
|
|
|
+" 'init_image': init_image,\n",
|
|
|
+" 'init_scale': init_scale,\n",
|
|
|
+" 'skip_steps': skip_steps,\n",
|
|
|
+" 'sharpen_preset': sharpen_preset,\n",
|
|
|
+" 'keep_unsharp': keep_unsharp,\n",
|
|
|
+" 'side_x': side_x,\n",
|
|
|
+" 'side_y': side_y,\n",
|
|
|
+" 'timestep_respacing': timestep_respacing,\n",
|
|
|
+" 'diffusion_steps': diffusion_steps,\n",
|
|
|
+" 'animation_mode': animation_mode,\n",
|
|
|
+" 'video_init_path': video_init_path,\n",
|
|
|
+" 'extract_nth_frame': extract_nth_frame,\n",
|
|
|
+" 'key_frames': key_frames,\n",
|
|
|
+" 'max_frames': max_frames if animation_mode != \"None\" else 1,\n",
|
|
|
+" 'interp_spline': interp_spline,\n",
|
|
|
+" 'start_frame': start_frame,\n",
|
|
|
+" 'angle': angle,\n",
|
|
|
+" 'zoom': zoom,\n",
|
|
|
+" 'translation_x': translation_x,\n",
|
|
|
+" 'translation_y': translation_y,\n",
|
|
|
+" 'translation_z': translation_z,\n",
|
|
|
+" 'rotation_3d_x': rotation_3d_x,\n",
|
|
|
+" 'rotation_3d_y': rotation_3d_y,\n",
|
|
|
+" 'rotation_3d_z': rotation_3d_z,\n",
|
|
|
+" 'midas_depth_model': midas_depth_model,\n",
|
|
|
+" 'midas_weight': midas_weight,\n",
|
|
|
+" 'near_plane': near_plane,\n",
|
|
|
+" 'far_plane': far_plane,\n",
|
|
|
+" 'fov': fov,\n",
|
|
|
+" 'padding_mode': padding_mode,\n",
|
|
|
+" 'sampling_mode': sampling_mode,\n",
|
|
|
+" 'angle_series':angle_series,\n",
|
|
|
+" 'zoom_series':zoom_series,\n",
|
|
|
+" 'translation_x_series':translation_x_series,\n",
|
|
|
+" 'translation_y_series':translation_y_series,\n",
|
|
|
+" 'translation_z_series':translation_z_series,\n",
|
|
|
+" 'rotation_3d_x_series':rotation_3d_x_series,\n",
|
|
|
+" 'rotation_3d_y_series':rotation_3d_y_series,\n",
|
|
|
+" 'rotation_3d_z_series':rotation_3d_z_series,\n",
|
|
|
+" 'frames_scale': frames_scale,\n",
|
|
|
+" 'calc_frames_skip_steps': calc_frames_skip_steps,\n",
|
|
|
+" 'skip_step_ratio': skip_step_ratio,\n",
|
|
|
+" 'calc_frames_skip_steps': calc_frames_skip_steps,\n",
|
|
|
+" 'text_prompts': text_prompts,\n",
|
|
|
+" 'image_prompts': image_prompts,\n",
|
|
|
+" 'cut_overview': eval(cut_overview),\n",
|
|
|
+" 'cut_innercut': eval(cut_innercut),\n",
|
|
|
+" 'cut_ic_pow': cut_ic_pow,\n",
|
|
|
+" 'cut_icgray_p': eval(cut_icgray_p),\n",
|
|
|
+" 'intermediate_saves': intermediate_saves,\n",
|
|
|
+" 'intermediates_in_subfolder': intermediates_in_subfolder,\n",
|
|
|
+" 'steps_per_checkpoint': steps_per_checkpoint,\n",
|
|
|
+" 'perlin_init': perlin_init,\n",
|
|
|
+" 'perlin_mode': perlin_mode,\n",
|
|
|
+" 'set_seed': set_seed,\n",
|
|
|
+" 'eta': eta,\n",
|
|
|
+" 'clamp_grad': clamp_grad,\n",
|
|
|
+" 'clamp_max': clamp_max,\n",
|
|
|
+" 'skip_augs': skip_augs,\n",
|
|
|
+" 'randomize_class': randomize_class,\n",
|
|
|
+" 'clip_denoised': clip_denoised,\n",
|
|
|
+" 'fuzzy_prompt': fuzzy_prompt,\n",
|
|
|
+" 'rand_mag': rand_mag,\n",
|
|
|
+"}\n",
|
|
|
+"\n",
|
|
|
+"args = SimpleNamespace(**args)\n",
|
|
|
+"\n",
|
|
|
+"print('Prepping model...')\n",
|
|
|
+"model, diffusion = create_model_and_diffusion(**model_config)\n",
|
|
|
+"diffusion_model_path = f'{model_path}/{diffusion_model}.pt'\n",
|
|
|
+"print(diffusion_model_path)\n",
|
|
|
+"model.load_state_dict(torch.load(diffusion_model_path, map_location='cpu'))\n",
|
|
|
+"model.requires_grad_(False).eval().to(device)\n",
|
|
|
+"for name, param in model.named_parameters():\n",
|
|
|
+" if 'qkv' in name or 'norm' in name or 'proj' in name:\n",
|
|
|
+" param.requires_grad_()\n",
|
|
|
+"if model_config['use_fp16']:\n",
|
|
|
+" model.convert_to_fp16()\n",
|
|
|
+"\n",
|
|
|
+"gc.collect()\n",
|
|
|
+"torch.cuda.empty_cache()\n",
|
|
|
+"try:\n",
|
|
|
+" do_run()\n",
|
|
|
+"except KeyboardInterrupt:\n",
|
|
|
+" pass\n",
|
|
|
+"finally:\n",
|
|
|
+" print('Seed used:', seed)\n",
|
|
|
+" gc.collect()\n",
|
|
|
+" torch.cuda.empty_cache()"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "markdown",
|
|
|
+"metadata": {
|
|
|
+"id": "EZUg3bfzazgW"
|
|
|
+},
|
|
|
+"source": [
|
|
|
+"# 5. Create the video (CURRENTLY NOT WORKING)"
|
|
|
+]
|
|
|
+},
|
|
|
+{
|
|
|
+"cell_type": "code",
|
|
|
+"execution_count": null,
|
|
|
+"metadata": {
|
|
|
+"cellView": "form",
|
|
|
+"id": "HV54fuU3pMzJ"
|
|
|
+},
|
|
|
+"outputs": [],
|
|
|
+"source": [
|
|
|
+"# @title ### **Create video**\n",
|
|
|
+"#@markdown Video file will save in the same folder as your images.\n",
|
|
|
+"\n",
|
|
|
+"skip_video_for_run_all = False #@param {type: 'boolean'}\n",
|
|
|
+"\n",
|
|
|
+"if skip_video_for_run_all == True:\n",
|
|
|
+" print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n",
|
|
|
+"\n",
|
|
|
+"else:\n",
|
|
|
+" # import subprocess in case this cell is run without the above cells\n",
|
|
|
+" import subprocess\n",
|
|
|
+" from base64 import b64encode\n",
|
|
|
+"\n",
|
|
|
+" latest_run = batchNum\n",
|
|
|
+"\n",
|
|
|
+" folder = batch_name #@param\n",
|
|
|
+" run = latest_run #@param\n",
|
|
|
+" final_frame = 'final_frame'\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+" init_frame = 1#@param {type:\"number\"} This is the frame where the video will start\n",
|
|
|
+" last_frame = final_frame#@param {type:\"number\"} You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.\n",
|
|
|
+" fps = 12#@param {type:\"number\"}\n",
|
|
|
+" # view_video_in_cell = True #@param {type: 'boolean'}\n",
|
|
|
+"\n",
|
|
|
+" frames = []\n",
|
|
|
+" # tqdm.write('Generating video...')\n",
|
|
|
+"\n",
|
|
|
+" if last_frame == 'final_frame':\n",
|
|
|
+" last_frame = len(glob(batchFolder+f\"/{folder}({run})_*.png\"))\n",
|
|
|
+" print(f'Total frames: {last_frame}')\n",
|
|
|
+"\n",
|
|
|
+" image_path = f\"{outDirPath}/{folder}/{folder}({run})_%04d.png\"\n",
|
|
|
+" filepath = f\"{outDirPath}/{folder}/{folder}({run}).mp4\"\n",
|
|
|
+"\n",
|
|
|
+"\n",
|
|
|
+" cmd = [\n",
|
|
|
+" 'ffmpeg',\n",
|
|
|
+" '-y',\n",
|
|
|
+" '-vcodec',\n",
|
|
|
+" 'png',\n",
|
|
|
+" '-r',\n",
|
|
|
+" str(fps),\n",
|
|
|
+" '-start_number',\n",
|
|
|
+" str(init_frame),\n",
|
|
|
+" #'\"' + str(init_frame) + '\"',\n",
|
|
|
+" '-i',\n",
|
|
|
+" '\"' + image_path + '\"',\n",
|
|
|
+" '-frames:v',\n",
|
|
|
+" str(last_frame+1),\n",
|
|
|
+" '-c:v',\n",
|
|
|
+" 'libx264',\n",
|
|
|
+" '-vf',\n",
|
|
|
+" f'fps={fps}',\n",
|
|
|
+" '-pix_fmt',\n",
|
|
|
+" 'yuv420p',\n",
|
|
|
+" '-crf',\n",
|
|
|
+" '17',\n",
|
|
|
+" '-preset',\n",
|
|
|
+" 'veryslow',\n",
|
|
|
+" '\"' + filepath + '\"'\n",
|
|
|
+" ]\n",
|
|
|
+"\n",
|
|
|
+" print(\"Going to run this command:\")\n",
|
|
|
+" print(\" \".join(cmd))\n",
|
|
|
+"\n",
|
|
|
+" process = subprocess.Popen(cmd, cwd=f'{batchFolder}', stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n",
|
|
|
+" stdout, stderr = process.communicate()\n",
|
|
|
+" if process.returncode != 0:\n",
|
|
|
+" print(stderr)\n",
|
|
|
+" raise RuntimeError(stderr)\n",
|
|
|
+" else:\n",
|
|
|
+" print(\"The video is ready and saved to the images folder\")\n",
|
|
|
+"\n",
|
|
|
+" # if view_video_in_cell:\n",
|
|
|
+" # mp4 = open(filepath,'rb').read()\n",
|
|
|
+" # data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
|
|
|
+" # display.HTML(f'<video width=400 controls><source src=\"{data_url}\" type=\"video/mp4\"></video>')"
|
|
|
+]
|
|
|
+}
|
|
|
+],
|
|
|
+"metadata": {
|
|
|
+"accelerator": "GPU",
|
|
|
+"colab": {
|
|
|
+"collapsed_sections": [
|
|
|
+"1YwMUyt9LHG1",
|
|
|
+"XTu6AjLyFQUq",
|
|
|
+"CQVtY1Ixnqx4",
|
|
|
+"XIwh5RvNpk4K",
|
|
|
+"EZUg3bfzazgW"
|
|
|
+],
|
|
|
+"machine_shape": "hm",
|
|
|
+"name": "Disco Diffusion v5 Turbo [w/ 3D animation]",
|
|
|
+"private_outputs": true,
|
|
|
+"provenance": []
|
|
|
+},
|
|
|
+"kernelspec": {
|
|
|
+"display_name": "Python 3 (ipykernel)",
|
|
|
+"language": "python",
|
|
|
+"name": "python3"
|
|
|
+},
|
|
|
+"language_info": {
|
|
|
+"codemirror_mode": {
|
|
|
+"name": "ipython",
|
|
|
+"version": 3
|
|
|
+},
|
|
|
+"file_extension": ".py",
|
|
|
+"mimetype": "text/x-python",
|
|
|
+"name": "python",
|
|
|
+"nbconvert_exporter": "python",
|
|
|
+"pygments_lexer": "ipython3",
|
|
|
+"version": "3.9.10"
|
|
|
+}
|
|
|
+},
|
|
|
+"nbformat": 4,
|
|
|
+"nbformat_minor": 1
|
|
|
}
|