  1. # -*- coding: utf-8 -*-
  2. """Copie de Disco Diffusion v5.1 [w/ Turbo]
  3. Automatically generated by Colaboratory.
  4. Original file is located at
  5. https://colab.research.google.com/drive/11dX8Ve_UQ45_sg-Y02sT3M7nbVOrBllB
  6. # Disco Diffusion v5.1 - Now with Turbo
  7. In case of confusion, Disco is the name of this notebook edit. The diffusion model in use is Katherine Crowson's fine-tuned 512x512 model.
  8. For issues, join the [Disco Diffusion Discord](https://discord.gg/msEZBy4HxA) or message us on twitter at [@somnai_dreams](https://twitter.com/somnai_dreams) or [@gandamu](https://twitter.com/gandamu_ml)
  9. ### Credits & Changelog ⬇️
  10. #### Credits
  11. Original notebook by Katherine Crowson (https://github.com/crowsonkb, https://twitter.com/RiversHaveWings). It uses either OpenAI's 256x256 unconditional ImageNet or Katherine Crowson's fine-tuned 512x512 diffusion model (https://github.com/openai/guided-diffusion), together with CLIP (https://github.com/openai/CLIP) to connect text prompts with images.
  12. Modified by Daniel Russell (https://github.com/russelldc, https://twitter.com/danielrussruss) to include (hopefully) optimal params for quick generations in 15-100 timesteps rather than 1000, as well as more robust augmentations.
  13. Further improvements from Dango233 and nsheppard helped improve the quality of diffusion in general, and especially so for shorter runs like this notebook aims to achieve.
  14. Vark added code to load multiple CLIP models at once; all prompts are evaluated against each of them, which may greatly improve accuracy.
  15. The latest zoom, pan, rotation, and keyframes features were taken from Chigozie Nri's VQGAN Zoom Notebook (https://github.com/chigozienri, https://twitter.com/chigozienri)
  16. The advanced DangoCutn cutout method is also from Dango233.
  17. --
  18. Disco:
  19. Somnai (https://twitter.com/Somnai_dreams) added Diffusion Animation techniques, QoL improvements and various implementations of tech and techniques, mostly listed in the changelog below.
  20. 3D animation implementation added by Adam Letts (https://twitter.com/gandamu_ml) in collaboration with Somnai.
  21. Turbo feature by Chris Allen (https://twitter.com/zippy731)
  22. #### License
  23. Licensed under the MIT License
  24. Copyright (c) 2021 Katherine Crowson
  25. Permission is hereby granted, free of charge, to any person obtaining a copy
  26. of this software and associated documentation files (the "Software"), to deal
  27. in the Software without restriction, including without limitation the rights
  28. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  29. copies of the Software, and to permit persons to whom the Software is
  30. furnished to do so, subject to the following conditions:
  31. The above copyright notice and this permission notice shall be included in
  32. all copies or substantial portions of the Software.
  33. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  34. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  35. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  36. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  37. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  38. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  39. THE SOFTWARE.
  40. --
  41. MIT License
  42. Copyright (c) 2019 Intel ISL (Intel Intelligent Systems Lab)
  43. Permission is hereby granted, free of charge, to any person obtaining a copy
  44. of this software and associated documentation files (the "Software"), to deal
  45. in the Software without restriction, including without limitation the rights
  46. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  47. copies of the Software, and to permit persons to whom the Software is
  48. furnished to do so, subject to the following conditions:
  49. The above copyright notice and this permission notice shall be included in all
  50. copies or substantial portions of the Software.
  51. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  52. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  53. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  54. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  55. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  56. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  57. SOFTWARE.
  58. --
  59. Licensed under the MIT License
  60. Copyright (c) 2021 Maxwell Ingham
  61. Copyright (c) 2022 Adam Letts
  62. Permission is hereby granted, free of charge, to any person obtaining a copy
  63. of this software and associated documentation files (the "Software"), to deal
  64. in the Software without restriction, including without limitation the rights
  65. to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  66. copies of the Software, and to permit persons to whom the Software is
  67. furnished to do so, subject to the following conditions:
  68. The above copyright notice and this permission notice shall be included in
  69. all copies or substantial portions of the Software.
  70. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  71. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  72. FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  73. AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  74. LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  75. OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
  76. THE SOFTWARE.
  77. #### Changelog
  78. """
  79. #@title <- View Changelog
  80. skip_for_run_all = True #@param {type: 'boolean'}
  81. if skip_for_run_all == False:
  82. print(
  83. '''
  84. v1 Update: Oct 29th 2021 - Somnai
  85. QoL improvements added by Somnai (@somnai_dreams), including user friendly UI, settings+prompt saving and improved google drive folder organization.
  86. v1.1 Update: Nov 13th 2021 - Somnai
  87. Now includes sizing options, intermediate saves, and fixed image prompts and perlin inits. The batch option is left unexposed since it doesn't work.
  88. v2 Update: Nov 22nd 2021 - Somnai
  89. Initial addition of Katherine Crowson's Secondary Model Method (https://colab.research.google.com/drive/1mpkrhOjoyzPeSWy2r7T8EYRaU7amYOOi#scrollTo=X5gODNAMEUCR)
  90. Noticed settings were saving with the wrong name so corrected it. Let me know if you preferred the old scheme.
  91. v3 Update: Dec 24th 2021 - Somnai
  92. Implemented Dango's advanced cutout method
  93. Added SLIP models, thanks to NeuralDivergent
  94. Fixed issue with NaNs resulting in black images, with massive help and testing from @Softology
  95. Perlin now changes properly within batches (not sure where this perlin_regen code came from originally, but thank you)
  96. v4 Update: Jan 2022 - Somnai
  97. Implemented Diffusion Zooming
  98. Added Chigozie keyframing
  99. Made a bunch of edits to processes
  100. v4.1 Update: Jan 14th 2022 - Somnai
  101. Added video input mode
  102. Added license that somehow went missing
  103. Added improved prompt keyframing, fixed image_prompts and multiple prompts
  104. Improved UI
  105. Significant under the hood cleanup and improvement
  106. Refined defaults for each mode
  107. Added latent-diffusion SuperRes for sharpening
  108. Added resume run mode
  109. v4.9 Update: Feb 5th 2022 - gandamu / Adam Letts
  110. Added 3D
  111. Added brightness corrections to prevent animation from steadily going dark over time
  112. v4.91 Update: Feb 19th 2022 - gandamu / Adam Letts
  113. Cleaned up 3D implementation and made associated args accessible via Colab UI elements
  114. v4.92 Update: Feb 20th 2022 - gandamu / Adam Letts
  115. Separated transform code
  116. v5.01 Update: Mar 10th 2022 - gandamu / Adam Letts
  117. IPython magic commands replaced by Python code
  118. v5.1 Update: Mar 30th 2022 - zippy / Chris Allen and gandamu / Adam Letts
  119. Integrated Turbo+Smooth features from Disco Diffusion Turbo -- just the implementation, without its defaults.
  120. Implemented resume of turbo animations in such a way that it's now possible to resume from different batch folders and batch numbers.
  121. 3D rotation parameter units are now degrees (rather than radians)
  122. Corrected name collision in sampling_mode (now diffusion_sampling_mode for plms/ddim, and sampling_mode for 3D transform sampling)
  123. Added video_init_seed_continuity option to make init video animations more continuous
  124. '''
  125. )
  126. """# Tutorial
  127. **Diffusion settings (Defaults are heavily outdated)**
  128. ---
  129. This section is outdated as of v2
  130. Setting | Description | Default
  131. --- | --- | ---
  132. **Your vision:**
  133. `text_prompts` | A description of what you'd like the machine to generate. Think of it like writing the caption below your image on a website. | N/A
  134. `image_prompts` | Think of these images more as a description of their contents. | N/A
  135. **Image quality:**
  136. `clip_guidance_scale` | Controls how much the image should look like the prompt. | 1000
  137. `tv_scale` | Controls the smoothness of the final output. | 150
  138. `range_scale` | Controls how far out of range RGB values are allowed to be. | 150
  139. `sat_scale` | Controls how much saturation is allowed. From nshepperd's JAX notebook. | 0
  140. `cutn` | Controls how many crops to take from the image. | 16
  141. `cutn_batches` | Accumulate CLIP gradient from multiple batches of cuts | 2
  142. **Init settings:**
  143. `init_image` | URL or local path | None
  144. `init_scale` | This enhances the effect of the init image; a good value is 1000 | 0
  145. `skip_steps` | Controls the starting point along the diffusion timesteps | 0
  146. `perlin_init` | Option to start with random perlin noise | False
  147. `perlin_mode` | ('gray', 'color') | 'mixed'
  148. **Advanced:**
  149. `skip_augs` | Controls whether to skip torchvision augmentations | False
  150. `randomize_class` | Controls whether the ImageNet class is randomly changed each iteration | True
  151. `clip_denoised` | Determines whether CLIP discriminates a noisy or denoised image | False
  152. `clamp_grad` | Experimental: Using adaptive clip grad in the cond_fn | True
  153. `seed` | Choose a random seed and print it at end of run for reproduction | random_seed
  154. `fuzzy_prompt` | Controls whether to add multiple noisy prompts to the prompt losses | False
  155. `rand_mag` | Controls the magnitude of the random noise | 0.1
  156. `eta` | DDIM hyperparameter | 0.5
  157. ..
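For orientation, here is a minimal, illustrative sketch of how the settings above are typically filled in further down the notebook (the values are placeholders taken from the defaults in this table; the optional `:weight` suffix on each prompt is split off by `parse_prompt` in section 1.5, and a negative weight pushes the image away from that prompt):

```python
text_prompts = {
    0: ["a lighthouse on a rocky shore at dawn, artstation:2",
        "blurry, low detail:-1"],
}
image_prompts = {}          # e.g. {0: ['some_reference_image.png:1']}
clip_guidance_scale = 1000
tv_scale = 150
cutn_batches = 2
init_image = None
skip_steps = 0
```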
  158. **Model settings**
  159. ---
  160. Setting | Description | Default
  161. --- | --- | ---
  162. **Diffusion:**
  163. `timestep_respacing` | Modify this value to decrease the number of timesteps. | ddim100
  164. `diffusion_steps` || 1000
  165. **CLIP:**
  166. `clip_models` | Models of CLIP to load. Typically the more, the better but they all come at a hefty VRAM cost. | ViT-B/32, ViT-B/16, RN50x4
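As a rough sketch of how a `clip_models` list such as the one above can be assembled with OpenAI's `clip` package (the notebook has its own model-loading cell; this is only for orientation, and `device` is the torch device set up in section 1.3):

```python
import clip
clip_models = [
    clip.load(name, jit=False)[0].eval().requires_grad_(False).to(device)
    for name in ('ViT-B/32', 'ViT-B/16', 'RN50x4')
]
```

Each additional model increases VRAM use and per-step time, since every prompt is scored against every loaded model.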
  167. # 1. Set Up
  168. """
  169. #@title 1.1 Check GPU Status
  170. import subprocess
  171. simple_nvidia_smi_display = False#@param {type:"boolean"}
  172. if simple_nvidia_smi_display:
  173. #!nvidia-smi
  174. nvidiasmi_output = subprocess.run(['nvidia-smi', '-L'], stdout=subprocess.PIPE).stdout.decode('utf-8')
  175. print(nvidiasmi_output)
  176. else:
  177. #!nvidia-smi -i 0 -e 0
  178. nvidiasmi_output = subprocess.run(['nvidia-smi'], stdout=subprocess.PIPE).stdout.decode('utf-8')
  179. print(nvidiasmi_output)
  180. nvidiasmi_ecc_note = subprocess.run(['nvidia-smi', '-i', '0', '-e', '0'], stdout=subprocess.PIPE).stdout.decode('utf-8')
  181. print(nvidiasmi_ecc_note)
  182. #@title 1.2 Prepare Folders
  183. import subprocess
  184. import sys
  185. import ipykernel
  186. def gitclone(url):
  187. res = subprocess.run(['git', 'clone', url], stdout=subprocess.PIPE).stdout.decode('utf-8')
  188. print(res)
  189. def pipi(modulestr):
  190. res = subprocess.run(['pip', 'install', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')
  191. print(res)
  192. def pipie(modulestr):
  193. res = subprocess.run(['pip', 'install', '-e', modulestr], stdout=subprocess.PIPE).stdout.decode('utf-8')
  194. print(res)
  195. def wget(url, outputdir):
  196. res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')
  197. print(res)
  198. try:
  199. from google.colab import drive
  200. print("Google Colab detected. Using Google Drive.")
  201. is_colab = True
  202. #@markdown If you connect your Google Drive, you can save the final image of each run on your drive.
  203. google_drive = True #@param {type:"boolean"}
  204. #@markdown Click here if you'd like to save the diffusion model checkpoint file to (and/or load from) your Google Drive:
  205. save_models_to_google_drive = True #@param {type:"boolean"}
  206. except:
  207. is_colab = False
  208. google_drive = False
  209. save_models_to_google_drive = False
  210. print("Google Colab not detected.")
  211. if is_colab:
  212. if google_drive is True:
  213. drive.mount('/content/drive')
  214. root_path = '/content/drive/MyDrive/AI/Disco_Diffusion'
  215. else:
  216. root_path = '/content'
  217. else:
  218. root_path = '.'
  219. import os
  220. def createPath(filepath):
  221. os.makedirs(filepath, exist_ok=True)
  222. initDirPath = f'{root_path}/init_images'
  223. createPath(initDirPath)
  224. outDirPath = f'{root_path}/images_out'
  225. createPath(outDirPath)
  226. if is_colab:
  227. if google_drive and not save_models_to_google_drive or not google_drive:
  228. model_path = '/content/model'
  229. createPath(model_path)
  230. if google_drive and save_models_to_google_drive:
  231. model_path = f'{root_path}/model'
  232. createPath(model_path)
  233. else:
  234. model_path = f'{root_path}/model'
  235. createPath(model_path)
  236. # libraries = f'{root_path}/libraries'
  237. # createPath(libraries)
  238. #@title ### 1.3 Install and import dependencies
  239. import pathlib, shutil
  240. if not is_colab:
  241. # If running locally, there's a good chance your env will need this in order to not crash upon np.matmul() or similar operations.
  242. os.environ['KMP_DUPLICATE_LIB_OK']='TRUE'
  243. PROJECT_DIR = os.path.abspath(os.getcwd())
  244. USE_ADABINS = True
  245. if is_colab:
  246. if google_drive is not True:
  247. root_path = f'/content'
  248. model_path = '/content/models'
  249. else:
  250. root_path = f'.'
  251. model_path = f'{root_path}/model'
  252. model_256_downloaded = False
  253. model_512_downloaded = False
  254. model_secondary_downloaded = False
  255. if is_colab:
  256. gitclone("https://github.com/openai/CLIP")
  257. #gitclone("https://github.com/facebookresearch/SLIP.git")
  258. gitclone("https://github.com/crowsonkb/guided-diffusion")
  259. gitclone("https://github.com/assafshocher/ResizeRight.git")
  260. gitclone("https://github.com/MSFTserver/pytorch3d-lite.git")
  261. pipie("./CLIP")
  262. pipie("./guided-diffusion")
  263. multipip_res = subprocess.run(['pip', 'install', 'lpips', 'datetime', 'timm', 'ftfy'], stdout=subprocess.PIPE).stdout.decode('utf-8')
  264. print(multipip_res)
  265. subprocess.run(['apt', 'install', 'imagemagick'], stdout=subprocess.PIPE).stdout.decode('utf-8')
  266. gitclone("https://github.com/isl-org/MiDaS.git")
  267. gitclone("https://github.com/alembics/disco-diffusion.git")
  268. pipi("pytorch-lightning")
  269. pipi("omegaconf")
  270. pipi("einops")
  271. # Rename a file to avoid a name conflict..
  272. try:
  273. os.rename("MiDaS/utils.py", "MiDaS/midas_utils.py")
  274. shutil.copyfile("disco-diffusion/disco_xform_utils.py", "disco_xform_utils.py")
  275. except:
  276. pass
  277. if not os.path.exists(f'{model_path}'):
  278. pathlib.Path(model_path).mkdir(parents=True, exist_ok=True)
  279. if not os.path.exists(f'{model_path}/dpt_large-midas-2f21e586.pt'):
  280. wget("https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", model_path)
  281. import sys
  282. import torch
  283. # sys.path.append('./SLIP')
  284. sys.path.append('./pytorch3d-lite')
  285. sys.path.append('./ResizeRight')
  286. sys.path.append('./MiDaS')
  287. from dataclasses import dataclass
  288. from functools import partial
  289. import cv2
  290. import pandas as pd
  291. import gc
  292. import io
  293. import math
  294. import timm
  295. from IPython import display
  296. import lpips
  297. from PIL import Image, ImageOps
  298. import requests
  299. from glob import glob
  300. import json
  301. from types import SimpleNamespace
  302. from torch import nn
  303. from torch.nn import functional as F
  304. import torchvision.transforms as T
  305. import torchvision.transforms.functional as TF
  306. from tqdm.notebook import tqdm
  307. sys.path.append('./CLIP')
  308. sys.path.append('./guided-diffusion')
  309. import clip
  310. from resize_right import resize
  311. # from models import SLIP_VITB16, SLIP, SLIP_VITL16
  312. from guided_diffusion.script_util import create_model_and_diffusion, model_and_diffusion_defaults
  313. from datetime import datetime
  314. import numpy as np
  315. import matplotlib.pyplot as plt
  316. import random
  317. from ipywidgets import Output
  318. import hashlib
  319. #SuperRes
  320. if is_colab:
  321. gitclone("https://github.com/CompVis/latent-diffusion.git")
  322. gitclone("https://github.com/CompVis/taming-transformers")
  323. pipie("./taming-transformers")
  324. pipi("ipywidgets omegaconf>=2.0.0 pytorch-lightning>=1.0.8 torch-fidelity einops wandb")
  325. #SuperRes
  326. import ipywidgets as widgets
  327. import os
  328. sys.path.append(".")
  329. sys.path.append('./taming-transformers')
  330. from taming.models import vqgan # checking correct import from taming
  331. from torchvision.datasets.utils import download_url
  332. if is_colab:
  333. os.chdir('/content/latent-diffusion')
  334. else:
  335. #os.chdir('latent-diffusion')
  336. sys.path.append('latent-diffusion')
  337. from functools import partial
  338. from ldm.util import instantiate_from_config
  339. from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like
  340. # from ldm.models.diffusion.ddim import DDIMSampler
  341. from ldm.util import ismap
  342. if is_colab:
  343. os.chdir('/content')
  344. from google.colab import files
  345. else:
  346. os.chdir(f'{PROJECT_DIR}')
  347. from IPython.display import Image as ipyimg
  348. from numpy import asarray
  349. from einops import rearrange, repeat
  350. import torch, torchvision
  351. import time
  352. from omegaconf import OmegaConf
  353. import warnings
  354. warnings.filterwarnings("ignore", category=UserWarning)
  355. # AdaBins stuff
  356. if USE_ADABINS:
  357. if is_colab:
  358. gitclone("https://github.com/shariqfarooq123/AdaBins.git")
  359. if not os.path.exists(f'{model_path}/AdaBins_nyu.pt'):
  360. wget("https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt", model_path)
  361. pathlib.Path("pretrained").mkdir(parents=True, exist_ok=True)
  362. shutil.copyfile(f"{model_path}/AdaBins_nyu.pt", "pretrained/AdaBins_nyu.pt")
  363. sys.path.append('./AdaBins')
  364. from infer import InferenceHelper
  365. MAX_ADABINS_AREA = 500000
  366. import torch
  367. DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
  368. print('Using device:', DEVICE)
  369. device = DEVICE # At least one of the modules expects this name..
  370. if torch.cuda.get_device_capability(DEVICE) == (8,0): ## A100 fix thanks to Emad
  371. print('Disabling CUDNN for A100 gpu', file=sys.stderr)
  372. torch.backends.cudnn.enabled = False
  373. #@title ### 1.4 Define Midas functions
  374. from midas.dpt_depth import DPTDepthModel
  375. from midas.midas_net import MidasNet
  376. from midas.midas_net_custom import MidasNet_small
  377. from midas.transforms import Resize, NormalizeImage, PrepareForNet
  378. # Initialize MiDaS depth model.
  379. # It remains resident in VRAM and likely takes around 2GB VRAM.
  380. # You could instead initialize it for each frame (and free it after each frame) to save VRAM.. but initializing it is slow.
  381. default_models = {
  382. "midas_v21_small": f"{model_path}/midas_v21_small-70d6b9c8.pt",
  383. "midas_v21": f"{model_path}/midas_v21-f6b98070.pt",
  384. "dpt_large": f"{model_path}/dpt_large-midas-2f21e586.pt",
  385. "dpt_hybrid": f"{model_path}/dpt_hybrid-midas-501f0c75.pt",
  386. "dpt_hybrid_nyu": f"{model_path}/dpt_hybrid_nyu-2ce69ec7.pt",}
  387. def init_midas_depth_model(midas_model_type="dpt_large", optimize=True):
  388. midas_model = None
  389. net_w = None
  390. net_h = None
  391. resize_mode = None
  392. normalization = None
  393. print(f"Initializing MiDaS '{midas_model_type}' depth model...")
  394. # load network
  395. midas_model_path = default_models[midas_model_type]
  396. if midas_model_type == "dpt_large": # DPT-Large
  397. midas_model = DPTDepthModel(
  398. path=midas_model_path,
  399. backbone="vitl16_384",
  400. non_negative=True,
  401. )
  402. net_w, net_h = 384, 384
  403. resize_mode = "minimal"
  404. normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  405. elif midas_model_type == "dpt_hybrid": #DPT-Hybrid
  406. midas_model = DPTDepthModel(
  407. path=midas_model_path,
  408. backbone="vitb_rn50_384",
  409. non_negative=True,
  410. )
  411. net_w, net_h = 384, 384
  412. resize_mode="minimal"
  413. normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  414. elif midas_model_type == "dpt_hybrid_nyu": #DPT-Hybrid-NYU
  415. midas_model = DPTDepthModel(
  416. path=midas_model_path,
  417. backbone="vitb_rn50_384",
  418. non_negative=True,
  419. )
  420. net_w, net_h = 384, 384
  421. resize_mode="minimal"
  422. normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
  423. elif midas_model_type == "midas_v21":
  424. midas_model = MidasNet(midas_model_path, non_negative=True)
  425. net_w, net_h = 384, 384
  426. resize_mode="upper_bound"
  427. normalization = NormalizeImage(
  428. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
  429. )
  430. elif midas_model_type == "midas_v21_small":
  431. midas_model = MidasNet_small(midas_model_path, features=64, backbone="efficientnet_lite3", exportable=True, non_negative=True, blocks={'expand': True})
  432. net_w, net_h = 256, 256
  433. resize_mode="upper_bound"
  434. normalization = NormalizeImage(
  435. mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
  436. )
  437. else:
  438. print(f"midas_model_type '{midas_model_type}' not implemented")
  439. assert False
  440. midas_transform = T.Compose(
  441. [
  442. Resize(
  443. net_w,
  444. net_h,
  445. resize_target=None,
  446. keep_aspect_ratio=True,
  447. ensure_multiple_of=32,
  448. resize_method=resize_mode,
  449. image_interpolation_method=cv2.INTER_CUBIC,
  450. ),
  451. normalization,
  452. PrepareForNet(),
  453. ]
  454. )
  455. midas_model.eval()
  456. if optimize==True:
  457. if DEVICE == torch.device("cuda"):
  458. midas_model = midas_model.to(memory_format=torch.channels_last)
  459. midas_model = midas_model.half()
  460. midas_model.to(DEVICE)
  461. print(f"MiDaS '{midas_model_type}' depth model initialized.")
  462. return midas_model, midas_transform, net_w, net_h, resize_mode, normalization
  463. #@title 1.5 Define necessary functions
  464. # https://gist.github.com/adefossez/0646dbe9ed4005480a2407c62aac8869
  465. import py3d_tools as p3dT
  466. import disco_xform_utils as dxf
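# Perlin-noise helpers: interp() is the smoothstep curve, perlin() builds one octave of gradient noise on a
# width x height lattice, perlin_ms() sums several octaves, and create_perlin_noise()/regen_perlin() turn the
# result into an RGB init image/tensor when perlin_init is enabled.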
  467. def interp(t):
  468. return 3 * t**2 - 2 * t ** 3
  469. def perlin(width, height, scale=10, device=None):
  470. gx, gy = torch.randn(2, width + 1, height + 1, 1, 1, device=device)
  471. xs = torch.linspace(0, 1, scale + 1)[:-1, None].to(device)
  472. ys = torch.linspace(0, 1, scale + 1)[None, :-1].to(device)
  473. wx = 1 - interp(xs)
  474. wy = 1 - interp(ys)
  475. dots = 0
  476. dots += wx * wy * (gx[:-1, :-1] * xs + gy[:-1, :-1] * ys)
  477. dots += (1 - wx) * wy * (-gx[1:, :-1] * (1 - xs) + gy[1:, :-1] * ys)
  478. dots += wx * (1 - wy) * (gx[:-1, 1:] * xs - gy[:-1, 1:] * (1 - ys))
  479. dots += (1 - wx) * (1 - wy) * (-gx[1:, 1:] * (1 - xs) - gy[1:, 1:] * (1 - ys))
  480. return dots.permute(0, 2, 1, 3).contiguous().view(width * scale, height * scale)
  481. def perlin_ms(octaves, width, height, grayscale, device=device):
  482. out_array = [0.5] if grayscale else [0.5, 0.5, 0.5]
  483. # out_array = [0.0] if grayscale else [0.0, 0.0, 0.0]
  484. for i in range(1 if grayscale else 3):
  485. scale = 2 ** len(octaves)
  486. oct_width = width
  487. oct_height = height
  488. for oct in octaves:
  489. p = perlin(oct_width, oct_height, scale, device)
  490. out_array[i] += p * oct
  491. scale //= 2
  492. oct_width *= 2
  493. oct_height *= 2
  494. return torch.cat(out_array)
  495. def create_perlin_noise(octaves=[1, 1, 1, 1], width=2, height=2, grayscale=True):
  496. out = perlin_ms(octaves, width, height, grayscale)
  497. if grayscale:
  498. out = TF.resize(size=(side_y, side_x), img=out.unsqueeze(0))
  499. out = TF.to_pil_image(out.clamp(0, 1)).convert('RGB')
  500. else:
  501. out = out.reshape(-1, 3, out.shape[0]//3, out.shape[1])
  502. out = TF.resize(size=(side_y, side_x), img=out)
  503. out = TF.to_pil_image(out.clamp(0, 1).squeeze())
  504. out = ImageOps.autocontrast(out)
  505. return out
  506. def regen_perlin():
  507. if perlin_mode == 'color':
  508. init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
  509. init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)
  510. elif perlin_mode == 'gray':
  511. init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)
  512. init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
  513. else:
  514. init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
  515. init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
  516. init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)
  517. del init2
  518. return init.expand(batch_size, -1, -1, -1)
  519. def fetch(url_or_path):
  520. if str(url_or_path).startswith('http://') or str(url_or_path).startswith('https://'):
  521. r = requests.get(url_or_path)
  522. r.raise_for_status()
  523. fd = io.BytesIO()
  524. fd.write(r.content)
  525. fd.seek(0)
  526. return fd
  527. return open(url_or_path, 'rb')
  528. def read_image_workaround(path):
  529. """OpenCV reads images as BGR, Pillow saves them as RGB. Work around
  530. this incompatibility to avoid colour inversions."""
  531. im_tmp = cv2.imread(path)
  532. return cv2.cvtColor(im_tmp, cv2.COLOR_BGR2RGB)
  533. def parse_prompt(prompt):
  534. if prompt.startswith('http://') or prompt.startswith('https://'):
  535. vals = prompt.rsplit(':', 2)
  536. vals = [vals[0] + ':' + vals[1], *vals[2:]]
  537. else:
  538. vals = prompt.rsplit(':', 1)
  539. vals = vals + ['', '1'][len(vals):]
  540. return vals[0], float(vals[1])
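# Illustrative examples: parse_prompt('a lighthouse:2') -> ('a lighthouse', 2.0); a URL prompt such as
# 'https://example.com/img.png:0.5' keeps the scheme's colon intact because URLs are split on their last two
# colons; a prompt with no suffix defaults to a weight of 1.
# The sinc/lanczos/ramp/resample helpers below implement Lanczos filtering, used to rescale cutouts to CLIP's
# input resolution.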
  541. def sinc(x):
  542. return torch.where(x != 0, torch.sin(math.pi * x) / (math.pi * x), x.new_ones([]))
  543. def lanczos(x, a):
  544. cond = torch.logical_and(-a < x, x < a)
  545. out = torch.where(cond, sinc(x) * sinc(x/a), x.new_zeros([]))
  546. return out / out.sum()
  547. def ramp(ratio, width):
  548. n = math.ceil(width / ratio + 1)
  549. out = torch.empty([n])
  550. cur = 0
  551. for i in range(out.shape[0]):
  552. out[i] = cur
  553. cur += ratio
  554. return torch.cat([-out[1:].flip([0]), out])[1:-1]
  555. def resample(input, size, align_corners=True):
  556. n, c, h, w = input.shape
  557. dh, dw = size
  558. input = input.reshape([n * c, 1, h, w])
  559. if dh < h:
  560. kernel_h = lanczos(ramp(dh / h, 2), 2).to(input.device, input.dtype)
  561. pad_h = (kernel_h.shape[0] - 1) // 2
  562. input = F.pad(input, (0, 0, pad_h, pad_h), 'reflect')
  563. input = F.conv2d(input, kernel_h[None, None, :, None])
  564. if dw < w:
  565. kernel_w = lanczos(ramp(dw / w, 2), 2).to(input.device, input.dtype)
  566. pad_w = (kernel_w.shape[0] - 1) // 2
  567. input = F.pad(input, (pad_w, pad_w, 0, 0), 'reflect')
  568. input = F.conv2d(input, kernel_w[None, None, None, :])
  569. input = input.reshape([n, c, h, w])
  570. return F.interpolate(input, size, mode='bicubic', align_corners=align_corners)
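# MakeCutouts: takes `cutn` random square crops of the padded input (the last few cuts are full-frame copies),
# optionally augments them, and resamples each to `cut_size` so CLIP can score them at its native input resolution.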
  571. class MakeCutouts(nn.Module):
  572. def __init__(self, cut_size, cutn, skip_augs=False):
  573. super().__init__()
  574. self.cut_size = cut_size
  575. self.cutn = cutn
  576. self.skip_augs = skip_augs
  577. self.augs = T.Compose([
  578. T.RandomHorizontalFlip(p=0.5),
  579. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  580. T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
  581. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  582. T.RandomPerspective(distortion_scale=0.4, p=0.7),
  583. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  584. T.RandomGrayscale(p=0.15),
  585. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  586. # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
  587. ])
  588. def forward(self, input):
  589. input = T.Pad(input.shape[2]//4, fill=0)(input)
  590. sideY, sideX = input.shape[2:4]
  591. max_size = min(sideX, sideY)
  592. cutouts = []
  593. for ch in range(self.cutn):
  594. if ch > self.cutn - self.cutn//4:
  595. cutout = input.clone()
  596. else:
  597. size = int(max_size * torch.zeros(1,).normal_(mean=.8, std=.3).clip(float(self.cut_size/max_size), 1.))
  598. offsetx = torch.randint(0, abs(sideX - size + 1), ())
  599. offsety = torch.randint(0, abs(sideY - size + 1), ())
  600. cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
  601. if not self.skip_augs:
  602. cutout = self.augs(cutout)
  603. cutouts.append(resample(cutout, (self.cut_size, self.cut_size)))
  604. del cutout
  605. cutouts = torch.cat(cutouts, dim=0)
  606. return cutouts
  607. cutout_debug = False
  608. padargs = {}
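# MakeCutoutsDango (Dango233's advanced cutout method): 'Overview' cuts see the whole padded image (optionally
# mirrored and/or grayscaled), while 'InnerCrop' cuts are random crops whose size distribution is shaped by
# IC_Size_Pow; roughly the first IC_Grey_P fraction of the inner crops is converted to grayscale. The augmentation
# pipeline is chosen according to args.animation_mode.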
  609. class MakeCutoutsDango(nn.Module):
  610. def __init__(self, cut_size,
  611. Overview=4,
  612. InnerCrop = 0, IC_Size_Pow=0.5, IC_Grey_P = 0.2
  613. ):
  614. super().__init__()
  615. self.cut_size = cut_size
  616. self.Overview = Overview
  617. self.InnerCrop = InnerCrop
  618. self.IC_Size_Pow = IC_Size_Pow
  619. self.IC_Grey_P = IC_Grey_P
  620. if args.animation_mode == 'None':
  621. self.augs = T.Compose([
  622. T.RandomHorizontalFlip(p=0.5),
  623. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  624. T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),
  625. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  626. T.RandomGrayscale(p=0.1),
  627. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  628. T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
  629. ])
  630. elif args.animation_mode == 'Video Input':
  631. self.augs = T.Compose([
  632. T.RandomHorizontalFlip(p=0.5),
  633. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  634. T.RandomAffine(degrees=15, translate=(0.1, 0.1)),
  635. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  636. T.RandomPerspective(distortion_scale=0.4, p=0.7),
  637. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  638. T.RandomGrayscale(p=0.15),
  639. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  640. # T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
  641. ])
  642. elif args.animation_mode == '2D' or args.animation_mode == '3D':
  643. self.augs = T.Compose([
  644. T.RandomHorizontalFlip(p=0.4),
  645. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  646. T.RandomAffine(degrees=10, translate=(0.05, 0.05), interpolation = T.InterpolationMode.BILINEAR),
  647. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  648. T.RandomGrayscale(p=0.1),
  649. T.Lambda(lambda x: x + torch.randn_like(x) * 0.01),
  650. T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.3),
  651. ])
  652. def forward(self, input):
  653. cutouts = []
  654. gray = T.Grayscale(3)
  655. sideY, sideX = input.shape[2:4]
  656. max_size = min(sideX, sideY)
  657. min_size = min(sideX, sideY, self.cut_size)
  658. l_size = max(sideX, sideY)
  659. output_shape = [1,3,self.cut_size,self.cut_size]
  660. output_shape_2 = [1,3,self.cut_size+2,self.cut_size+2]
  661. pad_input = F.pad(input,((sideY-max_size)//2,(sideY-max_size)//2,(sideX-max_size)//2,(sideX-max_size)//2), **padargs)
  662. cutout = resize(pad_input, out_shape=output_shape)
  663. if self.Overview>0:
  664. if self.Overview<=4:
  665. if self.Overview>=1:
  666. cutouts.append(cutout)
  667. if self.Overview>=2:
  668. cutouts.append(gray(cutout))
  669. if self.Overview>=3:
  670. cutouts.append(TF.hflip(cutout))
  671. if self.Overview==4:
  672. cutouts.append(gray(TF.hflip(cutout)))
  673. else:
  674. cutout = resize(pad_input, out_shape=output_shape)
  675. for _ in range(self.Overview):
  676. cutouts.append(cutout)
  677. if cutout_debug:
  678. if is_colab:
  679. TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save("/content/cutout_overview0.jpg",quality=99)
  680. else:
  681. TF.to_pil_image(cutouts[0].clamp(0, 1).squeeze(0)).save("cutout_overview0.jpg",quality=99)
  682. if self.InnerCrop >0:
  683. for i in range(self.InnerCrop):
  684. size = int(torch.rand([])**self.IC_Size_Pow * (max_size - min_size) + min_size)
  685. offsetx = torch.randint(0, sideX - size + 1, ())
  686. offsety = torch.randint(0, sideY - size + 1, ())
  687. cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
  688. if i <= int(self.IC_Grey_P * self.InnerCrop):
  689. cutout = gray(cutout)
  690. cutout = resize(cutout, out_shape=output_shape)
  691. cutouts.append(cutout)
  692. if cutout_debug:
  693. if is_colab:
  694. TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save("/content/cutout_InnerCrop.jpg",quality=99)
  695. else:
  696. TF.to_pil_image(cutouts[-1].clamp(0, 1).squeeze(0)).save("cutout_InnerCrop.jpg",quality=99)
  697. cutouts = torch.cat(cutouts)
  698. if skip_augs is not True: cutouts=self.augs(cutouts)
  699. return cutouts
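# spherical_dist_loss: squared angular (great-circle) distance between L2-normalized embeddings; cond_fn uses it
# to compare the CLIP embeddings of the cutouts with the prompt target embeddings.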
  700. def spherical_dist_loss(x, y):
  701. x = F.normalize(x, dim=-1)
  702. y = F.normalize(y, dim=-1)
  703. return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
  704. def tv_loss(input):
  705. """L2 total variation loss, as in Mahendran et al."""
  706. input = F.pad(input, (0, 1, 0, 1), 'replicate')
  707. x_diff = input[..., :-1, 1:] - input[..., :-1, :-1]
  708. y_diff = input[..., 1:, :-1] - input[..., :-1, :-1]
  709. return (x_diff**2 + y_diff**2).mean([1, 2, 3])
  710. def range_loss(input):
  711. return (input - input.clamp(-1, 1)).pow(2).mean([1, 2, 3])
  712. stop_on_next_loop = False # Make sure GPU memory doesn't get corrupted from cancelling the run mid-way through, allow a full frame to complete
  713. def do_3d_step(img_filepath, frame_num, midas_model, midas_transform):
  714. if args.key_frames:
  715. translation_x = args.translation_x_series[frame_num]
  716. translation_y = args.translation_y_series[frame_num]
  717. translation_z = args.translation_z_series[frame_num]
  718. rotation_3d_x = args.rotation_3d_x_series[frame_num]
  719. rotation_3d_y = args.rotation_3d_y_series[frame_num]
  720. rotation_3d_z = args.rotation_3d_z_series[frame_num]
  721. print(
  722. f'translation_x: {translation_x}',
  723. f'translation_y: {translation_y}',
  724. f'translation_z: {translation_z}',
  725. f'rotation_3d_x: {rotation_3d_x}',
  726. f'rotation_3d_y: {rotation_3d_y}',
  727. f'rotation_3d_z: {rotation_3d_z}',
  728. )
  729. trans_scale = 1.0/200.0
  730. translate_xyz = [-translation_x*trans_scale, translation_y*trans_scale, -translation_z*trans_scale]
  731. rotate_xyz_degrees = [rotation_3d_x, rotation_3d_y, rotation_3d_z]
  732. print('translation:',translate_xyz)
  733. print('rotation:',rotate_xyz_degrees)
  734. rotate_xyz = [math.radians(rotate_xyz_degrees[0]), math.radians(rotate_xyz_degrees[1]), math.radians(rotate_xyz_degrees[2])]
  735. rot_mat = p3dT.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0)
  736. print("rot_mat: " + str(rot_mat))
  737. next_step_pil = dxf.transform_image_3d(img_filepath, midas_model, midas_transform, DEVICE,
  738. rot_mat, translate_xyz, args.near_plane, args.far_plane,
  739. args.fov, padding_mode=args.padding_mode,
  740. sampling_mode=args.sampling_mode, midas_weight=args.midas_weight)
  741. return next_step_pil
  742. def do_run():
  743. seed = args.seed
  744. print(range(args.start_frame, args.max_frames))
  745. if (args.animation_mode == "3D") and (args.midas_weight > 0.0):
  746. midas_model, midas_transform, midas_net_w, midas_net_h, midas_resize_mode, midas_normalization = init_midas_depth_model(args.midas_depth_model)
  747. for frame_num in range(args.start_frame, args.max_frames):
  748. if stop_on_next_loop:
  749. break
  750. display.clear_output(wait=True)
  751. # Print Frame progress if animation mode is on
  752. if args.animation_mode != "None":
  753. batchBar = tqdm(range(args.max_frames), desc ="Frames")
  754. batchBar.n = frame_num
  755. batchBar.refresh()
  756. # Inits if not video frames
  757. if args.animation_mode != "Video Input":
  758. if args.init_image == '':
  759. init_image = None
  760. else:
  761. init_image = args.init_image
  762. init_scale = args.init_scale
  763. skip_steps = args.skip_steps
  764. if args.animation_mode == "2D":
  765. if args.key_frames:
  766. angle = args.angle_series[frame_num]
  767. zoom = args.zoom_series[frame_num]
  768. translation_x = args.translation_x_series[frame_num]
  769. translation_y = args.translation_y_series[frame_num]
  770. print(
  771. f'angle: {angle}',
  772. f'zoom: {zoom}',
  773. f'translation_x: {translation_x}',
  774. f'translation_y: {translation_y}',
  775. )
  776. if frame_num > 0:
  777. seed += 1
  778. if resume_run and frame_num == start_frame:
  779. img_0 = cv2.imread(batchFolder+f"/{batch_name}({batchNum})_{start_frame-1:04}.png")
  780. else:
  781. img_0 = cv2.imread('prevFrame.png')
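# Warp the previous frame: combine the per-frame translation with the rotation/zoom matrix into a single 3x3
# homography, apply it with wrap-around borders, and use the result as the next frame's init image.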
  782. center = (1*img_0.shape[1]//2, 1*img_0.shape[0]//2)
  783. trans_mat = np.float32(
  784. [[1, 0, translation_x],
  785. [0, 1, translation_y]]
  786. )
  787. rot_mat = cv2.getRotationMatrix2D( center, angle, zoom )
  788. trans_mat = np.vstack([trans_mat, [0,0,1]])
  789. rot_mat = np.vstack([rot_mat, [0,0,1]])
  790. transformation_matrix = np.matmul(rot_mat, trans_mat)
  791. img_0 = cv2.warpPerspective(
  792. img_0,
  793. transformation_matrix,
  794. (img_0.shape[1], img_0.shape[0]),
  795. borderMode=cv2.BORDER_WRAP
  796. )
  797. cv2.imwrite('prevFrameScaled.png', img_0)
  798. init_image = 'prevFrameScaled.png'
  799. init_scale = args.frames_scale
  800. skip_steps = args.calc_frames_skip_steps
  801. if args.animation_mode == "3D":
  802. if frame_num == 0:
  803. pass
  804. else:
  805. seed += 1
  806. if resume_run and frame_num == start_frame:
  807. img_filepath = batchFolder+f"/{batch_name}({batchNum})_{start_frame-1:04}.png"
  808. if turbo_mode and frame_num > turbo_preroll:
  809. shutil.copyfile(img_filepath, 'oldFrameScaled.png')
  810. else:
  811. img_filepath = '/content/prevFrame.png' if is_colab else 'prevFrame.png'
  812. next_step_pil = do_3d_step(img_filepath, frame_num, midas_model, midas_transform)
  813. next_step_pil.save('prevFrameScaled.png')
  814. ### Turbo mode - skip some diffusions, use 3d morph for clarity and to save time
  815. if turbo_mode:
  816. if frame_num == turbo_preroll: #start tracking oldframe
  817. next_step_pil.save('oldFrameScaled.png')#stash for later blending
  818. elif frame_num > turbo_preroll:
  819. #set up 2 warped image sequences, old & new, to blend toward new diff image
  820. old_frame = do_3d_step('oldFrameScaled.png', frame_num, midas_model, midas_transform)
  821. old_frame.save('oldFrameScaled.png')
  822. if frame_num % int(turbo_steps) != 0:
  823. print('turbo skip this frame: skipping clip diffusion steps')
  824. filename = f'{args.batch_name}({args.batchNum})_{frame_num:04}.png'
  825. blend_factor = ((frame_num % int(turbo_steps))+1)/int(turbo_steps)
  826. print('turbo skip this frame: skipping clip diffusion steps and saving blended frame')
  827. newWarpedImg = cv2.imread('prevFrameScaled.png')#this is already updated..
  828. oldWarpedImg = cv2.imread('oldFrameScaled.png')
  829. blendedImage = cv2.addWeighted(newWarpedImg, blend_factor, oldWarpedImg,1-blend_factor, 0.0)
  830. cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)
  831. next_step_pil.save(f'{img_filepath}') # save it also as prev_frame to feed next iteration
  832. continue
  833. else:
  834. #if not a skip frame, will run diffusion and need to blend.
  835. oldWarpedImg = cv2.imread('prevFrameScaled.png')
  836. cv2.imwrite(f'oldFrameScaled.png',oldWarpedImg)#swap in for blending later
  837. print('clip/diff this frame - generate clip diff image')
  838. init_image = 'prevFrameScaled.png'
  839. init_scale = args.frames_scale
  840. skip_steps = args.calc_frames_skip_steps
  841. if args.animation_mode == "Video Input":
  842. if not video_init_seed_continuity:
  843. seed += 1
  844. init_image = f'{videoFramesFolder}/{frame_num+1:04}.jpg'
  845. init_scale = args.frames_scale
  846. skip_steps = args.calc_frames_skip_steps
  847. loss_values = []
  848. if seed is not None:
  849. np.random.seed(seed)
  850. random.seed(seed)
  851. torch.manual_seed(seed)
  852. torch.cuda.manual_seed_all(seed)
  853. torch.backends.cudnn.deterministic = True
  854. target_embeds, weights = [], []
  855. if args.prompts_series is not None and frame_num >= len(args.prompts_series):
  856. frame_prompt = args.prompts_series[-1]
  857. elif args.prompts_series is not None:
  858. frame_prompt = args.prompts_series[frame_num]
  859. else:
  860. frame_prompt = []
  861. print(args.image_prompts_series)
  862. if args.image_prompts_series is not None and frame_num >= len(args.image_prompts_series):
  863. image_prompt = args.image_prompts_series[-1]
  864. elif args.image_prompts_series is not None:
  865. image_prompt = args.image_prompts_series[frame_num]
  866. else:
  867. image_prompt = []
  868. print(f'Frame {frame_num} Prompt: {frame_prompt}')
  869. model_stats = []
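# For each loaded CLIP model, encode this frame's text and image prompts into target embeddings and per-prompt
# weights (normalized below); cond_fn later compares cutout embeddings against these targets.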
  870. for clip_model in clip_models:
  871. cutn = 16
  872. model_stat = {"clip_model":None,"target_embeds":[],"make_cutouts":None,"weights":[]}
  873. model_stat["clip_model"] = clip_model
  874. for prompt in frame_prompt:
  875. txt, weight = parse_prompt(prompt)
  876. txt = clip_model.encode_text(clip.tokenize(prompt).to(device)).float()
  877. if args.fuzzy_prompt:
  878. for i in range(25):
  879. model_stat["target_embeds"].append((txt + torch.randn(txt.shape).cuda() * args.rand_mag).clamp(0,1))
  880. model_stat["weights"].append(weight)
  881. else:
  882. model_stat["target_embeds"].append(txt)
  883. model_stat["weights"].append(weight)
  884. if image_prompt:
  885. model_stat["make_cutouts"] = MakeCutouts(clip_model.visual.input_resolution, cutn, skip_augs=skip_augs)
  886. for prompt in image_prompt:
  887. path, weight = parse_prompt(prompt)
  888. img = Image.open(fetch(path)).convert('RGB')
  889. img = TF.resize(img, min(side_x, side_y, *img.size), T.InterpolationMode.LANCZOS)
  890. batch = model_stat["make_cutouts"](TF.to_tensor(img).to(device).unsqueeze(0).mul(2).sub(1))
  891. embed = clip_model.encode_image(normalize(batch)).float()
  892. if args.fuzzy_prompt:
  893. for i in range(25):
  894. model_stat["target_embeds"].append((embed + torch.randn(embed.shape).cuda() * args.rand_mag).clamp(0,1))
  895. model_stat["weights"].extend([weight / cutn] * cutn)
  896. else:
  897. model_stat["target_embeds"].append(embed)
  898. model_stat["weights"].extend([weight / cutn] * cutn)
  899. model_stat["target_embeds"] = torch.cat(model_stat["target_embeds"])
  900. model_stat["weights"] = torch.tensor(model_stat["weights"], device=device)
  901. if model_stat["weights"].sum().abs() < 1e-3:
  902. raise RuntimeError('The weights must not sum to 0.')
  903. model_stat["weights"] /= model_stat["weights"].sum().abs()
  904. model_stats.append(model_stat)
  905. init = None
  906. if init_image is not None:
  907. init = Image.open(fetch(init_image)).convert('RGB')
  908. init = init.resize((args.side_x, args.side_y), Image.LANCZOS)
  909. init = TF.to_tensor(init).to(device).unsqueeze(0).mul(2).sub(1)
  910. if args.perlin_init:
  911. if args.perlin_mode == 'color':
  912. init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
  913. init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, False)
  914. elif args.perlin_mode == 'gray':
  915. init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, True)
  916. init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
  917. else:
  918. init = create_perlin_noise([1.5**-i*0.5 for i in range(12)], 1, 1, False)
  919. init2 = create_perlin_noise([1.5**-i*0.5 for i in range(8)], 4, 4, True)
  920. # init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device)
  921. init = TF.to_tensor(init).add(TF.to_tensor(init2)).div(2).to(device).unsqueeze(0).mul(2).sub(1)
  922. del init2
  923. cur_t = None
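# cond_fn supplies the gradient that steers each diffusion step: it estimates the denoised image (via the
# secondary model or the diffusion model's predicted x_start), scores cutouts of that estimate with CLIP against
# the target embeddings, adds TV / range / saturation / init (LPIPS) losses, and backpropagates to x, optionally
# clamping the gradient magnitude.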
  924. def cond_fn(x, t, y=None):
  925. with torch.enable_grad():
  926. x_is_NaN = False
  927. x = x.detach().requires_grad_()
  928. n = x.shape[0]
  929. if use_secondary_model is True:
  930. alpha = torch.tensor(diffusion.sqrt_alphas_cumprod[cur_t], device=device, dtype=torch.float32)
  931. sigma = torch.tensor(diffusion.sqrt_one_minus_alphas_cumprod[cur_t], device=device, dtype=torch.float32)
  932. cosine_t = alpha_sigma_to_t(alpha, sigma)
  933. out = secondary_model(x, cosine_t[None].repeat([n])).pred
  934. fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
  935. x_in = out * fac + x * (1 - fac)
  936. x_in_grad = torch.zeros_like(x_in)
  937. else:
  938. my_t = torch.ones([n], device=device, dtype=torch.long) * cur_t
  939. out = diffusion.p_mean_variance(model, x, my_t, clip_denoised=False, model_kwargs={'y': y})
  940. fac = diffusion.sqrt_one_minus_alphas_cumprod[cur_t]
  941. x_in = out['pred_xstart'] * fac + x * (1 - fac)
  942. x_in_grad = torch.zeros_like(x_in)
  943. for model_stat in model_stats:
  944. for i in range(args.cutn_batches):
  945. t_int = int(t.item())+1 #errors on last step without +1, need to find source
  946. #when using SLIP Base model the dimensions need to be hard coded to avoid AttributeError: 'VisionTransformer' object has no attribute 'input_resolution'
  947. try:
  948. input_resolution=model_stat["clip_model"].visual.input_resolution
  949. except:
  950. input_resolution=224
  951. cuts = MakeCutoutsDango(input_resolution,
  952. Overview= args.cut_overview[1000-t_int],
  953. InnerCrop = args.cut_innercut[1000-t_int], IC_Size_Pow=args.cut_ic_pow, IC_Grey_P = args.cut_icgray_p[1000-t_int]
  954. )
  955. clip_in = normalize(cuts(x_in.add(1).div(2)))
  956. image_embeds = model_stat["clip_model"].encode_image(clip_in).float()
  957. dists = spherical_dist_loss(image_embeds.unsqueeze(1), model_stat["target_embeds"].unsqueeze(0))
  958. dists = dists.view([args.cut_overview[1000-t_int]+args.cut_innercut[1000-t_int], n, -1])
  959. losses = dists.mul(model_stat["weights"]).sum(2).mean(0)
  960. loss_values.append(losses.sum().item()) # log loss, probably shouldn't do per cutn_batch
  961. x_in_grad += torch.autograd.grad(losses.sum() * clip_guidance_scale, x_in)[0] / cutn_batches
  962. tv_losses = tv_loss(x_in)
  963. if use_secondary_model is True:
  964. range_losses = range_loss(out)
  965. else:
  966. range_losses = range_loss(out['pred_xstart'])
  967. sat_losses = torch.abs(x_in - x_in.clamp(min=-1,max=1)).mean()
  968. loss = tv_losses.sum() * tv_scale + range_losses.sum() * range_scale + sat_losses.sum() * sat_scale
  969. if init is not None and args.init_scale:
  970. init_losses = lpips_model(x_in, init)
  971. loss = loss + init_losses.sum() * args.init_scale
  972. x_in_grad += torch.autograd.grad(loss, x_in)[0]
  973. if torch.isnan(x_in_grad).any()==False:
  974. grad = -torch.autograd.grad(x_in, x, x_in_grad)[0]
  975. else:
  976. # print("NaN'd")
  977. x_is_NaN = True
  978. grad = torch.zeros_like(x)
  979. if args.clamp_grad and x_is_NaN == False:
  980. magnitude = grad.square().mean().sqrt()
  981. return grad * magnitude.clamp(max=args.clamp_max) / magnitude #min=-0.02, min=-clamp_max,
  982. return grad
  983. if args.diffusion_sampling_mode == 'ddim':
  984. sample_fn = diffusion.ddim_sample_loop_progressive
  985. else:
  986. sample_fn = diffusion.plms_sample_loop_progressive
  987. image_display = Output()
  988. for i in range(args.n_batches):
  989. if args.animation_mode == 'None':
  990. display.clear_output(wait=True)
  991. batchBar = tqdm(range(args.n_batches), desc ="Batches")
  992. batchBar.n = i
  993. batchBar.refresh()
  994. print('')
  995. display.display(image_display)
  996. gc.collect()
  997. torch.cuda.empty_cache()
  998. cur_t = diffusion.num_timesteps - skip_steps - 1
  999. total_steps = cur_t
  1000. if perlin_init:
  1001. init = regen_perlin()
  1002. if args.diffusion_sampling_mode == 'ddim':
  1003. samples = sample_fn(
  1004. model,
  1005. (batch_size, 3, args.side_y, args.side_x),
  1006. clip_denoised=clip_denoised,
  1007. model_kwargs={},
  1008. cond_fn=cond_fn,
  1009. progress=True,
  1010. skip_timesteps=skip_steps,
  1011. init_image=init,
  1012. randomize_class=randomize_class,
  1013. eta=eta,
  1014. )
  1015. else:
  1016. samples = sample_fn(
  1017. model,
  1018. (batch_size, 3, args.side_y, args.side_x),
  1019. clip_denoised=clip_denoised,
  1020. model_kwargs={},
  1021. cond_fn=cond_fn,
  1022. progress=True,
  1023. skip_timesteps=skip_steps,
  1024. init_image=init,
  1025. randomize_class=randomize_class,
  1026. order=2,
  1027. )
  1028. # with run_display:
  1029. # display.clear_output(wait=True)
  1030. imgToSharpen = None
  1031. for j, sample in enumerate(samples):
  1032. cur_t -= 1
  1033. intermediateStep = False
  1034. if args.steps_per_checkpoint is not None:
  1035. if j % steps_per_checkpoint == 0 and j > 0:
  1036. intermediateStep = True
  1037. elif j in args.intermediate_saves:
  1038. intermediateStep = True
  1039. with image_display:
  1040. if j % args.display_rate == 0 or cur_t == -1 or intermediateStep == True:
  1041. for k, image in enumerate(sample['pred_xstart']):
  1042. # tqdm.write(f'Batch {i}, step {j}, output {k}:')
  1043. current_time = datetime.now().strftime('%y%m%d-%H%M%S_%f')
  1044. percent = math.ceil(j/total_steps*100)
  1045. if args.n_batches > 0:
  1046. #if intermediates are saved to the subfolder, don't append a step or percentage to the name
  1047. if cur_t == -1 and args.intermediates_in_subfolder is True:
  1048. save_num = f'{frame_num:04}' if animation_mode != "None" else i
  1049. filename = f'{args.batch_name}({args.batchNum})_{save_num}.png'
  1050. else:
  1051. #If we're working with percentages, append it
  1052. if args.steps_per_checkpoint is not None:
  1053. filename = f'{args.batch_name}({args.batchNum})_{i:04}-{percent:02}%.png'
1054. # Otherwise, if we're working with specific steps, append those
  1055. else:
  1056. filename = f'{args.batch_name}({args.batchNum})_{i:04}-{j:03}.png'
  1057. image = TF.to_pil_image(image.add(1).div(2).clamp(0, 1))
  1058. if j % args.display_rate == 0 or cur_t == -1:
  1059. image.save('progress.png')
  1060. display.clear_output(wait=True)
  1061. display.display(display.Image('progress.png'))
  1062. if args.steps_per_checkpoint is not None:
  1063. if j % args.steps_per_checkpoint == 0 and j > 0:
  1064. if args.intermediates_in_subfolder is True:
  1065. image.save(f'{partialFolder}/{filename}')
  1066. else:
  1067. image.save(f'{batchFolder}/{filename}')
  1068. else:
  1069. if j in args.intermediate_saves:
  1070. if args.intermediates_in_subfolder is True:
  1071. image.save(f'{partialFolder}/{filename}')
  1072. else:
  1073. image.save(f'{batchFolder}/{filename}')
  1074. if cur_t == -1:
  1075. if frame_num == 0:
  1076. save_settings()
  1077. if args.animation_mode != "None":
  1078. image.save('prevFrame.png')
  1079. if args.sharpen_preset != "Off" and animation_mode == "None":
  1080. imgToSharpen = image
  1081. if args.keep_unsharp is True:
  1082. image.save(f'{unsharpenFolder}/{filename}')
  1083. else:
  1084. image.save(f'{batchFolder}/{filename}')
  1085. if args.animation_mode == "3D":
  1086. # If turbo, save a blended image
  1087. if turbo_mode:
  1088. # Mix new image with prevFrameScaled
  1089. blend_factor = (1)/int(turbo_steps)
  1090. newFrame = cv2.imread('prevFrame.png') # This is already updated..
  1091. prev_frame_warped = cv2.imread('prevFrameScaled.png')
  1092. blendedImage = cv2.addWeighted(newFrame, blend_factor, prev_frame_warped, (1-blend_factor), 0.0)
  1093. cv2.imwrite(f'{batchFolder}/{filename}',blendedImage)
  1094. else:
  1095. image.save(f'{batchFolder}/{filename}')
  1096. # if frame_num != args.max_frames-1:
  1097. # display.clear_output()
  1098. with image_display:
  1099. if args.sharpen_preset != "Off" and animation_mode == "None":
  1100. print('Starting Diffusion Sharpening...')
  1101. do_superres(imgToSharpen, f'{batchFolder}/{filename}')
  1102. display.clear_output()
  1103. plt.plot(np.array(loss_values), 'r')
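# Note on the clamp_grad branch in cond_fn above (illustration only): when enabled, the
# guidance gradient is rescaled so its RMS magnitude never exceeds clamp_max, keeping the
# gradient direction while taming occasional loss spikes. A minimal sketch of that rescaling
# (hypothetical helper, not called anywhere in this notebook):
def _demo_clamp_grad(grad, clamp_max=0.05):
    magnitude = grad.square().mean().sqrt()
    return grad * magnitude.clamp(max=clamp_max) / magnitude
# e.g. a gradient with RMS 0.5 is scaled by 0.1 when clamp_max=0.05, while a gradient whose
# RMS is already below clamp_max passes through unchanged.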
  1104. def save_settings():
  1105. setting_list = {
  1106. 'text_prompts': text_prompts,
  1107. 'image_prompts': image_prompts,
  1108. 'clip_guidance_scale': clip_guidance_scale,
  1109. 'tv_scale': tv_scale,
  1110. 'range_scale': range_scale,
  1111. 'sat_scale': sat_scale,
  1112. # 'cutn': cutn,
  1113. 'cutn_batches': cutn_batches,
  1114. 'max_frames': max_frames,
  1115. 'interp_spline': interp_spline,
  1116. # 'rotation_per_frame': rotation_per_frame,
  1117. 'init_image': init_image,
  1118. 'init_scale': init_scale,
  1119. 'skip_steps': skip_steps,
  1120. # 'zoom_per_frame': zoom_per_frame,
  1121. 'frames_scale': frames_scale,
  1122. 'frames_skip_steps': frames_skip_steps,
  1123. 'perlin_init': perlin_init,
  1124. 'perlin_mode': perlin_mode,
  1125. 'skip_augs': skip_augs,
  1126. 'randomize_class': randomize_class,
  1127. 'clip_denoised': clip_denoised,
  1128. 'clamp_grad': clamp_grad,
  1129. 'clamp_max': clamp_max,
  1130. 'seed': seed,
  1131. 'fuzzy_prompt': fuzzy_prompt,
  1132. 'rand_mag': rand_mag,
  1133. 'eta': eta,
  1134. 'width': width_height[0],
  1135. 'height': width_height[1],
  1136. 'diffusion_model': diffusion_model,
  1137. 'use_secondary_model': use_secondary_model,
  1138. 'steps': steps,
  1139. 'diffusion_steps': diffusion_steps,
  1140. 'diffusion_sampling_mode': diffusion_sampling_mode,
  1141. 'ViTB32': ViTB32,
  1142. 'ViTB16': ViTB16,
  1143. 'ViTL14': ViTL14,
  1144. 'RN101': RN101,
  1145. 'RN50': RN50,
  1146. 'RN50x4': RN50x4,
  1147. 'RN50x16': RN50x16,
  1148. 'RN50x64': RN50x64,
  1149. 'cut_overview': str(cut_overview),
  1150. 'cut_innercut': str(cut_innercut),
  1151. 'cut_ic_pow': cut_ic_pow,
  1152. 'cut_icgray_p': str(cut_icgray_p),
  1153. 'key_frames': key_frames,
  1154. 'max_frames': max_frames,
  1155. 'angle': angle,
  1156. 'zoom': zoom,
  1157. 'translation_x': translation_x,
  1158. 'translation_y': translation_y,
  1159. 'translation_z': translation_z,
  1160. 'rotation_3d_x': rotation_3d_x,
  1161. 'rotation_3d_y': rotation_3d_y,
  1162. 'rotation_3d_z': rotation_3d_z,
  1163. 'midas_depth_model': midas_depth_model,
  1164. 'midas_weight': midas_weight,
  1165. 'near_plane': near_plane,
  1166. 'far_plane': far_plane,
  1167. 'fov': fov,
  1168. 'padding_mode': padding_mode,
  1169. 'sampling_mode': sampling_mode,
  1170. 'video_init_path':video_init_path,
  1171. 'extract_nth_frame':extract_nth_frame,
  1172. 'video_init_seed_continuity': video_init_seed_continuity,
  1173. 'turbo_mode':turbo_mode,
  1174. 'turbo_steps':turbo_steps,
  1175. 'turbo_preroll':turbo_preroll,
  1176. }
  1177. # print('Settings:', setting_list)
  1178. with open(f"{batchFolder}/{batch_name}({batchNum})_settings.txt", "w+") as f: #save settings
  1179. json.dump(setting_list, f, ensure_ascii=False, indent=4)
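# The settings file written above is plain JSON, so a finished run can be reproduced by reading
# it back. A minimal sketch of the load counterpart (hypothetical helper, assuming the same
# folder layout that save_settings uses):
def load_settings(path):
    with open(path, 'r') as f:
        return json.load(f)
# e.g. load_settings(f'{batchFolder}/{batch_name}({batchNum})_settings.txt')['clip_guidance_scale']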
  1180. #@title 1.6 Define the secondary diffusion model
  1181. def append_dims(x, n):
  1182. return x[(Ellipsis, *(None,) * (n - x.ndim))]
  1183. def expand_to_planes(x, shape):
  1184. return append_dims(x, len(shape)).repeat([1, 1, *shape[2:]])
  1185. def alpha_sigma_to_t(alpha, sigma):
  1186. return torch.atan2(sigma, alpha) * 2 / math.pi
  1187. def t_to_alpha_sigma(t):
  1188. return torch.cos(t * math.pi / 2), torch.sin(t * math.pi / 2)
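# The two helpers above are inverses on the cosine schedule: alpha = cos(t*pi/2) and
# sigma = sin(t*pi/2), so atan2(sigma, alpha) * 2/pi recovers t for t in [0, 1].
# A quick self-contained check (illustration only, not part of the pipeline):
def _demo_alpha_sigma_round_trip():
    t = torch.linspace(0, 1, 5)
    alpha, sigma = t_to_alpha_sigma(t)
    assert torch.allclose(alpha_sigma_to_t(alpha, sigma), t, atol=1e-5)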
  1189. @dataclass
  1190. class DiffusionOutput:
  1191. v: torch.Tensor
  1192. pred: torch.Tensor
  1193. eps: torch.Tensor
  1194. class ConvBlock(nn.Sequential):
  1195. def __init__(self, c_in, c_out):
  1196. super().__init__(
  1197. nn.Conv2d(c_in, c_out, 3, padding=1),
  1198. nn.ReLU(inplace=True),
  1199. )
  1200. class SkipBlock(nn.Module):
  1201. def __init__(self, main, skip=None):
  1202. super().__init__()
  1203. self.main = nn.Sequential(*main)
  1204. self.skip = skip if skip else nn.Identity()
  1205. def forward(self, input):
  1206. return torch.cat([self.main(input), self.skip(input)], dim=1)
  1207. class FourierFeatures(nn.Module):
  1208. def __init__(self, in_features, out_features, std=1.):
  1209. super().__init__()
  1210. assert out_features % 2 == 0
  1211. self.weight = nn.Parameter(torch.randn([out_features // 2, in_features]) * std)
  1212. def forward(self, input):
  1213. f = 2 * math.pi * input @ self.weight.T
  1214. return torch.cat([f.cos(), f.sin()], dim=-1)
  1215. class SecondaryDiffusionImageNet(nn.Module):
  1216. def __init__(self):
  1217. super().__init__()
  1218. c = 64 # The base channel count
  1219. self.timestep_embed = FourierFeatures(1, 16)
  1220. self.net = nn.Sequential(
  1221. ConvBlock(3 + 16, c),
  1222. ConvBlock(c, c),
  1223. SkipBlock([
  1224. nn.AvgPool2d(2),
  1225. ConvBlock(c, c * 2),
  1226. ConvBlock(c * 2, c * 2),
  1227. SkipBlock([
  1228. nn.AvgPool2d(2),
  1229. ConvBlock(c * 2, c * 4),
  1230. ConvBlock(c * 4, c * 4),
  1231. SkipBlock([
  1232. nn.AvgPool2d(2),
  1233. ConvBlock(c * 4, c * 8),
  1234. ConvBlock(c * 8, c * 4),
  1235. nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
  1236. ]),
  1237. ConvBlock(c * 8, c * 4),
  1238. ConvBlock(c * 4, c * 2),
  1239. nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
  1240. ]),
  1241. ConvBlock(c * 4, c * 2),
  1242. ConvBlock(c * 2, c),
  1243. nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
  1244. ]),
  1245. ConvBlock(c * 2, c),
  1246. nn.Conv2d(c, 3, 3, padding=1),
  1247. )
  1248. def forward(self, input, t):
  1249. timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
  1250. v = self.net(torch.cat([input, timestep_embed], dim=1))
  1251. alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
  1252. pred = input * alphas - v * sigmas
  1253. eps = input * sigmas + v * alphas
  1254. return DiffusionOutput(v, pred, eps)
  1255. class SecondaryDiffusionImageNet2(nn.Module):
  1256. def __init__(self):
  1257. super().__init__()
  1258. c = 64 # The base channel count
  1259. cs = [c, c * 2, c * 2, c * 4, c * 4, c * 8]
  1260. self.timestep_embed = FourierFeatures(1, 16)
  1261. self.down = nn.AvgPool2d(2)
  1262. self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)
  1263. self.net = nn.Sequential(
  1264. ConvBlock(3 + 16, cs[0]),
  1265. ConvBlock(cs[0], cs[0]),
  1266. SkipBlock([
  1267. self.down,
  1268. ConvBlock(cs[0], cs[1]),
  1269. ConvBlock(cs[1], cs[1]),
  1270. SkipBlock([
  1271. self.down,
  1272. ConvBlock(cs[1], cs[2]),
  1273. ConvBlock(cs[2], cs[2]),
  1274. SkipBlock([
  1275. self.down,
  1276. ConvBlock(cs[2], cs[3]),
  1277. ConvBlock(cs[3], cs[3]),
  1278. SkipBlock([
  1279. self.down,
  1280. ConvBlock(cs[3], cs[4]),
  1281. ConvBlock(cs[4], cs[4]),
  1282. SkipBlock([
  1283. self.down,
  1284. ConvBlock(cs[4], cs[5]),
  1285. ConvBlock(cs[5], cs[5]),
  1286. ConvBlock(cs[5], cs[5]),
  1287. ConvBlock(cs[5], cs[4]),
  1288. self.up,
  1289. ]),
  1290. ConvBlock(cs[4] * 2, cs[4]),
  1291. ConvBlock(cs[4], cs[3]),
  1292. self.up,
  1293. ]),
  1294. ConvBlock(cs[3] * 2, cs[3]),
  1295. ConvBlock(cs[3], cs[2]),
  1296. self.up,
  1297. ]),
  1298. ConvBlock(cs[2] * 2, cs[2]),
  1299. ConvBlock(cs[2], cs[1]),
  1300. self.up,
  1301. ]),
  1302. ConvBlock(cs[1] * 2, cs[1]),
  1303. ConvBlock(cs[1], cs[0]),
  1304. self.up,
  1305. ]),
  1306. ConvBlock(cs[0] * 2, cs[0]),
  1307. nn.Conv2d(cs[0], 3, 3, padding=1),
  1308. )
  1309. def forward(self, input, t):
  1310. timestep_embed = expand_to_planes(self.timestep_embed(t[:, None]), input.shape)
  1311. v = self.net(torch.cat([input, timestep_embed], dim=1))
  1312. alphas, sigmas = map(partial(append_dims, n=v.ndim), t_to_alpha_sigma(t))
  1313. pred = input * alphas - v * sigmas
  1314. eps = input * sigmas + v * alphas
  1315. return DiffusionOutput(v, pred, eps)
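# In both secondary models the v-objective outputs obey a simple identity: with
# alpha = cos(t*pi/2), sigma = sin(t*pi/2) and alpha**2 + sigma**2 == 1,
#   pred = alpha*x - sigma*v   and   eps = sigma*x + alpha*v
# recombine to alpha*pred + sigma*eps == x, i.e. the noisy input is exactly recovered.
# A tiny sanity check of that algebra on random tensors (illustration only):
def _demo_v_objective_identity():
    x, v = torch.randn(2, 3, 8, 8), torch.randn(2, 3, 8, 8)
    t = torch.rand(2)
    alphas, sigmas = map(partial(append_dims, n=x.ndim), t_to_alpha_sigma(t))
    pred, eps = x * alphas - v * sigmas, x * sigmas + v * alphas
    assert torch.allclose(alphas * pred + sigmas * eps, x, atol=1e-5)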
  1316. #@title 1.7 SuperRes Define
  1317. class DDIMSampler(object):
  1318. def __init__(self, model, schedule="linear", **kwargs):
  1319. super().__init__()
  1320. self.model = model
  1321. self.ddpm_num_timesteps = model.num_timesteps
  1322. self.schedule = schedule
  1323. def register_buffer(self, name, attr):
  1324. if type(attr) == torch.Tensor:
  1325. if attr.device != torch.device("cuda"):
  1326. attr = attr.to(torch.device("cuda"))
  1327. setattr(self, name, attr)
  1328. def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True):
  1329. self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps,
  1330. num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose)
  1331. alphas_cumprod = self.model.alphas_cumprod
  1332. assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep'
  1333. to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device)
  1334. self.register_buffer('betas', to_torch(self.model.betas))
  1335. self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
  1336. self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev))
  1337. # calculations for diffusion q(x_t | x_{t-1}) and others
  1338. self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu())))
  1339. self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu())))
  1340. self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu())))
  1341. self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu())))
  1342. self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1)))
  1343. # ddim sampling parameters
  1344. ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(),
  1345. ddim_timesteps=self.ddim_timesteps,
  1346. eta=ddim_eta,verbose=verbose)
  1347. self.register_buffer('ddim_sigmas', ddim_sigmas)
  1348. self.register_buffer('ddim_alphas', ddim_alphas)
  1349. self.register_buffer('ddim_alphas_prev', ddim_alphas_prev)
  1350. self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas))
  1351. sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt(
  1352. (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * (
  1353. 1 - self.alphas_cumprod / self.alphas_cumprod_prev))
  1354. self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps)
  1355. @torch.no_grad()
  1356. def sample(self,
  1357. S,
  1358. batch_size,
  1359. shape,
  1360. conditioning=None,
  1361. callback=None,
  1362. normals_sequence=None,
  1363. img_callback=None,
  1364. quantize_x0=False,
  1365. eta=0.,
  1366. mask=None,
  1367. x0=None,
  1368. temperature=1.,
  1369. noise_dropout=0.,
  1370. score_corrector=None,
  1371. corrector_kwargs=None,
  1372. verbose=True,
  1373. x_T=None,
  1374. log_every_t=100,
  1375. **kwargs
  1376. ):
  1377. if conditioning is not None:
  1378. if isinstance(conditioning, dict):
  1379. cbs = conditioning[list(conditioning.keys())[0]].shape[0]
  1380. if cbs != batch_size:
  1381. print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
  1382. else:
  1383. if conditioning.shape[0] != batch_size:
  1384. print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
  1385. self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose)
  1386. # sampling
  1387. C, H, W = shape
  1388. size = (batch_size, C, H, W)
  1389. # print(f'Data shape for DDIM sampling is {size}, eta {eta}')
  1390. samples, intermediates = self.ddim_sampling(conditioning, size,
  1391. callback=callback,
  1392. img_callback=img_callback,
  1393. quantize_denoised=quantize_x0,
  1394. mask=mask, x0=x0,
  1395. ddim_use_original_steps=False,
  1396. noise_dropout=noise_dropout,
  1397. temperature=temperature,
  1398. score_corrector=score_corrector,
  1399. corrector_kwargs=corrector_kwargs,
  1400. x_T=x_T,
  1401. log_every_t=log_every_t
  1402. )
  1403. return samples, intermediates
  1404. @torch.no_grad()
  1405. def ddim_sampling(self, cond, shape,
  1406. x_T=None, ddim_use_original_steps=False,
  1407. callback=None, timesteps=None, quantize_denoised=False,
  1408. mask=None, x0=None, img_callback=None, log_every_t=100,
  1409. temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
  1410. device = self.model.betas.device
  1411. b = shape[0]
  1412. if x_T is None:
  1413. img = torch.randn(shape, device=device)
  1414. else:
  1415. img = x_T
  1416. if timesteps is None:
  1417. timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps
  1418. elif timesteps is not None and not ddim_use_original_steps:
  1419. subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1
  1420. timesteps = self.ddim_timesteps[:subset_end]
  1421. intermediates = {'x_inter': [img], 'pred_x0': [img]}
  1422. time_range = reversed(range(0,timesteps)) if ddim_use_original_steps else np.flip(timesteps)
  1423. total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0]
  1424. print(f"Running DDIM Sharpening with {total_steps} timesteps")
  1425. iterator = tqdm(time_range, desc='DDIM Sharpening', total=total_steps)
  1426. for i, step in enumerate(iterator):
  1427. index = total_steps - i - 1
  1428. ts = torch.full((b,), step, device=device, dtype=torch.long)
  1429. if mask is not None:
  1430. assert x0 is not None
  1431. img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass?
  1432. img = img_orig * mask + (1. - mask) * img
  1433. outs = self.p_sample_ddim(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps,
  1434. quantize_denoised=quantize_denoised, temperature=temperature,
  1435. noise_dropout=noise_dropout, score_corrector=score_corrector,
  1436. corrector_kwargs=corrector_kwargs)
  1437. img, pred_x0 = outs
  1438. if callback: callback(i)
  1439. if img_callback: img_callback(pred_x0, i)
  1440. if index % log_every_t == 0 or index == total_steps - 1:
  1441. intermediates['x_inter'].append(img)
  1442. intermediates['pred_x0'].append(pred_x0)
  1443. return img, intermediates
  1444. @torch.no_grad()
  1445. def p_sample_ddim(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False,
  1446. temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None):
  1447. b, *_, device = *x.shape, x.device
  1448. e_t = self.model.apply_model(x, t, c)
  1449. if score_corrector is not None:
  1450. assert self.model.parameterization == "eps"
  1451. e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs)
  1452. alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas
  1453. alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev
  1454. sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas
  1455. sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas
  1456. # select parameters corresponding to the currently considered timestep
  1457. a_t = torch.full((b, 1, 1, 1), alphas[index], device=device)
  1458. a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device)
  1459. sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device)
  1460. sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device)
  1461. # current prediction for x_0
  1462. pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()
  1463. if quantize_denoised:
  1464. pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0)
  1465. # direction pointing to x_t
  1466. dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t
  1467. noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature
  1468. if noise_dropout > 0.:
  1469. noise = torch.nn.functional.dropout(noise, p=noise_dropout)
  1470. x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise
  1471. return x_prev, pred_x0
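# The update in p_sample_ddim above is the standard DDIM step: with a_t and a_prev the
# cumulative alphas at the current and previous timestep and e_t the predicted noise,
#   pred_x0 = (x - sqrt(1 - a_t) * e_t) / sqrt(a_t)
#   x_prev  = sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev - sigma_t**2) * e_t + sigma_t * noise
# so eta = 0 (sigma_t = 0) makes the sharpening pass fully deterministic, while the eta = 1.0
# used by the sharpening code below re-injects noise at every step.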
  1472. def download_models(mode):
  1473. if mode == "superresolution":
  1474. # this is the small bsr light model
  1475. url_conf = 'https://heibox.uni-heidelberg.de/f/31a76b13ea27482981b4/?dl=1'
  1476. url_ckpt = 'https://heibox.uni-heidelberg.de/f/578df07c8fc04ffbadf3/?dl=1'
  1477. path_conf = f'{model_path}/superres/project.yaml'
  1478. path_ckpt = f'{model_path}/superres/last.ckpt'
  1479. download_url(url_conf, path_conf)
  1480. download_url(url_ckpt, path_ckpt)
  1481. path_conf = path_conf + '/?dl=1' # fix it
  1482. path_ckpt = path_ckpt + '/?dl=1' # fix it
  1483. return path_conf, path_ckpt
  1484. else:
  1485. raise NotImplementedError
  1486. def load_model_from_config(config, ckpt):
  1487. print(f"Loading model from {ckpt}")
  1488. pl_sd = torch.load(ckpt, map_location="cpu")
  1489. global_step = pl_sd["global_step"]
  1490. sd = pl_sd["state_dict"]
  1491. model = instantiate_from_config(config.model)
  1492. m, u = model.load_state_dict(sd, strict=False)
  1493. model.cuda()
  1494. model.eval()
  1495. return {"model": model}, global_step
  1496. def get_model(mode):
  1497. path_conf, path_ckpt = download_models(mode)
  1498. config = OmegaConf.load(path_conf)
  1499. model, step = load_model_from_config(config, path_ckpt)
  1500. return model
  1501. def get_custom_cond(mode):
  1502. dest = "data/example_conditioning"
  1503. if mode == "superresolution":
  1504. uploaded_img = files.upload()
  1505. filename = next(iter(uploaded_img))
  1506. name, filetype = filename.split(".") # todo assumes just one dot in name !
  1507. os.rename(f"{filename}", f"{dest}/{mode}/custom_{name}.{filetype}")
  1508. elif mode == "text_conditional":
  1509. w = widgets.Text(value='A cake with cream!', disabled=True)
  1510. display.display(w)
  1511. with open(f"{dest}/{mode}/custom_{w.value[:20]}.txt", 'w') as f:
  1512. f.write(w.value)
  1513. elif mode == "class_conditional":
  1514. w = widgets.IntSlider(min=0, max=1000)
  1515. display.display(w)
  1516. with open(f"{dest}/{mode}/custom.txt", 'w') as f:
1517. f.write(str(w.value))
  1518. else:
1519. raise NotImplementedError(f"cond not implemented for mode {mode}")
  1520. def get_cond_options(mode):
  1521. path = "data/example_conditioning"
  1522. path = os.path.join(path, mode)
  1523. onlyfiles = [f for f in sorted(os.listdir(path))]
  1524. return path, onlyfiles
  1525. def select_cond_path(mode):
  1526. path = "data/example_conditioning" # todo
  1527. path = os.path.join(path, mode)
  1528. onlyfiles = [f for f in sorted(os.listdir(path))]
  1529. selected = widgets.RadioButtons(
  1530. options=onlyfiles,
  1531. description='Select conditioning:',
  1532. disabled=False
  1533. )
  1534. display.display(selected)
  1535. selected_path = os.path.join(path, selected.value)
  1536. return selected_path
  1537. def get_cond(mode, img):
  1538. example = dict()
  1539. if mode == "superresolution":
  1540. up_f = 4
  1541. # visualize_cond_img(selected_path)
  1542. c = img
  1543. c = torch.unsqueeze(torchvision.transforms.ToTensor()(c), 0)
  1544. c_up = torchvision.transforms.functional.resize(c, size=[up_f * c.shape[2], up_f * c.shape[3]], antialias=True)
  1545. c_up = rearrange(c_up, '1 c h w -> 1 h w c')
  1546. c = rearrange(c, '1 c h w -> 1 h w c')
  1547. c = 2. * c - 1.
  1548. c = c.to(torch.device("cuda"))
  1549. example["LR_image"] = c
  1550. example["image"] = c_up
  1551. return example
  1552. def visualize_cond_img(path):
  1553. display.display(ipyimg(filename=path))
  1554. def sr_run(model, img, task, custom_steps, eta, resize_enabled=False, classifier_ckpt=None, global_step=None):
  1555. # global stride
  1556. example = get_cond(task, img)
  1557. save_intermediate_vid = False
  1558. n_runs = 1
  1559. masked = False
  1560. guider = None
  1561. ckwargs = None
  1562. mode = 'ddim'
  1563. ddim_use_x0_pred = False
  1564. temperature = 1.
  1565. eta = eta
  1566. make_progrow = True
  1567. custom_shape = None
  1568. height, width = example["image"].shape[1:3]
  1569. split_input = height >= 128 and width >= 128
  1570. if split_input:
  1571. ks = 128
  1572. stride = 64
  1573. vqf = 4 #
  1574. model.split_input_params = {"ks": (ks, ks), "stride": (stride, stride),
  1575. "vqf": vqf,
  1576. "patch_distributed_vq": True,
  1577. "tie_braker": False,
  1578. "clip_max_weight": 0.5,
  1579. "clip_min_weight": 0.01,
  1580. "clip_max_tie_weight": 0.5,
  1581. "clip_min_tie_weight": 0.01}
  1582. else:
  1583. if hasattr(model, "split_input_params"):
  1584. delattr(model, "split_input_params")
  1585. invert_mask = False
  1586. x_T = None
  1587. for n in range(n_runs):
  1588. if custom_shape is not None:
  1589. x_T = torch.randn(1, custom_shape[1], custom_shape[2], custom_shape[3]).to(model.device)
  1590. x_T = repeat(x_T, '1 c h w -> b c h w', b=custom_shape[0])
  1591. logs = make_convolutional_sample(example, model,
  1592. mode=mode, custom_steps=custom_steps,
  1593. eta=eta, swap_mode=False , masked=masked,
  1594. invert_mask=invert_mask, quantize_x0=False,
  1595. custom_schedule=None, decode_interval=10,
  1596. resize_enabled=resize_enabled, custom_shape=custom_shape,
  1597. temperature=temperature, noise_dropout=0.,
  1598. corrector=guider, corrector_kwargs=ckwargs, x_T=x_T, save_intermediate_vid=save_intermediate_vid,
  1599. make_progrow=make_progrow,ddim_use_x0_pred=ddim_use_x0_pred
  1600. )
  1601. return logs
  1602. @torch.no_grad()
  1603. def convsample_ddim(model, cond, steps, shape, eta=1.0, callback=None, normals_sequence=None,
  1604. mask=None, x0=None, quantize_x0=False, img_callback=None,
  1605. temperature=1., noise_dropout=0., score_corrector=None,
  1606. corrector_kwargs=None, x_T=None, log_every_t=None
  1607. ):
  1608. ddim = DDIMSampler(model)
1609. bs = shape[0] # don't know where this comes from, but it doesn't matter here
  1610. shape = shape[1:] # cut batch dim
  1611. # print(f"Sampling with eta = {eta}; steps: {steps}")
  1612. samples, intermediates = ddim.sample(steps, batch_size=bs, shape=shape, conditioning=cond, callback=callback,
  1613. normals_sequence=normals_sequence, quantize_x0=quantize_x0, eta=eta,
  1614. mask=mask, x0=x0, temperature=temperature, verbose=False,
  1615. score_corrector=score_corrector,
  1616. corrector_kwargs=corrector_kwargs, x_T=x_T)
  1617. return samples, intermediates
  1618. @torch.no_grad()
  1619. def make_convolutional_sample(batch, model, mode="vanilla", custom_steps=None, eta=1.0, swap_mode=False, masked=False,
  1620. invert_mask=True, quantize_x0=False, custom_schedule=None, decode_interval=1000,
  1621. resize_enabled=False, custom_shape=None, temperature=1., noise_dropout=0., corrector=None,
  1622. corrector_kwargs=None, x_T=None, save_intermediate_vid=False, make_progrow=True,ddim_use_x0_pred=False):
  1623. log = dict()
  1624. z, c, x, xrec, xc = model.get_input(batch, model.first_stage_key,
  1625. return_first_stage_outputs=True,
  1626. force_c_encode=not (hasattr(model, 'split_input_params')
  1627. and model.cond_stage_key == 'coordinates_bbox'),
  1628. return_original_cond=True)
  1629. log_every_t = 1 if save_intermediate_vid else None
  1630. if custom_shape is not None:
  1631. z = torch.randn(custom_shape)
  1632. # print(f"Generating {custom_shape[0]} samples of shape {custom_shape[1:]}")
  1633. z0 = None
  1634. log["input"] = x
  1635. log["reconstruction"] = xrec
  1636. if ismap(xc):
  1637. log["original_conditioning"] = model.to_rgb(xc)
  1638. if hasattr(model, 'cond_stage_key'):
  1639. log[model.cond_stage_key] = model.to_rgb(xc)
  1640. else:
  1641. log["original_conditioning"] = xc if xc is not None else torch.zeros_like(x)
  1642. if model.cond_stage_model:
  1643. log[model.cond_stage_key] = xc if xc is not None else torch.zeros_like(x)
  1644. if model.cond_stage_key =='class_label':
  1645. log[model.cond_stage_key] = xc[model.cond_stage_key]
  1646. with model.ema_scope("Plotting"):
  1647. t0 = time.time()
  1648. img_cb = None
  1649. sample, intermediates = convsample_ddim(model, c, steps=custom_steps, shape=z.shape,
  1650. eta=eta,
  1651. quantize_x0=quantize_x0, img_callback=img_cb, mask=None, x0=z0,
  1652. temperature=temperature, noise_dropout=noise_dropout,
  1653. score_corrector=corrector, corrector_kwargs=corrector_kwargs,
  1654. x_T=x_T, log_every_t=log_every_t)
  1655. t1 = time.time()
  1656. if ddim_use_x0_pred:
  1657. sample = intermediates['pred_x0'][-1]
  1658. x_sample = model.decode_first_stage(sample)
  1659. try:
  1660. x_sample_noquant = model.decode_first_stage(sample, force_not_quantize=True)
  1661. log["sample_noquant"] = x_sample_noquant
  1662. log["sample_diff"] = torch.abs(x_sample_noquant - x_sample)
  1663. except:
  1664. pass
  1665. log["sample"] = x_sample
  1666. log["time"] = t1 - t0
  1667. return log
  1668. sr_diffMode = 'superresolution'
  1669. sr_model = get_model('superresolution')
  1670. def do_superres(img, filepath):
  1671. if args.sharpen_preset == 'Faster':
  1672. sr_diffusion_steps = "25"
  1673. sr_pre_downsample = '1/2'
  1674. if args.sharpen_preset == 'Fast':
  1675. sr_diffusion_steps = "100"
  1676. sr_pre_downsample = '1/2'
  1677. if args.sharpen_preset == 'Slow':
  1678. sr_diffusion_steps = "25"
  1679. sr_pre_downsample = 'None'
  1680. if args.sharpen_preset == 'Very Slow':
  1681. sr_diffusion_steps = "100"
  1682. sr_pre_downsample = 'None'
  1683. sr_post_downsample = 'Original Size'
  1684. sr_diffusion_steps = int(sr_diffusion_steps)
  1685. sr_eta = 1.0
  1686. sr_downsample_method = 'Lanczos'
  1687. gc.collect()
  1688. torch.cuda.empty_cache()
  1689. im_og = img
  1690. width_og, height_og = im_og.size
  1691. #Downsample Pre
  1692. if sr_pre_downsample == '1/2':
  1693. downsample_rate = 2
  1694. elif sr_pre_downsample == '1/4':
  1695. downsample_rate = 4
  1696. else:
  1697. downsample_rate = 1
  1698. width_downsampled_pre = width_og//downsample_rate
  1699. height_downsampled_pre = height_og//downsample_rate
  1700. if downsample_rate != 1:
  1701. # print(f'Downsampling from [{width_og}, {height_og}] to [{width_downsampled_pre}, {height_downsampled_pre}]')
  1702. im_og = im_og.resize((width_downsampled_pre, height_downsampled_pre), Image.LANCZOS)
  1703. # im_og.save('/content/temp.png')
  1704. # filepath = '/content/temp.png'
  1705. logs = sr_run(sr_model["model"], im_og, sr_diffMode, sr_diffusion_steps, sr_eta)
  1706. sample = logs["sample"]
  1707. sample = sample.detach().cpu()
  1708. sample = torch.clamp(sample, -1., 1.)
  1709. sample = (sample + 1.) / 2. * 255
  1710. sample = sample.numpy().astype(np.uint8)
  1711. sample = np.transpose(sample, (0, 2, 3, 1))
  1712. a = Image.fromarray(sample[0])
  1713. #Downsample Post
  1714. if sr_post_downsample == '1/2':
  1715. downsample_rate = 2
  1716. elif sr_post_downsample == '1/4':
  1717. downsample_rate = 4
  1718. else:
  1719. downsample_rate = 1
  1720. width, height = a.size
  1721. width_downsampled_post = width//downsample_rate
  1722. height_downsampled_post = height//downsample_rate
  1723. if sr_downsample_method == 'Lanczos':
  1724. aliasing = Image.LANCZOS
  1725. else:
  1726. aliasing = Image.NEAREST
  1727. if downsample_rate != 1:
  1728. # print(f'Downsampling from [{width}, {height}] to [{width_downsampled_post}, {height_downsampled_post}]')
  1729. a = a.resize((width_downsampled_post, height_downsampled_post), aliasing)
  1730. elif sr_post_downsample == 'Original Size':
  1731. # print(f'Downsampling from [{width}, {height}] to Original Size [{width_og}, {height_og}]')
  1732. a = a.resize((width_og, height_og), aliasing)
  1733. display.display(a)
  1734. a.save(filepath)
  1735. return
  1736. print(f'Processing finished!')
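# Preset summary for do_superres (read straight off the branches above): 'Faster' runs 25 DDIM
# steps on a 1/2-downsampled input, 'Fast' 100 steps at 1/2, 'Slow' 25 steps at full size and
# 'Very Slow' 100 steps at full size; in every case the result is resized back to the original
# image dimensions before saving.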
  1737. """# 2. Diffusion and CLIP model settings"""
  1738. #@markdown ####**Models Settings:**
  1739. diffusion_model = "512x512_diffusion_uncond_finetune_008100" #@param ["256x256_diffusion_uncond", "512x512_diffusion_uncond_finetune_008100"]
  1740. use_secondary_model = True #@param {type: 'boolean'}
  1741. diffusion_sampling_mode = 'ddim' #@param ['plms','ddim']
  1742. timestep_respacing = '250' #@param ['25','50','100','150','250','500','1000','ddim25','ddim50', 'ddim75', 'ddim100','ddim150','ddim250','ddim500','ddim1000']
  1743. diffusion_steps = 300 #@param {type: 'number'}
  1744. use_checkpoint = True #@param {type: 'boolean'}
  1745. ViTB32 = True #@param{type:"boolean"}
  1746. ViTB16 = True #@param{type:"boolean"}
  1747. ViTL14 = False #@param{type:"boolean"}
  1748. RN101 = False #@param{type:"boolean"}
  1749. RN50 = True #@param{type:"boolean"}
  1750. RN50x4 = False #@param{type:"boolean"}
  1751. RN50x16 = False #@param{type:"boolean"}
  1752. RN50x64 = False #@param{type:"boolean"}
  1753. SLIPB16 = False #@param{type:"boolean"}
  1754. SLIPL16 = False #@param{type:"boolean"}
  1755. #@markdown If you're having issues with model downloads, check this to compare SHA's:
  1756. check_model_SHA = False #@param{type:"boolean"}
  1757. model_256_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'
  1758. model_512_SHA = '9c111ab89e214862b76e1fa6a1b3f1d329b1a88281885943d2cdbe357ad57648'
  1759. model_secondary_SHA = '983e3de6f95c88c81b2ca7ebb2c217933be1973b1ff058776b970f901584613a'
  1760. model_256_link = 'https://openaipublic.blob.core.windows.net/diffusion/jul-2021/256x256_diffusion_uncond.pt'
  1761. model_512_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/512x512_diffusion_uncond_finetune_008100.pt'
  1762. model_secondary_link = 'https://v-diffusion.s3.us-west-2.amazonaws.com/secondary_model_imagenet_2.pth'
  1763. model_256_path = f'{model_path}/256x256_diffusion_uncond.pt'
  1764. model_512_path = f'{model_path}/512x512_diffusion_uncond_finetune_008100.pt'
  1765. model_secondary_path = f'{model_path}/secondary_model_imagenet_2.pth'
  1766. # Download the diffusion model
  1767. if diffusion_model == '256x256_diffusion_uncond':
  1768. if os.path.exists(model_256_path) and check_model_SHA:
  1769. print('Checking 256 Diffusion File')
  1770. with open(model_256_path,"rb") as f:
  1771. bytes = f.read()
  1772. hash = hashlib.sha256(bytes).hexdigest();
  1773. if hash == model_256_SHA:
  1774. print('256 Model SHA matches')
  1775. model_256_downloaded = True
  1776. else:
  1777. print("256 Model SHA doesn't match, redownloading...")
  1778. wget(model_256_link, model_path)
  1779. model_256_downloaded = True
  1780. elif os.path.exists(model_256_path) and not check_model_SHA or model_256_downloaded == True:
1781. print('256 Model already downloaded; enable check_model_SHA if the file seems corrupt')
  1782. else:
  1783. wget(model_256_link, model_path)
  1784. model_256_downloaded = True
  1785. elif diffusion_model == '512x512_diffusion_uncond_finetune_008100':
  1786. if os.path.exists(model_512_path) and check_model_SHA:
  1787. print('Checking 512 Diffusion File')
  1788. with open(model_512_path,"rb") as f:
  1789. bytes = f.read()
  1790. hash = hashlib.sha256(bytes).hexdigest();
  1791. if hash == model_512_SHA:
  1792. print('512 Model SHA matches')
  1793. model_512_downloaded = True
  1794. else:
  1795. print("512 Model SHA doesn't match, redownloading...")
  1796. wget(model_512_link, model_path)
  1797. model_512_downloaded = True
  1798. elif os.path.exists(model_512_path) and not check_model_SHA or model_512_downloaded == True:
1799. print('512 Model already downloaded; enable check_model_SHA if the file seems corrupt')
  1800. else:
  1801. wget(model_512_link, model_path)
  1802. model_512_downloaded = True
  1803. # Download the secondary diffusion model v2
  1804. if use_secondary_model == True:
  1805. if os.path.exists(model_secondary_path) and check_model_SHA:
  1806. print('Checking Secondary Diffusion File')
  1807. with open(model_secondary_path,"rb") as f:
  1808. bytes = f.read()
  1809. hash = hashlib.sha256(bytes).hexdigest();
  1810. if hash == model_secondary_SHA:
  1811. print('Secondary Model SHA matches')
  1812. model_secondary_downloaded = True
  1813. else:
  1814. print("Secondary Model SHA doesn't match, redownloading...")
  1815. wget(model_secondary_link, model_path)
  1816. model_secondary_downloaded = True
  1817. elif os.path.exists(model_secondary_path) and not check_model_SHA or model_secondary_downloaded == True:
1818. print('Secondary Model already downloaded; enable check_model_SHA if the file seems corrupt')
  1819. else:
  1820. wget(model_secondary_link, model_path)
  1821. model_secondary_downloaded = True
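# The three download blocks above share the same integrity check: hash the local checkpoint
# with SHA-256 and compare it against the pinned constant. A compact sketch of that check as a
# reusable helper (hypothetical; the notebook inlines this logic instead):
def sha256_matches(path, expected_sha):
    with open(path, 'rb') as f:
        return hashlib.sha256(f.read()).hexdigest() == expected_sha
# e.g. sha256_matches(model_512_path, model_512_SHA)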
  1822. model_config = model_and_diffusion_defaults()
  1823. if diffusion_model == '512x512_diffusion_uncond_finetune_008100':
  1824. model_config.update({
  1825. 'attention_resolutions': '32, 16, 8',
  1826. 'class_cond': False,
  1827. 'diffusion_steps': diffusion_steps,
  1828. 'rescale_timesteps': True,
  1829. 'timestep_respacing': timestep_respacing,
  1830. 'image_size': 512,
  1831. 'learn_sigma': True,
  1832. 'noise_schedule': 'linear',
  1833. 'num_channels': 256,
  1834. 'num_head_channels': 64,
  1835. 'num_res_blocks': 2,
  1836. 'resblock_updown': True,
  1837. 'use_checkpoint': use_checkpoint,
  1838. 'use_fp16': True,
  1839. 'use_scale_shift_norm': True,
  1840. })
  1841. elif diffusion_model == '256x256_diffusion_uncond':
  1842. model_config.update({
  1843. 'attention_resolutions': '32, 16, 8',
  1844. 'class_cond': False,
  1845. 'diffusion_steps': diffusion_steps,
  1846. 'rescale_timesteps': True,
  1847. 'timestep_respacing': timestep_respacing,
  1848. 'image_size': 256,
  1849. 'learn_sigma': True,
  1850. 'noise_schedule': 'linear',
  1851. 'num_channels': 256,
  1852. 'num_head_channels': 64,
  1853. 'num_res_blocks': 2,
  1854. 'resblock_updown': True,
  1855. 'use_checkpoint': use_checkpoint,
  1856. 'use_fp16': True,
  1857. 'use_scale_shift_norm': True,
  1858. })
  1859. secondary_model_ver = 2
  1860. model_default = model_config['image_size']
  1861. if secondary_model_ver == 2:
  1862. secondary_model = SecondaryDiffusionImageNet2()
  1863. secondary_model.load_state_dict(torch.load(f'{model_path}/secondary_model_imagenet_2.pth', map_location='cpu'))
  1864. secondary_model.eval().requires_grad_(False).to(device)
  1865. clip_models = []
  1866. if ViTB32 is True: clip_models.append(clip.load('ViT-B/32', jit=False)[0].eval().requires_grad_(False).to(device))
  1867. if ViTB16 is True: clip_models.append(clip.load('ViT-B/16', jit=False)[0].eval().requires_grad_(False).to(device) )
  1868. if ViTL14 is True: clip_models.append(clip.load('ViT-L/14', jit=False)[0].eval().requires_grad_(False).to(device) )
  1869. if RN50 is True: clip_models.append(clip.load('RN50', jit=False)[0].eval().requires_grad_(False).to(device))
  1870. if RN50x4 is True: clip_models.append(clip.load('RN50x4', jit=False)[0].eval().requires_grad_(False).to(device))
  1871. if RN50x16 is True: clip_models.append(clip.load('RN50x16', jit=False)[0].eval().requires_grad_(False).to(device))
  1872. if RN50x64 is True: clip_models.append(clip.load('RN50x64', jit=False)[0].eval().requires_grad_(False).to(device))
  1873. if RN101 is True: clip_models.append(clip.load('RN101', jit=False)[0].eval().requires_grad_(False).to(device))
  1874. if SLIPB16:
  1875. SLIPB16model = SLIP_VITB16(ssl_mlp_dim=4096, ssl_emb_dim=256)
  1876. if not os.path.exists(f'{model_path}/slip_base_100ep.pt'):
  1877. wget("https://dl.fbaipublicfiles.com/slip/slip_base_100ep.pt", model_path)
  1878. sd = torch.load(f'{model_path}/slip_base_100ep.pt')
  1879. real_sd = {}
  1880. for k, v in sd['state_dict'].items():
  1881. real_sd['.'.join(k.split('.')[1:])] = v
  1882. del sd
  1883. SLIPB16model.load_state_dict(real_sd)
  1884. SLIPB16model.requires_grad_(False).eval().to(device)
  1885. clip_models.append(SLIPB16model)
  1886. if SLIPL16:
  1887. SLIPL16model = SLIP_VITL16(ssl_mlp_dim=4096, ssl_emb_dim=256)
  1888. if not os.path.exists(f'{model_path}/slip_large_100ep.pt'):
  1889. wget("https://dl.fbaipublicfiles.com/slip/slip_large_100ep.pt", model_path)
  1890. sd = torch.load(f'{model_path}/slip_large_100ep.pt')
  1891. real_sd = {}
  1892. for k, v in sd['state_dict'].items():
  1893. real_sd['.'.join(k.split('.')[1:])] = v
  1894. del sd
  1895. SLIPL16model.load_state_dict(real_sd)
  1896. SLIPL16model.requires_grad_(False).eval().to(device)
  1897. clip_models.append(SLIPL16model)
  1898. normalize = T.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])
  1899. lpips_model = lpips.LPIPS(net='vgg').to(device)
  1900. """# 3. Settings"""
  1901. #@markdown ####**Basic Settings:**
  1902. batch_name = 'new_House' #@param{type: 'string'}
  1903. steps = 300#@param [25,50,100,150,250,500,1000]{type: 'raw', allow-input: true}
  1904. width_height = [1280, 720]#@param{type: 'raw'}
  1905. clip_guidance_scale = 5000 #@param{type: 'number'}
  1906. tv_scale = 0#@param{type: 'number'}
  1907. range_scale = 150#@param{type: 'number'}
  1908. sat_scale = 0#@param{type: 'number'}
  1909. cutn_batches = 4#@param{type: 'number'}
  1910. skip_augs = False#@param{type: 'boolean'}
  1911. #@markdown ---
  1912. #@markdown ####**Init Settings:**
  1913. init_image = "/content/drive/MyDrive/AI/Disco_Diffusion/init_images/xv_1_decoupe_noback.jpg" #@param{type: 'string'}
  1914. init_scale = 1000 #@param{type: 'integer'}
  1915. skip_steps = 50 #@param{type: 'integer'}
  1916. #@markdown *Make sure you set skip_steps to ~50% of your steps if you want to use an init image.*
  1917. #Get corrected sizes
  1918. side_x = (width_height[0]//64)*64;
  1919. side_y = (width_height[1]//64)*64;
  1920. if side_x != width_height[0] or side_y != width_height[1]:
1921. print(f'Changing output size to {side_x}x{side_y}. Dimensions must be multiples of 64.')
  1922. #Update Model Settings
  1923. timestep_respacing = f'ddim{steps}'
  1924. diffusion_steps = (1000//steps)*steps if steps < 1000 else steps
  1925. model_config.update({
  1926. 'timestep_respacing': timestep_respacing,
  1927. 'diffusion_steps': diffusion_steps,
  1928. })
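# Worked example of the update above (illustration only): with the default steps = 300 the
# sampler is respaced to 'ddim300' and diffusion_steps becomes (1000 // 300) * 300 = 900,
# i.e. the largest multiple of `steps` not exceeding 1000; for steps >= 1000 it stays at `steps`.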
  1929. #Make folder for batch
  1930. batchFolder = f'{outDirPath}/{batch_name}'
  1931. createPath(batchFolder)
  1932. """### Animation Settings"""
  1933. #@markdown ####**Animation Mode:**
  1934. animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input'] {type:'string'}
  1935. #@markdown *For animation, you probably want to turn `cutn_batches` to 1 to make it quicker.*
  1936. #@markdown ---
  1937. #@markdown ####**Video Input Settings:**
  1938. if is_colab:
  1939. video_init_path = "/content/training.mp4" #@param {type: 'string'}
  1940. else:
  1941. video_init_path = "training.mp4" #@param {type: 'string'}
  1942. extract_nth_frame = 2 #@param {type: 'number'}
  1943. video_init_seed_continuity = True #@param {type: 'boolean'}
  1944. if animation_mode == "Video Input":
  1945. if is_colab:
  1946. videoFramesFolder = f'/content/videoFrames'
  1947. else:
  1948. videoFramesFolder = f'videoFrames'
  1949. createPath(videoFramesFolder)
  1950. print(f"Exporting Video Frames (1 every {extract_nth_frame})...")
  1951. try:
  1952. for f in pathlib.Path(f'{videoFramesFolder}').glob('*.jpg'):
  1953. f.unlink()
  1954. except:
  1955. print('')
  1956. vf = f'"select=not(mod(n\,{extract_nth_frame}))"'
  1957. subprocess.run(['ffmpeg', '-i', f'{video_init_path}', '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2', '-loglevel', 'error', '-stats', f'{videoFramesFolder}/%04d.jpg'], stdout=subprocess.PIPE).stdout.decode('utf-8')
  1958. #!ffmpeg -i {video_init_path} -vf {vf} -vsync vfr -q:v 2 -loglevel error -stats {videoFramesFolder}/%04d.jpg
  1959. #@markdown ---
  1960. #@markdown ####**2D Animation Settings:**
  1961. #@markdown `zoom` is a multiplier of dimensions, 1 is no zoom.
  1962. #@markdown All rotations are provided in degrees.
  1963. key_frames = True #@param {type:"boolean"}
  1964. max_frames = 10000#@param {type:"number"}
  1965. if animation_mode == "Video Input":
  1966. max_frames = len(glob(f'{videoFramesFolder}/*.jpg'))
  1967. interp_spline = 'Linear' #Do not change, currently will not look good. param ['Linear','Quadratic','Cubic']{type:"string"}
  1968. angle = "0:(0)"#@param {type:"string"}
  1969. zoom = "0: (1), 10: (1.05)"#@param {type:"string"}
  1970. translation_x = "0: (0)"#@param {type:"string"}
  1971. translation_y = "0: (0)"#@param {type:"string"}
  1972. translation_z = "0: (10.0)"#@param {type:"string"}
  1973. rotation_3d_x = "0: (0)"#@param {type:"string"}
  1974. rotation_3d_y = "0: (0)"#@param {type:"string"}
  1975. rotation_3d_z = "0: (0)"#@param {type:"string"}
  1976. midas_depth_model = "dpt_large"#@param {type:"string"}
  1977. midas_weight = 0.3#@param {type:"number"}
  1978. near_plane = 200#@param {type:"number"}
  1979. far_plane = 10000#@param {type:"number"}
  1980. fov = 40#@param {type:"number"}
  1981. padding_mode = 'border'#@param {type:"string"}
  1982. sampling_mode = 'bicubic'#@param {type:"string"}
  1983. #======= TURBO MODE
  1984. #@markdown ---
  1985. #@markdown ####**Turbo Mode (3D anim only):**
1986. #@markdown Starting after frame 10, turbo mode skips diffusion on in-between frames and instead uses the depth map to warp the previous image into place.
1987. #@markdown This speeds up rendering by 2x-4x and may improve image coherence between frames; frame_blend_mode smooths abrupt texture changes across consecutive frames (see the blending sketch at the end of this section).
  1988. #@markdown For different settings tuned for Turbo Mode, refer to the original Disco-Turbo Github: https://github.com/zippy731/disco-diffusion-turbo
  1989. turbo_mode = False #@param {type:"boolean"}
  1990. turbo_steps = "3" #@param ["2","3","4","5","6"] {type:"string"}
  1991. turbo_preroll = 10 # frames
1992. # Insist turbo only be used with 3D animation.
  1993. if turbo_mode and animation_mode != '3D':
  1994. print('=====')
  1995. print('Turbo mode only available with 3D animations. Disabling Turbo.')
  1996. print('=====')
  1997. turbo_mode = False
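# How the turbo blend is applied when a 3D-animation frame is saved (sketch only, mirroring the
# cv2.addWeighted call in do_run above): the written frame is a weighted mix of the freshly
# diffused image and the depth-warped previous frame, with blend_factor = 1 / int(turbo_steps).
def _demo_turbo_blend(new_frame, prev_frame_warped, turbo_steps="3"):
    blend_factor = 1 / int(turbo_steps)
    return cv2.addWeighted(new_frame, blend_factor, prev_frame_warped, 1 - blend_factor, 0.0)
# With turbo_steps = "3", each saved frame is 1/3 new render and 2/3 warped previous frame,
# which is what smooths abrupt texture changes between frames.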
  1998. #@markdown ---
  1999. #@markdown ####**Coherency Settings:**
2000. #@markdown `frames_scale` tries to guide the new frame toward looking like the previous one. A good default is 1500.
  2001. frames_scale = 1500 #@param{type: 'integer'}
2002. #@markdown `frames_skip_steps` will blur the previous frame - higher values will flicker less but struggle to add enough new detail to zoom into.
  2003. frames_skip_steps = '60%' #@param ['40%', '50%', '60%', '70%', '80%'] {type: 'string'}
  2004. def parse_key_frames(string, prompt_parser=None):
  2005. """Given a string representing frame numbers paired with parameter values at that frame,
  2006. return a dictionary with the frame numbers as keys and the parameter values as the values.
  2007. Parameters
  2008. ----------
  2009. string: string
  2010. Frame numbers paired with parameter values at that frame number, in the format
  2011. 'framenumber1: (parametervalues1), framenumber2: (parametervalues2), ...'
  2012. prompt_parser: function or None, optional
  2013. If provided, prompt_parser will be applied to each string of parameter values.
  2014. Returns
  2015. -------
  2016. dict
  2017. Frame numbers as keys, parameter values at that frame number as values
  2018. Raises
  2019. ------
  2020. RuntimeError
  2021. If the input string does not match the expected format.
  2022. Examples
  2023. --------
  2024. >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)")
  2025. {10: 'Apple: 1| Orange: 0', 20: 'Apple: 0| Orange: 1| Peach: 1'}
2026. >>> parse_key_frames("10:(Apple: 1| Orange: 0), 20: (Apple: 0| Orange: 1| Peach: 1)", prompt_parser=lambda x: x.lower())
  2027. {10: 'apple: 1| orange: 0', 20: 'apple: 0| orange: 1| peach: 1'}
  2028. """
  2029. import re
  2030. pattern = r'((?P<frame>[0-9]+):[\s]*[\(](?P<param>[\S\s]*?)[\)])'
  2031. frames = dict()
  2032. for match_object in re.finditer(pattern, string):
  2033. frame = int(match_object.groupdict()['frame'])
  2034. param = match_object.groupdict()['param']
  2035. if prompt_parser:
  2036. frames[frame] = prompt_parser(param)
  2037. else:
  2038. frames[frame] = param
  2039. if frames == {} and len(string) != 0:
  2040. raise RuntimeError('Key Frame string not correctly formatted')
  2041. return frames
  2042. def get_inbetweens(key_frames, integer=False):
  2043. """Given a dict with frame numbers as keys and a parameter value as values,
  2044. return a pandas Series containing the value of the parameter at every frame from 0 to max_frames.
  2045. Any values not provided in the input dict are calculated by linear interpolation between
  2046. the values of the previous and next provided frames. If there is no previous provided frame, then
  2047. the value is equal to the value of the next provided frame, or if there is no next provided frame,
  2048. then the value is equal to the value of the previous provided frame. If no frames are provided,
  2049. all frame values are NaN.
  2050. Parameters
  2051. ----------
  2052. key_frames: dict
  2053. A dict with integer frame numbers as keys and numerical values of a particular parameter as values.
  2054. integer: Bool, optional
  2055. If True, the values of the output series are converted to integers.
  2056. Otherwise, the values are floats.
  2057. Returns
  2058. -------
  2059. pd.Series
  2060. A Series with length max_frames representing the parameter values for each frame.
  2061. Examples
  2062. --------
  2063. >>> max_frames = 5
  2064. >>> get_inbetweens({1: 5, 3: 6})
  2065. 0 5.0
  2066. 1 5.0
  2067. 2 5.5
  2068. 3 6.0
  2069. 4 6.0
  2070. dtype: float64
  2071. >>> get_inbetweens({1: 5, 3: 6}, integer=True)
  2072. 0 5
  2073. 1 5
  2074. 2 5
  2075. 3 6
  2076. 4 6
  2077. dtype: int64
  2078. """
  2079. key_frame_series = pd.Series([np.nan for a in range(max_frames)])
  2080. for i, value in key_frames.items():
  2081. key_frame_series[i] = value
  2082. key_frame_series = key_frame_series.astype(float)
  2083. interp_method = interp_spline
  2084. if interp_method == 'Cubic' and len(key_frames.items()) <=3:
  2085. interp_method = 'Quadratic'
  2086. if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:
  2087. interp_method = 'Linear'
  2088. key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]
  2089. key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()]
  2090. # key_frame_series = key_frame_series.interpolate(method=intrp_method,order=1, limit_direction='both')
  2091. key_frame_series = key_frame_series.interpolate(method=interp_method.lower(),limit_direction='both')
  2092. if integer:
  2093. return key_frame_series.astype(int)
  2094. return key_frame_series
  2095. def split_prompts(prompts):
  2096. prompt_series = pd.Series([np.nan for a in range(max_frames)])
  2097. for i, prompt in prompts.items():
  2098. prompt_series[i] = prompt
  2099. # prompt_series = prompt_series.astype(str)
  2100. prompt_series = prompt_series.ffill().bfill()
  2101. return prompt_series
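# Putting the two helpers together (illustration only): a schedule string is first parsed into
# {frame: value} pairs and then expanded to one value per frame by get_inbetweens, using the
# global max_frames and interp_spline.
def _demo_keyframe_schedule():
    parsed = parse_key_frames("0: (1), 10: (1.05)")            # -> {0: '1', 10: '1.05'}
    series = get_inbetweens({k: float(v) for k, v in parsed.items()})
    return series.iloc[:11]                                     # frames 0-10 ramp linearly from 1.0 to 1.05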
if key_frames:
    try:
        angle_series = get_inbetweens(parse_key_frames(angle))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `angle` correctly for key frames.\n"
            "Attempting to interpret `angle` as "
            f'"0: ({angle})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        angle = f"0: ({angle})"
        angle_series = get_inbetweens(parse_key_frames(angle))

    try:
        zoom_series = get_inbetweens(parse_key_frames(zoom))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `zoom` correctly for key frames.\n"
            "Attempting to interpret `zoom` as "
            f'"0: ({zoom})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        zoom = f"0: ({zoom})"
        zoom_series = get_inbetweens(parse_key_frames(zoom))

    try:
        translation_x_series = get_inbetweens(parse_key_frames(translation_x))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `translation_x` correctly for key frames.\n"
            "Attempting to interpret `translation_x` as "
            f'"0: ({translation_x})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        translation_x = f"0: ({translation_x})"
        translation_x_series = get_inbetweens(parse_key_frames(translation_x))

    try:
        translation_y_series = get_inbetweens(parse_key_frames(translation_y))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `translation_y` correctly for key frames.\n"
            "Attempting to interpret `translation_y` as "
            f'"0: ({translation_y})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        translation_y = f"0: ({translation_y})"
        translation_y_series = get_inbetweens(parse_key_frames(translation_y))

    try:
        translation_z_series = get_inbetweens(parse_key_frames(translation_z))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `translation_z` correctly for key frames.\n"
            "Attempting to interpret `translation_z` as "
            f'"0: ({translation_z})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        translation_z = f"0: ({translation_z})"
        translation_z_series = get_inbetweens(parse_key_frames(translation_z))

    try:
        rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `rotation_3d_x` correctly for key frames.\n"
            "Attempting to interpret `rotation_3d_x` as "
            f'"0: ({rotation_3d_x})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        rotation_3d_x = f"0: ({rotation_3d_x})"
        rotation_3d_x_series = get_inbetweens(parse_key_frames(rotation_3d_x))

    try:
        rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `rotation_3d_y` correctly for key frames.\n"
            "Attempting to interpret `rotation_3d_y` as "
            f'"0: ({rotation_3d_y})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        rotation_3d_y = f"0: ({rotation_3d_y})"
        rotation_3d_y_series = get_inbetweens(parse_key_frames(rotation_3d_y))

    try:
        rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))
    except RuntimeError:
        print(
            "WARNING: You have selected to use key frames, but you have not "
            "formatted `rotation_3d_z` correctly for key frames.\n"
            "Attempting to interpret `rotation_3d_z` as "
            f'"0: ({rotation_3d_z})"\n'
            "Please read the instructions to find out how to use key frames "
            "correctly.\n"
        )
        rotation_3d_z = f"0: ({rotation_3d_z})"
        rotation_3d_z_series = get_inbetweens(parse_key_frames(rotation_3d_z))
else:
    angle = float(angle)
    zoom = float(zoom)
    translation_x = float(translation_x)
    translation_y = float(translation_y)
    translation_z = float(translation_z)
    rotation_3d_x = float(rotation_3d_x)
    rotation_3d_y = float(rotation_3d_y)
    rotation_3d_z = float(rotation_3d_z)
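# Hedged illustration (hypothetical schedule values): each of the parameters above
# is a key-frame string in the "frame: (value)" form that the warnings fall back to,
# e.g.
#   zoom = "0: (1.0), 30: (1.02), 60: (0.98)"
# parse_key_frames turns that into a {frame: value} dict, and get_inbetweens
# interpolates the remaining frames using the selected interp_spline.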
  2216. """### Extra Settings
  2217. Partial Saves, Diffusion Sharpening, Advanced Settings, Cutn Scheduling
  2218. """
  2219. #@markdown ####**Saving:**
  2220. intermediate_saves = 4#@param{type: 'raw'}
  2221. intermediates_in_subfolder = True #@param{type: 'boolean'}
  2222. #@markdown Intermediate steps will save a copy at your specified intervals. You can either format it as a single integer or a list of specific steps
  2223. #@markdown A value of `2` will save a copy at 33% and 66%. 0 will save none.
  2224. #@markdown A value of `[5, 9, 34, 45]` will save at steps 5, 9, 34, and 45. (Make sure to include the brackets)
  2225. if type(intermediate_saves) is not list:
  2226. if intermediate_saves:
  2227. steps_per_checkpoint = math.floor((steps - skip_steps - 1) // (intermediate_saves+1))
  2228. steps_per_checkpoint = steps_per_checkpoint if steps_per_checkpoint > 0 else 1
  2229. print(f'Will save every {steps_per_checkpoint} steps')
  2230. else:
  2231. steps_per_checkpoint = steps+10
  2232. else:
  2233. steps_per_checkpoint = None
  2234. if intermediate_saves and intermediates_in_subfolder is True:
  2235. partialFolder = f'{batchFolder}/partials'
  2236. createPath(partialFolder)
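# Hedged sanity check (illustrative numbers only, not pipeline state): with 240
# steps, no skipped steps, and intermediate_saves = 2, the formula above gives
# (240 - 0 - 1) // (2 + 1) == 79, i.e. an intermediate copy roughly every third
# of the run, matching the "33% and 66%" note.
assert (240 - 0 - 1) // (2 + 1) == 79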
#@markdown ---
#@markdown ####**SuperRes Sharpening:**
#@markdown *Sharpen each image using latent-diffusion. Does not run in animation mode. `keep_unsharp` will save both versions.*
sharpen_preset = 'Fast' #@param ['Off', 'Faster', 'Fast', 'Slow', 'Very Slow']
keep_unsharp = True #@param{type: 'boolean'}

if sharpen_preset != 'Off' and keep_unsharp is True:
    unsharpenFolder = f'{batchFolder}/unsharpened'
    createPath(unsharpenFolder)

#@markdown ---
#@markdown ####**Advanced Settings:**
#@markdown *There are a few extra advanced settings available if you double click this cell.*
#@markdown *Perlin init will replace your init, so uncheck if using one.*
perlin_init = False #@param{type: 'boolean'}
perlin_mode = 'mixed' #@param ['mixed', 'color', 'gray']
set_seed = 'random_seed' #@param{type: 'string'}
eta = 0.8 #@param{type: 'number'}
clamp_grad = True #@param{type: 'boolean'}
clamp_max = 0.05 #@param{type: 'number'}

### EXTRA ADVANCED SETTINGS:
randomize_class = True
clip_denoised = False
fuzzy_prompt = False
rand_mag = 0.05

#@markdown ---
#@markdown ####**Cutn Scheduling:**
#@markdown Format: `[40]*400+[20]*600` = 40 cuts for the first 400/1000 steps, then 20 for the last 600/1000.
#@markdown `cut_overview` and `cut_innercut` are cumulative for the total cutn on any given step. Overview cuts see the entire image and are good for early structure; innercuts are your standard cutn.
cut_overview = "[12]*400+[4]*600" #@param {type: 'string'}
cut_innercut = "[4]*400+[12]*600" #@param {type: 'string'}
cut_ic_pow = 1 #@param {type: 'number'}
cut_icgray_p = "[0.2]*400+[0]*600" #@param {type: 'string'}
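# Hedged illustration (throwaway variable, deleted immediately): each schedule
# string evaluates (via eval, when args are built below) to a per-step list over
# the 1000-step timeline, e.g. "[12]*400+[4]*600" means 12 overview cuts for the
# first 400 steps and 4 for the remaining 600.
_example_schedule = eval("[12]*400+[4]*600")
assert len(_example_schedule) == 1000 and _example_schedule[0] == 12 and _example_schedule[-1] == 4
del _example_schedule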
  2268. """### Prompts
  2269. `animation_mode: None` will only use the first set. `animation_mode: 2D / Video` will run through them per the set frames and hold on the last one.
  2270. """
  2271. text_prompts = {
  2272. 0: [
  2273. "megastructure in the cloud, blame!, contemporary house in the mist, artstation",
  2274. ]
  2275. }
  2276. image_prompts = {
  2277. # 0:['ImagePromptsWorkButArentVeryGood.png:2',],
  2278. }
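# Hedged illustration (hypothetical prompt text): in the animation modes the
# prompt series advances at the keyed frame numbers and holds on the last entry,
# so a schedule like the following would switch prompts at frame 100:
# text_prompts = {
#     0: ["megastructure in the cloud, blame!, artstation"],
#     100: ["megastructure emerging from morning fog, artstation"],
# }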
  2279. """# 4. Diffuse!"""
  2280. #@title Do the Run!
  2281. #@markdown `n_batches` ignored with animation modes.
  2282. display_rate = 50#@param{type: 'number'}
  2283. n_batches = 50#@param{type: 'number'}
  2284. #Update Model Settings
  2285. timestep_respacing = f'ddim{steps}'
  2286. diffusion_steps = (1000//steps)*steps if steps < 1000 else steps
  2287. model_config.update({
  2288. 'timestep_respacing': timestep_respacing,
  2289. 'diffusion_steps': diffusion_steps,
  2290. })
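# Hedged sanity check (illustrative numbers only): diffusion_steps is rounded
# down to a multiple of the requested step count, presumably so the respaced
# schedule divides evenly, e.g. 300 requested steps -> (1000 // 300) * 300 == 900.
assert (1000 // 300) * 300 == 900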
batch_size = 1

def move_files(start_num, end_num, old_folder, new_folder):
    for i in range(start_num, end_num):
        old_file = old_folder + f'/{batch_name}({batchNum})_{i:04}.png'
        new_file = new_folder + f'/{batch_name}({batchNum})_{i:04}.png'
        os.rename(old_file, new_file)

#@markdown ---
resume_run = False #@param{type: 'boolean'}
run_to_resume = 'latest' #@param{type: 'string'}
resume_from_frame = 'latest' #@param{type: 'string'}
retain_overwritten_frames = False #@param{type: 'boolean'}
if retain_overwritten_frames is True:
    retainFolder = f'{batchFolder}/retained'
    createPath(retainFolder)

skip_step_ratio = int(frames_skip_steps.rstrip("%")) / 100
calc_frames_skip_steps = math.floor(steps * skip_step_ratio)

if steps <= calc_frames_skip_steps:
    sys.exit("ERROR: You can't skip more steps than your total steps")
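# Hedged sanity check (illustrative numbers only): frames_skip_steps is a percent
# string, so a hypothetical "60%" becomes a 0.6 ratio and, with 250 steps, the
# animation frames after the first would skip math.floor(250 * 0.6) == 150 steps.
assert int("60%".rstrip("%")) / 100 == 0.6 and math.floor(250 * 0.6) == 150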
if resume_run:
    if run_to_resume == 'latest':
        try:
            batchNum
        except NameError:
            batchNum = len(glob(f"{batchFolder}/{batch_name}(*)_settings.txt")) - 1
    else:
        batchNum = int(run_to_resume)
    if resume_from_frame == 'latest':
        start_frame = len(glob(batchFolder + f"/{batch_name}({batchNum})_*.png"))
        if animation_mode != '3D' and turbo_mode == True and start_frame > turbo_preroll and start_frame % int(turbo_steps) != 0:
            start_frame = start_frame - (start_frame % int(turbo_steps))
    else:
        start_frame = int(resume_from_frame) + 1
        if animation_mode != '3D' and turbo_mode == True and start_frame > turbo_preroll and start_frame % int(turbo_steps) != 0:
            start_frame = start_frame - (start_frame % int(turbo_steps))
        if retain_overwritten_frames is True:
            existing_frames = len(glob(batchFolder + f"/{batch_name}({batchNum})_*.png"))
            frames_to_save = existing_frames - start_frame
            print(f'Moving {frames_to_save} frames to the Retained folder')
            move_files(start_frame, existing_frames, batchFolder, retainFolder)
else:
    start_frame = 0
    batchNum = len(glob(batchFolder + "/*.txt"))
    while os.path.isfile(f"{batchFolder}/{batch_name}({batchNum})_settings.txt") is True or os.path.isfile(f"{batchFolder}/{batch_name}-{batchNum}_settings.txt") is True:
        batchNum += 1

print(f'Starting Run: {batch_name}({batchNum}) at frame {start_frame}')

if set_seed == 'random_seed':
    random.seed()
    seed = random.randint(0, 2**32)
    # print(f'Using seed: {seed}')
else:
    seed = int(set_seed)
args = {
    'batchNum': batchNum,
    'prompts_series': split_prompts(text_prompts) if text_prompts else None,
    'image_prompts_series': split_prompts(image_prompts) if image_prompts else None,
    'seed': seed,
    'display_rate': display_rate,
    'n_batches': n_batches if animation_mode == 'None' else 1,
    'batch_size': batch_size,
    'batch_name': batch_name,
    'steps': steps,
    'diffusion_sampling_mode': diffusion_sampling_mode,
    'width_height': width_height,
    'clip_guidance_scale': clip_guidance_scale,
    'tv_scale': tv_scale,
    'range_scale': range_scale,
    'sat_scale': sat_scale,
    'cutn_batches': cutn_batches,
    'init_image': init_image,
    'init_scale': init_scale,
    'skip_steps': skip_steps,
    'sharpen_preset': sharpen_preset,
    'keep_unsharp': keep_unsharp,
    'side_x': side_x,
    'side_y': side_y,
    'timestep_respacing': timestep_respacing,
    'diffusion_steps': diffusion_steps,
    'animation_mode': animation_mode,
    'video_init_path': video_init_path,
    'extract_nth_frame': extract_nth_frame,
    'video_init_seed_continuity': video_init_seed_continuity,
    'key_frames': key_frames,
    'max_frames': max_frames if animation_mode != "None" else 1,
    'interp_spline': interp_spline,
    'start_frame': start_frame,
    'angle': angle,
    'zoom': zoom,
    'translation_x': translation_x,
    'translation_y': translation_y,
    'translation_z': translation_z,
    'rotation_3d_x': rotation_3d_x,
    'rotation_3d_y': rotation_3d_y,
    'rotation_3d_z': rotation_3d_z,
    'midas_depth_model': midas_depth_model,
    'midas_weight': midas_weight,
    'near_plane': near_plane,
    'far_plane': far_plane,
    'fov': fov,
    'padding_mode': padding_mode,
    'sampling_mode': sampling_mode,
    'angle_series': angle_series,
    'zoom_series': zoom_series,
    'translation_x_series': translation_x_series,
    'translation_y_series': translation_y_series,
    'translation_z_series': translation_z_series,
    'rotation_3d_x_series': rotation_3d_x_series,
    'rotation_3d_y_series': rotation_3d_y_series,
    'rotation_3d_z_series': rotation_3d_z_series,
    'frames_scale': frames_scale,
    'calc_frames_skip_steps': calc_frames_skip_steps,
    'skip_step_ratio': skip_step_ratio,
    'text_prompts': text_prompts,
    'image_prompts': image_prompts,
    'cut_overview': eval(cut_overview),
    'cut_innercut': eval(cut_innercut),
    'cut_ic_pow': cut_ic_pow,
    'cut_icgray_p': eval(cut_icgray_p),
    'intermediate_saves': intermediate_saves,
    'intermediates_in_subfolder': intermediates_in_subfolder,
    'steps_per_checkpoint': steps_per_checkpoint,
    'perlin_init': perlin_init,
    'perlin_mode': perlin_mode,
    'set_seed': set_seed,
    'eta': eta,
    'clamp_grad': clamp_grad,
    'clamp_max': clamp_max,
    'skip_augs': skip_augs,
    'randomize_class': randomize_class,
    'clip_denoised': clip_denoised,
    'fuzzy_prompt': fuzzy_prompt,
    'rand_mag': rand_mag,
}
args = SimpleNamespace(**args)

print('Prepping model...')
model, diffusion = create_model_and_diffusion(**model_config)
model.load_state_dict(torch.load(f'{model_path}/{diffusion_model}.pt', map_location='cpu'))
model.requires_grad_(False).eval().to(device)
# Re-enable gradients only on the attention (qkv), normalization and projection
# layers; the rest of the model stays frozen.
for name, param in model.named_parameters():
    if 'qkv' in name or 'norm' in name or 'proj' in name:
        param.requires_grad_()
if model_config['use_fp16']:
    model.convert_to_fp16()

gc.collect()
torch.cuda.empty_cache()
try:
    do_run()
except KeyboardInterrupt:
    pass
finally:
    print('Seed used:', seed)
    gc.collect()
    torch.cuda.empty_cache()
  2444. """# 5. Create the video"""
  2445. # @title ### **Create video**
  2446. #@markdown Video file will save in the same folder as your images.
  2447. skip_video_for_run_all = True #@param {type: 'boolean'}
  2448. if skip_video_for_run_all == True:
  2449. print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')
  2450. else:
  2451. # import subprocess in case this cell is run without the above cells
  2452. import subprocess
  2453. from base64 import b64encode
  2454. latest_run = batchNum
  2455. folder = batch_name #@param
  2456. run = latest_run #@param
  2457. final_frame = 'final_frame'
  2458. init_frame = 1#@param {type:"number"} This is the frame where the video will start
  2459. last_frame = final_frame#@param {type:"number"} You can change i to the number of the last frame you want to generate. It will raise an error if that number of frames does not exist.
  2460. fps = 12#@param {type:"number"}
  2461. # view_video_in_cell = True #@param {type: 'boolean'}
  2462. frames = []
  2463. # tqdm.write('Generating video...')
  2464. if last_frame == 'final_frame':
  2465. last_frame = len(glob(batchFolder+f"/{folder}({run})_*.png"))
  2466. print(f'Total frames: {last_frame}')
  2467. image_path = f"{outDirPath}/{folder}/{folder}({run})_%04d.png"
  2468. filepath = f"{outDirPath}/{folder}/{folder}({run}).mp4"
  2469. cmd = [
  2470. 'ffmpeg',
  2471. '-y',
  2472. '-vcodec',
  2473. 'png',
  2474. '-r',
  2475. str(fps),
  2476. '-start_number',
  2477. str(init_frame),
  2478. '-i',
  2479. image_path,
  2480. '-frames:v',
  2481. str(last_frame+1),
  2482. '-c:v',
  2483. 'libx264',
  2484. '-vf',
  2485. f'fps={fps}',
  2486. '-pix_fmt',
  2487. 'yuv420p',
  2488. '-crf',
  2489. '17',
  2490. '-preset',
  2491. 'veryslow',
  2492. filepath
  2493. ]
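    # Hedged illustration (hypothetical paths and run number): with fps=12 and
    # init_frame=1, the list above corresponds to a shell command roughly like
    #   ffmpeg -y -vcodec png -r 12 -start_number 1 -i "<outDirPath>/<folder>/<folder>(<run>)_%04d.png" \
    #          -frames:v <last_frame+1> -c:v libx264 -vf fps=12 -pix_fmt yuv420p -crf 17 -preset veryslow \
    #          "<outDirPath>/<folder>/<folder>(<run>).mp4"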
    process = subprocess.Popen(cmd, cwd=f'{batchFolder}', stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        print(stderr)
        raise RuntimeError(stderr)
    else:
        print("The video is ready and saved to the images folder")

    # if view_video_in_cell:
    #     mp4 = open(filepath, 'rb').read()
    #     data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
    #     display.HTML(f'<video width=400 controls><source src="{data_url}" type="video/mp4"></video>')