diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..45532f8dc7 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +# Unnecessary compiled python files. +__pycache__ +*.pyc +*.pyo + +# Output Images +outputs + +# Log files for colab-convert +cc-outputs.log \ No newline at end of file diff --git a/Deforum_Stable_Diffusion.ipynb b/Deforum_Stable_Diffusion.ipynb new file mode 100644 index 0000000000..058ea97837 --- /dev/null +++ b/Deforum_Stable_Diffusion.ipynb @@ -0,0 +1,1992 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "c442uQJ_gUgy" + }, + "source": [ + "# **Deforum Stable Diffusion v0.5**\n", + "[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion).\n", + "\n", + "Notebook by [deforum](https://discord.gg/upmXXsrwZc)\n", + "\n", + "**Important note:** this notebook severely lacks maintenance, as most of the devs have moved to [the WebUI extension](https://github.com/deforum-art/deforum-for-automatic1111-webui). Please visit the [Deforum Discord server](https://discord.gg/deforum) to get info on the more active forks.\n", + "\n", + "If you still want to use this notebook, **proceed only if you know what you're doing!**" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "LBamKxcmNI7-" + }, + "source": [ + "By using this Notebook, you agree to the following Terms of Use and license:\n", + "\n", + "**Stability.AI Model Terms of Use**\n", + "\n", + "This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage.\n", + "\n", + "The CreativeML OpenRAIL License specifies:\n", + "\n", + "You can't use the model to deliberately produce nor share illegal or harmful outputs or content\n", + "CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license\n", + "You may re-distribute the weights and use the model commercially and/or as a service. 
If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully)\n", + "\n", + "\n", + "Please read the full license here: https://huggingface.co/spaces/CompVis/stable-diffusion-license" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "T4knibRpAQ06" + }, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2g-f7cQmf2Nt", + "cellView": "form" + }, + "source": [ + "#@markdown **NVIDIA GPU**\n", + "import subprocess\n", + "sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + "print(sub_p_res)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "TxIOPT0G5Lx1" + }, + "source": [ + "#@markdown **Model and Output Paths**\n", + "# ask for the link\n", + "print(\"Local Path Variables:\\n\")\n", + "\n", + "models_path = \"/content/models\" #@param {type:\"string\"}\n", + "output_path = \"/content/output\" #@param {type:\"string\"}\n", + "\n", + "#@markdown **Google Drive Path Variables (Optional)**\n", + "mount_google_drive = True #@param {type:\"boolean\"}\n", + "force_remount = False\n", + "\n", + "if mount_google_drive:\n", + " from google.colab import drive # type: ignore\n", + " try:\n", + " drive_path = \"/content/drive\"\n", + " drive.mount(drive_path,force_remount=force_remount)\n", + " models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n", + " output_path_gdrive = \"/content/drive/MyDrive/AI/StableDiffusion\" #@param {type:\"string\"}\n", + " models_path = models_path_gdrive\n", + " output_path = output_path_gdrive\n", + " except:\n", + " print(\"...error mounting drive or with drive path variables\")\n", + " print(\"...reverting to default path variables\")\n", + "\n", + "import os\n", + "os.makedirs(models_path, exist_ok=True)\n", + "os.makedirs(output_path, exist_ok=True)\n", + "\n", + "print(f\"models_path: {models_path}\")\n", + "print(f\"output_path: {output_path}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "id": "VRNl2mfepEIe", + "cellView": "form" + }, + "source": [ + "#@markdown **Setup Environment**\n", + "\n", + "setup_environment = True #@param {type:\"boolean\"}\n", + "print_subprocess = False #@param {type:\"boolean\"}\n", + "\n", + "if setup_environment:\n", + " import subprocess, time\n", + " print(\"Setting up environment...\")\n", + " start_time = time.time()\n", + " all_process = [\n", + " ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n", + " ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'],\n", + " ['git', 'clone', 'https://github.com/deforum/stable-diffusion'],\n", + " ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],\n", + " ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n", + " ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq'],\n", + " ['git', 'clone', 'https://github.com/shariqfarooq123/AdaBins.git'],\n", + " ['git', 'clone', 
'https://github.com/isl-org/MiDaS.git'],\n", + " ['git', 'clone', 'https://github.com/MSFTserver/pytorch3d-lite.git'],\n", + " ]\n", + " for process in all_process:\n", + " running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + " if print_subprocess:\n", + " print(running)\n", + " \n", + " print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", + " with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:\n", + " f.write('')\n", + "\n", + " end_time = time.time()\n", + " print(f\"Environment set up in {end_time-start_time:.0f} seconds\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "id": "81qmVZbrm4uu", + "cellView": "form" + }, + "source": [ + "#@markdown **Python Definitions**\n", + "import json\n", + "from IPython import display\n", + "\n", + "import gc, math, os, pathlib, subprocess, sys, time\n", + "import cv2\n", + "import numpy as np\n", + "import pandas as pd\n", + "import random\n", + "import requests\n", + "import torch\n", + "import torch.nn as nn\n", + "import torchvision.transforms as T\n", + "import torchvision.transforms.functional as TF\n", + "from contextlib import contextmanager, nullcontext\n", + "from einops import rearrange, repeat\n", + "from omegaconf import OmegaConf\n", + "from PIL import Image\n", + "from pytorch_lightning import seed_everything\n", + "from skimage.exposure import match_histograms\n", + "from torchvision.utils import make_grid\n", + "from tqdm import tqdm, trange\n", + "from types import SimpleNamespace\n", + "from torch import autocast\n", + "import re\n", + "from scipy.ndimage import gaussian_filter\n", + "\n", + "sys.path.extend([\n", + " 'src/taming-transformers',\n", + " 'src/clip',\n", + " 'stable-diffusion/',\n", + " 'k-diffusion',\n", + " 'pytorch3d-lite',\n", + " 'AdaBins',\n", + " 'MiDaS',\n", + "])\n", + "\n", + "import py3d_tools as p3d\n", + "\n", + "from helpers import DepthModel, sampler_fn\n", + "from k_diffusion.external import CompVisDenoiser\n", + "from ldm.util import instantiate_from_config\n", + "from ldm.models.diffusion.ddim import DDIMSampler\n", + "from ldm.models.diffusion.plms import PLMSSampler\n", + "\n", + "def sanitize(prompt):\n", + " whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ')\n", + " tmp = ''.join(filter(whitelist.__contains__, prompt))\n", + " return tmp.replace(' ', '_')\n", + "\n", + "from functools import reduce\n", + "def construct_RotationMatrixHomogenous(rotation_angles):\n", + " assert(type(rotation_angles)==list and len(rotation_angles)==3)\n", + " RH = np.eye(4,4)\n", + " cv2.Rodrigues(np.array(rotation_angles), RH[0:3, 0:3])\n", + " return RH\n", + "\n", + "# https://en.wikipedia.org/wiki/Rotation_matrix\n", + "def getRotationMatrixManual(rotation_angles):\n", + "\t\n", + " rotation_angles = [np.deg2rad(x) for x in rotation_angles]\n", + " \n", + " phi = rotation_angles[0] # around x\n", + " gamma = rotation_angles[1] # around y\n", + " theta = rotation_angles[2] # around z\n", + " \n", + " # X rotation\n", + " Rphi = np.eye(4,4)\n", + " sp = np.sin(phi)\n", + " cp = np.cos(phi)\n", + " Rphi[1,1] = cp\n", + " Rphi[2,2] = Rphi[1,1]\n", + " Rphi[1,2] = -sp\n", + " Rphi[2,1] = sp\n", + " \n", + " # Y rotation\n", + " Rgamma = np.eye(4,4)\n", + " sg = np.sin(gamma)\n", + " cg = np.cos(gamma)\n", + " Rgamma[0,0] = cg\n", + " Rgamma[2,2] = Rgamma[0,0]\n", + " Rgamma[0,2] = sg\n", + " Rgamma[2,0] = -sg\n", + " \n", + " 
# Z rotation (in-image-plane)\n", + " Rtheta = np.eye(4,4)\n", + " st = np.sin(theta)\n", + " ct = np.cos(theta)\n", + " Rtheta[0,0] = ct\n", + " Rtheta[1,1] = Rtheta[0,0]\n", + " Rtheta[0,1] = -st\n", + " Rtheta[1,0] = st\n", + " \n", + " R = reduce(lambda x,y : np.matmul(x,y), [Rphi, Rgamma, Rtheta]) \n", + " \n", + " return R\n", + "\n", + "\n", + "def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength):\n", + " \n", + " ptsIn2D = ptsIn[0,:]\n", + " ptsOut2D = ptsOut[0,:]\n", + " ptsOut2Dlist = []\n", + " ptsIn2Dlist = []\n", + " \n", + " for i in range(0,4):\n", + " ptsOut2Dlist.append([ptsOut2D[i,0], ptsOut2D[i,1]])\n", + " ptsIn2Dlist.append([ptsIn2D[i,0], ptsIn2D[i,1]])\n", + " \n", + " pin = np.array(ptsIn2Dlist) + [W/2.,H/2.]\n", + " pout = (np.array(ptsOut2Dlist) + [1.,1.]) * (0.5*sidelength)\n", + " pin = pin.astype(np.float32)\n", + " pout = pout.astype(np.float32)\n", + " \n", + " return pin, pout\n", + "\n", + "def warpMatrix(W, H, theta, phi, gamma, scale, fV):\n", + " \n", + " # M is to be estimated\n", + " M = np.eye(4, 4)\n", + " \n", + " fVhalf = np.deg2rad(fV/2.)\n", + " d = np.sqrt(W*W+H*H)\n", + " sideLength = scale*d/np.cos(fVhalf)\n", + " h = d/(2.0*np.sin(fVhalf))\n", + " n = h-(d/2.0);\n", + " f = h+(d/2.0);\n", + " \n", + " # Translation along Z-axis by -h\n", + " T = np.eye(4,4)\n", + " T[2,3] = -h\n", + " \n", + " # Rotation matrices around x,y,z\n", + " R = getRotationMatrixManual([phi, gamma, theta])\n", + " \n", + " \n", + " # Projection Matrix \n", + " P = np.eye(4,4)\n", + " P[0,0] = 1.0/np.tan(fVhalf)\n", + " P[1,1] = P[0,0]\n", + " P[2,2] = -(f+n)/(f-n)\n", + " P[2,3] = -(2.0*f*n)/(f-n)\n", + " P[3,2] = -1.0\n", + " \n", + " # pythonic matrix multiplication\n", + " F = reduce(lambda x,y : np.matmul(x,y), [P, T, R]) \n", + " \n", + " # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data in this way. 
\n", + " # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3);\n", + " ptsIn = np.array([[\n", + " [-W/2., H/2., 0.],[ W/2., H/2., 0.],[ W/2.,-H/2., 0.],[-W/2.,-H/2., 0.]\n", + " ]])\n", + " ptsOut = np.array(np.zeros((ptsIn.shape), dtype=ptsIn.dtype))\n", + " ptsOut = cv2.perspectiveTransform(ptsIn, F)\n", + " \n", + " ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength)\n", + " \n", + " # check float32 otherwise OpenCV throws an error\n", + " assert(ptsInPt2f.dtype == np.float32)\n", + " assert(ptsOutPt2f.dtype == np.float32)\n", + " M33 = cv2.getPerspectiveTransform(ptsInPt2f,ptsOutPt2f)\n", + "\n", + " return M33, sideLength\n", + "\n", + "def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx):\n", + " angle = keys.angle_series[frame_idx]\n", + " zoom = keys.zoom_series[frame_idx]\n", + " translation_x = keys.translation_x_series[frame_idx]\n", + " translation_y = keys.translation_y_series[frame_idx]\n", + "\n", + " center = (args.W // 2, args.H // 2)\n", + " trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]])\n", + " rot_mat = cv2.getRotationMatrix2D(center, angle, zoom)\n", + " trans_mat = np.vstack([trans_mat, [0,0,1]])\n", + " rot_mat = np.vstack([rot_mat, [0,0,1]])\n", + " if anim_args.flip_2d_perspective:\n", + " perspective_flip_theta = keys.perspective_flip_theta_series[frame_idx]\n", + " perspective_flip_phi = keys.perspective_flip_phi_series[frame_idx]\n", + " perspective_flip_gamma = keys.perspective_flip_gamma_series[frame_idx]\n", + " perspective_flip_fv = keys.perspective_flip_fv_series[frame_idx]\n", + " M,sl = warpMatrix(args.W, args.H, perspective_flip_theta, perspective_flip_phi, perspective_flip_gamma, 1., perspective_flip_fv);\n", + " post_trans_mat = np.float32([[1, 0, (args.W-sl)/2], [0, 1, (args.H-sl)/2]])\n", + " post_trans_mat = np.vstack([post_trans_mat, [0,0,1]])\n", + " bM = np.matmul(M, post_trans_mat)\n", + " xform = np.matmul(bM, rot_mat, trans_mat)\n", + " else:\n", + " xform = np.matmul(rot_mat, trans_mat)\n", + "\n", + " return cv2.warpPerspective(\n", + " prev_img_cv2,\n", + " xform,\n", + " (prev_img_cv2.shape[1], prev_img_cv2.shape[0]),\n", + " borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE\n", + " )\n", + "\n", + "def anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx):\n", + " TRANSLATION_SCALE = 1.0/200.0 # matches Disco\n", + " translate_xyz = [\n", + " -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE, \n", + " keys.translation_y_series[frame_idx] * TRANSLATION_SCALE, \n", + " -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE\n", + " ]\n", + " rotate_xyz = [\n", + " math.radians(keys.rotation_3d_x_series[frame_idx]), \n", + " math.radians(keys.rotation_3d_y_series[frame_idx]), \n", + " math.radians(keys.rotation_3d_z_series[frame_idx])\n", + " ]\n", + " rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), \"XYZ\").unsqueeze(0)\n", + " result = transform_image_3d(prev_img_cv2, depth, rot_mat, translate_xyz, anim_args)\n", + " torch.cuda.empty_cache()\n", + " return result\n", + "\n", + "def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor:\n", + " return sample + torch.randn(sample.shape, device=sample.device) * noise_amt\n", + "\n", + "def get_output_folder(output_path, batch_folder):\n", + " out_path = os.path.join(output_path,time.strftime('%Y-%m'))\n", + " if batch_folder != \"\":\n", + " out_path = os.path.join(out_path, batch_folder)\n", + " 
os.makedirs(out_path, exist_ok=True)\n", + " return out_path\n", + "\n", + "def load_img(path, shape, use_alpha_as_mask=False):\n", + " # use_alpha_as_mask: Read the alpha channel of the image as the mask image\n", + " if path.startswith('http://') or path.startswith('https://'):\n", + " image = Image.open(requests.get(path, stream=True).raw)\n", + " else:\n", + " image = Image.open(path)\n", + "\n", + " if use_alpha_as_mask:\n", + " image = image.convert('RGBA')\n", + " else:\n", + " image = image.convert('RGB')\n", + "\n", + " image = image.resize(shape, resample=Image.LANCZOS)\n", + "\n", + " mask_image = None\n", + " if use_alpha_as_mask:\n", + " # Split alpha channel into a mask_image\n", + " red, green, blue, alpha = Image.Image.split(image)\n", + " mask_image = alpha.convert('L')\n", + " image = image.convert('RGB')\n", + "\n", + " image = np.array(image).astype(np.float16) / 255.0\n", + " image = image[None].transpose(0, 3, 1, 2)\n", + " image = torch.from_numpy(image)\n", + " image = 2.*image - 1.\n", + "\n", + " return image, mask_image\n", + "\n", + "def load_mask_latent(mask_input, shape):\n", + " # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object\n", + " # shape (list-like len(4)): shape of the image to match, usually latent_image.shape\n", + " \n", + " if isinstance(mask_input, str): # mask input is probably a file name\n", + " if mask_input.startswith('http://') or mask_input.startswith('https://'):\n", + " mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA')\n", + " else:\n", + " mask_image = Image.open(mask_input).convert('RGBA')\n", + " elif isinstance(mask_input, Image.Image):\n", + " mask_image = mask_input\n", + " else:\n", + " raise Exception(\"mask_input must be a PIL image or a file name\")\n", + "\n", + " mask_w_h = (shape[-1], shape[-2])\n", + " mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS)\n", + " mask = mask.convert(\"L\")\n", + " return mask\n", + "\n", + "def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0):\n", + " # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object\n", + " # shape (list-like len(4)): shape of the image to match, usually latent_image.shape\n", + " # mask_brightness_adjust (non-negative float): amount to adjust brightness of the iamge, \n", + " # 0 is black, 1 is no adjustment, >1 is brighter\n", + " # mask_contrast_adjust (non-negative float): amount to adjust contrast of the image, \n", + " # 0 is a flat grey image, 1 is no adjustment, >1 is more contrast\n", + " \n", + " mask = load_mask_latent(mask_input, mask_shape)\n", + "\n", + " # Mask brightness/contrast adjustments\n", + " if mask_brightness_adjust != 1:\n", + " mask = TF.adjust_brightness(mask, mask_brightness_adjust)\n", + " if mask_contrast_adjust != 1:\n", + " mask = TF.adjust_contrast(mask, mask_contrast_adjust)\n", + "\n", + " # Mask image to array\n", + " mask = np.array(mask).astype(np.float32) / 255.0\n", + " mask = np.tile(mask,(4,1,1))\n", + " mask = np.expand_dims(mask,axis=0)\n", + " mask = torch.from_numpy(mask)\n", + "\n", + " if args.invert_mask:\n", + " mask = ( (mask - 0.5) * -1) + 0.5\n", + " \n", + " mask = np.clip(mask,0,1)\n", + " return mask\n", + "\n", + "def maintain_colors(prev_img, color_match_sample, mode):\n", + " if mode == 'Match Frame 0 RGB':\n", + " return match_histograms(prev_img, color_match_sample, multichannel=True)\n", + " elif mode == 'Match Frame 0 HSV':\n", + " prev_img_hsv = 
cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV)\n", + " color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV)\n", + " matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True)\n", + " return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB)\n", + " else: # Match Frame 0 LAB\n", + " prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB)\n", + " color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB)\n", + " matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True)\n", + " return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB)\n", + "\n", + "\n", + "#\n", + "# Callback functions\n", + "#\n", + "class SamplerCallback(object):\n", + " # Creates the callback function to be passed into the samplers for each step\n", + " def __init__(self, args, mask=None, init_latent=None, sigmas=None, sampler=None,\n", + " verbose=False):\n", + " self.sampler_name = args.sampler\n", + " self.dynamic_threshold = args.dynamic_threshold\n", + " self.static_threshold = args.static_threshold\n", + " self.mask = mask\n", + " self.init_latent = init_latent \n", + " self.sigmas = sigmas\n", + " self.sampler = sampler\n", + " self.verbose = verbose\n", + "\n", + " self.batch_size = args.n_samples\n", + " self.save_sample_per_step = args.save_sample_per_step\n", + " self.show_sample_per_step = args.show_sample_per_step\n", + " self.paths_to_image_steps = [os.path.join( args.outdir, f\"{args.timestring}_{index:02}_{args.seed}\") for index in range(args.n_samples) ]\n", + "\n", + " if self.save_sample_per_step:\n", + " for path in self.paths_to_image_steps:\n", + " os.makedirs(path, exist_ok=True)\n", + "\n", + " self.step_index = 0\n", + "\n", + " self.noise = None\n", + " if init_latent is not None:\n", + " self.noise = torch.randn_like(init_latent, device=device)\n", + "\n", + " self.mask_schedule = None\n", + " if sigmas is not None and len(sigmas) > 0:\n", + " self.mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas))\n", + " elif len(sigmas) == 0:\n", + " self.mask = None # no mask needed if no steps (usually happens because strength==1.0)\n", + "\n", + " if self.sampler_name in [\"plms\",\"ddim\"]: \n", + " if mask is not None:\n", + " assert sampler is not None, \"Callback function for stable-diffusion samplers requires sampler variable\"\n", + "\n", + " if self.sampler_name in [\"plms\",\"ddim\"]: \n", + " # Callback function formated for compvis latent diffusion samplers\n", + " self.callback = self.img_callback_\n", + " else: \n", + " # Default callback function uses k-diffusion sampler variables\n", + " self.callback = self.k_callback_\n", + "\n", + " self.verbose_print = print if verbose else lambda *args, **kwargs: None\n", + "\n", + " def view_sample_step(self, latents, path_name_modifier=''):\n", + " if self.save_sample_per_step or self.show_sample_per_step:\n", + " samples = model.decode_first_stage(latents)\n", + " if self.save_sample_per_step:\n", + " fname = f'{path_name_modifier}_{self.step_index:05}.png'\n", + " for i, sample in enumerate(samples):\n", + " sample = sample.double().cpu().add(1).div(2).clamp(0, 1)\n", + " sample = torch.tensor(np.array(sample))\n", + " grid = make_grid(sample, 4).cpu()\n", + " TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname))\n", + " if self.show_sample_per_step:\n", + " print(path_name_modifier)\n", + " self.display_images(samples)\n", + " return\n", + "\n", + " def display_images(self, images):\n", + " images = images.double().cpu().add(1).div(2).clamp(0, 1)\n", + " images = 
torch.tensor(np.array(images))\n", + " grid = make_grid(images, 4).cpu()\n", + " display.display(TF.to_pil_image(grid))\n", + " return\n", + "\n", + " # The callback function is applied to the image at each step\n", + " def dynamic_thresholding_(self, img, threshold):\n", + " # Dynamic thresholding from Imagen paper (May 2022)\n", + " s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim)))\n", + " s = np.max(np.append(s,1.0))\n", + " torch.clamp_(img, -1*s, s)\n", + " torch.FloatTensor.div_(img, s)\n", + "\n", + " # Callback for samplers in the k-diffusion repo, called thus:\n", + " # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})\n", + " def k_callback_(self, args_dict):\n", + " self.step_index = args_dict['i']\n", + " if self.dynamic_threshold is not None:\n", + " self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold)\n", + " if self.static_threshold is not None:\n", + " torch.clamp_(args_dict['x'], -1*self.static_threshold, self.static_threshold)\n", + " if self.mask is not None:\n", + " init_noise = self.init_latent + self.noise * args_dict['sigma']\n", + " is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0 )\n", + " new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1)\n", + " args_dict['x'].copy_(new_img)\n", + "\n", + " self.view_sample_step(args_dict['denoised'], \"x0_pred\")\n", + "\n", + " # Callback for Compvis samplers\n", + " # Function that is called on the image (img) and step (i) at each step\n", + " def img_callback_(self, img, i):\n", + " self.step_index = i\n", + " # Thresholding functions\n", + " if self.dynamic_threshold is not None:\n", + " self.dynamic_thresholding_(img, self.dynamic_threshold)\n", + " if self.static_threshold is not None:\n", + " torch.clamp_(img, -1*self.static_threshold, self.static_threshold)\n", + " if self.mask is not None:\n", + " i_inv = len(self.sigmas) - i - 1\n", + " init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv]*self.batch_size).to(device), noise=self.noise)\n", + " is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0 )\n", + " new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1)\n", + " img.copy_(new_img)\n", + "\n", + " self.view_sample_step(img, \"x\")\n", + "\n", + "def sample_from_cv2(sample: np.ndarray) -> torch.Tensor:\n", + " sample = ((sample.astype(float) / 255.0) * 2) - 1\n", + " sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16)\n", + " sample = torch.from_numpy(sample)\n", + " return sample\n", + "\n", + "def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray:\n", + " sample_f32 = rearrange(sample.squeeze().cpu().numpy(), \"c h w -> h w c\").astype(np.float32)\n", + " sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1)\n", + " sample_int8 = (sample_f32 * 255)\n", + " return sample_int8.astype(type)\n", + "\n", + "def transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate, anim_args):\n", + " # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion \n", + " w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0]\n", + "\n", + " aspect_ratio = float(w)/float(h)\n", + " near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov\n", + " persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device)\n", + " 
persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device)\n", + "\n", + " # range of [-1,1] is important to torch grid_sample's padding handling\n", + " y,x = torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device))\n", + " z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device)\n", + " xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1)\n", + "\n", + " xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]\n", + " xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2]\n", + "\n", + " offset_xy = xyz_new_cam_xy - xyz_old_cam_xy\n", + " # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation.\n", + " identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0)\n", + " # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs.\n", + " coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False)\n", + " offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0)\n", + "\n", + " image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device)\n", + " new_image = torch.nn.functional.grid_sample(\n", + " image_tensor.add(1/512 - 0.0001).unsqueeze(0), \n", + " offset_coords_2d, \n", + " mode=anim_args.sampling_mode, \n", + " padding_mode=anim_args.padding_mode, \n", + " align_corners=False\n", + " )\n", + "\n", + " # convert back to cv2 style numpy array\n", + " result = rearrange(\n", + " new_image.squeeze().clamp(0,255), \n", + " 'c h w -> h w c'\n", + " ).cpu().numpy().astype(prev_img_cv2.dtype)\n", + " return result\n", + "\n", + "def check_is_number(value):\n", + " float_pattern = r'^(?=.)([+-]?([0-9]*)(\\.([0-9]+))?)$'\n", + " return re.match(float_pattern, value)\n", + "\n", + "# prompt weighting with colons and number coefficients (like 'bacon:0.75 eggs:0.25')\n", + "# borrowed from https://github.com/kylewlacy/stable-diffusion/blob/0a4397094eb6e875f98f9d71193e350d859c4220/ldm/dream/conditioning.py\n", + "# and https://github.com/raefu/stable-diffusion-automatic/blob/unstablediffusion/modules/processing.py\n", + "def get_uc_and_c(prompts, model, args, frame = 0):\n", + " prompt = prompts[0] # they are the same in a batch anyway\n", + "\n", + " # get weighted sub-prompts\n", + " negative_subprompts, positive_subprompts = split_weighted_subprompts(\n", + " prompt, frame, not args.normalize_prompt_weights\n", + " )\n", + "\n", + " uc = get_learned_conditioning(model, negative_subprompts, \"\", args, -1)\n", + " c = get_learned_conditioning(model, positive_subprompts, prompt, args, 1)\n", + "\n", + " return (uc, c)\n", + "\n", + "def get_learned_conditioning(model, weighted_subprompts, text, args, sign = 1):\n", + " if len(weighted_subprompts) < 1:\n", + " log_tokenization(text, model, args.log_weighted_subprompts, sign)\n", + " c = model.get_learned_conditioning(args.n_samples * [text])\n", + " else:\n", + " c = None\n", + " for subtext, subweight in weighted_subprompts:\n", + " log_tokenization(subtext, model, args.log_weighted_subprompts, sign * subweight)\n", + " if c is None:\n", + " c = model.get_learned_conditioning(args.n_samples * [subtext])\n", + " c *= subweight\n", + " else:\n", + " 
c.add_(model.get_learned_conditioning(args.n_samples * [subtext]), alpha=subweight)\n", + "    \n", + "    return c\n", + "\n", + "def parse_weight(match, frame = 0)->float:\n", + "    import numexpr\n", + "    w_raw = match.group(\"weight\")\n", + "    if w_raw == None:\n", + "        return 1\n", + "    if check_is_number(w_raw):\n", + "        return float(w_raw)\n", + "    else:\n", + "        t = frame\n", + "        if len(w_raw) < 3:\n", + "            print('the value inside `-characters cannot represent a math function')\n", + "            return 1\n", + "        return float(numexpr.evaluate(w_raw[1:-1]))\n", + "\n", + "def normalize_prompt_weights(parsed_prompts):\n", + "    if len(parsed_prompts) == 0:\n", + "        return parsed_prompts\n", + "    weight_sum = sum(map(lambda x: x[1], parsed_prompts))\n", + "    if weight_sum == 0:\n", + "        print(\n", + "            \"Warning: Subprompt weights add up to zero. Discarding and using even weights instead.\")\n", + "        equal_weight = 1 / max(len(parsed_prompts), 1)\n", + "        return [(x[0], equal_weight) for x in parsed_prompts]\n", + "    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]\n", + "\n", + "def split_weighted_subprompts(text, frame = 0, skip_normalize=False):\n", + "    \"\"\"\n", + "    grabs all text up to the first occurrence of ':'\n", + "    uses the grabbed text as a sub-prompt, and takes the value following ':' as weight\n", + "    if ':' has no value defined, defaults to 1.0\n", + "    repeats until no text remaining\n", + "    \"\"\"\n", + "    prompt_parser = re.compile(\"\"\"\n", + "            (?P<prompt> # capture group for 'prompt'\n", + "            (?:\\\\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\\:'\n", + "            ) # end 'prompt'\n", + "            (?: # non-capture group\n", + "            :+ # match one or more ':' characters\n", + "            (?P<weight>(( # capture group for 'weight'\n", + "            -?\\d+(?:\\.\\d+)? # match positive or negative integer or decimal number\n", + "            )|( # or\n", + "            `[\\S\\s]*?`# a math function\n", + "            )))? 
# end weight capture group, make optional\n", + " \\s* # strip spaces after weight\n", + " | # OR\n", + " $ # else, if no ':' then match end of line\n", + " ) # end non-capture group\n", + " \"\"\", re.VERBOSE)\n", + " negative_prompts = []\n", + " positive_prompts = []\n", + " for match in re.finditer(prompt_parser, text):\n", + " w = parse_weight(match, frame)\n", + " if w < 0:\n", + " # negating the sign as we'll feed this to uc\n", + " negative_prompts.append((match.group(\"prompt\").replace(\"\\\\:\", \":\"), -w))\n", + " elif w > 0:\n", + " positive_prompts.append((match.group(\"prompt\").replace(\"\\\\:\", \":\"), w))\n", + "\n", + " if skip_normalize:\n", + " return (negative_prompts, positive_prompts)\n", + " return (normalize_prompt_weights(negative_prompts), normalize_prompt_weights(positive_prompts))\n", + "\n", + "# shows how the prompt is tokenized\n", + "# usually tokens have '' to indicate end-of-word,\n", + "# but for readability it has been replaced with ' '\n", + "def log_tokenization(text, model, log=False, weight=1):\n", + " if not log:\n", + " return\n", + " tokens = model.cond_stage_model.tokenizer._tokenize(text)\n", + " tokenized = \"\"\n", + " discarded = \"\"\n", + " usedTokens = 0\n", + " totalTokens = len(tokens)\n", + " for i in range(0, totalTokens):\n", + " token = tokens[i].replace('', ' ')\n", + " # alternate color\n", + " s = (usedTokens % 6) + 1\n", + " if i < model.cond_stage_model.max_length:\n", + " tokenized = tokenized + f\"\\x1b[0;3{s};40m{token}\"\n", + " usedTokens += 1\n", + " else: # over max token length\n", + " discarded = discarded + f\"\\x1b[0;3{s};40m{token}\"\n", + " print(f\"\\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\\n{tokenized}\\x1b[0m\")\n", + " if discarded != \"\":\n", + " print(\n", + " f\">> Tokens Discarded ({totalTokens-usedTokens}):\\n{discarded}\\x1b[0m\"\n", + " )\n", + "\n", + "def generate(args, frame = 0, return_latent=False, return_sample=False, return_c=False):\n", + " seed_everything(args.seed)\n", + " os.makedirs(args.outdir, exist_ok=True)\n", + "\n", + " sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model)\n", + " model_wrap = CompVisDenoiser(model)\n", + " batch_size = args.n_samples\n", + " prompt = args.prompt\n", + " assert prompt is not None\n", + " data = [batch_size * [prompt]]\n", + " precision_scope = autocast if args.precision == \"autocast\" else nullcontext\n", + "\n", + " init_latent = None\n", + " mask_image = None\n", + " init_image = None\n", + " if args.init_latent is not None:\n", + " init_latent = args.init_latent\n", + " elif args.init_sample is not None:\n", + " with precision_scope(\"cuda\"):\n", + " init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample))\n", + " elif args.use_init and args.init_image != None and args.init_image != '':\n", + " init_image, mask_image = load_img(args.init_image, \n", + " shape=(args.W, args.H), \n", + " use_alpha_as_mask=args.use_alpha_as_mask)\n", + " init_image = init_image.to(device)\n", + " init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)\n", + " with precision_scope(\"cuda\"):\n", + " init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space \n", + "\n", + " if not args.use_init and args.strength > 0 and args.strength_0_no_init:\n", + " print(\"\\nNo init image, but strength > 0. 
Strength has been auto set to 0, since use_init is False.\")\n", + " print(\"If you want to force strength > 0 with no init, please set strength_0_no_init to False.\\n\")\n", + " args.strength = 0\n", + "\n", + " # Mask functions\n", + " if args.use_mask:\n", + " assert args.mask_file is not None or mask_image is not None, \"use_mask==True: An mask image is required for a mask. Please enter a mask_file or use an init image with an alpha channel\"\n", + " assert args.use_init, \"use_mask==True: use_init is required for a mask\"\n", + " assert init_latent is not None, \"use_mask==True: An latent init image is required for a mask\"\n", + "\n", + "\n", + " mask = prepare_mask(args.mask_file if mask_image is None else mask_image, \n", + " init_latent.shape, \n", + " args.mask_contrast_adjust, \n", + " args.mask_brightness_adjust)\n", + " \n", + " if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask:\n", + " raise Warning(\"use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.\")\n", + " \n", + " mask = mask.to(device)\n", + " mask = repeat(mask, '1 ... -> b ...', b=batch_size)\n", + " else:\n", + " mask = None\n", + "\n", + " assert not ( (args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), \"Need an init image when use_mask == True and overlay_mask == True\"\n", + " \n", + " t_enc = int((1.0-args.strength) * args.steps)\n", + "\n", + " # Noise schedule for the k-diffusion samplers (used for masking)\n", + " k_sigmas = model_wrap.get_sigmas(args.steps)\n", + " k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:]\n", + "\n", + " if args.sampler in ['plms','ddim']:\n", + " sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False)\n", + "\n", + " callback = SamplerCallback(args=args,\n", + " mask=mask, \n", + " init_latent=init_latent,\n", + " sigmas=k_sigmas,\n", + " sampler=sampler,\n", + " verbose=False).callback \n", + "\n", + " results = []\n", + " with torch.no_grad():\n", + " with precision_scope(\"cuda\"):\n", + " with model.ema_scope():\n", + " for prompts in data:\n", + " if isinstance(prompts, tuple):\n", + " prompts = list(prompts)\n", + " if args.prompt_weighting:\n", + " uc, c = get_uc_and_c(prompts, model, args, frame)\n", + " else:\n", + " uc = model.get_learned_conditioning(batch_size * [\"\"])\n", + " c = model.get_learned_conditioning(prompts)\n", + "\n", + "\n", + " if args.scale == 1.0:\n", + " uc = None\n", + " if args.init_c != None:\n", + " c = args.init_c\n", + "\n", + " if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n", + " samples = sampler_fn(\n", + " c=c, \n", + " uc=uc, \n", + " args=args, \n", + " model_wrap=model_wrap, \n", + " init_latent=init_latent, \n", + " t_enc=t_enc, \n", + " device=device, \n", + " cb=callback)\n", + " else:\n", + " # args.sampler == 'plms' or args.sampler == 'ddim':\n", + " if init_latent is not None and args.strength > 0:\n", + " z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n", + " else:\n", + " z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n", + " if args.sampler == 'ddim':\n", + " samples = sampler.decode(z_enc, \n", + " c, \n", + " t_enc, \n", + " unconditional_guidance_scale=args.scale,\n", + " unconditional_conditioning=uc,\n", + " img_callback=callback)\n", + " elif args.sampler == 'plms': # no \"decode\" function in plms, 
so use \"sample\"\n", + " shape = [args.C, args.H // args.f, args.W // args.f]\n", + " samples, _ = sampler.sample(S=args.steps,\n", + " conditioning=c,\n", + " batch_size=args.n_samples,\n", + " shape=shape,\n", + " verbose=False,\n", + " unconditional_guidance_scale=args.scale,\n", + " unconditional_conditioning=uc,\n", + " eta=args.ddim_eta,\n", + " x_T=z_enc,\n", + " img_callback=callback)\n", + " else:\n", + " raise Exception(f\"Sampler {args.sampler} not recognised.\")\n", + "\n", + " \n", + " if return_latent:\n", + " results.append(samples.clone())\n", + "\n", + " x_samples = model.decode_first_stage(samples)\n", + "\n", + " if args.use_mask and args.overlay_mask:\n", + " # Overlay the masked image after the image is generated\n", + " if args.init_sample is not None:\n", + " img_original = args.init_sample\n", + " elif init_image is not None:\n", + " img_original = init_image\n", + " else:\n", + " raise Exception(\"Cannot overlay the masked image without an init image to overlay\")\n", + "\n", + " mask_fullres = prepare_mask(args.mask_file if mask_image is None else mask_image, \n", + " img_original.shape, \n", + " args.mask_contrast_adjust, \n", + " args.mask_brightness_adjust)\n", + " mask_fullres = mask_fullres[:,:3,:,:]\n", + " mask_fullres = repeat(mask_fullres, '1 ... -> b ...', b=batch_size)\n", + "\n", + " mask_fullres[mask_fullres < mask_fullres.max()] = 0\n", + " mask_fullres = gaussian_filter(mask_fullres, args.mask_overlay_blur)\n", + " mask_fullres = torch.Tensor(mask_fullres).to(device)\n", + "\n", + " x_samples = img_original * mask_fullres + x_samples * ((mask_fullres * -1.0) + 1)\n", + "\n", + "\n", + " if return_sample:\n", + " results.append(x_samples.clone())\n", + "\n", + " x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n", + "\n", + " if return_c:\n", + " results.append(c.clone())\n", + "\n", + " for x_sample in x_samples:\n", + " x_sample = 255. 
* rearrange(x_sample.cpu().numpy(), 'c h w -> h w c')\n", + " image = Image.fromarray(x_sample.astype(np.uint8))\n", + " results.append(image)\n", + " return results" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "CIUJ7lWI4v53" + }, + "source": [ + "#@markdown **Select and Load Model**\n", + "\n", + "model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n", + "model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4-full-ema.ckpt\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\", \"robo-diffusion-v1.ckpt\",\"waifu-diffusion-v1-3.ckpt\"]\n", + "if model_checkpoint == \"waifu-diffusion-v1-3.ckpt\":\n", + " model_checkpoint = \"model-epoch05-float16.ckpt\"\n", + "custom_config_path = \"\" #@param {type:\"string\"}\n", + "custom_checkpoint_path = \"\" #@param {type:\"string\"}\n", + "\n", + "load_on_run_all = True #@param {type: 'boolean'}\n", + "half_precision = True # check\n", + "check_sha256 = True #@param {type:\"boolean\"}\n", + "\n", + "model_map = {\n", + " \"sd-v1-4-full-ema.ckpt\": {\n", + " 'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-4-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-4.ckpt\": {\n", + " 'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-3-full-ema.ckpt\": {\n", + " 'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/blob/main/sd-v1-3-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-3.ckpt\": {\n", + " 'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-2-full-ema.ckpt\": {\n", + " 'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-2-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-2.ckpt\": {\n", + " 'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-1-full-ema.ckpt\": {\n", + " 'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829',\n", + " 'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"sd-v1-1.ckpt\": {\n", + " 'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea',\n", + " 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt',\n", + " 'requires_login': True,\n", + " },\n", + " \"robo-diffusion-v1.ckpt\": {\n", + " 'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16',\n", + " 'url': 
'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt',\n", + " 'requires_login': False,\n", + " },\n", + " \"model-epoch05-float16.ckpt\": {\n", + " 'sha256': '26cf2a2e30095926bb9fd9de0c83f47adc0b442dbfdc3d667d43778e8b70bece',\n", + " 'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/model-epoch05-float16.ckpt',\n", + " 'requires_login': False,\n", + " },\n", + "}\n", + "\n", + "# config path\n", + "ckpt_config_path = custom_config_path if model_config == \"custom\" else os.path.join(models_path, model_config)\n", + "if os.path.exists(ckpt_config_path):\n", + " print(f\"{ckpt_config_path} exists\")\n", + "else:\n", + " ckpt_config_path = \"./stable-diffusion/configs/stable-diffusion/v1-inference.yaml\"\n", + "print(f\"Using config: {ckpt_config_path}\")\n", + "\n", + "# checkpoint path or download\n", + "ckpt_path = custom_checkpoint_path if model_checkpoint == \"custom\" else os.path.join(models_path, model_checkpoint)\n", + "ckpt_valid = True\n", + "if os.path.exists(ckpt_path):\n", + " print(f\"{ckpt_path} exists\")\n", + "elif 'url' in model_map[model_checkpoint]:\n", + " url = model_map[model_checkpoint]['url']\n", + "\n", + " # CLI dialogue to authenticate download\n", + " if model_map[model_checkpoint]['requires_login']:\n", + " print(\"This model requires an authentication token\")\n", + " print(\"Please ensure you have accepted its terms of service before continuing.\")\n", + "\n", + " username = input(\"What is your huggingface username?:\")\n", + " token = input(\"What is your huggingface token?:\")\n", + "\n", + " _, path = url.split(\"https://\")\n", + "\n", + " url = f\"https://{username}:{token}@{path}\"\n", + "\n", + " # contact server for model\n", + " print(f\"Attempting to download {model_checkpoint}...this may take a while\")\n", + " ckpt_request = requests.get(url)\n", + " request_status = ckpt_request.status_code\n", + "\n", + " # inform user of errors\n", + " if request_status == 403:\n", + " raise ConnectionRefusedError(\"You have not accepted the license for this model.\")\n", + " elif request_status == 404:\n", + " raise ConnectionError(\"Could not make contact with server\")\n", + " elif request_status != 200:\n", + " raise ConnectionError(f\"Some other error has ocurred - response code: {request_status}\")\n", + "\n", + " # write to model path\n", + " with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file:\n", + " model_file.write(ckpt_request.content)\n", + "else:\n", + " print(f\"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}\")\n", + " ckpt_valid = False\n", + "\n", + "if check_sha256 and model_checkpoint != \"custom\" and ckpt_valid:\n", + " import hashlib\n", + " print(\"\\n...checking sha256\")\n", + " with open(ckpt_path, \"rb\") as f:\n", + " bytes = f.read() \n", + " hash = hashlib.sha256(bytes).hexdigest()\n", + " del bytes\n", + " if model_map[model_checkpoint][\"sha256\"] == hash:\n", + " print(\"hash is correct\\n\")\n", + " else:\n", + " print(\"hash in not correct\\n\")\n", + " ckpt_valid = False\n", + "\n", + "if ckpt_valid:\n", + " print(f\"Using ckpt: {ckpt_path}\")\n", + "\n", + "def load_model_from_config(config, ckpt, verbose=False, device='cuda', half_precision=True):\n", + " map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n", + " print(f\"Loading model from {ckpt}\")\n", + " pl_sd = torch.load(ckpt, map_location=map_location)\n", + " if \"global_step\" in pl_sd:\n", + " print(f\"Global Step: 
{pl_sd['global_step']}\")\n", + " sd = pl_sd[\"state_dict\"]\n", + " model = instantiate_from_config(config.model)\n", + " m, u = model.load_state_dict(sd, strict=False)\n", + " if len(m) > 0 and verbose:\n", + " print(\"missing keys:\")\n", + " print(m)\n", + " if len(u) > 0 and verbose:\n", + " print(\"unexpected keys:\")\n", + " print(u)\n", + "\n", + " if half_precision:\n", + " model = model.half().to(device)\n", + " else:\n", + " model = model.to(device)\n", + " model.eval()\n", + " return model\n", + "\n", + "if load_on_run_all and ckpt_valid:\n", + " local_config = OmegaConf.load(f\"{ckpt_config_path}\")\n", + " model = load_model_from_config(local_config, f\"{ckpt_path}\", half_precision=half_precision)\n", + " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + " model = model.to(device)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ov3r4RD1tzsT" + }, + "source": [ + "# Settings" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0j7rgxvLvfay" + }, + "source": [ + "### Animation Settings" + ] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "8HJN2TE3vh-J" + }, + "source": [ + "\n", + "def DeforumAnimArgs():\n", + "\n", + " #@markdown ####**Animation:**\n", + " animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'}\n", + " max_frames = 1000 #@param {type:\"number\"}\n", + " border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'}\n", + "\n", + " #@markdown ####**Motion Parameters:**\n", + " angle = \"0:(0)\"#@param {type:\"string\"}\n", + " zoom = \"0:(1.04)\"#@param {type:\"string\"}\n", + " translation_x = \"0:(10*sin(2*3.14*t/10))\"#@param {type:\"string\"}\n", + " translation_y = \"0:(0)\"#@param {type:\"string\"}\n", + " translation_z = \"0:(10)\"#@param {type:\"string\"}\n", + " rotation_3d_x = \"0:(0)\"#@param {type:\"string\"}\n", + " rotation_3d_y = \"0:(0)\"#@param {type:\"string\"}\n", + " rotation_3d_z = \"0:(0)\"#@param {type:\"string\"}\n", + " flip_2d_perspective = False #@param {type:\"boolean\"}\n", + " perspective_flip_theta = \"0:(0)\"#@param {type:\"string\"}\n", + " perspective_flip_phi = \"0:(t%15)\"#@param {type:\"string\"}\n", + " perspective_flip_gamma = \"0:(0)\"#@param {type:\"string\"}\n", + " perspective_flip_fv = \"0:(53)\"#@param {type:\"string\"}\n", + " noise_schedule = \"0: (0.02)\"#@param {type:\"string\"}\n", + " strength_schedule = \"0: (0.65)\"#@param {type:\"string\"}\n", + " contrast_schedule = \"0: (1.0)\"#@param {type:\"string\"}\n", + "\n", + " #@markdown ####**Coherence:**\n", + " color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'}\n", + " diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'}\n", + "\n", + " #@markdown ####**3D Depth Warping:**\n", + " use_depth_warping = True #@param {type:\"boolean\"}\n", + " midas_weight = 0.3#@param {type:\"number\"}\n", + " near_plane = 200\n", + " far_plane = 10000\n", + " fov = 40#@param {type:\"number\"}\n", + " padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'}\n", + " sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'}\n", + " save_depth_maps = False #@param {type:\"boolean\"}\n", + "\n", + " #@markdown ####**Video Input:**\n", + " video_init_path ='/content/video_in.mp4'#@param {type:\"string\"}\n", + " extract_nth_frame = 
1#@param {type:\"number\"}\n", + " overwrite_extracted_frames = True #@param {type:\"boolean\"}\n", + " use_mask_video = False #@param {type:\"boolean\"}\n", + " video_mask_path ='/content/video_in.mp4'#@param {type:\"string\"}\n", + "\n", + " #@markdown ####**Interpolation:**\n", + " interpolate_key_frames = False #@param {type:\"boolean\"}\n", + " interpolate_x_frames = 4 #@param {type:\"number\"}\n", + " \n", + " #@markdown ####**Resume Animation:**\n", + " resume_from_timestring = False #@param {type:\"boolean\"}\n", + " resume_timestring = \"20220829210106\" #@param {type:\"string\"}\n", + "\n", + " return locals()\n", + "\n", + "class DeformAnimKeys():\n", + " def __init__(self, anim_args):\n", + " self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames)\n", + " self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames)\n", + " self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames)\n", + " self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames)\n", + " self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames)\n", + " self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames)\n", + " self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames)\n", + " self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames)\n", + " self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames)\n", + " self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames)\n", + " self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames)\n", + " self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames)\n", + " self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames)\n", + " self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames)\n", + " self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames)\n", + "\n", + "\n", + "def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear'):\n", + " import numexpr\n", + " key_frame_series = pd.Series([np.nan for a in range(max_frames)])\n", + " \n", + " for i in range(0, max_frames):\n", + " if i in key_frames:\n", + " value = key_frames[i]\n", + " value_is_number = check_is_number(value)\n", + " # if it's only a number, leave the rest for the default interpolation\n", + " if value_is_number:\n", + " t = i\n", + " key_frame_series[i] = value\n", + " if not value_is_number:\n", + " t = i\n", + " key_frame_series[i] = numexpr.evaluate(value)\n", + " key_frame_series = key_frame_series.astype(float)\n", + " \n", + " if interp_method == 'Cubic' and len(key_frames.items()) <= 3:\n", + " interp_method = 'Quadratic' \n", + " if interp_method == 'Quadratic' and len(key_frames.items()) <= 2:\n", + " interp_method = 'Linear'\n", + " \n", + " key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()]\n", + " key_frame_series[max_frames-1] = 
key_frame_series[key_frame_series.last_valid_index()]\n", + "    key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both')\n", + "    if integer:\n", + "        return key_frame_series.astype(int)\n", + "    return key_frame_series\n", + "\n", + "def parse_key_frames(string, prompt_parser=None):\n", + "    # because math functions (i.e. sin(t)) can utilize brackets \n", + "    # it extracts the value in form of some stuff\n", + "    # which has previously been enclosed with brackets and\n", + "    # with a comma or end of line existing after the closing one\n", + "    pattern = r'((?P<frame>[0-9]+):[\s]*\((?P<param>[\S\s]*?)\)([,][\s]?|[\s]?$))'\n", + "    frames = dict()\n", + "    for match_object in re.finditer(pattern, string):\n", + "        frame = int(match_object.groupdict()['frame'])\n", + "        param = match_object.groupdict()['param']\n", + "        if prompt_parser:\n", + "            frames[frame] = prompt_parser(param)\n", + "        else:\n", + "            frames[frame] = param\n", + "    if frames == {} and len(string) != 0:\n", + "        raise RuntimeError('Key Frame string not correctly formatted')\n", + "    return frames" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "63UOJvU3xdPS" + }, + "source": [ + "### Prompts\n", + "`animation_mode: None` batches on a list of *prompts*. `animation_mode: 2D` uses the *animation_prompts* key frame sequence." + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2ujwkGZTcGev" + }, + "source": [ + "\n", + "prompts = [\n", + "    \"a beautiful forest by Asher Brown Durand, trending on Artstation\", # the first prompt I want\n", + "    \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", # the second prompt I want\n", + "    #\"this prompt I don't want it I commented it out\",\n", + "    #\"a nousr robot, trending on Artstation\", # use \"nousr robot\" with the robot diffusion model (see model_checkpoint setting)\n", + "    #\"touhou 1girl komeiji_koishi portrait, green hair\", # waifu diffusion prompts can use danbooru tag groups (see model_checkpoint)\n", + "    #\"this prompt has weights if prompt weighting enabled:2 can also do negative:-2\", # (see prompt_weighting)\n", + "]\n", + "\n", + "animation_prompts = {\n", + "    0: \"a beautiful apple, trending on Artstation\",\n", + "    20: \"a beautiful banana, trending on Artstation\",\n", + "    30: \"a beautiful coconut, trending on Artstation\",\n", + "    40: \"a beautiful durian, trending on Artstation\",\n", + "}" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s8RAo2zI-vQm" + }, + "source": [ + "# Run" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qH74gBWDd2oq", + "cellView": "form" + }, + "source": [ + "#@markdown **Load Settings**\n", + "override_settings_with_file = False #@param {type:\"boolean\"}\n", + "custom_settings_file = \"/content/drive/MyDrive/Settings.txt\"#@param {type:\"string\"}\n", + "\n", + "def DeforumArgs():\n", + "    #@markdown **Image Settings**\n", + "    W = 512 #@param\n", + "    H = 512 #@param\n", + "    W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64\n", + "\n", + "    #@markdown **Sampling Settings**\n", + "    seed = -1 #@param\n", + "    sampler = 'klms' #@param [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\",\"plms\", \"ddim\"]\n", + "    steps = 50 #@param\n", + "    scale = 7 #@param\n", + "    ddim_eta = 0.0 #@param\n", + "    dynamic_threshold = None\n", + "    static_threshold = None \n", + "\n", + "    #@markdown **Save & Display Settings**\n", + "    
save_samples = True #@param {type:\"boolean\"}\n", + " save_settings = True #@param {type:\"boolean\"}\n", + " display_samples = True #@param {type:\"boolean\"}\n", + " save_sample_per_step = False #@param {type:\"boolean\"}\n", + " show_sample_per_step = False #@param {type:\"boolean\"}\n", + "\n", + " #@markdown **Prompt Settings**\n", + " prompt_weighting = False #@param {type:\"boolean\"}\n", + " normalize_prompt_weights = True #@param {type:\"boolean\"}\n", + " log_weighted_subprompts = False #@param {type:\"boolean\"}\n", + "\n", + " #@markdown **Batch Settings**\n", + " n_batch = 1 #@param\n", + " batch_name = \"StableFun\" #@param {type:\"string\"}\n", + " filename_format = \"{timestring}_{index}_{prompt}.png\" #@param [\"{timestring}_{index}_{seed}.png\",\"{timestring}_{index}_{prompt}.png\"]\n", + " seed_behavior = \"iter\" #@param [\"iter\",\"fixed\",\"random\"]\n", + " make_grid = False #@param {type:\"boolean\"}\n", + " grid_rows = 2 #@param \n", + " outdir = get_output_folder(output_path, batch_name)\n", + "\n", + " #@markdown **Init Settings**\n", + " use_init = False #@param {type:\"boolean\"}\n", + " strength = 0.0 #@param {type:\"number\"}\n", + " strength_0_no_init = True # Set the strength to 0 automatically when no init image is used\n", + " init_image = \"https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg\" #@param {type:\"string\"}\n", + " # Whiter areas of the mask are areas that change more\n", + " use_mask = False #@param {type:\"boolean\"}\n", + " use_alpha_as_mask = False # use the alpha channel of the init image as the mask\n", + " mask_file = \"https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg\" #@param {type:\"string\"}\n", + " invert_mask = False #@param {type:\"boolean\"}\n", + " # Adjust mask image, 1.0 is no adjustment. Should be positive numbers.\n", + " mask_brightness_adjust = 1.0 #@param {type:\"number\"}\n", + " mask_contrast_adjust = 1.0 #@param {type:\"number\"}\n", + " # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding\n", + " overlay_mask = True # {type:\"boolean\"}\n", + " # Blur edges of final overlay mask, if used. 
Minimum = 0 (no blur)\n", + " mask_overlay_blur = 5 # {type:\"number\"}\n", + "\n", + " n_samples = 1 # doesnt do anything\n", + " precision = 'autocast' \n", + " C = 4\n", + " f = 8\n", + "\n", + " prompt = \"\"\n", + " timestring = \"\"\n", + " init_latent = None\n", + " init_sample = None\n", + " init_c = None\n", + "\n", + " return locals()\n", + "\n", + "\n", + "\n", + "def next_seed(args):\n", + " if args.seed_behavior == 'iter':\n", + " args.seed += 1\n", + " elif args.seed_behavior == 'fixed':\n", + " pass # always keep seed the same\n", + " else:\n", + " args.seed = random.randint(0, 2**32 - 1)\n", + " return args.seed\n", + "\n", + "def render_image_batch(args):\n", + " args.prompts = {k: f\"{v:05d}\" for v, k in enumerate(prompts)}\n", + " \n", + " # create output folder for the batch\n", + " os.makedirs(args.outdir, exist_ok=True)\n", + " if args.save_settings or args.save_samples:\n", + " print(f\"Saving to {os.path.join(args.outdir, args.timestring)}_*\")\n", + "\n", + " # save settings for the batch\n", + " if args.save_settings:\n", + " filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n", + " with open(filename, \"w+\", encoding=\"utf-8\") as f:\n", + " json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4)\n", + "\n", + " index = 0\n", + " \n", + " # function for init image batching\n", + " init_array = []\n", + " if args.use_init:\n", + " if args.init_image == \"\":\n", + " raise FileNotFoundError(\"No path was given for init_image\")\n", + " if args.init_image.startswith('http://') or args.init_image.startswith('https://'):\n", + " init_array.append(args.init_image)\n", + " elif not os.path.isfile(args.init_image):\n", + " if args.init_image[-1] != \"/\": # avoids path error by adding / to end if not there\n", + " args.init_image += \"/\" \n", + " for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array\n", + " if image.split(\".\")[-1] in (\"png\", \"jpg\", \"jpeg\"):\n", + " init_array.append(args.init_image + image)\n", + " else:\n", + " init_array.append(args.init_image)\n", + " else:\n", + " init_array = [\"\"]\n", + "\n", + " # when doing large batches don't flood browser with images\n", + " clear_between_batches = args.n_batch >= 32\n", + "\n", + " for iprompt, prompt in enumerate(prompts): \n", + " args.prompt = prompt\n", + " print(f\"Prompt {iprompt+1} of {len(prompts)}\")\n", + " print(f\"{args.prompt}\")\n", + "\n", + " all_images = []\n", + "\n", + " for batch_index in range(args.n_batch):\n", + " if clear_between_batches and batch_index % 32 == 0: \n", + " display.clear_output(wait=True) \n", + " print(f\"Batch {batch_index+1} of {args.n_batch}\")\n", + " \n", + " for image in init_array: # iterates the init images\n", + " args.init_image = image\n", + " results = generate(args)\n", + " for image in results:\n", + " if args.make_grid:\n", + " all_images.append(T.functional.pil_to_tensor(image))\n", + " if args.save_samples:\n", + " if args.filename_format == \"{timestring}_{index}_{prompt}.png\":\n", + " filename = f\"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png\"\n", + " else:\n", + " filename = f\"{args.timestring}_{index:05}_{args.seed}.png\"\n", + " image.save(os.path.join(args.outdir, filename))\n", + " if args.display_samples:\n", + " display.display(image)\n", + " index += 1\n", + " args.seed = next_seed(args)\n", + "\n", + " #print(len(all_images))\n", + " if args.make_grid:\n", + " grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows))\n", + " 
grid = rearrange(grid, 'c h w -> h w c').cpu().numpy()\n", + " filename = f\"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png\"\n", + " grid_image = Image.fromarray(grid.astype(np.uint8))\n", + " grid_image.save(os.path.join(args.outdir, filename))\n", + " display.clear_output(wait=True) \n", + " display.display(grid_image)\n", + "\n", + "\n", + "def render_animation(args, anim_args):\n", + " # animations use key framed prompts\n", + " args.prompts = animation_prompts\n", + "\n", + " # expand key frame strings to values\n", + " keys = DeformAnimKeys(anim_args)\n", + "\n", + " # resume animation\n", + " start_frame = 0\n", + " if anim_args.resume_from_timestring:\n", + " for tmp in os.listdir(args.outdir):\n", + " if tmp.split(\"_\")[0] == anim_args.resume_timestring:\n", + " start_frame += 1\n", + " start_frame = start_frame - 1\n", + "\n", + " # create output folder for the batch\n", + " os.makedirs(args.outdir, exist_ok=True)\n", + " print(f\"Saving animation frames to {args.outdir}\")\n", + "\n", + " # save settings for the batch\n", + " settings_filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n", + " with open(settings_filename, \"w+\", encoding=\"utf-8\") as f:\n", + " s = {**dict(args.__dict__), **dict(anim_args.__dict__)}\n", + " json.dump(s, f, ensure_ascii=False, indent=4)\n", + " \n", + " # resume from timestring\n", + " if anim_args.resume_from_timestring:\n", + " args.timestring = anim_args.resume_timestring\n", + "\n", + " # expand prompts out to per-frame\n", + " prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)])\n", + " for i, prompt in animation_prompts.items():\n", + " prompt_series[int(i)] = prompt\n", + " prompt_series = prompt_series.ffill().bfill()\n", + "\n", + " # check for video inits\n", + " using_vid_init = anim_args.animation_mode == 'Video Input'\n", + "\n", + " # load depth model for 3D\n", + " predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps\n", + " if predict_depths:\n", + " depth_model = DepthModel(device)\n", + " depth_model.load_midas(models_path)\n", + " if anim_args.midas_weight < 1.0:\n", + " depth_model.load_adabins()\n", + " else:\n", + " depth_model = None\n", + " anim_args.save_depth_maps = False\n", + "\n", + " # state for interpolating between diffusion steps\n", + " turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence)\n", + " turbo_prev_image, turbo_prev_frame_idx = None, 0\n", + " turbo_next_image, turbo_next_frame_idx = None, 0\n", + "\n", + " # resume animation\n", + " prev_sample = None\n", + " color_match_sample = None\n", + " if anim_args.resume_from_timestring:\n", + " last_frame = start_frame-1\n", + " if turbo_steps > 1:\n", + " last_frame -= last_frame%turbo_steps\n", + " path = os.path.join(args.outdir,f\"{args.timestring}_{last_frame:05}.png\")\n", + " img = cv2.imread(path)\n", + " img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)\n", + " prev_sample = sample_from_cv2(img)\n", + " if anim_args.color_coherence != 'None':\n", + " color_match_sample = img\n", + " if turbo_steps > 1:\n", + " turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame\n", + " turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx\n", + " start_frame = last_frame+turbo_steps\n", + "\n", + " args.n_samples = 1\n", + " frame_idx = start_frame\n", + " while frame_idx < anim_args.max_frames:\n", + " print(f\"Rendering animation frame {frame_idx} of 
{anim_args.max_frames}\")\n", + " noise = keys.noise_schedule_series[frame_idx]\n", + " strength = keys.strength_schedule_series[frame_idx]\n", + " contrast = keys.contrast_schedule_series[frame_idx]\n", + " depth = None\n", + " \n", + " # emit in-between frames\n", + " if turbo_steps > 1:\n", + " tween_frame_start_idx = max(0, frame_idx-turbo_steps)\n", + " for tween_frame_idx in range(tween_frame_start_idx, frame_idx):\n", + " tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx)\n", + " print(f\" creating in between frame {tween_frame_idx} tween:{tween:0.2f}\")\n", + "\n", + " advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx\n", + " advance_next = tween_frame_idx > turbo_next_frame_idx\n", + "\n", + " if depth_model is not None:\n", + " assert(turbo_next_image is not None)\n", + " depth = depth_model.predict(turbo_next_image, anim_args)\n", + "\n", + " if anim_args.animation_mode == '2D':\n", + " if advance_prev:\n", + " turbo_prev_image = anim_frame_warp_2d(turbo_prev_image, args, anim_args, keys, tween_frame_idx)\n", + " if advance_next:\n", + " turbo_next_image = anim_frame_warp_2d(turbo_next_image, args, anim_args, keys, tween_frame_idx)\n", + " else: # '3D'\n", + " if advance_prev:\n", + " turbo_prev_image = anim_frame_warp_3d(turbo_prev_image, depth, anim_args, keys, tween_frame_idx)\n", + " if advance_next:\n", + " turbo_next_image = anim_frame_warp_3d(turbo_next_image, depth, anim_args, keys, tween_frame_idx)\n", + " turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx\n", + "\n", + " if turbo_prev_image is not None and tween < 1.0:\n", + " img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween\n", + " else:\n", + " img = turbo_next_image\n", + "\n", + " filename = f\"{args.timestring}_{tween_frame_idx:05}.png\"\n", + " cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))\n", + " if anim_args.save_depth_maps:\n", + " depth_model.save(os.path.join(args.outdir, f\"{args.timestring}_depth_{tween_frame_idx:05}.png\"), depth)\n", + " if turbo_next_image is not None:\n", + " prev_sample = sample_from_cv2(turbo_next_image)\n", + "\n", + " # apply transforms to previous frame\n", + " if prev_sample is not None:\n", + " if anim_args.animation_mode == '2D':\n", + " prev_img = anim_frame_warp_2d(sample_to_cv2(prev_sample), args, anim_args, keys, frame_idx)\n", + " else: # '3D'\n", + " prev_img_cv2 = sample_to_cv2(prev_sample)\n", + " depth = depth_model.predict(prev_img_cv2, anim_args) if depth_model else None\n", + " prev_img = anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx)\n", + "\n", + " # apply color matching\n", + " if anim_args.color_coherence != 'None':\n", + " if color_match_sample is None:\n", + " color_match_sample = prev_img.copy()\n", + " else:\n", + " prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence)\n", + "\n", + " # apply scaling\n", + " contrast_sample = prev_img * contrast\n", + " # apply frame noising\n", + " noised_sample = add_noise(sample_from_cv2(contrast_sample), noise)\n", + "\n", + " # use transformed previous frame as init for current\n", + " args.use_init = True\n", + " if half_precision:\n", + " args.init_sample = noised_sample.half().to(device)\n", + " else:\n", + " args.init_sample = noised_sample.to(device)\n", + " args.strength = max(0.0, min(1.0, strength))\n", + "\n", + " # grab prompt for current frame\n", + " args.prompt = prompt_series[frame_idx]\n", 
+ " print(f\"{args.prompt} {args.seed}\")\n", + " if not using_vid_init:\n", + " print(f\"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}\")\n", + " print(f\"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}\")\n", + " print(f\"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}\")\n", + "\n", + " # grab init image for current frame\n", + " if using_vid_init:\n", + " init_frame = os.path.join(args.outdir, 'inputframes', f\"{frame_idx+1:05}.jpg\") \n", + " print(f\"Using video init frame {init_frame}\")\n", + " args.init_image = init_frame\n", + " if anim_args.use_mask_video:\n", + " mask_frame = os.path.join(args.outdir, 'maskframes', f\"{frame_idx+1:05}.jpg\")\n", + " args.mask_file = mask_frame\n", + "\n", + " # sample the diffusion model\n", + " sample, image = generate(args, frame_idx, return_latent=False, return_sample=True)\n", + " if not using_vid_init:\n", + " prev_sample = sample\n", + "\n", + " if turbo_steps > 1:\n", + " turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx\n", + " turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx\n", + " frame_idx += turbo_steps\n", + " else: \n", + " filename = f\"{args.timestring}_{frame_idx:05}.png\"\n", + " image.save(os.path.join(args.outdir, filename))\n", + " if anim_args.save_depth_maps:\n", + " if depth is None:\n", + " depth = depth_model.predict(sample_to_cv2(sample), anim_args)\n", + " depth_model.save(os.path.join(args.outdir, f\"{args.timestring}_depth_{frame_idx:05}.png\"), depth)\n", + " frame_idx += 1\n", + "\n", + " display.clear_output(wait=True)\n", + " display.display(image)\n", + "\n", + " args.seed = next_seed(args)\n", + "\n", + "def vid2frames(video_path, frames_path, n=1, overwrite=True): \n", + " if not os.path.exists(frames_path) or overwrite: \n", + " try:\n", + " for f in pathlib.Path(video_in_frame_path).glob('*.jpg'):\n", + " f.unlink()\n", + " except:\n", + " pass\n", + " assert os.path.exists(video_path), f\"Video input {video_path} does not exist\"\n", + " \n", + " vidcap = cv2.VideoCapture(video_path)\n", + " success,image = vidcap.read()\n", + " count = 0\n", + " t=1\n", + " success = True\n", + " while success:\n", + " if count % n == 0:\n", + " cv2.imwrite(frames_path + os.path.sep + f\"{t:05}.jpg\" , image) # save frame as JPEG file\n", + " t += 1\n", + " success,image = vidcap.read()\n", + " count += 1\n", + " print(\"Converted %d frames\" % count)\n", + " else: print(\"Frames already unpacked\")\n", + "\n", + "def render_input_video(args, anim_args):\n", + " # create a folder for the video input frames to live in\n", + " video_in_frame_path = os.path.join(args.outdir, 'inputframes') \n", + " os.makedirs(video_in_frame_path, exist_ok=True)\n", + " \n", + " # save the video frames from input video\n", + " print(f\"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...\")\n", + " vid2frames(anim_args.video_init_path, video_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames)\n", + "\n", + " # determine max frames from length of input frames\n", + " anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')])\n", + " args.use_init = True\n", + " print(f\"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to 
{args.outdir}\")\n", + "\n", + " if anim_args.use_mask_video:\n", + " # create a folder for the mask video input frames to live in\n", + " mask_in_frame_path = os.path.join(args.outdir, 'maskframes') \n", + " os.makedirs(mask_in_frame_path, exist_ok=True)\n", + "\n", + " # save the video frames from mask video\n", + " print(f\"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...\")\n", + " vid2frames(anim_args.video_mask_path, mask_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames)\n", + " args.use_mask = True\n", + " args.overlay_mask = True\n", + "\n", + " render_animation(args, anim_args)\n", + "\n", + "def render_interpolation(args, anim_args):\n", + " # animations use key framed prompts\n", + " args.prompts = animation_prompts\n", + "\n", + " # create output folder for the batch\n", + " os.makedirs(args.outdir, exist_ok=True)\n", + " print(f\"Saving animation frames to {args.outdir}\")\n", + "\n", + " # save settings for the batch\n", + " settings_filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n", + " with open(settings_filename, \"w+\", encoding=\"utf-8\") as f:\n", + " s = {**dict(args.__dict__), **dict(anim_args.__dict__)}\n", + " json.dump(s, f, ensure_ascii=False, indent=4)\n", + " \n", + " # Interpolation Settings\n", + " args.n_samples = 1\n", + " args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available\n", + " prompts_c_s = [] # cache all the text embeddings\n", + "\n", + " print(f\"Preparing for interpolation of the following...\")\n", + "\n", + " for i, prompt in animation_prompts.items():\n", + " args.prompt = prompt\n", + "\n", + " # sample the diffusion model\n", + " results = generate(args, return_c=True)\n", + " c, image = results[0], results[1]\n", + " prompts_c_s.append(c) \n", + " \n", + " # display.clear_output(wait=True)\n", + " display.display(image)\n", + " \n", + " args.seed = next_seed(args)\n", + "\n", + " display.clear_output(wait=True)\n", + " print(f\"Interpolation start...\")\n", + "\n", + " frame_idx = 0\n", + "\n", + " if anim_args.interpolate_key_frames:\n", + " for i in range(len(prompts_c_s)-1):\n", + " dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0]\n", + " if dist_frames <= 0:\n", + " print(\"key frames duplicated or reversed. 
interpolation skipped.\")\n", + " return\n", + " else:\n", + " for j in range(dist_frames):\n", + " # interpolate the text embedding\n", + " prompt1_c = prompts_c_s[i]\n", + " prompt2_c = prompts_c_s[i+1] \n", + " args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames))\n", + "\n", + " # sample the diffusion model\n", + " results = generate(args)\n", + " image = results[0]\n", + "\n", + " filename = f\"{args.timestring}_{frame_idx:05}.png\"\n", + " image.save(os.path.join(args.outdir, filename))\n", + " frame_idx += 1\n", + "\n", + " display.clear_output(wait=True)\n", + " display.display(image)\n", + "\n", + " args.seed = next_seed(args)\n", + "\n", + " else:\n", + " for i in range(len(prompts_c_s)-1):\n", + " for j in range(anim_args.interpolate_x_frames+1):\n", + " # interpolate the text embedding\n", + " prompt1_c = prompts_c_s[i]\n", + " prompt2_c = prompts_c_s[i+1] \n", + " args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1)))\n", + "\n", + " # sample the diffusion model\n", + " results = generate(args)\n", + " image = results[0]\n", + "\n", + " filename = f\"{args.timestring}_{frame_idx:05}.png\"\n", + " image.save(os.path.join(args.outdir, filename))\n", + " frame_idx += 1\n", + "\n", + " display.clear_output(wait=True)\n", + " display.display(image)\n", + "\n", + " args.seed = next_seed(args)\n", + "\n", + " # generate the last prompt\n", + " args.init_c = prompts_c_s[-1]\n", + " results = generate(args)\n", + " image = results[0]\n", + " filename = f\"{args.timestring}_{frame_idx:05}.png\"\n", + " image.save(os.path.join(args.outdir, filename))\n", + "\n", + " display.clear_output(wait=True)\n", + " display.display(image)\n", + " args.seed = next_seed(args)\n", + "\n", + " #clear init_c\n", + " args.init_c = None\n", + "\n", + "\n", + "args_dict = DeforumArgs()\n", + "anim_args_dict = DeforumAnimArgs()\n", + "\n", + "if override_settings_with_file:\n", + " print(f\"reading custom settings from {custom_settings_file}\")\n", + " if not os.path.isfile(custom_settings_file):\n", + " print('The custom settings file does not exist. The in-notebook settings will be used instead')\n", + " else:\n", + " with open(custom_settings_file, \"r\") as f:\n", + " jdata = json.loads(f.read())\n", + " animation_prompts = jdata[\"prompts\"]\n", + " for i, k in enumerate(args_dict):\n", + " if k in jdata:\n", + " args_dict[k] = jdata[k]\n", + " else:\n", + " print(f\"key {k} doesn't exist in the custom settings data! using the default value of {args_dict[k]}\")\n", + " for i, k in enumerate(anim_args_dict):\n", + " if k in jdata:\n", + " anim_args_dict[k] = jdata[k]\n", + " else:\n", + " print(f\"key {k} doesn't exist in the custom settings data! 
using the default value of {anim_args_dict[k]}\")\n", + " print(args_dict)\n", + " print(anim_args_dict)\n", + "\n", + "args = SimpleNamespace(**args_dict)\n", + "anim_args = SimpleNamespace(**anim_args_dict)\n", + "\n", + "args.timestring = time.strftime('%Y%m%d%H%M%S')\n", + "args.strength = max(0.0, min(1.0, args.strength))\n", + "\n", + "if args.seed == -1:\n", + " args.seed = random.randint(0, 2**32 - 1)\n", + "if not args.use_init:\n", + " args.init_image = None\n", + "if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'):\n", + " print(f\"Init images aren't supported with PLMS yet, switching to KLMS\")\n", + " args.sampler = 'klms'\n", + "if args.sampler != 'ddim':\n", + " args.ddim_eta = 0\n", + "\n", + "if anim_args.animation_mode == 'None':\n", + " anim_args.max_frames = 1\n", + "elif anim_args.animation_mode == 'Video Input':\n", + " args.use_init = True\n", + "\n", + "# clean up unused memory\n", + "gc.collect()\n", + "torch.cuda.empty_cache()\n", + "\n", + "# dispatch to appropriate renderer\n", + "if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D':\n", + " render_animation(args, anim_args)\n", + "elif anim_args.animation_mode == 'Video Input':\n", + " render_input_video(args, anim_args)\n", + "elif anim_args.animation_mode == 'Interpolation':\n", + " render_interpolation(args, anim_args)\n", + "else:\n", + " render_image_batch(args) " + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "4zV0J_YbMCTx" + }, + "source": [ + "# Create video from frames" + ] + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "no2jP8HTMBM0" + }, + "source": [ + "skip_video_for_run_all = True #@param {type: 'boolean'}\n", + "fps = 12 #@param {type:\"number\"}\n", + "#@markdown **Manual Settings**\n", + "use_manual_settings = False #@param {type:\"boolean\"}\n", + "image_path = \"/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939_%05d.png\" #@param {type:\"string\"}\n", + "mp4_path = \"/content/drive/MyDrive/AI/StableDiffu'/content/drive/MyDrive/AI/StableDiffusion/2022-09/sion/2022-09/20220903000939.mp4\" #@param {type:\"string\"}\n", + "render_steps = True #@param {type: 'boolean'}\n", + "path_name_modifier = \"x0_pred\" #@param [\"x0_pred\",\"x\"]\n", + "\n", + "\n", + "if skip_video_for_run_all == True:\n", + " print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it')\n", + "else:\n", + " import os\n", + " import subprocess\n", + " from base64 import b64encode\n", + "\n", + " print(f\"{image_path} -> {mp4_path}\")\n", + "\n", + " if use_manual_settings:\n", + " max_frames = \"200\" #@param {type:\"string\"}\n", + " else:\n", + " if render_steps: # render steps from a single image\n", + " fname = f\"{path_name_modifier}_%05d.png\"\n", + " all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))]\n", + " newest_dir = max(all_step_dirs, key=os.path.getmtime)\n", + " image_path = os.path.join(newest_dir, fname)\n", + " print(f\"Reading images from {image_path}\")\n", + " mp4_path = os.path.join(newest_dir, f\"{args.timestring}_{path_name_modifier}.mp4\")\n", + " max_frames = str(args.steps)\n", + " else: # render images for a video\n", + " image_path = os.path.join(args.outdir, f\"{args.timestring}_%05d.png\")\n", + " mp4_path = os.path.join(args.outdir, f\"{args.timestring}.mp4\")\n", + " max_frames = str(anim_args.max_frames)\n", + "\n", + " # make video\n", 
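+        "    # the argv list below is equivalent to running (shown here with the default fps of 12):\n",
+        "    #   ffmpeg -y -vcodec png -r 12 -start_number 0 -i {image_path} -frames:v {max_frames}\n",
+        "    #     -c:v libx264 -vf fps=12 -pix_fmt yuv420p -crf 17 -preset veryfast -pattern_type sequence {mp4_path}\n",
+        "    # i.e. read the numbered PNG frames as an image sequence and encode an H.264 mp4\n",
+        "    # (yuv420p for broad player support, crf 17 for near-lossless quality)\n",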
+ " cmd = [\n", + " 'ffmpeg',\n", + " '-y',\n", + " '-vcodec', 'png',\n", + " '-r', str(fps),\n", + " '-start_number', str(0),\n", + " '-i', image_path,\n", + " '-frames:v', max_frames,\n", + " '-c:v', 'libx264',\n", + " '-vf',\n", + " f'fps={fps}',\n", + " '-pix_fmt', 'yuv420p',\n", + " '-crf', '17',\n", + " '-preset', 'veryfast',\n", + " '-pattern_type', 'sequence',\n", + " mp4_path\n", + " ]\n", + " process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE)\n", + " stdout, stderr = process.communicate()\n", + " if process.returncode != 0:\n", + " print(stderr)\n", + " raise RuntimeError(stderr)\n", + "\n", + " mp4 = open(mp4_path,'rb').read()\n", + " data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n", + " display.display( display.HTML(f'') )" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "markdown", + "source": [ + "# Disconnect when finished" + ], + "metadata": { + "id": "RoECylTun7AA" + } + }, + { + "cell_type": "code", + "source": [ + "skip_disconnect_for_run_all = True #@param {type: 'boolean'}\n", + "\n", + "if skip_disconnect_for_run_all == True:\n", + " print('Skipping disconnect, uncheck skip_disconnect_for_run_all if you want to run it')\n", + "else:\n", + " from google.colab import runtime\n", + " runtime.unassign()" + ], + "metadata": { + "cellView": "form", + "id": "bfXpWRgSn-eH" + }, + "execution_count": null, + "outputs": [] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "provenance": [], + "private_outputs": true + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/Deforum_Stable_Diffusion.py b/Deforum_Stable_Diffusion.py new file mode 100644 index 0000000000..d6aad9ffac --- /dev/null +++ b/Deforum_Stable_Diffusion.py @@ -0,0 +1,1924 @@ +# %% +# !! {"metadata":{ +# !! "id": "c442uQJ_gUgy" +# !! }} +""" +# **Deforum Stable Diffusion v0.5** +[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team. [K Diffusion](https://github.com/crowsonkb/k-diffusion) by [Katherine Crowson](https://twitter.com/RiversHaveWings). You need to get the ckpt file and put it on your Google Drive first to use this. It can be downloaded from [HuggingFace](https://huggingface.co/CompVis/stable-diffusion). + +Notebook by [deforum](https://discord.gg/upmXXsrwZc) +""" + +# %% +# !! {"metadata":{ +# !! "id": "LBamKxcmNI7-" +# !! }} +""" +By using this Notebook, you agree to the following Terms of Use, and license: + +**Stablity.AI Model Terms of Use** + +This model is open access and available to all, with a CreativeML OpenRAIL-M license further specifying rights and usage. + +The CreativeML OpenRAIL License specifies: + +You can't use the model to deliberately produce nor share illegal or harmful outputs or content +CompVis claims no rights on the outputs you generate, you are free to use them and are accountable for their use which must not go against the provisions set in the license +You may re-distribute the weights and use the model commercially and/or as a service. 
If you do, please be aware you have to include the same use restrictions as the ones in the license and share a copy of the CreativeML OpenRAIL-M to all your users (please read the license entirely and carefully) + + +Please read the full license here: https://huggingface.co/spaces/CompVis/stable-diffusion-license +""" + +# %% +# !! {"metadata":{ +# !! "id": "T4knibRpAQ06" +# !! }} +""" +# Setup +""" + +# %% +# !! {"metadata":{ +# !! "id": "2g-f7cQmf2Nt", +# !! "cellView": "form" +# !! }} +#@markdown **NVIDIA GPU** +import subprocess +sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8') +print(sub_p_res) + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "TxIOPT0G5Lx1" +# !! }} +#@markdown **Model and Output Paths** +# ask for the link +print("Local Path Variables:\n") + +models_path = "/content/models" #@param {type:"string"} +output_path = "/content/output" #@param {type:"string"} + +#@markdown **Google Drive Path Variables (Optional)** +mount_google_drive = True #@param {type:"boolean"} +force_remount = False + +if mount_google_drive: + from google.colab import drive # type: ignore + try: + drive_path = "/content/drive" + drive.mount(drive_path,force_remount=force_remount) + models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"} + output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"} + models_path = models_path_gdrive + output_path = output_path_gdrive + except: + print("...error mounting drive or with drive path variables") + print("...reverting to default path variables") + +import os +os.makedirs(models_path, exist_ok=True) +os.makedirs(output_path, exist_ok=True) + +print(f"models_path: {models_path}") +print(f"output_path: {output_path}") + +# %% +# !! {"metadata":{ +# !! "id": "VRNl2mfepEIe", +# !! "cellView": "form" +# !! 
}} +#@markdown **Setup Environment** + +setup_environment = True #@param {type:"boolean"} +print_subprocess = False #@param {type:"boolean"} + +if setup_environment: + import subprocess, time + print("Setting up environment...") + start_time = time.time() + all_process = [ + ['pip', 'install', 'torch==1.12.1+cu113', 'torchvision==0.13.1+cu113', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'], + ['pip', 'install', 'omegaconf==2.2.3', 'einops==0.4.1', 'pytorch-lightning==1.7.4', 'torchmetrics==0.9.3', 'torchtext==0.13.1', 'transformers==4.21.2', 'kornia==0.6.7'], + ['git', 'clone', 'https://github.com/deforum/stable-diffusion'], + ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'], + ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'], + ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'matplotlib', 'resize-right', 'timm', 'torchdiffeq'], + ['git', 'clone', 'https://github.com/shariqfarooq123/AdaBins.git'], + ['git', 'clone', 'https://github.com/isl-org/MiDaS.git'], + ['git', 'clone', 'https://github.com/MSFTserver/pytorch3d-lite.git'], + ] + for process in all_process: + running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8') + if print_subprocess: + print(running) + + print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8')) + with open('k-diffusion/k_diffusion/__init__.py', 'w') as f: + f.write('') + + end_time = time.time() + print(f"Environment set up in {end_time-start_time:.0f} seconds") + +# %% +# !! {"metadata":{ +# !! "id": "81qmVZbrm4uu", +# !! "cellView": "form" +# !! }} +#@markdown **Python Definitions** +import json +from IPython import display + +import gc, math, os, pathlib, subprocess, sys, time +import cv2 +import numpy as np +import pandas as pd +import random +import requests +import torch +import torch.nn as nn +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from contextlib import contextmanager, nullcontext +from einops import rearrange, repeat +from omegaconf import OmegaConf +from PIL import Image +from pytorch_lightning import seed_everything +from skimage.exposure import match_histograms +from torchvision.utils import make_grid +from tqdm import tqdm, trange +from types import SimpleNamespace +from torch import autocast +import re +from scipy.ndimage import gaussian_filter + +sys.path.extend([ + 'src/taming-transformers', + 'src/clip', + 'stable-diffusion/', + 'k-diffusion', + 'pytorch3d-lite', + 'AdaBins', + 'MiDaS', +]) + +import py3d_tools as p3d + +from helpers import DepthModel, sampler_fn +from k_diffusion.external import CompVisDenoiser +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +def sanitize(prompt): + whitelist = set('abcdefghijklmnopqrstuvwxyz ABCDEFGHIJKLMNOPQRSTUVWXYZ') + tmp = ''.join(filter(whitelist.__contains__, prompt)) + return tmp.replace(' ', '_') + +from functools import reduce +def construct_RotationMatrixHomogenous(rotation_angles): + assert(type(rotation_angles)==list and len(rotation_angles)==3) + RH = np.eye(4,4) + cv2.Rodrigues(np.array(rotation_angles), RH[0:3, 0:3]) + return RH + +# https://en.wikipedia.org/wiki/Rotation_matrix +def getRotationMatrixManual(rotation_angles): + + rotation_angles = [np.deg2rad(x) for x in rotation_angles] + + phi = rotation_angles[0] # around x + gamma = 
rotation_angles[1] # around y + theta = rotation_angles[2] # around z + + # X rotation + Rphi = np.eye(4,4) + sp = np.sin(phi) + cp = np.cos(phi) + Rphi[1,1] = cp + Rphi[2,2] = Rphi[1,1] + Rphi[1,2] = -sp + Rphi[2,1] = sp + + # Y rotation + Rgamma = np.eye(4,4) + sg = np.sin(gamma) + cg = np.cos(gamma) + Rgamma[0,0] = cg + Rgamma[2,2] = Rgamma[0,0] + Rgamma[0,2] = sg + Rgamma[2,0] = -sg + + # Z rotation (in-image-plane) + Rtheta = np.eye(4,4) + st = np.sin(theta) + ct = np.cos(theta) + Rtheta[0,0] = ct + Rtheta[1,1] = Rtheta[0,0] + Rtheta[0,1] = -st + Rtheta[1,0] = st + + R = reduce(lambda x,y : np.matmul(x,y), [Rphi, Rgamma, Rtheta]) + + return R + + +def getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sidelength): + + ptsIn2D = ptsIn[0,:] + ptsOut2D = ptsOut[0,:] + ptsOut2Dlist = [] + ptsIn2Dlist = [] + + for i in range(0,4): + ptsOut2Dlist.append([ptsOut2D[i,0], ptsOut2D[i,1]]) + ptsIn2Dlist.append([ptsIn2D[i,0], ptsIn2D[i,1]]) + + pin = np.array(ptsIn2Dlist) + [W/2.,H/2.] + pout = (np.array(ptsOut2Dlist) + [1.,1.]) * (0.5*sidelength) + pin = pin.astype(np.float32) + pout = pout.astype(np.float32) + + return pin, pout + +def warpMatrix(W, H, theta, phi, gamma, scale, fV): + + # M is to be estimated + M = np.eye(4, 4) + + fVhalf = np.deg2rad(fV/2.) + d = np.sqrt(W*W+H*H) + sideLength = scale*d/np.cos(fVhalf) + h = d/(2.0*np.sin(fVhalf)) + n = h-(d/2.0); + f = h+(d/2.0); + + # Translation along Z-axis by -h + T = np.eye(4,4) + T[2,3] = -h + + # Rotation matrices around x,y,z + R = getRotationMatrixManual([phi, gamma, theta]) + + + # Projection Matrix + P = np.eye(4,4) + P[0,0] = 1.0/np.tan(fVhalf) + P[1,1] = P[0,0] + P[2,2] = -(f+n)/(f-n) + P[2,3] = -(2.0*f*n)/(f-n) + P[3,2] = -1.0 + + # pythonic matrix multiplication + F = reduce(lambda x,y : np.matmul(x,y), [P, T, R]) + + # shape should be 1,4,3 for ptsIn and ptsOut since perspectiveTransform() expects data in this way. + # In C++, this can be achieved by Mat ptsIn(1,4,CV_64FC3); + ptsIn = np.array([[ + [-W/2., H/2., 0.],[ W/2., H/2., 0.],[ W/2.,-H/2., 0.],[-W/2.,-H/2., 0.] 
+ ]]) + ptsOut = np.array(np.zeros((ptsIn.shape), dtype=ptsIn.dtype)) + ptsOut = cv2.perspectiveTransform(ptsIn, F) + + ptsInPt2f, ptsOutPt2f = getPoints_for_PerspectiveTranformEstimation(ptsIn, ptsOut, W, H, sideLength) + + # check float32 otherwise OpenCV throws an error + assert(ptsInPt2f.dtype == np.float32) + assert(ptsOutPt2f.dtype == np.float32) + M33 = cv2.getPerspectiveTransform(ptsInPt2f,ptsOutPt2f) + + return M33, sideLength + +def anim_frame_warp_2d(prev_img_cv2, args, anim_args, keys, frame_idx): + angle = keys.angle_series[frame_idx] + zoom = keys.zoom_series[frame_idx] + translation_x = keys.translation_x_series[frame_idx] + translation_y = keys.translation_y_series[frame_idx] + + center = (args.W // 2, args.H // 2) + trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]]) + rot_mat = cv2.getRotationMatrix2D(center, angle, zoom) + trans_mat = np.vstack([trans_mat, [0,0,1]]) + rot_mat = np.vstack([rot_mat, [0,0,1]]) + if anim_args.flip_2d_perspective: + perspective_flip_theta = keys.perspective_flip_theta_series[frame_idx] + perspective_flip_phi = keys.perspective_flip_phi_series[frame_idx] + perspective_flip_gamma = keys.perspective_flip_gamma_series[frame_idx] + perspective_flip_fv = keys.perspective_flip_fv_series[frame_idx] + M,sl = warpMatrix(args.W, args.H, perspective_flip_theta, perspective_flip_phi, perspective_flip_gamma, 1., perspective_flip_fv); + post_trans_mat = np.float32([[1, 0, (args.W-sl)/2], [0, 1, (args.H-sl)/2]]) + post_trans_mat = np.vstack([post_trans_mat, [0,0,1]]) + bM = np.matmul(M, post_trans_mat) + xform = np.matmul(bM, rot_mat, trans_mat) + else: + xform = np.matmul(rot_mat, trans_mat) + + return cv2.warpPerspective( + prev_img_cv2, + xform, + (prev_img_cv2.shape[1], prev_img_cv2.shape[0]), + borderMode=cv2.BORDER_WRAP if anim_args.border == 'wrap' else cv2.BORDER_REPLICATE + ) + +def anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx): + TRANSLATION_SCALE = 1.0/200.0 # matches Disco + translate_xyz = [ + -keys.translation_x_series[frame_idx] * TRANSLATION_SCALE, + keys.translation_y_series[frame_idx] * TRANSLATION_SCALE, + -keys.translation_z_series[frame_idx] * TRANSLATION_SCALE + ] + rotate_xyz = [ + math.radians(keys.rotation_3d_x_series[frame_idx]), + math.radians(keys.rotation_3d_y_series[frame_idx]), + math.radians(keys.rotation_3d_z_series[frame_idx]) + ] + rot_mat = p3d.euler_angles_to_matrix(torch.tensor(rotate_xyz, device=device), "XYZ").unsqueeze(0) + result = transform_image_3d(prev_img_cv2, depth, rot_mat, translate_xyz, anim_args) + torch.cuda.empty_cache() + return result + +def add_noise(sample: torch.Tensor, noise_amt: float) -> torch.Tensor: + return sample + torch.randn(sample.shape, device=sample.device) * noise_amt + +def get_output_folder(output_path, batch_folder): + out_path = os.path.join(output_path,time.strftime('%Y-%m')) + if batch_folder != "": + out_path = os.path.join(out_path, batch_folder) + os.makedirs(out_path, exist_ok=True) + return out_path + +def load_img(path, shape, use_alpha_as_mask=False): + # use_alpha_as_mask: Read the alpha channel of the image as the mask image + if path.startswith('http://') or path.startswith('https://'): + image = Image.open(requests.get(path, stream=True).raw) + else: + image = Image.open(path) + + if use_alpha_as_mask: + image = image.convert('RGBA') + else: + image = image.convert('RGB') + + image = image.resize(shape, resample=Image.LANCZOS) + + mask_image = None + if use_alpha_as_mask: + # Split alpha channel into a mask_image + red, green, 
blue, alpha = Image.Image.split(image) + mask_image = alpha.convert('L') + image = image.convert('RGB') + + image = np.array(image).astype(np.float16) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + image = 2.*image - 1. + + return image, mask_image + +def load_mask_latent(mask_input, shape): + # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object + # shape (list-like len(4)): shape of the image to match, usually latent_image.shape + + if isinstance(mask_input, str): # mask input is probably a file name + if mask_input.startswith('http://') or mask_input.startswith('https://'): + mask_image = Image.open(requests.get(mask_input, stream=True).raw).convert('RGBA') + else: + mask_image = Image.open(mask_input).convert('RGBA') + elif isinstance(mask_input, Image.Image): + mask_image = mask_input + else: + raise Exception("mask_input must be a PIL image or a file name") + + mask_w_h = (shape[-1], shape[-2]) + mask = mask_image.resize(mask_w_h, resample=Image.LANCZOS) + mask = mask.convert("L") + return mask + +def prepare_mask(mask_input, mask_shape, mask_brightness_adjust=1.0, mask_contrast_adjust=1.0): + # mask_input (str or PIL Image.Image): Path to the mask image or a PIL Image object + # shape (list-like len(4)): shape of the image to match, usually latent_image.shape + # mask_brightness_adjust (non-negative float): amount to adjust brightness of the iamge, + # 0 is black, 1 is no adjustment, >1 is brighter + # mask_contrast_adjust (non-negative float): amount to adjust contrast of the image, + # 0 is a flat grey image, 1 is no adjustment, >1 is more contrast + + mask = load_mask_latent(mask_input, mask_shape) + + # Mask brightness/contrast adjustments + if mask_brightness_adjust != 1: + mask = TF.adjust_brightness(mask, mask_brightness_adjust) + if mask_contrast_adjust != 1: + mask = TF.adjust_contrast(mask, mask_contrast_adjust) + + # Mask image to array + mask = np.array(mask).astype(np.float32) / 255.0 + mask = np.tile(mask,(4,1,1)) + mask = np.expand_dims(mask,axis=0) + mask = torch.from_numpy(mask) + + if args.invert_mask: + mask = ( (mask - 0.5) * -1) + 0.5 + + mask = np.clip(mask,0,1) + return mask + +def maintain_colors(prev_img, color_match_sample, mode): + if mode == 'Match Frame 0 RGB': + return match_histograms(prev_img, color_match_sample, multichannel=True) + elif mode == 'Match Frame 0 HSV': + prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV) + color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV) + matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True) + return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB) + else: # Match Frame 0 LAB + prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB) + color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB) + matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True) + return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB) + + +# +# Callback functions +# +class SamplerCallback(object): + # Creates the callback function to be passed into the samplers for each step + def __init__(self, args, mask=None, init_latent=None, sigmas=None, sampler=None, + verbose=False): + self.sampler_name = args.sampler + self.dynamic_threshold = args.dynamic_threshold + self.static_threshold = args.static_threshold + self.mask = mask + self.init_latent = init_latent + self.sigmas = sigmas + self.sampler = sampler + self.verbose = verbose + + self.batch_size = args.n_samples + self.save_sample_per_step = 
args.save_sample_per_step + self.show_sample_per_step = args.show_sample_per_step + self.paths_to_image_steps = [os.path.join( args.outdir, f"{args.timestring}_{index:02}_{args.seed}") for index in range(args.n_samples) ] + + if self.save_sample_per_step: + for path in self.paths_to_image_steps: + os.makedirs(path, exist_ok=True) + + self.step_index = 0 + + self.noise = None + if init_latent is not None: + self.noise = torch.randn_like(init_latent, device=device) + + self.mask_schedule = None + if sigmas is not None and len(sigmas) > 0: + self.mask_schedule, _ = torch.sort(sigmas/torch.max(sigmas)) + elif len(sigmas) == 0: + self.mask = None # no mask needed if no steps (usually happens because strength==1.0) + + if self.sampler_name in ["plms","ddim"]: + if mask is not None: + assert sampler is not None, "Callback function for stable-diffusion samplers requires sampler variable" + + if self.sampler_name in ["plms","ddim"]: + # Callback function formated for compvis latent diffusion samplers + self.callback = self.img_callback_ + else: + # Default callback function uses k-diffusion sampler variables + self.callback = self.k_callback_ + + self.verbose_print = print if verbose else lambda *args, **kwargs: None + + def view_sample_step(self, latents, path_name_modifier=''): + if self.save_sample_per_step or self.show_sample_per_step: + samples = model.decode_first_stage(latents) + if self.save_sample_per_step: + fname = f'{path_name_modifier}_{self.step_index:05}.png' + for i, sample in enumerate(samples): + sample = sample.double().cpu().add(1).div(2).clamp(0, 1) + sample = torch.tensor(np.array(sample)) + grid = make_grid(sample, 4).cpu() + TF.to_pil_image(grid).save(os.path.join(self.paths_to_image_steps[i], fname)) + if self.show_sample_per_step: + print(path_name_modifier) + self.display_images(samples) + return + + def display_images(self, images): + images = images.double().cpu().add(1).div(2).clamp(0, 1) + images = torch.tensor(np.array(images)) + grid = make_grid(images, 4).cpu() + display.display(TF.to_pil_image(grid)) + return + + # The callback function is applied to the image at each step + def dynamic_thresholding_(self, img, threshold): + # Dynamic thresholding from Imagen paper (May 2022) + s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim))) + s = np.max(np.append(s,1.0)) + torch.clamp_(img, -1*s, s) + torch.FloatTensor.div_(img, s) + + # Callback for samplers in the k-diffusion repo, called thus: + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + def k_callback_(self, args_dict): + self.step_index = args_dict['i'] + if self.dynamic_threshold is not None: + self.dynamic_thresholding_(args_dict['x'], self.dynamic_threshold) + if self.static_threshold is not None: + torch.clamp_(args_dict['x'], -1*self.static_threshold, self.static_threshold) + if self.mask is not None: + init_noise = self.init_latent + self.noise * args_dict['sigma'] + is_masked = torch.logical_and(self.mask >= self.mask_schedule[args_dict['i']], self.mask != 0 ) + new_img = init_noise * torch.where(is_masked,1,0) + args_dict['x'] * torch.where(is_masked,0,1) + args_dict['x'].copy_(new_img) + + self.view_sample_step(args_dict['denoised'], "x0_pred") + + # Callback for Compvis samplers + # Function that is called on the image (img) and step (i) at each step + def img_callback_(self, img, i): + self.step_index = i + # Thresholding functions + if self.dynamic_threshold is not None: + self.dynamic_thresholding_(img, self.dynamic_threshold) + 
if self.static_threshold is not None: + torch.clamp_(img, -1*self.static_threshold, self.static_threshold) + if self.mask is not None: + i_inv = len(self.sigmas) - i - 1 + init_noise = self.sampler.stochastic_encode(self.init_latent, torch.tensor([i_inv]*self.batch_size).to(device), noise=self.noise) + is_masked = torch.logical_and(self.mask >= self.mask_schedule[i], self.mask != 0 ) + new_img = init_noise * torch.where(is_masked,1,0) + img * torch.where(is_masked,0,1) + img.copy_(new_img) + + self.view_sample_step(img, "x") + +def sample_from_cv2(sample: np.ndarray) -> torch.Tensor: + sample = ((sample.astype(float) / 255.0) * 2) - 1 + sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16) + sample = torch.from_numpy(sample) + return sample + +def sample_to_cv2(sample: torch.Tensor, type=np.uint8) -> np.ndarray: + sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype(np.float32) + sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1) + sample_int8 = (sample_f32 * 255) + return sample_int8.astype(type) + +def transform_image_3d(prev_img_cv2, depth_tensor, rot_mat, translate, anim_args): + # adapted and optimized version of transform_image_3d from Disco Diffusion https://github.com/alembics/disco-diffusion + w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] + + aspect_ratio = float(w)/float(h) + near, far, fov_deg = anim_args.near_plane, anim_args.far_plane, anim_args.fov + persp_cam_old = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, device=device) + persp_cam_new = p3d.FoVPerspectiveCameras(near, far, aspect_ratio, fov=fov_deg, degrees=True, R=rot_mat, T=torch.tensor([translate]), device=device) + + # range of [-1,1] is important to torch grid_sample's padding handling + y,x = torch.meshgrid(torch.linspace(-1.,1.,h,dtype=torch.float32,device=device),torch.linspace(-1.,1.,w,dtype=torch.float32,device=device)) + z = torch.as_tensor(depth_tensor, dtype=torch.float32, device=device) + xyz_old_world = torch.stack((x.flatten(), y.flatten(), z.flatten()), dim=1) + + xyz_old_cam_xy = persp_cam_old.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] + xyz_new_cam_xy = persp_cam_new.get_full_projection_transform().transform_points(xyz_old_world)[:,0:2] + + offset_xy = xyz_new_cam_xy - xyz_old_cam_xy + # affine_grid theta param expects a batch of 2D mats. Each is 2x3 to do rotation+translation. + identity_2d_batch = torch.tensor([[1.,0.,0.],[0.,1.,0.]], device=device).unsqueeze(0) + # coords_2d will have shape (N,H,W,2).. which is also what grid_sample needs. 
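+    # a brief note on the warp below: affine_grid with an identity matrix simply enumerates the
+    # normalized (x, y) coordinate of every output pixel; subtracting the projected per-pixel
+    # offset tells grid_sample where in the previous frame each output pixel should be sampled
+    # from, which applies the 3D camera motion in image space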
+ coords_2d = torch.nn.functional.affine_grid(identity_2d_batch, [1,1,h,w], align_corners=False) + offset_coords_2d = coords_2d - torch.reshape(offset_xy, (h,w,2)).unsqueeze(0) + + image_tensor = rearrange(torch.from_numpy(prev_img_cv2.astype(np.float32)), 'h w c -> c h w').to(device) + new_image = torch.nn.functional.grid_sample( + image_tensor.add(1/512 - 0.0001).unsqueeze(0), + offset_coords_2d, + mode=anim_args.sampling_mode, + padding_mode=anim_args.padding_mode, + align_corners=False + ) + + # convert back to cv2 style numpy array + result = rearrange( + new_image.squeeze().clamp(0,255), + 'c h w -> h w c' + ).cpu().numpy().astype(prev_img_cv2.dtype) + return result + +def check_is_number(value): + float_pattern = r'^(?=.)([+-]?([0-9]*)(\.([0-9]+))?)$' + return re.match(float_pattern, value) + +# prompt weighting with colons and number coefficients (like 'bacon:0.75 eggs:0.25') +# borrowed from https://github.com/kylewlacy/stable-diffusion/blob/0a4397094eb6e875f98f9d71193e350d859c4220/ldm/dream/conditioning.py +# and https://github.com/raefu/stable-diffusion-automatic/blob/unstablediffusion/modules/processing.py +def get_uc_and_c(prompts, model, args, frame = 0): + prompt = prompts[0] # they are the same in a batch anyway + + # get weighted sub-prompts + negative_subprompts, positive_subprompts = split_weighted_subprompts( + prompt, frame, not args.normalize_prompt_weights + ) + + uc = get_learned_conditioning(model, negative_subprompts, "", args, -1) + c = get_learned_conditioning(model, positive_subprompts, prompt, args, 1) + + return (uc, c) + +def get_learned_conditioning(model, weighted_subprompts, text, args, sign = 1): + if len(weighted_subprompts) < 1: + log_tokenization(text, model, args.log_weighted_subprompts, sign) + c = model.get_learned_conditioning(args.n_samples * [text]) + else: + c = None + for subtext, subweight in weighted_subprompts: + log_tokenization(subtext, model, args.log_weighted_subprompts, sign * subweight) + if c is None: + c = model.get_learned_conditioning(args.n_samples * [subtext]) + c *= subweight + else: + c.add_(model.get_learned_conditioning(args.n_samples * [subtext]), alpha=subweight) + + return c + +def parse_weight(match, frame = 0)->float: + import numexpr + w_raw = match.group("weight") + if w_raw == None: + return 1 + if check_is_number(w_raw): + return float(w_raw) + else: + t = frame + if len(w_raw) < 3: + print('the value inside `-characters cannot represent a math function') + return 1 + return float(numexpr.evaluate(w_raw[1:-1])) + +def normalize_prompt_weights(parsed_prompts): + if len(parsed_prompts) == 0: + return parsed_prompts + weight_sum = sum(map(lambda x: x[1], parsed_prompts)) + if weight_sum == 0: + print( + "Warning: Subprompt weights add up to zero. 
Discarding and using even weights instead.") + equal_weight = 1 / max(len(parsed_prompts), 1) + return [(x[0], equal_weight) for x in parsed_prompts] + return [(x[0], x[1] / weight_sum) for x in parsed_prompts] + +def split_weighted_subprompts(text, frame = 0, skip_normalize=False): + """ + grabs all text up to the first occurrence of ':' + uses the grabbed text as a sub-prompt, and takes the value following ':' as weight + if ':' has no value defined, defaults to 1.0 + repeats until no text remaining + """ + prompt_parser = re.compile(""" + (?P # capture group for 'prompt' + (?:\\\:|[^:])+ # match one or more non ':' characters or escaped colons '\:' + ) # end 'prompt' + (?: # non-capture group + :+ # match one or more ':' characters + (?P(( # capture group for 'weight' + -?\d+(?:\.\d+)? # match positive or negative integer or decimal number + )|( # or + `[\S\s]*?`# a math function + )))? # end weight capture group, make optional + \s* # strip spaces after weight + | # OR + $ # else, if no ':' then match end of line + ) # end non-capture group + """, re.VERBOSE) + negative_prompts = [] + positive_prompts = [] + for match in re.finditer(prompt_parser, text): + w = parse_weight(match, frame) + if w < 0: + # negating the sign as we'll feed this to uc + negative_prompts.append((match.group("prompt").replace("\\:", ":"), -w)) + elif w > 0: + positive_prompts.append((match.group("prompt").replace("\\:", ":"), w)) + + if skip_normalize: + return (negative_prompts, positive_prompts) + return (normalize_prompt_weights(negative_prompts), normalize_prompt_weights(positive_prompts)) + +# shows how the prompt is tokenized +# usually tokens have '' to indicate end-of-word, +# but for readability it has been replaced with ' ' +def log_tokenization(text, model, log=False, weight=1): + if not log: + return + tokens = model.cond_stage_model.tokenizer._tokenize(text) + tokenized = "" + discarded = "" + usedTokens = 0 + totalTokens = len(tokens) + for i in range(0, totalTokens): + token = tokens[i].replace('', ' ') + # alternate color + s = (usedTokens % 6) + 1 + if i < model.cond_stage_model.max_length: + tokenized = tokenized + f"\x1b[0;3{s};40m{token}" + usedTokens += 1 + else: # over max token length + discarded = discarded + f"\x1b[0;3{s};40m{token}" + print(f"\n>> Tokens ({usedTokens}), Weight ({weight:.2f}):\n{tokenized}\x1b[0m") + if discarded != "": + print( + f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m" + ) + +def generate(args, frame = 0, return_latent=False, return_sample=False, return_c=False): + seed_everything(args.seed) + os.makedirs(args.outdir, exist_ok=True) + + sampler = PLMSSampler(model) if args.sampler == 'plms' else DDIMSampler(model) + model_wrap = CompVisDenoiser(model) + batch_size = args.n_samples + prompt = args.prompt + assert prompt is not None + data = [batch_size * [prompt]] + precision_scope = autocast if args.precision == "autocast" else nullcontext + + init_latent = None + mask_image = None + init_image = None + if args.init_latent is not None: + init_latent = args.init_latent + elif args.init_sample is not None: + with precision_scope("cuda"): + init_latent = model.get_first_stage_encoding(model.encode_first_stage(args.init_sample)) + elif args.use_init and args.init_image != None and args.init_image != '': + init_image, mask_image = load_img(args.init_image, + shape=(args.W, args.H), + use_alpha_as_mask=args.use_alpha_as_mask) + init_image = init_image.to(device) + init_image = repeat(init_image, '1 ... 
-> b ...', b=batch_size) + with precision_scope("cuda"): + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + if not args.use_init and args.strength > 0 and args.strength_0_no_init: + print("\nNo init image, but strength > 0. Strength has been auto set to 0, since use_init is False.") + print("If you want to force strength > 0 with no init, please set strength_0_no_init to False.\n") + args.strength = 0 + + # Mask functions + if args.use_mask: + assert args.mask_file is not None or mask_image is not None, "use_mask==True: An mask image is required for a mask. Please enter a mask_file or use an init image with an alpha channel" + assert args.use_init, "use_mask==True: use_init is required for a mask" + assert init_latent is not None, "use_mask==True: An latent init image is required for a mask" + + + mask = prepare_mask(args.mask_file if mask_image is None else mask_image, + init_latent.shape, + args.mask_contrast_adjust, + args.mask_brightness_adjust) + + if (torch.all(mask == 0) or torch.all(mask == 1)) and args.use_alpha_as_mask: + raise Warning("use_alpha_as_mask==True: Using the alpha channel from the init image as a mask, but the alpha channel is blank.") + + mask = mask.to(device) + mask = repeat(mask, '1 ... -> b ...', b=batch_size) + else: + mask = None + + assert not ( (args.use_mask and args.overlay_mask) and (args.init_sample is None and init_image is None)), "Need an init image when use_mask == True and overlay_mask == True" + + t_enc = int((1.0-args.strength) * args.steps) + + # Noise schedule for the k-diffusion samplers (used for masking) + k_sigmas = model_wrap.get_sigmas(args.steps) + k_sigmas = k_sigmas[len(k_sigmas)-t_enc-1:] + + if args.sampler in ['plms','ddim']: + sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, ddim_discretize='fill', verbose=False) + + callback = SamplerCallback(args=args, + mask=mask, + init_latent=init_latent, + sigmas=k_sigmas, + sampler=sampler, + verbose=False).callback + + results = [] + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + for prompts in data: + if isinstance(prompts, tuple): + prompts = list(prompts) + if args.prompt_weighting: + uc, c = get_uc_and_c(prompts, model, args, frame) + else: + uc = model.get_learned_conditioning(batch_size * [""]) + c = model.get_learned_conditioning(prompts) + + + if args.scale == 1.0: + uc = None + if args.init_c != None: + c = args.init_c + + if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]: + samples = sampler_fn( + c=c, + uc=uc, + args=args, + model_wrap=model_wrap, + init_latent=init_latent, + t_enc=t_enc, + device=device, + cb=callback) + else: + # args.sampler == 'plms' or args.sampler == 'ddim': + if init_latent is not None and args.strength > 0: + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + else: + z_enc = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device) + if args.sampler == 'ddim': + samples = sampler.decode(z_enc, + c, + t_enc, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + img_callback=callback) + elif args.sampler == 'plms': # no "decode" function in plms, so use "sample" + shape = [args.C, args.H // args.f, args.W // args.f] + samples, _ = sampler.sample(S=args.steps, + conditioning=c, + batch_size=args.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=args.scale, + 
unconditional_conditioning=uc, + eta=args.ddim_eta, + x_T=z_enc, + img_callback=callback) + else: + raise Exception(f"Sampler {args.sampler} not recognised.") + + + if return_latent: + results.append(samples.clone()) + + x_samples = model.decode_first_stage(samples) + + if args.use_mask and args.overlay_mask: + # Overlay the masked image after the image is generated + if args.init_sample is not None: + img_original = args.init_sample + elif init_image is not None: + img_original = init_image + else: + raise Exception("Cannot overlay the masked image without an init image to overlay") + + mask_fullres = prepare_mask(args.mask_file if mask_image is None else mask_image, + img_original.shape, + args.mask_contrast_adjust, + args.mask_brightness_adjust) + mask_fullres = mask_fullres[:,:3,:,:] + mask_fullres = repeat(mask_fullres, '1 ... -> b ...', b=batch_size) + + mask_fullres[mask_fullres < mask_fullres.max()] = 0 + mask_fullres = gaussian_filter(mask_fullres, args.mask_overlay_blur) + mask_fullres = torch.Tensor(mask_fullres).to(device) + + x_samples = img_original * mask_fullres + x_samples * ((mask_fullres * -1.0) + 1) + + + if return_sample: + results.append(x_samples.clone()) + + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + if return_c: + results.append(c.clone()) + + for x_sample in x_samples: + x_sample = 255. * rearrange(x_sample.cpu().numpy(), 'c h w -> h w c') + image = Image.fromarray(x_sample.astype(np.uint8)) + results.append(image) + return results + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "CIUJ7lWI4v53" +# !! }} +#@markdown **Select and Load Model** + +model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"] +model_checkpoint = "sd-v1-4.ckpt" #@param ["custom","sd-v1-4-full-ema.ckpt","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt", "robo-diffusion-v1.ckpt","waifu-diffusion-v1-3.ckpt"] +if model_checkpoint == "waifu-diffusion-v1-3.ckpt": + model_checkpoint = "model-epoch05-float16.ckpt" +custom_config_path = "" #@param {type:"string"} +custom_checkpoint_path = "" #@param {type:"string"} + +load_on_run_all = True #@param {type: 'boolean'} +half_precision = True # check +check_sha256 = True #@param {type:"boolean"} + +model_map = { + "sd-v1-4-full-ema.ckpt": { + 'sha256': '14749efc0ae8ef0329391ad4436feb781b402f4fece4883c7ad8d10556d8a36a', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-4-full-ema.ckpt', + 'requires_login': True, + }, + "sd-v1-4.ckpt": { + 'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-4-original/resolve/main/sd-v1-4.ckpt', + 'requires_login': True, + }, + "sd-v1-3-full-ema.ckpt": { + 'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/blob/main/sd-v1-3-full-ema.ckpt', + 'requires_login': True, + }, + "sd-v1-3.ckpt": { + 'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-3-original/resolve/main/sd-v1-3.ckpt', + 'requires_login': True, + }, + "sd-v1-2-full-ema.ckpt": { + 'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/blob/main/sd-v1-2-full-ema.ckpt', + 'requires_login': True, + }, + "sd-v1-2.ckpt": { 
+ 'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-2-original/resolve/main/sd-v1-2.ckpt', + 'requires_login': True, + }, + "sd-v1-1-full-ema.ckpt": { + 'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829', + 'url':'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1-full-ema.ckpt', + 'requires_login': True, + }, + "sd-v1-1.ckpt": { + 'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea', + 'url': 'https://huggingface.co/CompVis/stable-diffusion-v-1-1-original/resolve/main/sd-v1-1.ckpt', + 'requires_login': True, + }, + "robo-diffusion-v1.ckpt": { + 'sha256': '244dbe0dcb55c761bde9c2ac0e9b46cc9705ebfe5f1f3a7cc46251573ea14e16', + 'url': 'https://huggingface.co/nousr/robo-diffusion/resolve/main/models/robo-diffusion-v1.ckpt', + 'requires_login': False, + }, + "model-epoch05-float16.ckpt": { + 'sha256': '26cf2a2e30095926bb9fd9de0c83f47adc0b442dbfdc3d667d43778e8b70bece', + 'url': 'https://huggingface.co/hakurei/waifu-diffusion-v1-3/resolve/main/model-epoch05-float16.ckpt', + 'requires_login': False, + }, +} + +# config path +ckpt_config_path = custom_config_path if model_config == "custom" else os.path.join(models_path, model_config) +if os.path.exists(ckpt_config_path): + print(f"{ckpt_config_path} exists") +else: + ckpt_config_path = "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml" +print(f"Using config: {ckpt_config_path}") + +# checkpoint path or download +ckpt_path = custom_checkpoint_path if model_checkpoint == "custom" else os.path.join(models_path, model_checkpoint) +ckpt_valid = True +if os.path.exists(ckpt_path): + print(f"{ckpt_path} exists") +elif 'url' in model_map[model_checkpoint]: + url = model_map[model_checkpoint]['url'] + + # CLI dialogue to authenticate download + if model_map[model_checkpoint]['requires_login']: + print("This model requires an authentication token") + print("Please ensure you have accepted its terms of service before continuing.") + + username = input("What is your huggingface username?:") + token = input("What is your huggingface token?:") + + _, path = url.split("https://") + + url = f"https://{username}:{token}@{path}" + + # contact server for model + print(f"Attempting to download {model_checkpoint}...this may take a while") + ckpt_request = requests.get(url) + request_status = ckpt_request.status_code + + # inform user of errors + if request_status == 403: + raise ConnectionRefusedError("You have not accepted the license for this model.") + elif request_status == 404: + raise ConnectionError("Could not make contact with server") + elif request_status != 200: + raise ConnectionError(f"Some other error has ocurred - response code: {request_status}") + + # write to model path + with open(os.path.join(models_path, model_checkpoint), 'wb') as model_file: + model_file.write(ckpt_request.content) +else: + print(f"Please download model checkpoint and place in {os.path.join(models_path, model_checkpoint)}") + ckpt_valid = False + +if check_sha256 and model_checkpoint != "custom" and ckpt_valid: + import hashlib + print("\n...checking sha256") + with open(ckpt_path, "rb") as f: + bytes = f.read() + hash = hashlib.sha256(bytes).hexdigest() + del bytes + if model_map[model_checkpoint]["sha256"] == hash: + print("hash is correct\n") + else: + print("hash in not correct\n") + ckpt_valid = False + +if ckpt_valid: + print(f"Using ckpt: {ckpt_path}") + +def load_model_from_config(config, 
ckpt, verbose=False, device='cuda', half_precision=True): + map_location = "cuda" #@param ["cpu", "cuda"] + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location=map_location) + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if half_precision: + model = model.half().to(device) + else: + model = model.to(device) + model.eval() + return model + +if load_on_run_all and ckpt_valid: + local_config = OmegaConf.load(f"{ckpt_config_path}") + model = load_model_from_config(local_config, f"{ckpt_path}", half_precision=half_precision) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + +# %% +# !! {"metadata":{ +# !! "id": "ov3r4RD1tzsT" +# !! }} +""" +# Settings +""" + +# %% +# !! {"metadata":{ +# !! "id": "0j7rgxvLvfay" +# !! }} +""" +### Animation Settings +""" + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "8HJN2TE3vh-J" +# !! }} + +def DeforumAnimArgs(): + + #@markdown ####**Animation:** + animation_mode = 'None' #@param ['None', '2D', '3D', 'Video Input', 'Interpolation'] {type:'string'} + max_frames = 1000 #@param {type:"number"} + border = 'replicate' #@param ['wrap', 'replicate'] {type:'string'} + + #@markdown ####**Motion Parameters:** + angle = "0:(0)"#@param {type:"string"} + zoom = "0:(1.04)"#@param {type:"string"} + translation_x = "0:(10*sin(2*3.14*t/10))"#@param {type:"string"} + translation_y = "0:(0)"#@param {type:"string"} + translation_z = "0:(10)"#@param {type:"string"} + rotation_3d_x = "0:(0)"#@param {type:"string"} + rotation_3d_y = "0:(0)"#@param {type:"string"} + rotation_3d_z = "0:(0)"#@param {type:"string"} + flip_2d_perspective = False #@param {type:"boolean"} + perspective_flip_theta = "0:(0)"#@param {type:"string"} + perspective_flip_phi = "0:(t%15)"#@param {type:"string"} + perspective_flip_gamma = "0:(0)"#@param {type:"string"} + perspective_flip_fv = "0:(53)"#@param {type:"string"} + noise_schedule = "0: (0.02)"#@param {type:"string"} + strength_schedule = "0: (0.65)"#@param {type:"string"} + contrast_schedule = "0: (1.0)"#@param {type:"string"} + + #@markdown ####**Coherence:** + color_coherence = 'Match Frame 0 LAB' #@param ['None', 'Match Frame 0 HSV', 'Match Frame 0 LAB', 'Match Frame 0 RGB'] {type:'string'} + diffusion_cadence = '1' #@param ['1','2','3','4','5','6','7','8'] {type:'string'} + + #@markdown ####**3D Depth Warping:** + use_depth_warping = True #@param {type:"boolean"} + midas_weight = 0.3#@param {type:"number"} + near_plane = 200 + far_plane = 10000 + fov = 40#@param {type:"number"} + padding_mode = 'border'#@param ['border', 'reflection', 'zeros'] {type:'string'} + sampling_mode = 'bicubic'#@param ['bicubic', 'bilinear', 'nearest'] {type:'string'} + save_depth_maps = False #@param {type:"boolean"} + + #@markdown ####**Video Input:** + video_init_path ='/content/video_in.mp4'#@param {type:"string"} + extract_nth_frame = 1#@param {type:"number"} + overwrite_extracted_frames = True #@param {type:"boolean"} + use_mask_video = False #@param {type:"boolean"} + video_mask_path ='/content/video_in.mp4'#@param {type:"string"} + + #@markdown ####**Interpolation:** + interpolate_key_frames = False #@param {type:"boolean"} + interpolate_x_frames = 4 #@param {type:"number"} + 
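+    # Keyframe string format: the motion and schedule parameters above use
+    # "frame: (expression)" pairs separated by commas. parse_key_frames() below
+    # extracts each frame/expression pair and get_inbetweens() evaluates the
+    # expression with numexpr (t holds that keyframe's frame index), then
+    # interpolates values for the frames in between. Illustrative values only
+    # (not the notebook defaults):
+    #   zoom  = "0:(1.0), 60:(1.1), 120:(1.0)"   # zoom in until frame 60, then back out
+    #   angle = "0:(0), 100:(360)"               # one full rotation over 100 frames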
+ #@markdown ####**Resume Animation:** + resume_from_timestring = False #@param {type:"boolean"} + resume_timestring = "20220829210106" #@param {type:"string"} + + return locals() + +class DeformAnimKeys(): + def __init__(self, anim_args): + self.angle_series = get_inbetweens(parse_key_frames(anim_args.angle), anim_args.max_frames) + self.zoom_series = get_inbetweens(parse_key_frames(anim_args.zoom), anim_args.max_frames) + self.translation_x_series = get_inbetweens(parse_key_frames(anim_args.translation_x), anim_args.max_frames) + self.translation_y_series = get_inbetweens(parse_key_frames(anim_args.translation_y), anim_args.max_frames) + self.translation_z_series = get_inbetweens(parse_key_frames(anim_args.translation_z), anim_args.max_frames) + self.rotation_3d_x_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_x), anim_args.max_frames) + self.rotation_3d_y_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_y), anim_args.max_frames) + self.rotation_3d_z_series = get_inbetweens(parse_key_frames(anim_args.rotation_3d_z), anim_args.max_frames) + self.perspective_flip_theta_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_theta), anim_args.max_frames) + self.perspective_flip_phi_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_phi), anim_args.max_frames) + self.perspective_flip_gamma_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_gamma), anim_args.max_frames) + self.perspective_flip_fv_series = get_inbetweens(parse_key_frames(anim_args.perspective_flip_fv), anim_args.max_frames) + self.noise_schedule_series = get_inbetweens(parse_key_frames(anim_args.noise_schedule), anim_args.max_frames) + self.strength_schedule_series = get_inbetweens(parse_key_frames(anim_args.strength_schedule), anim_args.max_frames) + self.contrast_schedule_series = get_inbetweens(parse_key_frames(anim_args.contrast_schedule), anim_args.max_frames) + + +def get_inbetweens(key_frames, max_frames, integer=False, interp_method='Linear'): + import numexpr + key_frame_series = pd.Series([np.nan for a in range(max_frames)]) + + for i in range(0, max_frames): + if i in key_frames: + value = key_frames[i] + value_is_number = check_is_number(value) + # if it's only a number, leave the rest for the default interpolation + if value_is_number: + t = i + key_frame_series[i] = value + if not value_is_number: + t = i + key_frame_series[i] = numexpr.evaluate(value) + key_frame_series = key_frame_series.astype(float) + + if interp_method == 'Cubic' and len(key_frames.items()) <= 3: + interp_method = 'Quadratic' + if interp_method == 'Quadratic' and len(key_frames.items()) <= 2: + interp_method = 'Linear' + + key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] + key_frame_series[max_frames-1] = key_frame_series[key_frame_series.last_valid_index()] + key_frame_series = key_frame_series.interpolate(method=interp_method.lower(), limit_direction='both') + if integer: + return key_frame_series.astype(int) + return key_frame_series + +def parse_key_frames(string, prompt_parser=None): + # because math functions (i.e. 
sin(t)) can utilize brackets + # it extracts the value in form of some stuff + # which has previously been enclosed with brackets and + # with a comma or end of line existing after the closing one + pattern = r'((?P[0-9]+):[\s]*\((?P[\S\s]*?)\)([,][\s]?|[\s]?$))' + frames = dict() + for match_object in re.finditer(pattern, string): + frame = int(match_object.groupdict()['frame']) + param = match_object.groupdict()['param'] + if prompt_parser: + frames[frame] = prompt_parser(param) + else: + frames[frame] = param + if frames == {} and len(string) != 0: + raise RuntimeError('Key Frame string not correctly formatted') + return frames + +# %% +# !! {"metadata":{ +# !! "id": "63UOJvU3xdPS" +# !! }} +""" +### Prompts +`animation_mode: None` batches on list of *prompts*. `animation_mode: 2D` uses *animation_prompts* key frame sequence +""" + +# %% +# !! {"metadata":{ +# !! "id": "2ujwkGZTcGev" +# !! }} + +prompts = [ + "a beautiful forest by Asher Brown Durand, trending on Artstation", # the first prompt I want + "a beautiful portrait of a woman by Artgerm, trending on Artstation", # the second prompt I want + #"this prompt I don't want it I commented it out", + #"a nousr robot, trending on Artstation", # use "nousr robot" with the robot diffusion model (see model_checkpoint setting) + #"touhou 1girl komeiji_koishi portrait, green hair", # waifu diffusion prompts can use danbooru tag groups (see model_checkpoint) + #"this prompt has weights if prompt weighting enabled:2 can also do negative:-2", # (see prompt_weighting) +] + +animation_prompts = { + 0: "a beautiful apple, trending on Artstation", + 20: "a beautiful banana, trending on Artstation", + 30: "a beautiful coconut, trending on Artstation", + 40: "a beautiful durian, trending on Artstation", +} + +# %% +# !! {"metadata":{ +# !! "id": "s8RAo2zI-vQm" +# !! }} +""" +# Run +""" + +# %% +# !! {"metadata":{ +# !! "id": "qH74gBWDd2oq", +# !! "cellView": "form" +# !! 
}} +#@markdown **Load Settings** +override_settings_with_file = False #@param {type:"boolean"} +custom_settings_file = "/content/drive/MyDrive/Settings.txt"#@param {type:"string"} + +def DeforumArgs(): + #@markdown **Image Settings** + W = 512 #@param + H = 512 #@param + W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 + + #@markdown **Sampling Settings** + seed = -1 #@param + sampler = 'klms' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"] + steps = 50 #@param + scale = 7 #@param + ddim_eta = 0.0 #@param + dynamic_threshold = None + static_threshold = None + + #@markdown **Save & Display Settings** + save_samples = True #@param {type:"boolean"} + save_settings = True #@param {type:"boolean"} + display_samples = True #@param {type:"boolean"} + save_sample_per_step = False #@param {type:"boolean"} + show_sample_per_step = False #@param {type:"boolean"} + + #@markdown **Prompt Settings** + prompt_weighting = False #@param {type:"boolean"} + normalize_prompt_weights = True #@param {type:"boolean"} + log_weighted_subprompts = False #@param {type:"boolean"} + + #@markdown **Batch Settings** + n_batch = 1 #@param + batch_name = "StableFun" #@param {type:"string"} + filename_format = "{timestring}_{index}_{prompt}.png" #@param ["{timestring}_{index}_{seed}.png","{timestring}_{index}_{prompt}.png"] + seed_behavior = "iter" #@param ["iter","fixed","random"] + make_grid = False #@param {type:"boolean"} + grid_rows = 2 #@param + outdir = get_output_folder(output_path, batch_name) + + #@markdown **Init Settings** + use_init = False #@param {type:"boolean"} + strength = 0.0 #@param {type:"number"} + strength_0_no_init = True # Set the strength to 0 automatically when no init image is used + init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" #@param {type:"string"} + # Whiter areas of the mask are areas that change more + use_mask = False #@param {type:"boolean"} + use_alpha_as_mask = False # use the alpha channel of the init image as the mask + mask_file = "https://www.filterforge.com/wiki/images/archive/b/b7/20080927223728%21Polygonal_gradient_thumb.jpg" #@param {type:"string"} + invert_mask = False #@param {type:"boolean"} + # Adjust mask image, 1.0 is no adjustment. Should be positive numbers. + mask_brightness_adjust = 1.0 #@param {type:"number"} + mask_contrast_adjust = 1.0 #@param {type:"number"} + # Overlay the masked image at the end of the generation so it does not get degraded by encoding and decoding + overlay_mask = True # {type:"boolean"} + # Blur edges of final overlay mask, if used. 
Minimum = 0 (no blur) + mask_overlay_blur = 5 # {type:"number"} + + n_samples = 1 # doesnt do anything + precision = 'autocast' + C = 4 + f = 8 + + prompt = "" + timestring = "" + init_latent = None + init_sample = None + init_c = None + + return locals() + + + +def next_seed(args): + if args.seed_behavior == 'iter': + args.seed += 1 + elif args.seed_behavior == 'fixed': + pass # always keep seed the same + else: + args.seed = random.randint(0, 2**32 - 1) + return args.seed + +def render_image_batch(args): + args.prompts = {k: f"{v:05d}" for v, k in enumerate(prompts)} + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + if args.save_settings or args.save_samples: + print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*") + + # save settings for the batch + if args.save_settings: + filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(filename, "w+", encoding="utf-8") as f: + json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4) + + index = 0 + + # function for init image batching + init_array = [] + if args.use_init: + if args.init_image == "": + raise FileNotFoundError("No path was given for init_image") + if args.init_image.startswith('http://') or args.init_image.startswith('https://'): + init_array.append(args.init_image) + elif not os.path.isfile(args.init_image): + if args.init_image[-1] != "/": # avoids path error by adding / to end if not there + args.init_image += "/" + for image in sorted(os.listdir(args.init_image)): # iterates dir and appends images to init_array + if image.split(".")[-1] in ("png", "jpg", "jpeg"): + init_array.append(args.init_image + image) + else: + init_array.append(args.init_image) + else: + init_array = [""] + + # when doing large batches don't flood browser with images + clear_between_batches = args.n_batch >= 32 + + for iprompt, prompt in enumerate(prompts): + args.prompt = prompt + print(f"Prompt {iprompt+1} of {len(prompts)}") + print(f"{args.prompt}") + + all_images = [] + + for batch_index in range(args.n_batch): + if clear_between_batches and batch_index % 32 == 0: + display.clear_output(wait=True) + print(f"Batch {batch_index+1} of {args.n_batch}") + + for image in init_array: # iterates the init images + args.init_image = image + results = generate(args) + for image in results: + if args.make_grid: + all_images.append(T.functional.pil_to_tensor(image)) + if args.save_samples: + if args.filename_format == "{timestring}_{index}_{prompt}.png": + filename = f"{args.timestring}_{index:05}_{sanitize(prompt)[:160]}.png" + else: + filename = f"{args.timestring}_{index:05}_{args.seed}.png" + image.save(os.path.join(args.outdir, filename)) + if args.display_samples: + display.display(image) + index += 1 + args.seed = next_seed(args) + + #print(len(all_images)) + if args.make_grid: + grid = make_grid(all_images, nrow=int(len(all_images)/args.grid_rows)) + grid = rearrange(grid, 'c h w -> h w c').cpu().numpy() + filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png" + grid_image = Image.fromarray(grid.astype(np.uint8)) + grid_image.save(os.path.join(args.outdir, filename)) + display.clear_output(wait=True) + display.display(grid_image) + + +def render_animation(args, anim_args): + # animations use key framed prompts + args.prompts = animation_prompts + + # expand key frame strings to values + keys = DeformAnimKeys(anim_args) + + # resume animation + start_frame = 0 + if anim_args.resume_from_timestring: + for tmp in os.listdir(args.outdir): + if tmp.split("_")[0] 
== anim_args.resume_timestring: + start_frame += 1 + start_frame = start_frame - 1 + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to {args.outdir}") + + # save settings for the batch + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(settings_filename, "w+", encoding="utf-8") as f: + s = {**dict(args.__dict__), **dict(anim_args.__dict__)} + json.dump(s, f, ensure_ascii=False, indent=4) + + # resume from timestring + if anim_args.resume_from_timestring: + args.timestring = anim_args.resume_timestring + + # expand prompts out to per-frame + prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) + for i, prompt in animation_prompts.items(): + prompt_series[i] = prompt + prompt_series = prompt_series.ffill().bfill() + + # check for video inits + using_vid_init = anim_args.animation_mode == 'Video Input' + + # load depth model for 3D + predict_depths = (anim_args.animation_mode == '3D' and anim_args.use_depth_warping) or anim_args.save_depth_maps + if predict_depths: + depth_model = DepthModel(device) + depth_model.load_midas(models_path) + if anim_args.midas_weight < 1.0: + depth_model.load_adabins() + else: + depth_model = None + anim_args.save_depth_maps = False + + # state for interpolating between diffusion steps + turbo_steps = 1 if using_vid_init else int(anim_args.diffusion_cadence) + turbo_prev_image, turbo_prev_frame_idx = None, 0 + turbo_next_image, turbo_next_frame_idx = None, 0 + + # resume animation + prev_sample = None + color_match_sample = None + if anim_args.resume_from_timestring: + last_frame = start_frame-1 + if turbo_steps > 1: + last_frame -= last_frame%turbo_steps + path = os.path.join(args.outdir,f"{args.timestring}_{last_frame:05}.png") + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + prev_sample = sample_from_cv2(img) + if anim_args.color_coherence != 'None': + color_match_sample = img + if turbo_steps > 1: + turbo_next_image, turbo_next_frame_idx = sample_to_cv2(prev_sample, type=np.float32), last_frame + turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx + start_frame = last_frame+turbo_steps + + args.n_samples = 1 + frame_idx = start_frame + while frame_idx < anim_args.max_frames: + print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}") + noise = keys.noise_schedule_series[frame_idx] + strength = keys.strength_schedule_series[frame_idx] + contrast = keys.contrast_schedule_series[frame_idx] + depth = None + + # emit in-between frames + if turbo_steps > 1: + tween_frame_start_idx = max(0, frame_idx-turbo_steps) + for tween_frame_idx in range(tween_frame_start_idx, frame_idx): + tween = float(tween_frame_idx - tween_frame_start_idx + 1) / float(frame_idx - tween_frame_start_idx) + print(f" creating in between frame {tween_frame_idx} tween:{tween:0.2f}") + + advance_prev = turbo_prev_image is not None and tween_frame_idx > turbo_prev_frame_idx + advance_next = tween_frame_idx > turbo_next_frame_idx + + if depth_model is not None: + assert(turbo_next_image is not None) + depth = depth_model.predict(turbo_next_image, anim_args) + + if anim_args.animation_mode == '2D': + if advance_prev: + turbo_prev_image = anim_frame_warp_2d(turbo_prev_image, args, anim_args, keys, tween_frame_idx) + if advance_next: + turbo_next_image = anim_frame_warp_2d(turbo_next_image, args, anim_args, keys, tween_frame_idx) + else: # '3D' + if advance_prev: + turbo_prev_image = 
anim_frame_warp_3d(turbo_prev_image, depth, anim_args, keys, tween_frame_idx) + if advance_next: + turbo_next_image = anim_frame_warp_3d(turbo_next_image, depth, anim_args, keys, tween_frame_idx) + turbo_prev_frame_idx = turbo_next_frame_idx = tween_frame_idx + + if turbo_prev_image is not None and tween < 1.0: + img = turbo_prev_image*(1.0-tween) + turbo_next_image*tween + else: + img = turbo_next_image + + filename = f"{args.timestring}_{tween_frame_idx:05}.png" + cv2.imwrite(os.path.join(args.outdir, filename), cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR)) + if anim_args.save_depth_maps: + depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{tween_frame_idx:05}.png"), depth) + if turbo_next_image is not None: + prev_sample = sample_from_cv2(turbo_next_image) + + # apply transforms to previous frame + if prev_sample is not None: + if anim_args.animation_mode == '2D': + prev_img = anim_frame_warp_2d(sample_to_cv2(prev_sample), args, anim_args, keys, frame_idx) + else: # '3D' + prev_img_cv2 = sample_to_cv2(prev_sample) + depth = depth_model.predict(prev_img_cv2, anim_args) if depth_model else None + prev_img = anim_frame_warp_3d(prev_img_cv2, depth, anim_args, keys, frame_idx) + + # apply color matching + if anim_args.color_coherence != 'None': + if color_match_sample is None: + color_match_sample = prev_img.copy() + else: + prev_img = maintain_colors(prev_img, color_match_sample, anim_args.color_coherence) + + # apply scaling + contrast_sample = prev_img * contrast + # apply frame noising + noised_sample = add_noise(sample_from_cv2(contrast_sample), noise) + + # use transformed previous frame as init for current + args.use_init = True + if half_precision: + args.init_sample = noised_sample.half().to(device) + else: + args.init_sample = noised_sample.to(device) + args.strength = max(0.0, min(1.0, strength)) + + # grab prompt for current frame + args.prompt = prompt_series[frame_idx] + print(f"{args.prompt} {args.seed}") + if not using_vid_init: + print(f"Angle: {keys.angle_series[frame_idx]} Zoom: {keys.zoom_series[frame_idx]}") + print(f"Tx: {keys.translation_x_series[frame_idx]} Ty: {keys.translation_y_series[frame_idx]} Tz: {keys.translation_z_series[frame_idx]}") + print(f"Rx: {keys.rotation_3d_x_series[frame_idx]} Ry: {keys.rotation_3d_y_series[frame_idx]} Rz: {keys.rotation_3d_z_series[frame_idx]}") + + # grab init image for current frame + if using_vid_init: + init_frame = os.path.join(args.outdir, 'inputframes', f"{frame_idx+1:05}.jpg") + print(f"Using video init frame {init_frame}") + args.init_image = init_frame + if anim_args.use_mask_video: + mask_frame = os.path.join(args.outdir, 'maskframes', f"{frame_idx+1:05}.jpg") + args.mask_file = mask_frame + + # sample the diffusion model + sample, image = generate(args, frame_idx, return_latent=False, return_sample=True) + if not using_vid_init: + prev_sample = sample + + if turbo_steps > 1: + turbo_prev_image, turbo_prev_frame_idx = turbo_next_image, turbo_next_frame_idx + turbo_next_image, turbo_next_frame_idx = sample_to_cv2(sample, type=np.float32), frame_idx + frame_idx += turbo_steps + else: + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + if anim_args.save_depth_maps: + if depth is None: + depth = depth_model.predict(sample_to_cv2(sample), anim_args) + depth_model.save(os.path.join(args.outdir, f"{args.timestring}_depth_{frame_idx:05}.png"), depth) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = 
next_seed(args) + +def vid2frames(video_path, frames_path, n=1, overwrite=True): + if not os.path.exists(frames_path) or overwrite: + try: + for f in pathlib.Path(video_in_frame_path).glob('*.jpg'): + f.unlink() + except: + pass + assert os.path.exists(video_path), f"Video input {video_path} does not exist" + + vidcap = cv2.VideoCapture(video_path) + success,image = vidcap.read() + count = 0 + t=1 + success = True + while success: + if count % n == 0: + cv2.imwrite(frames_path + os.path.sep + f"{t:05}.jpg" , image) # save frame as JPEG file + t += 1 + success,image = vidcap.read() + count += 1 + print("Converted %d frames" % count) + else: print("Frames already unpacked") + +def render_input_video(args, anim_args): + # create a folder for the video input frames to live in + video_in_frame_path = os.path.join(args.outdir, 'inputframes') + os.makedirs(video_in_frame_path, exist_ok=True) + + # save the video frames from input video + print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}...") + vid2frames(anim_args.video_init_path, video_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames) + + # determine max frames from length of input frames + anim_args.max_frames = len([f for f in pathlib.Path(video_in_frame_path).glob('*.jpg')]) + args.use_init = True + print(f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}") + + if anim_args.use_mask_video: + # create a folder for the mask video input frames to live in + mask_in_frame_path = os.path.join(args.outdir, 'maskframes') + os.makedirs(mask_in_frame_path, exist_ok=True) + + # save the video frames from mask video + print(f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {mask_in_frame_path}...") + vid2frames(anim_args.video_mask_path, mask_in_frame_path, anim_args.extract_nth_frame, anim_args.overwrite_extracted_frames) + args.use_mask = True + args.overlay_mask = True + + render_animation(args, anim_args) + +def render_interpolation(args, anim_args): + # animations use key framed prompts + args.prompts = animation_prompts + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to {args.outdir}") + + # save settings for the batch + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(settings_filename, "w+", encoding="utf-8") as f: + s = {**dict(args.__dict__), **dict(anim_args.__dict__)} + json.dump(s, f, ensure_ascii=False, indent=4) + + # Interpolation Settings + args.n_samples = 1 + args.seed_behavior = 'fixed' # force fix seed at the moment bc only 1 seed is available + prompts_c_s = [] # cache all the text embeddings + + print(f"Preparing for interpolation of the following...") + + for i, prompt in animation_prompts.items(): + args.prompt = prompt + + # sample the diffusion model + results = generate(args, return_c=True) + c, image = results[0], results[1] + prompts_c_s.append(c) + + # display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + display.clear_output(wait=True) + print(f"Interpolation start...") + + frame_idx = 0 + + if anim_args.interpolate_key_frames: + for i in range(len(prompts_c_s)-1): + dist_frames = list(animation_prompts.items())[i+1][0] - list(animation_prompts.items())[i][0] + if dist_frames <= 0: + print("key frames duplicated or reversed. 
interpolation skipped.") + return + else: + for j in range(dist_frames): + # interpolate the text embedding + prompt1_c = prompts_c_s[i] + prompt2_c = prompts_c_s[i+1] + args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/dist_frames)) + + # sample the diffusion model + results = generate(args) + image = results[0] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + else: + for i in range(len(prompts_c_s)-1): + for j in range(anim_args.interpolate_x_frames+1): + # interpolate the text embedding + prompt1_c = prompts_c_s[i] + prompt2_c = prompts_c_s[i+1] + args.init_c = prompt1_c.add(prompt2_c.sub(prompt1_c).mul(j * 1/(anim_args.interpolate_x_frames+1))) + + # sample the diffusion model + results = generate(args) + image = results[0] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + # generate the last prompt + args.init_c = prompts_c_s[-1] + results = generate(args) + image = results[0] + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + + display.clear_output(wait=True) + display.display(image) + args.seed = next_seed(args) + + #clear init_c + args.init_c = None + + +args_dict = DeforumArgs() +anim_args_dict = DeforumAnimArgs() + +if override_settings_with_file: + print(f"reading custom settings from {custom_settings_file}") + if not os.path.isfile(custom_settings_file): + print('The custom settings file does not exist. The in-notebook settings will be used instead') + else: + with open(custom_settings_file, "r") as f: + jdata = json.loads(f.read()) + animation_prompts = jdata["prompts"] + for i, k in enumerate(args_dict): + if k in jdata: + args_dict[k] = jdata[k] + else: + print(f"key {k} doesn't exist in the custom settings data! using the default value of {args_dict[k]}") + for i, k in enumerate(anim_args_dict): + if k in jdata: + anim_args_dict[k] = jdata[k] + else: + print(f"key {k} doesn't exist in the custom settings data! using the default value of {anim_args_dict[k]}") + print(args_dict) + print(anim_args_dict) + +args = SimpleNamespace(**args_dict) +anim_args = SimpleNamespace(**anim_args_dict) + +args.timestring = time.strftime('%Y%m%d%H%M%S') +args.strength = max(0.0, min(1.0, args.strength)) + +if args.seed == -1: + args.seed = random.randint(0, 2**32 - 1) +if not args.use_init: + args.init_image = None +if args.sampler == 'plms' and (args.use_init or anim_args.animation_mode != 'None'): + print(f"Init images aren't supported with PLMS yet, switching to KLMS") + args.sampler = 'klms' +if args.sampler != 'ddim': + args.ddim_eta = 0 + +if anim_args.animation_mode == 'None': + anim_args.max_frames = 1 +elif anim_args.animation_mode == 'Video Input': + args.use_init = True + +# clean up unused memory +gc.collect() +torch.cuda.empty_cache() + +# dispatch to appropriate renderer +if anim_args.animation_mode == '2D' or anim_args.animation_mode == '3D': + render_animation(args, anim_args) +elif anim_args.animation_mode == 'Video Input': + render_input_video(args, anim_args) +elif anim_args.animation_mode == 'Interpolation': + render_interpolation(args, anim_args) +else: + render_image_batch(args) + +# %% +# !! {"metadata":{ +# !! "id": "4zV0J_YbMCTx" +# !! 
}} +""" +# Create video from frames +""" + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "no2jP8HTMBM0" +# !! }} +skip_video_for_run_all = True #@param {type: 'boolean'} +fps = 12 #@param {type:"number"} +#@markdown **Manual Settings** +use_manual_settings = False #@param {type:"boolean"} +image_path = "/content/drive/MyDrive/AI/StableDiffusion/2022-09/20220903000939_%05d.png" #@param {type:"string"} +mp4_path = "/content/drive/MyDrive/AI/StableDiffu'/content/drive/MyDrive/AI/StableDiffusion/2022-09/sion/2022-09/20220903000939.mp4" #@param {type:"string"} +render_steps = True #@param {type: 'boolean'} +path_name_modifier = "x0_pred" #@param ["x0_pred","x"] + + +if skip_video_for_run_all == True: + print('Skipping video creation, uncheck skip_video_for_run_all if you want to run it') +else: + import os + import subprocess + from base64 import b64encode + + print(f"{image_path} -> {mp4_path}") + + if use_manual_settings: + max_frames = "200" #@param {type:"string"} + else: + if render_steps: # render steps from a single image + fname = f"{path_name_modifier}_%05d.png" + all_step_dirs = [os.path.join(args.outdir, d) for d in os.listdir(args.outdir) if os.path.isdir(os.path.join(args.outdir,d))] + newest_dir = max(all_step_dirs, key=os.path.getmtime) + image_path = os.path.join(newest_dir, fname) + print(f"Reading images from {image_path}") + mp4_path = os.path.join(newest_dir, f"{args.timestring}_{path_name_modifier}.mp4") + max_frames = str(args.steps) + else: # render images for a video + image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png") + mp4_path = os.path.join(args.outdir, f"{args.timestring}.mp4") + max_frames = str(anim_args.max_frames) + + # make video + cmd = [ + 'ffmpeg', + '-y', + '-vcodec', 'png', + '-r', str(fps), + '-start_number', str(0), + '-i', image_path, + '-frames:v', max_frames, + '-c:v', 'libx264', + '-vf', + f'fps={fps}', + '-pix_fmt', 'yuv420p', + '-crf', '17', + '-preset', 'veryfast', + '-pattern_type', 'sequence', + mp4_path + ] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + if process.returncode != 0: + print(stderr) + raise RuntimeError(stderr) + + mp4 = open(mp4_path,'rb').read() + data_url = "data:video/mp4;base64," + b64encode(mp4).decode() + display.display( display.HTML(f'') ) + +# %% +# !! {"metadata":{ +# !! "id": "RoECylTun7AA" +# !! }} +""" +# Disconnect when finished +""" + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "bfXpWRgSn-eH" +# !! }} +skip_disconnect_for_run_all = True #@param {type: 'boolean'} + +if skip_disconnect_for_run_all == True: + print('Skipping disconnect, uncheck skip_disconnect_for_run_all if you want to run it') +else: + from google.colab import runtime + runtime.unassign() + +# %% +# !! {"main_metadata":{ +# !! "accelerator": "GPU", +# !! "colab": { +# !! "collapsed_sections": [], +# !! "provenance": [], +# !! "private_outputs": true +# !! }, +# !! "gpuClass": "standard", +# !! "kernelspec": { +# !! "display_name": "Python 3", +# !! "name": "python3" +# !! }, +# !! "language_info": { +# !! "name": "python" +# !! } +# !! 
}} diff --git a/README.md b/README.md index c9e6c3bb13..de7d291982 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,8 @@ + +[![Deforum Stable Diffusion](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/deforum/stable-diffusion/blob/main/Deforum_Stable_Diffusion.ipynb) +![visitors](https://visitor-badge.glitch.me/badge?page_id=deforum_sd_repo) +[![Replicate](https://replicate.com/deforum/deforum_stable_diffusion/badge)](https://replicate.com/deforum/deforum_stable_diffusion) + # Stable Diffusion *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:* diff --git a/backup/Deforum_Stable_Diffusion.ipynb b/backup/Deforum_Stable_Diffusion.ipynb new file mode 100644 index 0000000000..47c30151e3 --- /dev/null +++ b/backup/Deforum_Stable_Diffusion.ipynb @@ -0,0 +1,622 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "c442uQJ_gUgy" + }, + "source": [ + "# **Deforum Stable Diffusion**\n", + "[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Bj\u00f6rn Ommer and the [Stability.ai](https://stability.ai/) Team\n", + "\n", + "Notebook by [deforum](https://twitter.com/deforum_art)" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "2g-f7cQmf2Nt", + "cellView": "form" + }, + "source": [ + "#@markdown **NVIDIA GPU**\n", + "import subprocess\n", + "sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + "print(sub_p_res)" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "id": "VRNl2mfepEIe", + "cellView": "form" + }, + "source": [ + "#@markdown **Setup Environment**\n", + "\n", + "setup_environment = True #@param {type:\"boolean\"}\n", + "print_subprocess = False #@param {type:\"boolean\"}\n", + "\n", + "if setup_environment:\n", + " import subprocess\n", + " print(\"...setting up environment\")\n", + " all_process = [['pip', 'install', 'torch==1.11.0+cu113', 'torchvision==0.12.0+cu113', 'torchaudio==0.11.0', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'],\n", + " ['pip', 'install', 'omegaconf==2.1.1', 'einops==0.3.0', 'pytorch-lightning==1.4.2', 'torchmetrics==0.6.0', 'torchtext==0.2.3', 'transformers==4.19.2', 'kornia==0.6'],\n", + " ['git', 'clone', 'https://github.com/deforum/stable-diffusion'],\n", + " ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'],\n", + " ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'],\n", + " ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'resize-right', 'torchdiffeq'],\n", + " ]\n", + " for process in all_process:\n", + " running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + " if print_subprocess:\n", + " print(running)\n", + " \n", + " print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8'))\n", + " with open('k-diffusion/k_diffusion/__init__.py', 'w') as f:\n", + " f.write('')\n", + " \n", + " import sys\n", + " sys.path.append('./src/taming-transformers')\n", + " sys.path.append('./src/clip')\n", + " sys.path.append('./stable-diffusion/')\n", + " sys.path.append('./k-diffusion')" + ], + "outputs": [], + 
"execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "81qmVZbrm4uu" + }, + "source": [ + "#@markdown **Python Definitions**\n", + "import json\n", + "from IPython import display\n", + "\n", + "import sys, os\n", + "import argparse, glob\n", + "import torch\n", + "import torch.nn as nn\n", + "import numpy as np\n", + "import requests\n", + "import shutil\n", + "from types import SimpleNamespace\n", + "from omegaconf import OmegaConf\n", + "from PIL import Image\n", + "from tqdm import tqdm, trange\n", + "from itertools import islice\n", + "from einops import rearrange, repeat\n", + "from torchvision.utils import make_grid\n", + "import time\n", + "from pytorch_lightning import seed_everything\n", + "from torch import autocast\n", + "from contextlib import contextmanager, nullcontext\n", + "\n", + "from helpers import save_samples\n", + "from ldm.util import instantiate_from_config\n", + "from ldm.models.diffusion.ddim import DDIMSampler\n", + "from ldm.models.diffusion.plms import PLMSSampler\n", + "\n", + "import accelerate\n", + "from k_diffusion import sampling\n", + "from k_diffusion.external import CompVisDenoiser\n", + "\n", + "def chunk(it, size):\n", + " it = iter(it)\n", + " return iter(lambda: tuple(islice(it, size)), ())\n", + "\n", + "def get_output_folder(output_path,batch_folder=None):\n", + " yearMonth = time.strftime('%Y-%m/')\n", + " out_path = output_path+\"/\"+yearMonth\n", + " if batch_folder != \"\":\n", + " out_path += batch_folder\n", + " if out_path[-1] != \"/\":\n", + " out_path += \"/\"\n", + " os.makedirs(out_path, exist_ok=True)\n", + " return out_path\n", + "\n", + "def load_img(path, shape):\n", + " if path.startswith('http://') or path.startswith('https://'):\n", + " image = Image.open(requests.get(path, stream=True).raw).convert('RGB')\n", + " else:\n", + " image = Image.open(path).convert('RGB')\n", + "\n", + " image = image.resize(shape, resample=Image.LANCZOS)\n", + " image = np.array(image).astype(np.float16) / 255.0\n", + " image = image[None].transpose(0, 3, 1, 2)\n", + " image = torch.from_numpy(image)\n", + " return 2.*image - 1.\n", + "\n", + "class CFGDenoiser(nn.Module):\n", + " def __init__(self, model):\n", + " super().__init__()\n", + " self.inner_model = model\n", + "\n", + " def forward(self, x, sigma, uncond, cond, cond_scale):\n", + " x_in = torch.cat([x] * 2)\n", + " sigma_in = torch.cat([sigma] * 2)\n", + " cond_in = torch.cat([uncond, cond])\n", + " uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)\n", + " return uncond + (cond - uncond) * cond_scale\n", + "\n", + "def make_callback(sampler, dynamic_threshold=None, static_threshold=None): \n", + " # Creates the callback function to be passed into the samplers\n", + " # The callback function is applied to the image after each step\n", + " def dynamic_thresholding_(img, threshold):\n", + " # Dynamic thresholding from Imagen paper (May 2022)\n", + " s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim)))\n", + " s = np.max(np.append(s,1.0))\n", + " torch.clamp_(img, -1*s, s)\n", + " torch.FloatTensor.div_(img, s)\n", + "\n", + " # Callback for samplers in the k-diffusion repo, called thus:\n", + " # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})\n", + " def k_callback(args_dict):\n", + " if static_threshold is not None:\n", + " torch.clamp_(args_dict['x'], -1*static_threshold, static_threshold)\n", + " if dynamic_threshold is not None:\n", + " 
dynamic_thresholding_(args_dict['x'], dynamic_threshold)\n", + "\n", + " # Function that is called on the image (img) and step (i) at each step\n", + " def img_callback(img, i):\n", + " # Thresholding functions\n", + " if dynamic_threshold is not None:\n", + " dynamic_thresholding_(img, dynamic_threshold)\n", + " if static_threshold is not None:\n", + " torch.clamp_(img, -1*static_threshold, static_threshold)\n", + "\n", + " if sampler in [\"plms\",\"ddim\"]: \n", + " # Callback function formated for compvis latent diffusion samplers\n", + " callback = img_callback\n", + " else: \n", + " # Default callback function uses k-diffusion sampler variables\n", + " callback = k_callback\n", + "\n", + " return callback\n", + "\n", + "def run(args, local_seed):\n", + "\n", + " # load settings\n", + " accelerator = accelerate.Accelerator()\n", + " device = accelerator.device\n", + " seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes])\n", + " torch.manual_seed(seeds[accelerator.process_index].item())\n", + "\n", + " # plms\n", + " if args.sampler==\"plms\":\n", + " args.eta = 0\n", + " sampler = PLMSSampler(model)\n", + " else:\n", + " sampler = DDIMSampler(model)\n", + "\n", + " model_wrap = CompVisDenoiser(model)\n", + " sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item()\n", + "\n", + " batch_size = args.n_samples\n", + " n_rows = args.n_rows if args.n_rows > 0 else batch_size\n", + "\n", + " data = list(chunk(args.prompts, batch_size))\n", + " sample_index = 0\n", + "\n", + " start_code = None\n", + " \n", + " # init image\n", + " if args.use_init:\n", + " assert os.path.isfile(args.init_image)\n", + " init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device)\n", + " init_image = repeat(init_image, '1 ... -> b ...', b=batch_size)\n", + " init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space\n", + "\n", + " sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.eta, verbose=False)\n", + "\n", + " assert 0. 
<= args.strength <= 1., 'can only work with strength in [0.0, 1.0]'\n", + " t_enc = int(args.strength * args.steps)\n", + " print(f\"target t_enc is {t_enc} steps\")\n", + "\n", + " # no init image\n", + " else:\n", + " if args.fixed_code:\n", + " start_code = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device)\n", + "\n", + " precision_scope = autocast if args.precision==\"autocast\" else nullcontext\n", + " with torch.no_grad():\n", + " with precision_scope(\"cuda\"):\n", + " with model.ema_scope():\n", + " tic = time.time()\n", + " for prompt_index, prompts in enumerate(data):\n", + " print(prompts)\n", + " prompt_seed = local_seed + prompt_index\n", + " seed_everything(prompt_seed)\n", + "\n", + " callback = make_callback(sampler=args.sampler,\n", + " dynamic_threshold=args.dynamic_threshold, \n", + " static_threshold=args.static_threshold) \n", + "\n", + " uc = None\n", + " if args.scale != 1.0:\n", + " uc = model.get_learned_conditioning(batch_size * [\"\"])\n", + " if isinstance(prompts, tuple):\n", + " prompts = list(prompts)\n", + " c = model.get_learned_conditioning(prompts)\n", + "\n", + " if args.sampler in [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\"]:\n", + " shape = [args.C, args.H // args.f, args.W // args.f]\n", + " sigmas = model_wrap.get_sigmas(args.steps)\n", + " torch.manual_seed(prompt_seed)\n", + " if args.use_init:\n", + " sigmas = sigmas[t_enc:]\n", + " x = init_latent + torch.randn([args.n_samples, *shape], device=device) * sigmas[0]\n", + " else:\n", + " x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0]\n", + " model_wrap_cfg = CFGDenoiser(model_wrap)\n", + " extra_args = {'cond': c, 'uncond': uc, 'cond_scale': args.scale}\n", + " if args.sampler==\"klms\":\n", + " samples = sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n", + " elif args.sampler==\"dpm2\":\n", + " samples = sampling.sample_dpm_2(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n", + " elif args.sampler==\"dpm2_ancestral\":\n", + " samples = sampling.sample_dpm_2_ancestral(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n", + " elif args.sampler==\"heun\":\n", + " samples = sampling.sample_heun(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n", + " elif args.sampler==\"euler\":\n", + " samples = sampling.sample_euler(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n", + " elif args.sampler==\"euler_ancestral\":\n", + " samples = sampling.sample_euler_ancestral(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback)\n", + "\n", + " x_samples = model.decode_first_stage(samples)\n", + " x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n", + " x_samples = accelerator.gather(x_samples)\n", + "\n", + " else:\n", + "\n", + " # no init image\n", + " if not args.use_init:\n", + " shape = [args.C, args.H // args.f, args.W // args.f]\n", + "\n", + " samples, _ = sampler.sample(S=args.steps,\n", + " conditioning=c,\n", + " batch_size=args.n_samples,\n", + " shape=shape,\n", + " verbose=False,\n", + " unconditional_guidance_scale=args.scale,\n", + " unconditional_conditioning=uc,\n", + " eta=args.eta,\n", + " 
x_T=start_code,\n", + " img_callback=callback)\n", + "\n", + " # init image\n", + " else:\n", + " # encode (scaled latent)\n", + " z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device))\n", + " # decode it\n", + " samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale,\n", + " unconditional_conditioning=uc,)\n", + "\n", + " x_samples = model.decode_first_stage(samples)\n", + " x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)\n", + " \n", + "\n", + " grid, images = save_samples(\n", + " args, x_samples=x_samples, seed=prompt_seed, n_rows=n_rows\n", + " )\n", + " if args.display_samples:\n", + " for im in images:\n", + " display.display(im)\n", + " if args.display_grid:\n", + " display.display(grid)\n", + "\n", + " # stop timer\n", + " toc = time.time()\n", + "\n", + " #print(f\"Your samples are ready and waiting for you here: \\n{outpath} \\n\" f\" \\nEnjoy.\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "TxIOPT0G5Lx1" + }, + "source": [ + "#@markdown **Model Path Variables**\n", + "# ask for the link\n", + "print(\"Local Path Variables:\\n\")\n", + "\n", + "models_path = \"/content/models\" #@param {type:\"string\"}\n", + "output_path = \"/content/output\" #@param {type:\"string\"}\n", + "\n", + "#@markdown **Google Drive Path Variables (Optional)**\n", + "mount_google_drive = True #@param {type:\"boolean\"}\n", + "force_remount = False\n", + "\n", + "if mount_google_drive:\n", + " from google.colab import drive\n", + " try:\n", + " drive_path = \"/content/drive\"\n", + " drive.mount(drive_path,force_remount=force_remount)\n", + " models_path_gdrive = \"/content/drive/MyDrive/AI/models\" #@param {type:\"string\"}\n", + " output_path_gdrive = \"/content/drive/MyDrive/AI/StableDiffusion\" #@param {type:\"string\"}\n", + " models_path = models_path_gdrive\n", + " output_path = output_path_gdrive\n", + " except:\n", + " print(\"...error mounting drive or with drive path variables\")\n", + " print(\"...reverting to default path variables\")\n", + "\n", + "os.makedirs(models_path, exist_ok=True)\n", + "os.makedirs(output_path, exist_ok=True)\n", + "\n", + "print(f\"models_path: {models_path}\")\n", + "print(f\"output_path: {output_path}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "CIUJ7lWI4v53" + }, + "source": [ + "#@markdown **Select Model**\n", + "print(\"\\nSelect Model:\\n\")\n", + "\n", + "model_config = \"v1-inference.yaml\" #@param [\"custom\",\"v1-inference.yaml\"]\n", + "model_checkpoint = \"sd-v1-4.ckpt\" #@param [\"custom\",\"sd-v1-4.ckpt\",\"sd-v1-3-full-ema.ckpt\",\"sd-v1-3.ckpt\",\"sd-v1-2-full-ema.ckpt\",\"sd-v1-2.ckpt\",\"sd-v1-1-full-ema.ckpt\",\"sd-v1-1.ckpt\"]\n", + "custom_config_path = \"\" #@param {type:\"string\"}\n", + "custom_checkpoint_path = \"\" #@param {type:\"string\"}\n", + "\n", + "check_sha256 = True #@param {type:\"boolean\"}\n", + "\n", + "model_map = {\n", + " \"sd-v1-4.ckpt\": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'},\n", + " \"sd-v1-3-full-ema.ckpt\": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'},\n", + " \"sd-v1-3.ckpt\": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'},\n", + " \"sd-v1-2-full-ema.ckpt\": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'},\n", + " \"sd-v1-2.ckpt\": 
{'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'},\n", + " \"sd-v1-1-full-ema.ckpt\": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'},\n", + " \"sd-v1-1.ckpt\": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'}\n", + "}\n", + "\n", + "def wget(url, outputdir):\n", + " res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8')\n", + " print(res)\n", + "\n", + "def download_model(model_checkpoint):\n", + " download_link = model_map[model_checkpoint][\"link\"][0]\n", + " print(f\"!wget -O {models_path}/{model_checkpoint} {download_link}\")\n", + " wget(download_link, models_path)\n", + " return\n", + "\n", + "# config path\n", + "if os.path.exists(models_path+'/'+model_config):\n", + " print(f\"{models_path+'/'+model_config} exists\")\n", + "else:\n", + " print(\"cp ./stable-diffusion/configs/stable-diffusion/v1-inference.yaml $models_path/.\")\n", + " shutil.copy('./stable-diffusion/configs/stable-diffusion/v1-inference.yaml', models_path)\n", + "\n", + "# checkpoint path or download\n", + "if os.path.exists(models_path+'/'+model_checkpoint):\n", + " print(f\"{models_path+'/'+model_checkpoint} exists\")\n", + "else:\n", + " print(f\"download model checkpoint and place in {models_path+'/'+model_checkpoint}\")\n", + " #download_model(model_checkpoint)\n", + "\n", + "if check_sha256:\n", + " import hashlib\n", + " print(\"\\n...checking sha256\")\n", + " with open(models_path+'/'+model_checkpoint, \"rb\") as f:\n", + " bytes = f.read() \n", + " hash = hashlib.sha256(bytes).hexdigest()\n", + " del bytes\n", + " if model_map[model_checkpoint][\"sha256\"] == hash:\n", + " print(\"hash is correct\\n\")\n", + " else:\n", + " print(\"hash in not correct\\n\")\n", + "\n", + "if model_config == \"custom\":\n", + " config = custom_config_path\n", + "else:\n", + " config = models_path+'/'+model_config\n", + "\n", + "if model_checkpoint == \"custom\":\n", + " ckpt = custom_checkpoint_path\n", + "else:\n", + " ckpt = models_path+'/'+model_checkpoint\n", + "\n", + "print(f\"config: {config}\")\n", + "print(f\"ckpt: {ckpt}\")" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "IJiMgz_96nr3" + }, + "source": [ + "#@markdown **Load Stable Diffusion**\n", + "\n", + "def load_model_from_config(config, ckpt, verbose=False, device='cuda'):\n", + " map_location = \"cuda\" #@param [\"cpu\", \"cuda\"]\n", + " print(f\"Loading model from {ckpt}\")\n", + " pl_sd = torch.load(ckpt, map_location=map_location)\n", + " if \"global_step\" in pl_sd:\n", + " print(f\"Global Step: {pl_sd['global_step']}\")\n", + " sd = pl_sd[\"state_dict\"]\n", + " model = instantiate_from_config(config.model)\n", + " m, u = model.load_state_dict(sd, strict=False)\n", + " if len(m) > 0 and verbose:\n", + " print(\"missing keys:\")\n", + " print(m)\n", + " if len(u) > 0 and verbose:\n", + " print(\"unexpected keys:\")\n", + " print(u)\n", + "\n", + " #model.cuda()\n", + " model = model.half().to(device)\n", + " model.eval()\n", + " return model\n", + "\n", + "load_on_run_all = True #@param {type: 'boolean'}\n", + "\n", + "if load_on_run_all:\n", + "\n", + " local_config = OmegaConf.load(f\"{config}\")\n", + " model = load_model_from_config(local_config, f\"{ckpt}\")\n", + " device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n", + " model = model.to(device)" + ], + "outputs": [], + "execution_count": 
null + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ov3r4RD1tzsT" + }, + "source": [ + "# **Run**" + ] + }, + { + "cell_type": "code", + "metadata": { + "id": "qH74gBWDd2oq" + }, + "source": [ + "def DeforumArgs():\n", + " #@markdown **Save & Display Settings**\n", + " batchdir = \"test\" #@param {type:\"string\"}\n", + " outdir = get_output_folder(output_path, batchdir)\n", + " save_grid = False\n", + " save_samples = True #@param {type:\"boolean\"}\n", + " save_settings = True #@param {type:\"boolean\"}\n", + " display_grid = False\n", + " display_samples = True #@param {type:\"boolean\"}\n", + "\n", + " #@markdown **Image Settings**\n", + " n_samples = 1 #@param\n", + " n_rows = 1 #@param\n", + " W = 512 #@param\n", + " H = 576 #@param\n", + " W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64\n", + "\n", + "\n", + " #@markdown **Init Settings**\n", + " use_init = False #@param {type:\"boolean\"}\n", + " init_image = \"/content/drive/MyDrive/AI/escape.jpg\" #@param {type:\"string\"}\n", + " strength = 0.5 #@param {type:\"number\"}\n", + "\n", + " #@markdown **Sampling Settings**\n", + " seed = 1 #@param\n", + " sampler = 'euler_ancestral' #@param [\"klms\",\"dpm2\",\"dpm2_ancestral\",\"heun\",\"euler\",\"euler_ancestral\",\"plms\", \"ddim\"]\n", + " steps = 50 #@param\n", + " scale = 7 #@param\n", + " eta = 0.0 #@param\n", + " dynamic_threshold = None\n", + " static_threshold = None \n", + "\n", + " #@markdown **Batch Settings**\n", + " n_batch = 2 #@param\n", + "\n", + " precision = 'autocast' \n", + " fixed_code = True\n", + " C = 4\n", + " f = 8\n", + " prompts = globals()['prompts']\n", + " timestring = \"\"\n", + "\n", + " return locals()\n", + "\n", + "args = SimpleNamespace(**DeforumArgs())" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "id": "2ujwkGZTcGev" + }, + "source": [ + "prompts = [\n", + " \"a beautiful forest by Asher Brown Durand, trending on Artstation\", #the first prompt I want\n", + " \"a beautiful portrait of a woman by Artgerm, trending on Artstation\", #the second prompt I want\n", + " #\"the third prompt I don't want it I commented it with an\",\n", + "]" + ], + "outputs": [], + "execution_count": null + }, + { + "cell_type": "code", + "metadata": { + "cellView": "form", + "id": "cxx8BzxjiaXg" + }, + "source": [ + "#@markdown **Run**\n", + "args = DeforumArgs()\n", + "args.filename = None\n", + "args.prompts = prompts\n", + "\n", + "def do_batch_run():\n", + " # create output folder\n", + " os.makedirs(args.outdir, exist_ok=True)\n", + "\n", + " # current timestring for filenames\n", + " args.timestring = time.strftime('%Y%m%d%H%M%S')\n", + "\n", + " # save settings for the batch\n", + " if args.save_settings:\n", + " filename = os.path.join(args.outdir, f\"{args.timestring}_settings.txt\")\n", + " with open(filename, \"w+\", encoding=\"utf-8\") as f:\n", + " json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4)\n", + "\n", + " for batch_index in range(args.n_batch):\n", + "\n", + " # random seed\n", + " if args.seed == -1:\n", + " local_seed = np.random.randint(0,4294967295)\n", + " else:\n", + " local_seed = args.seed\n", + "\n", + " print(f\"run {batch_index+1} of {args.n_batch}\")\n", + " run(args, local_seed)\n", + "\n", + "do_batch_run()" + ], + "outputs": [], + "execution_count": null + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "collapsed_sections": [], + "name": "Deforum_Stable_Diffusion.ipynb", + "provenance": [], + 
"private_outputs": true + }, + "gpuClass": "standard", + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} \ No newline at end of file diff --git a/backup/Deforum_Stable_Diffusion.py b/backup/Deforum_Stable_Diffusion.py new file mode 100644 index 0000000000..e4e668197f --- /dev/null +++ b/backup/Deforum_Stable_Diffusion.py @@ -0,0 +1,571 @@ +# %% +# !! {"metadata":{ +# !! "id": "c442uQJ_gUgy" +# !! }} +""" +# **Deforum Stable Diffusion** +[Stable Diffusion](https://github.com/CompVis/stable-diffusion) by Robin Rombach, Andreas Blattmann, Dominik Lorenz, Patrick Esser, Björn Ommer and the [Stability.ai](https://stability.ai/) Team + +Notebook by [deforum](https://twitter.com/deforum_art) +""" + +# %% +# !! {"metadata":{ +# !! "id": "2g-f7cQmf2Nt", +# !! "cellView": "form" +# !! }} +#@markdown **NVIDIA GPU** +import subprocess +sub_p_res = subprocess.run(['nvidia-smi', '--query-gpu=name,memory.total,memory.free', '--format=csv,noheader'], stdout=subprocess.PIPE).stdout.decode('utf-8') +print(sub_p_res) + +# %% +# !! {"metadata":{ +# !! "id": "VRNl2mfepEIe", +# !! "cellView": "form" +# !! }} +#@markdown **Setup Environment** + +setup_environment = True #@param {type:"boolean"} +print_subprocess = False #@param {type:"boolean"} + +if setup_environment: + import subprocess + print("...setting up environment") + all_process = [['pip', 'install', 'torch==1.11.0+cu113', 'torchvision==0.12.0+cu113', 'torchaudio==0.11.0', '--extra-index-url', 'https://download.pytorch.org/whl/cu113'], + ['pip', 'install', 'omegaconf==2.1.1', 'einops==0.3.0', 'pytorch-lightning==1.4.2', 'torchmetrics==0.6.0', 'torchtext==0.2.3', 'transformers==4.19.2', 'kornia==0.6'], + ['git', 'clone', 'https://github.com/deforum/stable-diffusion'], + ['pip', 'install', '-e', 'git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers'], + ['pip', 'install', '-e', 'git+https://github.com/openai/CLIP.git@main#egg=clip'], + ['pip', 'install', 'accelerate', 'ftfy', 'jsonmerge', 'resize-right', 'torchdiffeq'], + ] + for process in all_process: + running = subprocess.run(process,stdout=subprocess.PIPE).stdout.decode('utf-8') + if print_subprocess: + print(running) + + print(subprocess.run(['git', 'clone', 'https://github.com/deforum/k-diffusion/'], stdout=subprocess.PIPE).stdout.decode('utf-8')) + with open('k-diffusion/k_diffusion/__init__.py', 'w') as f: + f.write('') + + import sys + sys.path.append('./src/taming-transformers') + sys.path.append('./src/clip') + sys.path.append('./stable-diffusion/') + sys.path.append('./k-diffusion') + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "81qmVZbrm4uu" +# !! 
}} +#@markdown **Python Definitions** +import json +from IPython import display + +import sys, os +import argparse, glob +import torch +import torch.nn as nn +import numpy as np +import requests +import shutil +from types import SimpleNamespace +from omegaconf import OmegaConf +from PIL import Image +from tqdm import tqdm, trange +from itertools import islice +from einops import rearrange, repeat +from torchvision.utils import make_grid +import time +from pytorch_lightning import seed_everything +from torch import autocast +from contextlib import contextmanager, nullcontext + +from helpers import save_samples +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +import accelerate +from k_diffusion import sampling +from k_diffusion.external import CompVisDenoiser + +def chunk(it, size): + it = iter(it) + return iter(lambda: tuple(islice(it, size)), ()) + +def get_output_folder(output_path,batch_folder=None): + yearMonth = time.strftime('%Y-%m/') + out_path = output_path+"/"+yearMonth + if batch_folder != "": + out_path += batch_folder + if out_path[-1] != "/": + out_path += "/" + os.makedirs(out_path, exist_ok=True) + return out_path + +def load_img(path, shape): + if path.startswith('http://') or path.startswith('https://'): + image = Image.open(requests.get(path, stream=True).raw).convert('RGB') + else: + image = Image.open(path).convert('RGB') + + image = image.resize(shape, resample=Image.LANCZOS) + image = np.array(image).astype(np.float16) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.*image - 1. + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale + +def make_callback(sampler, dynamic_threshold=None, static_threshold=None): + # Creates the callback function to be passed into the samplers + # The callback function is applied to the image after each step + def dynamic_thresholding_(img, threshold): + # Dynamic thresholding from Imagen paper (May 2022) + s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1,img.ndim))) + s = np.max(np.append(s,1.0)) + torch.clamp_(img, -1*s, s) + torch.FloatTensor.div_(img, s) + + # Callback for samplers in the k-diffusion repo, called thus: + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + def k_callback(args_dict): + if static_threshold is not None: + torch.clamp_(args_dict['x'], -1*static_threshold, static_threshold) + if dynamic_threshold is not None: + dynamic_thresholding_(args_dict['x'], dynamic_threshold) + + # Function that is called on the image (img) and step (i) at each step + def img_callback(img, i): + # Thresholding functions + if dynamic_threshold is not None: + dynamic_thresholding_(img, dynamic_threshold) + if static_threshold is not None: + torch.clamp_(img, -1*static_threshold, static_threshold) + + if sampler in ["plms","ddim"]: + # Callback function formated for compvis latent diffusion samplers + callback = img_callback + else: + # Default callback function uses k-diffusion sampler variables + callback = k_callback + + return callback + +def run(args, local_seed): + + # load settings + 
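
For reference, the dynamic thresholding that make_callback() above attaches to every sampler step can be exercised in isolation. A minimal sketch, with a random tensor standing in for the latent and a made-up percentile threshold (the notebook leaves both thresholds as None by default):

```python
import numpy as np
import torch

def dynamic_thresholding_(img: torch.Tensor, threshold: float) -> None:
    # Imagen-style dynamic thresholding: take a per-sample percentile of |latent|,
    # floor it at 1.0, then clamp and rescale the latent in place.
    s = np.percentile(np.abs(img.cpu().numpy()), threshold, axis=tuple(range(1, img.ndim)))
    s = float(np.max(np.append(s, 1.0)))
    img.clamp_(-s, s)
    img.div_(s)

latent = torch.randn(1, 4, 64, 64) * 3.0   # stand-in for a noisy latent after one step
print("max |x| before:", latent.abs().max().item())
dynamic_thresholding_(latent, threshold=99.5)
print("max |x| after: ", latent.abs().max().item())   # now at most 1.0

# The static variant used by the same callback is just a hard clamp:
# latent.clamp_(-static_threshold, static_threshold)
```
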
accelerator = accelerate.Accelerator() + device = accelerator.device + seeds = torch.randint(-2 ** 63, 2 ** 63 - 1, [accelerator.num_processes]) + torch.manual_seed(seeds[accelerator.process_index].item()) + + # plms + if args.sampler=="plms": + args.eta = 0 + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + model_wrap = CompVisDenoiser(model) + sigma_min, sigma_max = model_wrap.sigmas[0].item(), model_wrap.sigmas[-1].item() + + batch_size = args.n_samples + n_rows = args.n_rows if args.n_rows > 0 else batch_size + + data = list(chunk(args.prompts, batch_size)) + sample_index = 0 + + start_code = None + + # init image + if args.use_init: + assert os.path.isfile(args.init_image) + init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device) + init_image = repeat(init_image, '1 ... -> b ...', b=batch_size) + init_latent = model.get_first_stage_encoding(model.encode_first_stage(init_image)) # move to latent space + + sampler.make_schedule(ddim_num_steps=args.steps, ddim_eta=args.eta, verbose=False) + + assert 0. <= args.strength <= 1., 'can only work with strength in [0.0, 1.0]' + t_enc = int(args.strength * args.steps) + print(f"target t_enc is {t_enc} steps") + + # no init image + else: + if args.fixed_code: + start_code = torch.randn([args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device) + + precision_scope = autocast if args.precision=="autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + tic = time.time() + for prompt_index, prompts in enumerate(data): + print(prompts) + prompt_seed = local_seed + prompt_index + seed_everything(prompt_seed) + + callback = make_callback(sampler=args.sampler, + dynamic_threshold=args.dynamic_threshold, + static_threshold=args.static_threshold) + + uc = None + if args.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + if args.sampler in ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral"]: + shape = [args.C, args.H // args.f, args.W // args.f] + sigmas = model_wrap.get_sigmas(args.steps) + torch.manual_seed(prompt_seed) + if args.use_init: + sigmas = sigmas[t_enc:] + x = init_latent + torch.randn([args.n_samples, *shape], device=device) * sigmas[0] + else: + x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0] + model_wrap_cfg = CFGDenoiser(model_wrap) + extra_args = {'cond': c, 'uncond': uc, 'cond_scale': args.scale} + if args.sampler=="klms": + samples = sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback) + elif args.sampler=="dpm2": + samples = sampling.sample_dpm_2(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback) + elif args.sampler=="dpm2_ancestral": + samples = sampling.sample_dpm_2_ancestral(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback) + elif args.sampler=="heun": + samples = sampling.sample_heun(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback) + elif args.sampler=="euler": + samples = sampling.sample_euler(model_wrap_cfg, x, sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback) + elif args.sampler=="euler_ancestral": + samples = sampling.sample_euler_ancestral(model_wrap_cfg, x, 
sigmas, extra_args=extra_args, disable=not accelerator.is_main_process, callback=callback) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + x_samples = accelerator.gather(x_samples) + + else: + + # no init image + if not args.use_init: + shape = [args.C, args.H // args.f, args.W // args.f] + + samples, _ = sampler.sample(S=args.steps, + conditioning=c, + batch_size=args.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + eta=args.eta, + x_T=start_code, + img_callback=callback) + + # init image + else: + # encode (scaled latent) + z_enc = sampler.stochastic_encode(init_latent, torch.tensor([t_enc]*batch_size).to(device)) + # decode it + samples = sampler.decode(z_enc, c, t_enc, unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc,) + + x_samples = model.decode_first_stage(samples) + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + + grid, images = save_samples( + args, x_samples=x_samples, seed=prompt_seed, n_rows=n_rows + ) + if args.display_samples: + for im in images: + display.display(im) + if args.display_grid: + display.display(grid) + + # stop timer + toc = time.time() + + #print(f"Your samples are ready and waiting for you here: \n{outpath} \n" f" \nEnjoy.") + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "TxIOPT0G5Lx1" +# !! }} +#@markdown **Model Path Variables** +# ask for the link +print("Local Path Variables:\n") + +models_path = "/content/models" #@param {type:"string"} +output_path = "/content/output" #@param {type:"string"} + +#@markdown **Google Drive Path Variables (Optional)** +mount_google_drive = True #@param {type:"boolean"} +force_remount = False + +if mount_google_drive: + from google.colab import drive + try: + drive_path = "/content/drive" + drive.mount(drive_path,force_remount=force_remount) + models_path_gdrive = "/content/drive/MyDrive/AI/models" #@param {type:"string"} + output_path_gdrive = "/content/drive/MyDrive/AI/StableDiffusion" #@param {type:"string"} + models_path = models_path_gdrive + output_path = output_path_gdrive + except: + print("...error mounting drive or with drive path variables") + print("...reverting to default path variables") + +os.makedirs(models_path, exist_ok=True) +os.makedirs(output_path, exist_ok=True) + +print(f"models_path: {models_path}") +print(f"output_path: {output_path}") + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "CIUJ7lWI4v53" +# !! 
}} +#@markdown **Select Model** +print("\nSelect Model:\n") + +model_config = "v1-inference.yaml" #@param ["custom","v1-inference.yaml"] +model_checkpoint = "sd-v1-4.ckpt" #@param ["custom","sd-v1-4.ckpt","sd-v1-3-full-ema.ckpt","sd-v1-3.ckpt","sd-v1-2-full-ema.ckpt","sd-v1-2.ckpt","sd-v1-1-full-ema.ckpt","sd-v1-1.ckpt"] +custom_config_path = "" #@param {type:"string"} +custom_checkpoint_path = "" #@param {type:"string"} + +check_sha256 = True #@param {type:"boolean"} + +model_map = { + "sd-v1-4.ckpt": {'sha256': 'fe4efff1e174c627256e44ec2991ba279b3816e364b49f9be2abc0b3ff3f8556'}, + "sd-v1-3-full-ema.ckpt": {'sha256': '54632c6e8a36eecae65e36cb0595fab314e1a1545a65209f24fde221a8d4b2ca'}, + "sd-v1-3.ckpt": {'sha256': '2cff93af4dcc07c3e03110205988ff98481e86539c51a8098d4f2236e41f7f2f'}, + "sd-v1-2-full-ema.ckpt": {'sha256': 'bc5086a904d7b9d13d2a7bccf38f089824755be7261c7399d92e555e1e9ac69a'}, + "sd-v1-2.ckpt": {'sha256': '3b87d30facd5bafca1cbed71cfb86648aad75d1c264663c0cc78c7aea8daec0d'}, + "sd-v1-1-full-ema.ckpt": {'sha256': 'efdeb5dc418a025d9a8cc0a8617e106c69044bc2925abecc8a254b2910d69829'}, + "sd-v1-1.ckpt": {'sha256': '86cd1d3ccb044d7ba8db743d717c9bac603c4043508ad2571383f954390f3cea'} +} + +def wget(url, outputdir): + res = subprocess.run(['wget', url, '-P', f'{outputdir}'], stdout=subprocess.PIPE).stdout.decode('utf-8') + print(res) + +def download_model(model_checkpoint): + download_link = model_map[model_checkpoint]["link"][0] + print(f"!wget -O {models_path}/{model_checkpoint} {download_link}") + wget(download_link, models_path) + return + +# config path +if os.path.exists(models_path+'/'+model_config): + print(f"{models_path+'/'+model_config} exists") +else: + print("cp ./stable-diffusion/configs/stable-diffusion/v1-inference.yaml $models_path/.") + shutil.copy('./stable-diffusion/configs/stable-diffusion/v1-inference.yaml', models_path) + +# checkpoint path or download +if os.path.exists(models_path+'/'+model_checkpoint): + print(f"{models_path+'/'+model_checkpoint} exists") +else: + print(f"download model checkpoint and place in {models_path+'/'+model_checkpoint}") + #download_model(model_checkpoint) + +if check_sha256: + import hashlib + print("\n...checking sha256") + with open(models_path+'/'+model_checkpoint, "rb") as f: + bytes = f.read() + hash = hashlib.sha256(bytes).hexdigest() + del bytes + if model_map[model_checkpoint]["sha256"] == hash: + print("hash is correct\n") + else: + print("hash in not correct\n") + +if model_config == "custom": + config = custom_config_path +else: + config = models_path+'/'+model_config + +if model_checkpoint == "custom": + ckpt = custom_checkpoint_path +else: + ckpt = models_path+'/'+model_checkpoint + +print(f"config: {config}") +print(f"ckpt: {ckpt}") + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "IJiMgz_96nr3" +# !! 
}} +#@markdown **Load Stable Diffusion** + +def load_model_from_config(config, ckpt, verbose=False, device='cuda'): + map_location = "cuda" #@param ["cpu", "cuda"] + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location=map_location) + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + #model.cuda() + model = model.half().to(device) + model.eval() + return model + +load_on_run_all = True #@param {type: 'boolean'} + +if load_on_run_all: + + local_config = OmegaConf.load(f"{config}") + model = load_model_from_config(local_config, f"{ckpt}") + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + model = model.to(device) + +# %% +# !! {"metadata":{ +# !! "id": "ov3r4RD1tzsT" +# !! }} +""" +# **Run** +""" + +# %% +# !! {"metadata":{ +# !! "id": "qH74gBWDd2oq" +# !! }} +def DeforumArgs(): + #@markdown **Save & Display Settings** + batchdir = "test" #@param {type:"string"} + outdir = get_output_folder(output_path, batchdir) + save_grid = False + save_samples = True #@param {type:"boolean"} + save_settings = True #@param {type:"boolean"} + display_grid = False + display_samples = True #@param {type:"boolean"} + + #@markdown **Image Settings** + n_samples = 1 #@param + n_rows = 1 #@param + W = 512 #@param + H = 576 #@param + W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 + + + #@markdown **Init Settings** + use_init = False #@param {type:"boolean"} + init_image = "/content/drive/MyDrive/AI/escape.jpg" #@param {type:"string"} + strength = 0.5 #@param {type:"number"} + + #@markdown **Sampling Settings** + seed = 1 #@param + sampler = 'euler_ancestral' #@param ["klms","dpm2","dpm2_ancestral","heun","euler","euler_ancestral","plms", "ddim"] + steps = 50 #@param + scale = 7 #@param + eta = 0.0 #@param + dynamic_threshold = None + static_threshold = None + + #@markdown **Batch Settings** + n_batch = 2 #@param + + precision = 'autocast' + fixed_code = True + C = 4 + f = 8 + prompts = [] + timestring = "" + + return locals() + +args = SimpleNamespace(**DeforumArgs()) + + +# %% +# !! {"metadata":{ +# !! "id": "2ujwkGZTcGev" +# !! }} +prompts = [ + "a beautiful forest by Asher Brown Durand, trending on Artstation", #the first prompt I want + "a beautiful portrait of a woman by Artgerm, trending on Artstation", #the second prompt I want + #"the third prompt I don't want it I commented it with an", +] + +# %% +# !! {"metadata":{ +# !! "cellView": "form", +# !! "id": "cxx8BzxjiaXg" +# !! 
}} +#@markdown **Run** +args = DeforumArgs() +args.filename = None +args.prompts = prompts + +def do_batch_run(): + # create output folder + os.makedirs(args.outdir, exist_ok=True) + + # current timestring for filenames + args.timestring = time.strftime('%Y%m%d%H%M%S') + + # save settings for the batch + if args.save_settings: + filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(filename, "w+", encoding="utf-8") as f: + json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4) + + for batch_index in range(args.n_batch): + + # random seed + if args.seed == -1: + local_seed = np.random.randint(0,4294967295) + else: + local_seed = args.seed + + print(f"run {batch_index+1} of {args.n_batch}") + run(args, local_seed) + +do_batch_run() + +# %% +# !! {"main_metadata":{ +# !! "accelerator": "GPU", +# !! "colab": { +# !! "collapsed_sections": [], +# !! "name": "Deforum_Stable_Diffusion.ipynb", +# !! "provenance": [], +# !! "private_outputs": true +# !! }, +# !! "gpuClass": "standard", +# !! "kernelspec": { +# !! "display_name": "Python 3", +# !! "name": "python3" +# !! }, +# !! "language_info": { +# !! "name": "python" +# !! } +# !! }} diff --git a/convert_colab.py b/convert_colab.py new file mode 100644 index 0000000000..6907ef8d7c --- /dev/null +++ b/convert_colab.py @@ -0,0 +1,16 @@ +from colab_convert import convert +from argparse import ArgumentParser + +argp = ArgumentParser() +argp.add_argument( + "-o", + "--output", + default="./Deforum_Stable_Diffusion_test.ipynb", + help="Output ipynb colab file", +) + +convert( + "./Deforum_Stable_Diffusion.py", + argp.parse_args().output, + extra_flags={}, +) \ No newline at end of file diff --git a/embedding_manager.py b/embedding_manager.py new file mode 100644 index 0000000000..82d9188d7d --- /dev/null +++ b/embedding_manager.py @@ -0,0 +1,178 @@ +import torch +from torch import nn + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ["*"] + +PROGRESSIVE_SCALE = 2000 + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"] + #assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" + + return tokens[0, 1] + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + #assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. 
Please use another string" + + token = token[0, 1] + + return token + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs + ): + super().__init__() + self.embedder = embedder + self.string_to_token_dict = {} + + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = nn.ParameterDict() # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) + token_dim = 768 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) + + token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) + self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + else: + token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + + if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + + placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), 
tokenized_text[row][col + 1:]], axis=0)[:n] + new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save({"string_to_token": self.string_to_token_dict, + "string_to_param": self.string_to_param_dict}, ckpt_path) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + # Handle .pt textual inversion files + if 'string_to_token' in ckpt and 'string_to_param' in ckpt: + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + + # Handle .bin textual inversion files from Huggingface Concepts + # https://huggingface.co/sd-concepts-library + else: + for token_str in list(ckpt.keys()): + token = get_clip_token_for_string(self.embedder.tokenizer, token_str) + self.string_to_token_dict[token_str] = token + ckpt[token_str] = torch.nn.Parameter(ckpt[token_str]) + + self.string_to_param_dict.update(ckpt) + if not full: + for key, value in self.string_to_param_dict.items(): + self.string_to_param_dict[key] = torch.nn.Parameter(value.half()) + + print(f'Added terms: {", ".join(self.string_to_param_dict.keys())}') + + def get_embedding_norms_squared(self): + all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0. + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings + + return loss diff --git a/helpers/__init__.py b/helpers/__init__.py new file mode 100644 index 0000000000..c0595ebe3f --- /dev/null +++ b/helpers/__init__.py @@ -0,0 +1,3 @@ +from .save_images import save_samples +from .k_samplers import sampler_fn +from .depth import DepthModel \ No newline at end of file diff --git a/helpers/depth.py b/helpers/depth.py new file mode 100644 index 0000000000..9f40aab22d --- /dev/null +++ b/helpers/depth.py @@ -0,0 +1,156 @@ +import math, os, subprocess +import cv2 +import numpy as np +import torch +import torchvision.transforms as T + +from einops import rearrange, repeat +from PIL import Image + +from infer import InferenceHelper +from midas.dpt_depth import DPTDepthModel +from midas.transforms import Resize, NormalizeImage, PrepareForNet + + +def wget(url, outputdir): + print(subprocess.run(['wget', url, '-P', outputdir], stdout=subprocess.PIPE).stdout.decode('utf-8')) + + +class DepthModel(): + def __init__(self, device): + self.adabins_helper = None + self.depth_min = 1000 + self.depth_max = -1000 + self.device = device + self.midas_model = None + self.midas_transform = None + + def load_adabins(self): + if not os.path.exists('pretrained/AdaBins_nyu.pt'): + print("Downloading AdaBins_nyu.pt...") + os.makedirs('pretrained', exist_ok=True) + wget("https://cloudflare-ipfs.com/ipfs/Qmd2mMnDLWePKmgfS8m6ntAg4nhV5VkUyAydYBp8cWWeB7/AdaBins_nyu.pt", 'pretrained') + self.adabins_helper = InferenceHelper(dataset='nyu', device=self.device) + + def load_midas(self, models_path, 
half_precision=True): + if not os.path.exists(os.path.join(models_path, 'dpt_large-midas-2f21e586.pt')): + print("Downloading dpt_large-midas-2f21e586.pt...") + wget("https://github.com/intel-isl/DPT/releases/download/1_0/dpt_large-midas-2f21e586.pt", models_path) + + self.midas_model = DPTDepthModel( + path=f"{models_path}/dpt_large-midas-2f21e586.pt", + backbone="vitl16_384", + non_negative=True, + ) + normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) + + self.midas_transform = T.Compose([ + Resize( + 384, 384, + resize_target=None, + keep_aspect_ratio=True, + ensure_multiple_of=32, + resize_method="minimal", + image_interpolation_method=cv2.INTER_CUBIC, + ), + normalization, + PrepareForNet() + ]) + + self.midas_model.eval() + if half_precision and self.device == torch.device("cuda"): + self.midas_model = self.midas_model.to(memory_format=torch.channels_last) + self.midas_model = self.midas_model.half() + self.midas_model.to(self.device) + + def predict(self, prev_img_cv2, anim_args) -> torch.Tensor: + w, h = prev_img_cv2.shape[1], prev_img_cv2.shape[0] + + # predict depth with AdaBins + use_adabins = anim_args.midas_weight < 1.0 and self.adabins_helper is not None + if use_adabins: + MAX_ADABINS_AREA = 500000 + MIN_ADABINS_AREA = 448*448 + + # resize image if too large or too small + img_pil = Image.fromarray(cv2.cvtColor(prev_img_cv2.astype(np.uint8), cv2.COLOR_RGB2BGR)) + image_pil_area = w*h + resized = True + if image_pil_area > MAX_ADABINS_AREA: + scale = math.sqrt(MAX_ADABINS_AREA) / math.sqrt(image_pil_area) + depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.LANCZOS) # LANCZOS is good for downsampling + print(f" resized to {depth_input.width}x{depth_input.height}") + elif image_pil_area < MIN_ADABINS_AREA: + scale = math.sqrt(MIN_ADABINS_AREA) / math.sqrt(image_pil_area) + depth_input = img_pil.resize((int(w*scale), int(h*scale)), Image.BICUBIC) + print(f" resized to {depth_input.width}x{depth_input.height}") + else: + depth_input = img_pil + resized = False + + # predict depth and resize back to original dimensions + try: + with torch.no_grad(): + _, adabins_depth = self.adabins_helper.predict_pil(depth_input) + if resized: + adabins_depth = TF.resize( + torch.from_numpy(adabins_depth), + torch.Size([h, w]), + interpolation=TF.InterpolationMode.BICUBIC + ) + adabins_depth = adabins_depth.squeeze() + except: + print(f" exception encountered, falling back to pure MiDaS") + use_adabins = False + torch.cuda.empty_cache() + + if self.midas_model is not None: + # convert image from 0->255 uint8 to 0->1 float for feeding to MiDaS + img_midas = prev_img_cv2.astype(np.float32) / 255.0 + img_midas_input = self.midas_transform({"image": img_midas})["image"] + + # MiDaS depth estimation implementation + sample = torch.from_numpy(img_midas_input).float().to(self.device).unsqueeze(0) + if self.device == torch.device("cuda"): + sample = sample.to(memory_format=torch.channels_last) + sample = sample.half() + with torch.no_grad(): + midas_depth = self.midas_model.forward(sample) + midas_depth = torch.nn.functional.interpolate( + midas_depth.unsqueeze(1), + size=img_midas.shape[:2], + mode="bicubic", + align_corners=False, + ).squeeze() + midas_depth = midas_depth.cpu().numpy() + torch.cuda.empty_cache() + + # MiDaS makes the near values greater, and the far values lesser. Let's reverse that and try to align with AdaBins a bit better. 
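
The lines that follow rescale the MiDaS prediction (which grows for near objects) so it roughly lines up with AdaBins' metric depth, then blend the two maps with midas_weight. A toy numpy sketch of that blend, with invented depth values in place of real model output:

```python
import numpy as np

midas_weight = 0.3                                    # plays the role of anim_args.midas_weight
midas_depth = np.array([[40.0, 10.0], [25.0, 5.0]])   # raw MiDaS output: larger = nearer
adabins_depth = np.array([[1.0, 3.0], [2.0, 4.0]])    # AdaBins output: metric depth, larger = farther

# Same normalisation as the code below: flip MiDaS so far > near, then rescale.
midas_depth = (50.0 - midas_depth) / 19.0

blended = midas_depth * midas_weight + adabins_depth * (1.0 - midas_weight)
print(blended)
```
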
+ midas_depth = np.subtract(50.0, midas_depth) + midas_depth = midas_depth / 19.0 + + # blend between MiDaS and AdaBins predictions + if use_adabins: + depth_map = midas_depth*anim_args.midas_weight + adabins_depth*(1.0-anim_args.midas_weight) + else: + depth_map = midas_depth + + depth_map = np.expand_dims(depth_map, axis=0) + depth_tensor = torch.from_numpy(depth_map).squeeze().to(self.device) + else: + depth_tensor = torch.ones((h, w), device=self.device) + + return depth_tensor + + def save(self, filename: str, depth: torch.Tensor): + depth = depth.cpu().numpy() + if len(depth.shape) == 2: + depth = np.expand_dims(depth, axis=0) + self.depth_min = min(self.depth_min, depth.min()) + self.depth_max = max(self.depth_max, depth.max()) + print(f" depth min:{depth.min()} max:{depth.max()}") + denom = max(1e-8, self.depth_max - self.depth_min) + temp = rearrange((depth - self.depth_min) / denom * 255, 'c h w -> h w c') + temp = repeat(temp, 'h w 1 -> h w c', c=3) + Image.fromarray(temp.astype(np.uint8)).save(filename) + diff --git a/helpers/k_samplers.py b/helpers/k_samplers.py new file mode 100644 index 0000000000..c0197b2f5e --- /dev/null +++ b/helpers/k_samplers.py @@ -0,0 +1,66 @@ +from typing import Any, Callable, Optional +import torch +from k_diffusion.external import CompVisDenoiser +from k_diffusion import sampling +from torch import nn + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale + + +def sampler_fn( + c: torch.Tensor, + uc: torch.Tensor, + args, + model_wrap: CompVisDenoiser, + init_latent: Optional[torch.Tensor] = None, + t_enc: Optional[torch.Tensor] = None, + device=torch.device("cpu") + if not torch.cuda.is_available() + else torch.device("cuda"), + cb: Callable[[Any], None] = None, +) -> torch.Tensor: + shape = [args.C, args.H // args.f, args.W // args.f] + sigmas: torch.Tensor = model_wrap.get_sigmas(args.steps) + sigmas = sigmas[len(sigmas) - t_enc - 1 :] + if args.use_init: + if len(sigmas) > 0: + x = ( + init_latent + + torch.randn([args.n_samples, *shape], device=device) * sigmas[0] + ) + else: + x = init_latent + else: + if len(sigmas) > 0: + x = torch.randn([args.n_samples, *shape], device=device) * sigmas[0] + else: + x = torch.zeros([args.n_samples, *shape], device=device) + sampler_args = { + "model": CFGDenoiser(model_wrap), + "x": x, + "sigmas": sigmas, + "extra_args": {"cond": c, "uncond": uc, "cond_scale": args.scale}, + "disable": False, + "callback": cb, + } + sampler_map = { + "klms": sampling.sample_lms, + "dpm2": sampling.sample_dpm_2, + "dpm2_ancestral": sampling.sample_dpm_2_ancestral, + "heun": sampling.sample_heun, + "euler": sampling.sample_euler, + "euler_ancestral": sampling.sample_euler_ancestral, + } + + samples = sampler_map[args.sampler](**sampler_args) + return samples diff --git a/helpers/save_images.py b/helpers/save_images.py new file mode 100644 index 0000000000..a890c6df91 --- /dev/null +++ b/helpers/save_images.py @@ -0,0 +1,51 @@ +from typing import List, Tuple +from einops import rearrange +import numpy as np, os, torch +from PIL import Image +from torchvision.utils import make_grid + + +def save_samples( + args, x_samples: torch.Tensor, seed: int, n_rows: int +) -> Tuple[Image.Image, 
List[Image.Image]]: + """Function to save samples to disk. + Args: + args: Stable deforum diffusion arguments. + x_samples: Samples to save. + seed: Seed for the experiment. + n_rows: Number of rows in the grid. + Returns: + A tuple of the grid image and a list of the generated images. + ( grid_image, generated_images ) + """ + + # save samples + images = [] + grid_image = None + if args.display_samples or args.save_samples: + for index, x_sample in enumerate(x_samples): + x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c") + images.append(Image.fromarray(x_sample.astype(np.uint8))) + if args.save_samples: + images[-1].save( + os.path.join( + args.outdir, f"{args.timestring}_{index:02}_{seed}.png" + ) + ) + + # save grid + if args.display_grid or args.save_grid: + grid = torch.stack([x_samples], 0) + grid = rearrange(grid, "n b c h w -> (n b) c h w") + grid = make_grid(grid, nrow=n_rows, padding=0) + + # to image + grid = 255.0 * rearrange(grid, "c h w -> h w c").cpu().numpy() + grid_image = Image.fromarray(grid.astype(np.uint8)) + if args.save_grid: + grid_image.save( + os.path.join(args.outdir, f"{args.timestring}_{seed}_grid.png") + ) + + # return grid_image and individual sample images + return grid_image, images diff --git a/ldm/data/personalized.py b/ldm/data/personalized.py new file mode 100644 index 0000000000..3c147e8e78 --- /dev/null +++ b/ldm/data/personalized.py @@ -0,0 +1,173 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random + +imagenet_templates_smallest = [ + 'a photo of a {}', +] + +imagenet_templates_small = [ + 'a photo of a {}', + 'a rendering of a {}', + 'a cropped photo of the {}', + 'the photo of a {}', + 'a photo of a clean {}', + 'a photo of a dirty {}', + 'a dark photo of the {}', + 'a photo of my {}', + 'a photo of the cool {}', + 'a close-up photo of a {}', + 'a bright photo of the {}', + 'a cropped photo of a {}', + 'a photo of the {}', + 'a good photo of the {}', + 'a photo of one {}', + 'a close-up photo of the {}', + 'a rendition of the {}', + 'a photo of the clean {}', + 'a rendition of a {}', + 'a photo of a nice {}', + 'a good photo of a {}', + 'a photo of the nice {}', + 'a photo of the small {}', + 'a photo of the weird {}', + 'a photo of the large {}', + 'a photo of a cool {}', + 'a photo of a small {}', +] + +imagenet_dual_templates_small = [ + 'a photo of a {} with {}', + 'a rendering of a {} with {}', + 'a cropped photo of the {} with {}', + 'the photo of a {} with {}', + 'a photo of a clean {} with {}', + 'a photo of a dirty {} with {}', + 'a dark photo of the {} with {}', + 'a photo of my {} with {}', + 'a photo of the cool {} with {}', + 'a close-up photo of a {} with {}', + 'a bright photo of the {} with {}', + 'a cropped photo of a {} with {}', + 'a photo of the {} with {}', + 'a good photo of the {} with {}', + 'a photo of one {} with {}', + 'a close-up photo of the {} with {}', + 'a rendition of the {} with {}', + 'a photo of the clean {} with {}', + 'a rendition of a {} with {}', + 'a photo of a nice {} with {}', + 'a good photo of a {} with {}', + 'a photo of the nice {} with {}', + 'a photo of the small {} with {}', + 'a photo of the weird {} with {}', + 'a photo of the large {} with {}', + 'a photo of a cool {} with {}', + 'a photo of a small {} with {}', +] + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +class 
PersonalizedBase(Dataset): + def __init__(self, + data_root, + size=None, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + per_image_tokens=False, + center_crop=False, + mixing_prob=0.25, + coarse_class_text=None, + ): + + self.data_root = data_root + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + # self._length = len(self.image_paths) + self.num_images = len(self.image_paths) + self._length = self.num_images + + self.placeholder_token = placeholder_token + + self.per_image_tokens = per_image_tokens + self.center_crop = center_crop + self.mixing_prob = mixing_prob + + self.coarse_class_text = coarse_class_text + + if per_image_tokens: + assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." + + if set == "train": + self._length = self.num_images * repeats + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + placeholder_string = self.placeholder_token + if self.coarse_class_text: + placeholder_string = f"{self.coarse_class_text} {placeholder_string}" + + image = image.convert('RGBA') + new_image = Image.new('RGBA', image.size, 'WHITE') + new_image.paste(image, (0, 0), image) + image = new_image.convert('RGB') + + templates = [ + 'a {} portrait of {}', + 'an {} image of {}', + 'a {} pretty picture of {}', + 'a {} clip art picture of {}', + 'an {} illustration of {}', + 'a {} 3D render of {}', + 'a {} {}', + ] + + filename = os.path.basename(self.image_paths[i % self.num_images]) + filename_tokens = os.path.splitext(filename)[0].replace(' ', '-').replace('_', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + text = random.choice(templates).format(' '.join(filename_tokens), self.placeholder_token) + + example["caption"] = text + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example \ No newline at end of file diff --git a/ldm/data/personalized_style.py b/ldm/data/personalized_style.py new file mode 100644 index 0000000000..1fefb6dd34 --- /dev/null +++ b/ldm/data/personalized_style.py @@ -0,0 +1,143 @@ +import os +import numpy as np +import PIL +from PIL import Image +from torch.utils.data import Dataset +from torchvision import transforms + +import random + +imagenet_templates_small = [ + 'a painting in the style of {}', + 'a rendering in the style of {}', + 'a cropped painting in the style of {}', + 'the painting in the style of {}', + 'a clean painting in the style of {}', + 'a dirty painting in the style of {}', + 'a dark painting in the style of 
{}', + 'a picture in the style of {}', + 'a cool painting in the style of {}', + 'a close-up painting in the style of {}', + 'a bright painting in the style of {}', + 'a cropped painting in the style of {}', + 'a good painting in the style of {}', + 'a close-up painting in the style of {}', + 'a rendition in the style of {}', + 'a nice painting in the style of {}', + 'a small painting in the style of {}', + 'a weird painting in the style of {}', + 'a large painting in the style of {}', +] + +imagenet_dual_templates_small = [ + 'a painting in the style of {} with {}', + 'a rendering in the style of {} with {}', + 'a cropped painting in the style of {} with {}', + 'the painting in the style of {} with {}', + 'a clean painting in the style of {} with {}', + 'a dirty painting in the style of {} with {}', + 'a dark painting in the style of {} with {}', + 'a cool painting in the style of {} with {}', + 'a close-up painting in the style of {} with {}', + 'a bright painting in the style of {} with {}', + 'a cropped painting in the style of {} with {}', + 'a good painting in the style of {} with {}', + 'a painting of one {} in the style of {}', + 'a nice painting in the style of {} with {}', + 'a small painting in the style of {} with {}', + 'a weird painting in the style of {} with {}', + 'a large painting in the style of {} with {}', +] + +per_img_token_list = [ + 'א', 'ב', 'ג', 'ד', 'ה', 'ו', 'ז', 'ח', 'ט', 'י', 'כ', 'ל', 'מ', 'נ', 'ס', 'ע', 'פ', 'צ', 'ק', 'ר', 'ש', 'ת', +] + +class PersonalizedBase(Dataset): + def __init__(self, + data_root, + size=None, + repeats=100, + interpolation="bicubic", + flip_p=0.5, + set="train", + placeholder_token="*", + per_image_tokens=False, + center_crop=False, + ): + + self.data_root = data_root + + self.image_paths = [os.path.join(self.data_root, file_path) for file_path in os.listdir(self.data_root)] + + # self._length = len(self.image_paths) + self.num_images = len(self.image_paths) + self._length = self.num_images + + self.placeholder_token = placeholder_token + + self.per_image_tokens = per_image_tokens + self.center_crop = center_crop + + if per_image_tokens: + assert self.num_images < len(per_img_token_list), f"Can't use per-image tokens when the training set contains more than {len(per_img_token_list)} tokens. To enable larger sets, add more tokens to 'per_img_token_list'." 
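
A small sketch of how the repeats/length bookkeeping around this point behaves at indexing time, using hypothetical numbers (the real class lists images from data_root, and __getitem__ wraps the index with `i % self.num_images`):

```python
num_images = 5      # files found in data_root
repeats = 100       # default repeats for the "train" set

length = num_images * repeats            # what __len__() reports for the training split
for i in (0, 4, 5, 499):
    print(i, "->", i % num_images)       # each epoch cycles over the same few source images
```
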
+ + if set == "train": + self._length = self.num_images * repeats + + self.size = size + self.interpolation = {"linear": PIL.Image.LINEAR, + "bilinear": PIL.Image.BILINEAR, + "bicubic": PIL.Image.BICUBIC, + "lanczos": PIL.Image.LANCZOS, + }[interpolation] + self.flip = transforms.RandomHorizontalFlip(p=flip_p) + + def __len__(self): + return self._length + + def __getitem__(self, i): + example = {} + image = Image.open(self.image_paths[i % self.num_images]) + + image = image.convert('RGBA') + new_image = Image.new('RGBA', image.size, 'WHITE') + new_image.paste(image, (0, 0), image) + image = new_image.convert('RGB') + + templates = [ + 'a {} portrait of {}', + 'an {} image of {}', + 'a {} pretty picture of {}', + 'a {} clip art picture of {}', + 'an {} illustration of {}', + 'a {} 3D render of {}', + 'a {} {}', + ] + + filename = os.path.basename(self.image_paths[i % self.num_images]) + filename_tokens = os.path.splitext(filename)[0].replace('_', '-').split('-') + filename_tokens = [token for token in filename_tokens if token.isalpha()] + + text = random.choice(templates).format(' '.join(filename_tokens), self.placeholder_token) + print(text) + + example["caption"] = text + + # default to score-sde preprocessing + img = np.array(image).astype(np.uint8) + + if self.center_crop: + crop = min(img.shape[0], img.shape[1]) + h, w, = img.shape[0], img.shape[1] + img = img[(h - crop) // 2:(h + crop) // 2, + (w - crop) // 2:(w + crop) // 2] + + image = Image.fromarray(img) + if self.size is not None: + image = image.resize((self.size, self.size), resample=self.interpolation) + + image = self.flip(image) + image = np.array(image).astype(np.uint8) + example["image"] = (image / 127.5 - 1.0).astype(np.float32) + return example \ No newline at end of file diff --git a/ldm/models/diffusion/ddim.py b/ldm/models/diffusion/ddim.py index fb31215db5..105c38d4b3 100644 --- a/ldm/models/diffusion/ddim.py +++ b/ldm/models/diffusion/ddim.py @@ -2,6 +2,7 @@ import torch import numpy as np +#from tqdm.notebook import tqdm from tqdm import tqdm from functools import partial @@ -154,7 +155,7 @@ def ddim_sampling(self, cond, shape, unconditional_conditioning=unconditional_conditioning) img, pred_x0 = outs if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if img_callback: img_callback(img, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) @@ -221,7 +222,7 @@ def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): @torch.no_grad() def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unconditional_conditioning=None, - use_original_steps=False): + use_original_steps=False, img_callback=None): timesteps = np.arange(self.ddpm_num_timesteps) if use_original_steps else self.ddim_timesteps timesteps = timesteps[:t_start] @@ -234,8 +235,12 @@ def decode(self, x_latent, cond, t_start, unconditional_guidance_scale=1.0, unco x_dec = x_latent for i, step in enumerate(iterator): index = total_steps - i - 1 + ts = torch.full((x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long) x_dec, _ = self.p_sample_ddim(x_dec, cond, ts, index=index, use_original_steps=use_original_steps, unconditional_guidance_scale=unconditional_guidance_scale, unconditional_conditioning=unconditional_conditioning) - return x_dec \ No newline at end of file + + if img_callback: img_callback(x_dec, i) + + return x_dec diff --git a/ldm/models/diffusion/ddpm.py b/ldm/models/diffusion/ddpm.py index bbedd04cfd..96670b9c96 100644 --- 
a/ldm/models/diffusion/ddpm.py +++ b/ldm/models/diffusion/ddpm.py @@ -14,6 +14,7 @@ from einops import rearrange, repeat from contextlib import contextmanager from functools import partial +#from tqdm.notebook import tqdm from tqdm import tqdm from torchvision.utils import make_grid from pytorch_lightning.utilities.distributed import rank_zero_only diff --git a/ldm/models/diffusion/plms.py b/ldm/models/diffusion/plms.py index 78eeb1003a..f1a7cf5f56 100644 --- a/ldm/models/diffusion/plms.py +++ b/ldm/models/diffusion/plms.py @@ -2,6 +2,7 @@ import torch import numpy as np +#from tqdm.notebook import tqdm from tqdm import tqdm from functools import partial @@ -161,7 +162,7 @@ def plms_sampling(self, cond, shape, if len(old_eps) >= 4: old_eps.pop(0) if callback: callback(i) - if img_callback: img_callback(pred_x0, i) + if img_callback: img_callback(img, i) if index % log_every_t == 0 or index == total_steps - 1: intermediates['x_inter'].append(img) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39ccb..f848a7c75f 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,3 +1,4 @@ +import gc from inspect import isfunction import math import torch @@ -170,27 +171,56 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def forward(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + k_in = self.to_k(context) + v_in = self.to_v(context) + del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 + mem_required = tensor_size * 2.5 + steps = 1 - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). 
' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + + s2 = s1.softmax(dim=-1) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class BasicTransformerBlock(nn.Module): diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a20..9be8922e9d 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -119,18 +120,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -138,7 +151,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + return x + h8 class LinAttnBlock(LinearAttention): @@ -174,32 +187,68 @@ def __init__(self, in_channels): stride=1, padding=0) - def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2) + del w2 + + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = 
torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 - h_ = self.proj_out(h_) + h2 = h_.reshape(b, c, h, w) + del h_ - return x+h_ + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 def make_attn(in_channels, attn_type="vanilla"): @@ -540,31 +589,54 @@ def forward(self, z): temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + gc.collect() + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + return h diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index a952e6c403..17f9679a36 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -49,12 +49,17 @@ def make_ddim_timesteps(ddim_discr_method, num_ddim_timesteps, num_ddpm_timestep ddim_timesteps = np.asarray(list(range(0, num_ddpm_timesteps, c))) elif ddim_discr_method == 'quad': ddim_timesteps = ((np.linspace(0, np.sqrt(num_ddpm_timesteps * .8), num_ddim_timesteps)) ** 2).astype(int) + elif ddim_discr_method == 'fill': + ddim_timesteps = np.linspace(0, num_ddpm_timesteps-1,num_ddim_timesteps+1).astype(int) else: raise NotImplementedError(f'There is no ddim discretization method called "{ddim_discr_method}"') # assert ddim_timesteps.shape[0] == num_ddim_timesteps - # add one to get the final alpha values right (the ones from first scale to data during sampling) - steps_out = ddim_timesteps + 1 + if ddim_discr_method == 'fill': + steps_out = ddim_timesteps + else: + # add one to get the final alpha values right (the ones from first scale to data during sampling) + steps_out = ddim_timesteps + 1 if verbose: print(f'Selected timesteps for ddim sampler: {steps_out}') return steps_out diff --git a/ldm/modules/embedding_manager.py b/ldm/modules/embedding_manager.py new file mode 100644 index 0000000000..fb6ae420a9 --- /dev/null +++ b/ldm/modules/embedding_manager.py @@ -0,0 +1,255 @@ +from cmath import log +import torch +from torch import nn + +import sys + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ['*'] + +PROGRESSIVE_SCALE = 2000 + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt', + ) + tokens = batch_encoding['input_ids'] + """ assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' 
maps to more than a single token. Please use another string" """ + + return tokens[0, 1] + + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. Please use another string" + + token = token[0, 1] + + return token + + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs, + ): + super().__init__() + + self.embedder = embedder + + self.string_to_token_dict = {} + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = ( + nn.ParameterDict() + ) # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr( + embedder, 'tokenizer' + ): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial( + get_clip_token_for_string, embedder.tokenizer + ) + get_embedding_for_tkn = partial( + get_embedding_for_clip_token, + embedder.transformer.text_model.embeddings, + ) + token_dim = 1280 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial( + get_bert_token_for_string, embedder.tknz_fn + ) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn( + init_word_token.cpu() + ) + + token_params = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=True, + ) + self.initial_embeddings[ + placeholder_string + ] = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=False, + ) + else: + token_params = torch.nn.Parameter( + torch.rand( + size=(num_vectors_per_token, token_dim), + requires_grad=True, + ) + ) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for ( + placeholder_string, + placeholder_token, + ) in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[ + placeholder_string + ].to(device) + + if ( + self.max_vectors_per_token == 1 + ): # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where( + tokenized_text == placeholder_token.to(device) + ) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = ( + 1 + self.progressive_counter // PROGRESSIVE_SCALE + ) + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min( + placeholder_embedding.shape[0], max_step_tokens + ) + + placeholder_rows, placeholder_cols = torch.where( + tokenized_text 
== placeholder_token.to(device) + ) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort( + placeholder_cols, descending=True + ) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat( + [ + tokenized_text[row][:col], + placeholder_token.repeat(num_vectors_for_token).to( + device + ), + tokenized_text[row][col + 1 :], + ], + axis=0, + )[:n] + new_embed_row = torch.cat( + [ + embedded_text[row][:col], + placeholder_embedding[:num_vectors_for_token], + embedded_text[row][col + 1 :], + ], + axis=0, + )[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save( + { + 'string_to_token': self.string_to_token_dict, + 'string_to_param': self.string_to_param_dict, + }, + ckpt_path, + ) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + + + def get_embedding_norms_squared(self): + all_params = torch.cat( + list(self.string_to_param_dict.values()), axis=0 + ) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum( + axis=-1 + ) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0.0 + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = ( + loss + + (optimized - coarse) + @ (optimized - coarse).T + / num_embeddings + ) + + return loss diff --git a/ldm/modules/embedding_managerbin.py b/ldm/modules/embedding_managerbin.py new file mode 100644 index 0000000000..25df677444 --- /dev/null +++ b/ldm/modules/embedding_managerbin.py @@ -0,0 +1,175 @@ +import torch +from torch import nn + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ["*"] + +PROGRESSIVE_SCALE = 2000 + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer(string, truncation=True, max_length=77, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"] + #assert torch.count_nonzero(tokens - 49407) == 2, f"String '{string}' maps to more than a single token. Please use another string" + + return tokens[0, 1] + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. 
Please use another string" + + token = token[0, 1] + + return token + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs + ): + super().__init__() + + self.string_to_token_dict = {} + + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = nn.ParameterDict() # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr(embedder, 'tokenizer'): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial(get_clip_token_for_string, embedder.tokenizer) + get_embedding_for_tkn = partial(get_embedding_for_clip_token, embedder.transformer.text_model.embeddings) + token_dim = 768 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial(get_bert_token_for_string, embedder.tknz_fn) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn(init_word_token.cpu()) + + token_params = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=True) + self.initial_embeddings[placeholder_string] = torch.nn.Parameter(init_word_embedding.unsqueeze(0).repeat(num_vectors_per_token, 1), requires_grad=False) + else: + token_params = torch.nn.Parameter(torch.rand(size=(num_vectors_per_token, token_dim), requires_grad=True)) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for placeholder_string, placeholder_token in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[placeholder_string].to(device) + + if self.max_vectors_per_token == 1: # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where(tokenized_text == placeholder_token.to(device)) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = 1 + self.progressive_counter // PROGRESSIVE_SCALE + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min(placeholder_embedding.shape[0], max_step_tokens) + + placeholder_rows, placeholder_cols = torch.where(tokenized_text == placeholder_token.to(device)) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort(placeholder_cols, descending=True) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat([tokenized_text[row][:col], placeholder_token.repeat(num_vectors_for_token).to(device), tokenized_text[row][col + 1:]], 
axis=0)[:n] + new_embed_row = torch.cat([embedded_text[row][:col], placeholder_embedding[:num_vectors_for_token], embedded_text[row][col + 1:]], axis=0)[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save({"string_to_token": self.string_to_token_dict, + "string_to_param": self.string_to_param_dict}, ckpt_path) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + if isinstance(ckpt, nn.ParameterDict): + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + else: + file_token = list(ckpt.keys())[0] + new_token = '*' + + tensor_size = ckpt[file_token].count_nonzero() + newt = ckpt[file_token].reshape(1, tensor_size) + newt = newt.half() + + nparam = nn.Parameter(data = newt, requires_grad=True) + + self.string_to_token_dict = {new_token: torch.tensor(265)} + self.string_to_param_dict = nn.ParameterDict({new_token: nparam}) + + print(f'Added terms: {", ".join(self.string_to_param_dict.keys())}') + + def get_embedding_norms_squared(self): + all_params = torch.cat(list(self.string_to_param_dict.values()), axis=0) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum(axis=-1) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0. + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = loss + (optimized - coarse) @ (optimized - coarse).T / num_embeddings + + return loss diff --git a/ldm/modules/embedding_managerpt.py b/ldm/modules/embedding_managerpt.py new file mode 100644 index 0000000000..fb6ae420a9 --- /dev/null +++ b/ldm/modules/embedding_managerpt.py @@ -0,0 +1,255 @@ +from cmath import log +import torch +from torch import nn + +import sys + +from ldm.data.personalized import per_img_token_list +from transformers import CLIPTokenizer +from functools import partial + +DEFAULT_PLACEHOLDER_TOKEN = ['*'] + +PROGRESSIVE_SCALE = 2000 + + +def get_clip_token_for_string(tokenizer, string): + batch_encoding = tokenizer( + string, + truncation=True, + max_length=77, + return_length=True, + return_overflowing_tokens=False, + padding='max_length', + return_tensors='pt', + ) + tokens = batch_encoding['input_ids'] + """ assert ( + torch.count_nonzero(tokens - 49407) == 2 + ), f"String '{string}' maps to more than a single token. Please use another string" """ + + return tokens[0, 1] + + +def get_bert_token_for_string(tokenizer, string): + token = tokenizer(string) + # assert torch.count_nonzero(token) == 3, f"String '{string}' maps to more than a single token. 
Please use another string" + + token = token[0, 1] + + return token + + +def get_embedding_for_clip_token(embedder, token): + return embedder(token.unsqueeze(0))[0, 0] + + +class EmbeddingManager(nn.Module): + def __init__( + self, + embedder, + placeholder_strings=None, + initializer_words=None, + per_image_tokens=False, + num_vectors_per_token=1, + progressive_words=False, + **kwargs, + ): + super().__init__() + + self.embedder = embedder + + self.string_to_token_dict = {} + self.string_to_param_dict = nn.ParameterDict() + + self.initial_embeddings = ( + nn.ParameterDict() + ) # These should not be optimized + + self.progressive_words = progressive_words + self.progressive_counter = 0 + + self.max_vectors_per_token = num_vectors_per_token + + if hasattr( + embedder, 'tokenizer' + ): # using Stable Diffusion's CLIP encoder + self.is_clip = True + get_token_for_string = partial( + get_clip_token_for_string, embedder.tokenizer + ) + get_embedding_for_tkn = partial( + get_embedding_for_clip_token, + embedder.transformer.text_model.embeddings, + ) + token_dim = 1280 + else: # using LDM's BERT encoder + self.is_clip = False + get_token_for_string = partial( + get_bert_token_for_string, embedder.tknz_fn + ) + get_embedding_for_tkn = embedder.transformer.token_emb + token_dim = 1280 + + if per_image_tokens: + placeholder_strings.extend(per_img_token_list) + + for idx, placeholder_string in enumerate(placeholder_strings): + + token = get_token_for_string(placeholder_string) + + if initializer_words and idx < len(initializer_words): + init_word_token = get_token_for_string(initializer_words[idx]) + + with torch.no_grad(): + init_word_embedding = get_embedding_for_tkn( + init_word_token.cpu() + ) + + token_params = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=True, + ) + self.initial_embeddings[ + placeholder_string + ] = torch.nn.Parameter( + init_word_embedding.unsqueeze(0).repeat( + num_vectors_per_token, 1 + ), + requires_grad=False, + ) + else: + token_params = torch.nn.Parameter( + torch.rand( + size=(num_vectors_per_token, token_dim), + requires_grad=True, + ) + ) + + self.string_to_token_dict[placeholder_string] = token + self.string_to_param_dict[placeholder_string] = token_params + + def forward( + self, + tokenized_text, + embedded_text, + ): + b, n, device = *tokenized_text.shape, tokenized_text.device + + for ( + placeholder_string, + placeholder_token, + ) in self.string_to_token_dict.items(): + + placeholder_embedding = self.string_to_param_dict[ + placeholder_string + ].to(device) + + if ( + self.max_vectors_per_token == 1 + ): # If there's only one vector per token, we can do a simple replacement + placeholder_idx = torch.where( + tokenized_text == placeholder_token.to(device) + ) + embedded_text[placeholder_idx] = placeholder_embedding + else: # otherwise, need to insert and keep track of changing indices + if self.progressive_words: + self.progressive_counter += 1 + max_step_tokens = ( + 1 + self.progressive_counter // PROGRESSIVE_SCALE + ) + else: + max_step_tokens = self.max_vectors_per_token + + num_vectors_for_token = min( + placeholder_embedding.shape[0], max_step_tokens + ) + + placeholder_rows, placeholder_cols = torch.where( + tokenized_text == placeholder_token.to(device) + ) + + if placeholder_rows.nelement() == 0: + continue + + sorted_cols, sort_idx = torch.sort( + placeholder_cols, descending=True + ) + sorted_rows = placeholder_rows[sort_idx] + + for idx in range(len(sorted_rows)): + row = 
sorted_rows[idx] + col = sorted_cols[idx] + + new_token_row = torch.cat( + [ + tokenized_text[row][:col], + placeholder_token.repeat(num_vectors_for_token).to( + device + ), + tokenized_text[row][col + 1 :], + ], + axis=0, + )[:n] + new_embed_row = torch.cat( + [ + embedded_text[row][:col], + placeholder_embedding[:num_vectors_for_token], + embedded_text[row][col + 1 :], + ], + axis=0, + )[:n] + + embedded_text[row] = new_embed_row + tokenized_text[row] = new_token_row + + return embedded_text + + def save(self, ckpt_path): + torch.save( + { + 'string_to_token': self.string_to_token_dict, + 'string_to_param': self.string_to_param_dict, + }, + ckpt_path, + ) + + def load(self, ckpt_path): + ckpt = torch.load(ckpt_path, map_location='cpu') + + self.string_to_token_dict = ckpt["string_to_token"] + self.string_to_param_dict = ckpt["string_to_param"] + + + def get_embedding_norms_squared(self): + all_params = torch.cat( + list(self.string_to_param_dict.values()), axis=0 + ) # num_placeholders x embedding_dim + param_norm_squared = (all_params * all_params).sum( + axis=-1 + ) # num_placeholders + + return param_norm_squared + + def embedding_parameters(self): + return self.string_to_param_dict.parameters() + + def embedding_to_coarse_loss(self): + + loss = 0.0 + num_embeddings = len(self.initial_embeddings) + + for key in self.initial_embeddings: + optimized = self.string_to_param_dict[key] + coarse = self.initial_embeddings[key].clone().to(optimized.device) + + loss = ( + loss + + (optimized - coarse) + @ (optimized - coarse).T + / num_embeddings + ) + + return loss diff --git a/replicate/cog.yaml b/replicate/cog.yaml new file mode 100644 index 0000000000..7482fa99d4 --- /dev/null +++ b/replicate/cog.yaml @@ -0,0 +1,35 @@ +build: + gpu: true + cuda: "11.3" + python_version: "3.8" + system_packages: + - "libgl1-mesa-glx" + - "libglib2.0-0" + python_packages: + - "ipython==8.4.0" + - "pandas==1.4.4" + - "scikit-image==0.19.3" + - "clean_fid==0.1.28" + - "torch==1.11.0 --extra-index-url=https://download.pytorch.org/whl/cu113" + - "torchvision==0.12.0 --extra-index-url=https://download.pytorch.org/whl/cu113" + - "ftfy==6.1.1" + - "scipy==1.9.0" + - "transformers==4.21.1" + - "omegaconf==2.1.1" + - "einops==0.3.0" + - "pytorch-lightning==1.4.2" + - "torchmetrics==0.6.0" + - "kornia==0.6" + - "accelerate==0.12.0" + - "jsonmerge==1.8.0" + - "resize-right==0.0.2" + - "torchdiffeq==0.2.3" + - "opencv-python==4.6.0.66" + + run: + - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/checkpoint_liberty_with_aug.pth" "https://github.com/DagnyT/hardnet/raw/master/pretrained/train_liberty_with_aug/checkpoint_liberty_with_aug.pth" + - apt-get update && apt-get install -y software-properties-common + - add-apt-repository ppa:ubuntu-toolchain-r/test + - apt update -y && apt-get install ffmpeg -y + +predict: "predict.py:Predictor" diff --git a/replicate/replicate.py b/replicate/replicate.py new file mode 100644 index 0000000000..ca78dd3f30 --- /dev/null +++ b/replicate/replicate.py @@ -0,0 +1,1026 @@ +""" +clone/install the following repo beforehand +git clone https://github.com/deforum/stable-diffusion +git clone https://github.com/deforum/k-diffusion +pip install -e git+https://github.com/CompVis/taming-transformers.git@master#egg=taming-transformers +pip install -e git+https://github.com/openai/CLIP.git@main#egg=clip + +weights for openai/clip-vit-large-patch14 and stable-diffusion sd-v1-4.ckpt are downloaded to ./weights +in 
./stable-diffusion/ldm/modules/encoders/modules.py, load from local weights (local_files_only=True) for FrozenCLIPEmbedder() +""" + +import os +from typing import Optional, List +from collections import OrderedDict +from PIL import Image +from itertools import islice +import shutil +import json +from IPython import display +import argparse, glob, os, pathlib, subprocess, sys, time +import cv2 +import numpy as np +import pandas as pd +import random +import requests +import shutil +import torch +import torch.nn as nn +from torch import autocast +import torchvision.transforms as T +import torchvision.transforms.functional as TF +from torchvision.utils import make_grid +from contextlib import contextmanager, nullcontext +from einops import rearrange, repeat +from omegaconf import OmegaConf +from pytorch_lightning import seed_everything +from skimage.exposure import match_histograms +from tqdm import tqdm, trange +from types import SimpleNamespace +import subprocess +from base64 import b64encode +from cog import BasePredictor, Input, Path + +sys.path.append("./src/taming-transformers") +sys.path.append("./src/clip") +sys.path.append("./stable-diffusion/") +sys.path.append("./k-diffusion") + +from helpers import save_samples, sampler_fn +from ldm.util import instantiate_from_config +from ldm.models.diffusion.ddim import DDIMSampler +from ldm.models.diffusion.plms import PLMSSampler + +from k_diffusion import sampling +from k_diffusion.external import CompVisDenoiser + + +class Predictor(BasePredictor): + def setup(self): + """Load the model into memory to make running multiple predictions efficient""" + ckpt_config_path = ( + "./stable-diffusion/configs/stable-diffusion/v1-inference.yaml" + ) + ckpt_path = "./weights/sd-v1-4.ckpt" + local_config = OmegaConf.load(f"{ckpt_config_path}") + + half_precision = True + self.model = load_model_from_config( + local_config, f"{ckpt_path}", half_precision=half_precision + ) + self.device = ( + torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + ) + self.model = self.model.to(self.device) + + def predict( + self, + max_frames: int = Input( + description="Number of frames for animation", ge=100, le=1000, default=30 + ), + animation_prompts: str = Input( + default="0: a beautiful portrait of a woman by Artgerm, trending on Artstation", + description="Prompt for animation. Provide 'frame number : prompt at this frame', separate different prompts with '|'. Make sure the frame number does not exceed the max_frames.", + ), + angle: str = Input( + description="angle parameter for the motion", default="0:(0)" + ), + zoom: str = Input( + description="zoom parameter for the motion", default="0: (1.04)" + ), + translation_x: str = Input( + description="translation_x parameter for the motion", default="0: (0)" + ), + translation_y: str = Input( + description="translation_y parameter for the motion", default="0: (0)" + ), + color_coherence: str = Input( + choices=[ + "None", + "Match Frame 0 HSV", + "Match Frame 0 LAB", + "Match Frame 0 RGB", + ], + default="Match Frame 0 LAB", + ), + sampler: str = Input( + choices=[ + "klms", + "dpm2", + "dpm2_ancestral", + "heun", + "euler", + "euler_ancestral", + "plms", + "ddim", + ], + default="plms", + ), + fps: int = Input( + default=15, ge=10, le=60, description="Choose fps for the video." + ), + seed: int = Input( + description="Random seed. 
Leave blank to randomize the seed", default=None + ), + ) -> Path: + """Run a single prediction on the model""" + + # sanity checks: + animation_prompts_dict = {} + animation_prompts = animation_prompts.split("|") + assert len(animation_prompts) > 0, "Please provide valid prompt for animation." + if len(animation_prompts) == 1: + animation_prompts = {0: animation_prompts[0]} + else: + for frame_prompt in animation_prompts: + frame_prompt = frame_prompt.split(":") + assert ( + len(frame_prompt) == 2 + ), "Please follow the 'frame_num: prompt' format." + frame_id, prompt = frame_prompt[0].strip(), frame_prompt[1].strip() + assert ( + frame_id.isdigit() and 0<= int(frame_id) <= max_frames + ), "frame_num should be an integer and 0<= frame_num <= max_frames" + assert ( + int(frame_id) not in animation_prompts_dict + ), f"Duplicate prompts for frame_num {frame_id}. " + assert len(prompt) > 0, "prompt cannot be empty" + animation_prompts_dict[int(frame_id)] = prompt + animation_prompts = OrderedDict(sorted(animation_prompts_dict.items())) + + outdir = "cog_out" + if os.path.exists(outdir): + shutil.rmtree(outdir) + os.makedirs(outdir) + + # load default args + anim_args = SimpleNamespace(**DeforumAnimArgs()) + + # overwrite with user input + anim_args.max_frames = max_frames + anim_args.angle = angle + anim_args.zoom = zoom + anim_args.translation_x = translation_x + anim_args.translation_y = translation_y + anim_args.color_coherence = color_coherence + + if anim_args.animation_mode == "None": + anim_args.max_frames = 1 + + if anim_args.key_frames: + anim_args.angle_series = get_inbetweens( + anim_args, parse_key_frames(anim_args.angle) + ) + anim_args.zoom_series = get_inbetweens( + anim_args, parse_key_frames(anim_args.zoom) + ) + anim_args.translation_x_series = get_inbetweens( + anim_args, parse_key_frames(anim_args.translation_x) + ) + anim_args.translation_y_series = get_inbetweens( + anim_args, parse_key_frames(anim_args.translation_y) + ) + anim_args.noise_schedule_series = get_inbetweens( + anim_args, parse_key_frames(anim_args.noise_schedule) + ) + anim_args.strength_schedule_series = get_inbetweens( + anim_args, parse_key_frames(anim_args.strength_schedule) + ) + anim_args.contrast_schedule_series = get_inbetweens( + anim_args, parse_key_frames(anim_args.contrast_schedule) + ) + + args = SimpleNamespace(**DeforumArgs()) + args.timestring = time.strftime("%Y%m%d%H%M%S") + args.strength = max(0.0, min(1.0, args.strength)) + + if seed is None: + seed = int.from_bytes(os.urandom(2), "big") + print(f"Using seed: {seed}") + args.seed = seed + args.outdir = outdir + + if anim_args.animation_mode == "Video Input": + args.use_init = True + if not args.use_init: + args.init_image = None + args.strength = 0 + if args.sampler == "plms" and ( + args.use_init or anim_args.animation_mode != "None" + ): + print(f"Init images aren't supported with PLMS yet, switching to KLMS") + args.sampler = "klms" + if args.sampler != "ddim": + args.ddim_eta = 0 + + if anim_args.animation_mode == "2D": + anim_args.animation_prompts = animation_prompts + render_animation(args, anim_args, self.model, self.device) + elif anim_args.animation_mode == "Video Input": + render_input_video(args, anim_args, self.model, self.device) + elif anim_args.animation_mode == "Interpolation": + render_interpolation(args, anim_args, self.model, self.device) + else: + render_image_batch(args, prompts, self.model, self.device) + + # make video + image_path = os.path.join(args.outdir, f"{args.timestring}_%05d.png") + mp4_path = 
f"/tmp/out.mp4" + + # make video + cmd = [ + "ffmpeg", + "-y", + "-vcodec", + "png", + "-r", + str(fps), + "-start_number", + str(0), + "-i", + image_path, + "-frames:v", + str(anim_args.max_frames), + "-c:v", + "libx264", + "-vf", + f"fps={fps}", + "-pix_fmt", + "yuv420p", + "-crf", + "17", + "-preset", + "veryfast", + mp4_path, + ] + process = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE) + stdout, stderr = process.communicate() + if process.returncode != 0: + print(stderr) + raise RuntimeError(stderr) + + return Path(mp4_path) + + +def DeforumArgs(): + # Save & Display Settings + batch_name = "StableFun" + outdir = "cog_output" + save_settings = False + save_samples = True + display_samples = False + + # Image Settings + n_samples = 1 # hidden + W = 512 + H = 512 + W, H = map(lambda x: x - x % 64, (W, H)) # resize to integer multiple of 64 + + # Init Settings + use_init = False + strength = 0.5 + init_image = "https://cdn.pixabay.com/photo/2022/07/30/13/10/green-longhorn-beetle-7353749_1280.jpg" + + # Sampling Settings + seed = -1 + sampler = "klms" + steps = 50 + scale = 7 + ddim_eta = 0.0 + dynamic_threshold = None + static_threshold = None + + # Batch Settings + n_batch = 1 + seed_behavior = "iter" + + # Grid Settings + make_grid = False + grid_rows = 2 + + precision = "autocast" + fixed_code = True + C = 4 + f = 8 + + prompt = "" + timestring = "" + init_latent = None + init_sample = None + init_c = None + + return locals() + + +def DeforumAnimArgs(): + + # Animation + animation_mode = "2D" + max_frames = 1000 + border = "wrap" + + # Motion Parameters + key_frames = True + interp_spline = "Linear" + angle = "0:(0)" + zoom = "0: (1.04)" + translation_x = "0: (0)" + translation_y = "0: (0)" + noise_schedule = "0: (0.02)" + strength_schedule = "0: (0.65)" + contrast_schedule = "0: (1.0)" + + # Coherence + color_coherence = "Match Frame 0 LAB" + + # Video Input + video_init_path = "/content/video_in.mp4" + extract_nth_frame = 1 + + # Interpolation + interpolate_key_frames = False + interpolate_x_frames = 4 + + # Resume Animation + resume_from_timestring = False + resume_timestring = " " + return locals() + + +def load_model_from_config( + config, ckpt, verbose=False, device="cuda", half_precision=True +): + map_location = "cuda" + print(f"Loading model from {ckpt}") + pl_sd = torch.load(ckpt, map_location=map_location) + if "global_step" in pl_sd: + print(f"Global Step: {pl_sd['global_step']}") + sd = pl_sd["state_dict"] + model = instantiate_from_config(config.model) + m, u = model.load_state_dict(sd, strict=False) + if len(m) > 0 and verbose: + print("missing keys:") + print(m) + if len(u) > 0 and verbose: + print("unexpected keys:") + print(u) + + if half_precision: + model = model.half().to(device) + else: + model = model.to(device) + model.eval() + return model + + +class CFGDenoiser(nn.Module): + def __init__(self, model): + super().__init__() + self.inner_model = model + + def forward(self, x, sigma, uncond, cond, cond_scale): + x_in = torch.cat([x] * 2) + sigma_in = torch.cat([sigma] * 2) + cond_in = torch.cat([uncond, cond]) + uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) + return uncond + (cond - uncond) * cond_scale + + +def add_noise(sample: torch.Tensor, noise_amt: float): + return sample + torch.randn(sample.shape, device=sample.device) * noise_amt + + +def load_img(path, shape): + if path.startswith("http://") or path.startswith("https://"): + image = Image.open(requests.get(path, stream=True).raw).convert("RGB") + else: + 
image = Image.open(path).convert("RGB") + + image = image.resize(shape, resample=Image.LANCZOS) + image = np.array(image).astype(np.float16) / 255.0 + image = image[None].transpose(0, 3, 1, 2) + image = torch.from_numpy(image) + return 2.0 * image - 1.0 + + +def maintain_colors(prev_img, color_match_sample, mode): + if mode == "Match Frame 0 RGB": + return match_histograms(prev_img, color_match_sample, multichannel=True) + elif mode == "Match Frame 0 HSV": + prev_img_hsv = cv2.cvtColor(prev_img, cv2.COLOR_RGB2HSV) + color_match_hsv = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2HSV) + matched_hsv = match_histograms(prev_img_hsv, color_match_hsv, multichannel=True) + return cv2.cvtColor(matched_hsv, cv2.COLOR_HSV2RGB) + else: # Match Frame 0 LAB + prev_img_lab = cv2.cvtColor(prev_img, cv2.COLOR_RGB2LAB) + color_match_lab = cv2.cvtColor(color_match_sample, cv2.COLOR_RGB2LAB) + matched_lab = match_histograms(prev_img_lab, color_match_lab, multichannel=True) + return cv2.cvtColor(matched_lab, cv2.COLOR_LAB2RGB) + + +def make_callback(sampler, dynamic_threshold=None, static_threshold=None): + # Creates the callback function to be passed into the samplers + # The callback function is applied to the image after each step + def dynamic_thresholding_(img, threshold): + # Dynamic thresholding from Imagen paper (May 2022) + s = np.percentile(np.abs(img.cpu()), threshold, axis=tuple(range(1, img.ndim))) + s = np.max(np.append(s, 1.0)) + torch.clamp_(img, -1 * s, s) + torch.FloatTensor.div_(img, s) + + # Callback for samplers in the k-diffusion repo, called thus: + # callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + def k_callback(args_dict): + if static_threshold is not None: + torch.clamp_(args_dict["x"], -1 * static_threshold, static_threshold) + if dynamic_threshold is not None: + dynamic_thresholding_(args_dict["x"], dynamic_threshold) + + # Function that is called on the image (img) and step (i) at each step + def img_callback(img, i): + # Thresholding functions + if dynamic_threshold is not None: + dynamic_thresholding_(img, dynamic_threshold) + if static_threshold is not None: + torch.clamp_(img, -1 * static_threshold, static_threshold) + + if sampler in ["plms", "ddim"]: + # Callback function formated for compvis latent diffusion samplers + callback = img_callback + else: + # Default callback function uses k-diffusion sampler variables + callback = k_callback + + return callback + + +def generate( + args, model, device, return_latent=False, return_sample=False, return_c=False +): + seed_everything(args.seed) + os.makedirs(args.outdir, exist_ok=True) + + if args.sampler == "plms": + sampler = PLMSSampler(model) + else: + sampler = DDIMSampler(model) + + model_wrap = CompVisDenoiser(model) + batch_size = args.n_samples + prompt = args.prompt + assert prompt is not None + data = [batch_size * [prompt]] + + init_latent = None + if args.init_latent is not None: + init_latent = args.init_latent + elif args.init_sample is not None: + init_latent = model.get_first_stage_encoding( + model.encode_first_stage(args.init_sample) + ) + elif args.init_image != None and args.init_image != "": + init_image = load_img(args.init_image, shape=(args.W, args.H)).to(device) + init_image = repeat(init_image, "1 ... 
-> b ...", b=batch_size) + init_latent = model.get_first_stage_encoding( + model.encode_first_stage(init_image) + ) # move to latent space + + sampler.make_schedule( + ddim_num_steps=args.steps, ddim_eta=args.ddim_eta, verbose=False + ) + + t_enc = int((1.0 - args.strength) * args.steps) + + start_code = None + if args.fixed_code and init_latent == None: + start_code = torch.randn( + [args.n_samples, args.C, args.H // args.f, args.W // args.f], device=device + ) + + callback = make_callback( + sampler=args.sampler, + dynamic_threshold=args.dynamic_threshold, + static_threshold=args.static_threshold, + ) + + results = [] + precision_scope = autocast if args.precision == "autocast" else nullcontext + with torch.no_grad(): + with precision_scope("cuda"): + with model.ema_scope(): + for prompts in data: + uc = None + if args.scale != 1.0: + uc = model.get_learned_conditioning(batch_size * [""]) + if isinstance(prompts, tuple): + prompts = list(prompts) + c = model.get_learned_conditioning(prompts) + + if args.init_c != None: + c = args.init_c + + if args.sampler in [ + "klms", + "dpm2", + "dpm2_ancestral", + "heun", + "euler", + "euler_ancestral", + ]: + samples = sampler_fn( + c=c, + uc=uc, + args=args, + model_wrap=model_wrap, + init_latent=init_latent, + t_enc=t_enc, + device=device, + cb=callback, + ) + else: + + if init_latent != None: + z_enc = sampler.stochastic_encode( + init_latent, + torch.tensor([t_enc] * batch_size).to(device), + ) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + ) + else: + if args.sampler == "plms" or args.sampler == "ddim": + shape = [args.C, args.H // args.f, args.W // args.f] + samples, _ = sampler.sample( + S=args.steps, + conditioning=c, + batch_size=args.n_samples, + shape=shape, + verbose=False, + unconditional_guidance_scale=args.scale, + unconditional_conditioning=uc, + eta=args.ddim_eta, + x_T=start_code, + img_callback=callback, + ) + + if return_latent: + results.append(samples.clone()) + + x_samples = model.decode_first_stage(samples) + if return_sample: + results.append(x_samples.clone()) + + x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0) + + if return_c: + results.append(c.clone()) + + for x_sample in x_samples: + x_sample = 255.0 * rearrange( + x_sample.cpu().numpy(), "c h w -> h w c" + ) + image = Image.fromarray(x_sample.astype(np.uint8)) + results.append(image) + return results + + +def sample_from_cv2(sample: np.ndarray) -> torch.Tensor: + sample = ((sample.astype(float) / 255.0) * 2) - 1 + sample = sample[None].transpose(0, 3, 1, 2).astype(np.float16) + sample = torch.from_numpy(sample) + return sample + + +def sample_to_cv2(sample: torch.Tensor) -> np.ndarray: + sample_f32 = rearrange(sample.squeeze().cpu().numpy(), "c h w -> h w c").astype( + np.float32 + ) + sample_f32 = ((sample_f32 * 0.5) + 0.5).clip(0, 1) + sample_int8 = (sample_f32 * 255).astype(np.uint8) + return sample_int8 + + +def make_xform_2d(width, height, translation_x, translation_y, angle, scale): + center = (width // 2, height // 2) + trans_mat = np.float32([[1, 0, translation_x], [0, 1, translation_y]]) + rot_mat = cv2.getRotationMatrix2D(center, angle, scale) + trans_mat = np.vstack([trans_mat, [0, 0, 1]]) + rot_mat = np.vstack([rot_mat, [0, 0, 1]]) + return np.matmul(rot_mat, trans_mat) + + +def parse_key_frames(string, prompt_parser=None): + import re + + pattern = r"((?P[0-9]+):[\s]*[\(](?P[\S\s]*?)[\)])" + frames = dict() + for match_object in re.finditer(pattern, 
string): + frame = int(match_object.groupdict()["frame"]) + param = match_object.groupdict()["param"] + if prompt_parser: + frames[frame] = prompt_parser(param) + else: + frames[frame] = param + if frames == {} and len(string) != 0: + raise RuntimeError("Key Frame string not correctly formatted") + return frames + + +def get_inbetweens(anim_args, key_frames, integer=False): + key_frame_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) + + for i, value in key_frames.items(): + key_frame_series[i] = value + key_frame_series = key_frame_series.astype(float) + + interp_method = anim_args.interp_spline + if interp_method == "Cubic" and len(key_frames.items()) <= 3: + interp_method = "Quadratic" + if interp_method == "Quadratic" and len(key_frames.items()) <= 2: + interp_method = "Linear" + + key_frame_series[0] = key_frame_series[key_frame_series.first_valid_index()] + key_frame_series[anim_args.max_frames - 1] = key_frame_series[ + key_frame_series.last_valid_index() + ] + key_frame_series = key_frame_series.interpolate( + method=interp_method.lower(), limit_direction="both" + ) + if integer: + return key_frame_series.astype(int) + return key_frame_series + + +def next_seed(args): + if args.seed_behavior == "iter": + args.seed += 1 + elif args.seed_behavior == "fixed": + pass # always keep seed the same + else: + args.seed = random.randint(0, 2**32) + return args.seed + + +def render_image_batch(args, prompts, model, device): + args.prompts = prompts + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + if args.save_settings or args.save_samples: + print(f"Saving to {os.path.join(args.outdir, args.timestring)}_*") + + # save settings for the batch + if args.save_settings: + filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + with open(filename, "w+", encoding="utf-8") as f: + json.dump(dict(args.__dict__), f, ensure_ascii=False, indent=4) + + index = 0 + + # function for init image batching + init_array = [] + if args.use_init: + if args.init_image == "": + raise FileNotFoundError("No path was given for init_image") + if args.init_image.startswith("http://") or args.init_image.startswith( + "https://" + ): + init_array.append(args.init_image) + elif not os.path.isfile(args.init_image): + if ( + args.init_image[-1] != "/" + ): # avoids path error by adding / to end if not there + args.init_image += "/" + for image in sorted( + os.listdir(args.init_image) + ): # iterates dir and appends images to init_array + if image.split(".")[-1] in ("png", "jpg", "jpeg"): + init_array.append(args.init_image + image) + else: + init_array.append(args.init_image) + else: + init_array = [""] + + # when doing large batches don't flood browser with images + clear_between_batches = args.n_batch >= 32 + + for iprompt, prompt in enumerate(prompts): + args.prompt = prompt + + all_images = [] + + for batch_index in range(args.n_batch): + if clear_between_batches: + display.clear_output(wait=True) + print(f"Batch {batch_index+1} of {args.n_batch}") + + for image in init_array: # iterates the init images + args.init_image = image + results = generate(args, model, device) + for image in results: + if args.make_grid: + all_images.append(T.functional.pil_to_tensor(image)) + if args.save_samples: + filename = f"{args.timestring}_{index:05}_{args.seed}.png" + image.save(os.path.join(args.outdir, filename)) + if args.display_samples: + display.display(image) + index += 1 + args.seed = next_seed(args) + + # print(len(all_images)) + if args.make_grid: + grid = 
make_grid(all_images, nrow=int(len(all_images) / args.grid_rows)) + grid = rearrange(grid, "c h w -> h w c").cpu().numpy() + filename = f"{args.timestring}_{iprompt:05d}_grid_{args.seed}.png" + grid_image = Image.fromarray(grid.astype(np.uint8)) + grid_image.save(os.path.join(args.outdir, filename)) + display.clear_output(wait=True) + display.display(grid_image) + + +def render_animation(args, anim_args, model, device): + # animations use key framed prompts + args.prompts = anim_args.animation_prompts + + # resume animation + start_frame = 0 + if anim_args.resume_from_timestring: + for tmp in os.listdir(args.outdir): + if tmp.split("_")[0] == anim_args.resume_timestring: + start_frame += 1 + start_frame = start_frame - 1 + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to {args.outdir}") + + # save settings for the batch + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + + # resume from timestring + if anim_args.resume_from_timestring: + args.timestring = anim_args.resume_timestring + + # expand prompts out to per-frame + prompt_series = pd.Series([np.nan for a in range(anim_args.max_frames)]) + for i, prompt in anim_args.animation_prompts.items(): + prompt_series[i] = prompt + prompt_series = prompt_series.ffill().bfill() + + # check for video inits + using_vid_init = anim_args.animation_mode == "Video Input" + + args.n_samples = 1 + prev_sample = None + color_match_sample = None + for frame_idx in range(start_frame, anim_args.max_frames): + print(f"Rendering animation frame {frame_idx} of {anim_args.max_frames}") + + # resume animation + if anim_args.resume_from_timestring: + path = os.path.join(args.outdir, f"{args.timestring}_{frame_idx-1:05}.png") + img = cv2.imread(path) + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + prev_sample = sample_from_cv2(img) + + # apply transforms to previous frame + if prev_sample is not None: + if anim_args.key_frames: + angle = anim_args.angle_series[frame_idx] + zoom = anim_args.zoom_series[frame_idx] + translation_x = anim_args.translation_x_series[frame_idx] + translation_y = anim_args.translation_y_series[frame_idx] + noise = anim_args.noise_schedule_series[frame_idx] + strength = anim_args.strength_schedule_series[frame_idx] + contrast = anim_args.contrast_schedule_series[frame_idx] + print( + f"angle: {angle}", + f"zoom: {zoom}", + f"translation_x: {translation_x}", + f"translation_y: {translation_y}", + f"noise: {noise}", + f"strength: {strength}", + f"contrast: {contrast}", + ) + xform = make_xform_2d( + args.W, args.H, translation_x, translation_y, angle, zoom + ) + + # transform previous frame + prev_img = sample_to_cv2(prev_sample) + prev_img = cv2.warpPerspective( + prev_img, + xform, + (prev_img.shape[1], prev_img.shape[0]), + borderMode=cv2.BORDER_WRAP + if anim_args.border == "wrap" + else cv2.BORDER_REPLICATE, + ) + + # apply color matching + if anim_args.color_coherence != "None": + if color_match_sample is None: + color_match_sample = prev_img.copy() + else: + prev_img = maintain_colors( + prev_img, color_match_sample, anim_args.color_coherence + ) + + # apply scaling + contrast_sample = prev_img * contrast + # apply frame noising + noised_sample = add_noise(sample_from_cv2(contrast_sample), noise) + + # use transformed previous frame as init for current + args.use_init = True + args.init_sample = noised_sample.half().to(device) + args.strength = max(0.0, min(1.0, strength)) + + # grab prompt for current frame + args.prompt = 
prompt_series[frame_idx] + print(f"{args.prompt} {args.seed}") + + # grab init image for current frame + if using_vid_init: + init_frame = os.path.join( + args.outdir, "inputframes", f"{frame_idx+1:04}.jpg" + ) + print(f"Using video init frame {init_frame}") + args.init_image = init_frame + + # sample the diffusion model + results = generate(args, model, device, return_latent=False, return_sample=True) + sample, image = results[0], results[1] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + if not using_vid_init: + prev_sample = sample + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + +def render_input_video(args, anim_args, model, dvice): + # create a folder for the video input frames to live in + video_in_frame_path = os.path.join(args.outdir, "inputframes") + os.makedirs(os.path.join(args.outdir, video_in_frame_path), exist_ok=True) + + # save the video frames from input video + print( + f"Exporting Video Frames (1 every {anim_args.extract_nth_frame}) frames to {video_in_frame_path}..." + ) + try: + for f in pathlib.Path(video_in_frame_path).glob("*.jpg"): + f.unlink() + except: + pass + vf = r"select=not(mod(n\," + str(anim_args.extract_nth_frame) + "))" + subprocess.run( + [ + "ffmpeg", + "-i", + f"{anim_args.video_init_path}", + "-vf", + f"{vf}", + "-vsync", + "vfr", + "-q:v", + "2", + "-loglevel", + "error", + "-stats", + os.path.join(video_in_frame_path, "%04d.jpg"), + ], + stdout=subprocess.PIPE, + ).stdout.decode("utf-8") + + # determine max frames from length of input frames + anim_args.max_frames = len( + [f for f in pathlib.Path(video_in_frame_path).glob("*.jpg")] + ) + + args.use_init = True + print( + f"Loading {anim_args.max_frames} input frames from {video_in_frame_path} and saving video frames to {args.outdir}" + ) + render_animation(args, anim_args, model, device) + + +def render_interpolation(args, anim_args, model, device): + # animations use key framed prompts + args.prompts = animation_prompts + + # create output folder for the batch + os.makedirs(args.outdir, exist_ok=True) + print(f"Saving animation frames to {args.outdir}") + + # save settings for the batch + settings_filename = os.path.join(args.outdir, f"{args.timestring}_settings.txt") + # with open(settings_filename, "w+", encoding="utf-8") as f: + # s = {**dict(args.__dict__), **dict(anim_args.__dict__)} + # json.dump(s, f, ensure_ascii=False, indent=4) + + # Interpolation Settings + args.n_samples = 1 + args.seed_behavior = ( + "fixed" # force fix seed at the moment bc only 1 seed is available + ) + prompts_c_s = [] # cache all the text embeddings + + print(f"Preparing for interpolation of the following...") + + for i, prompt in animation_prompts.items(): + args.prompt = prompt + + # sample the diffusion model + results = generate(args, model, device, return_c=True) + c, image = results[0], results[1] + prompts_c_s.append(c) + + # display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + display.clear_output(wait=True) + print(f"Interpolation start...") + + frame_idx = 0 + + if anim_args.interpolate_key_frames: + for i in range(len(prompts_c_s) - 1): + dist_frames = ( + list(animation_prompts.items())[i + 1][0] + - list(animation_prompts.items())[i][0] + ) + if dist_frames <= 0: + print("key frames duplicated or reversed. 
interpolation skipped.") + return + else: + for j in range(dist_frames): + # interpolate the text embedding + prompt1_c = prompts_c_s[i] + prompt2_c = prompts_c_s[i + 1] + args.init_c = prompt1_c.add( + prompt2_c.sub(prompt1_c).mul(j * 1 / dist_frames) + ) + + # sample the diffusion model + results = generate(args, model, device) + image = results[0] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + else: + for i in range(len(prompts_c_s) - 1): + for j in range(anim_args.interpolate_x_frames + 1): + # interpolate the text embedding + prompt1_c = prompts_c_s[i] + prompt2_c = prompts_c_s[i + 1] + args.init_c = prompt1_c.add( + prompt2_c.sub(prompt1_c).mul( + j * 1 / (anim_args.interpolate_x_frames + 1) + ) + ) + + # sample the diffusion model + results = generate(args, model, device) + image = results[0] + + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + frame_idx += 1 + + display.clear_output(wait=True) + display.display(image) + + args.seed = next_seed(args) + + # generate the last prompt + args.init_c = prompts_c_s[-1] + results = generate(args, model, device) + image = results[0] + filename = f"{args.timestring}_{frame_idx:05}.png" + image.save(os.path.join(args.outdir, filename)) + + display.clear_output(wait=True) + display.display(image) + args.seed = next_seed(args) + + # clear init_c + args.init_c = None diff --git a/scripts/img2img.py b/scripts/img2img.py index 421e2151d9..5b4537d4e2 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -198,6 +198,7 @@ def main(): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 59c16a1db8..28db4e78a9 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -238,6 +238,7 @@ def main(): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device)
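
Note on the memory-aware attention slicing introduced above in ldm/modules/attention.py and ldm/modules/diffusionmodules/model.py: both hunks estimate the free VRAM (driver-reported free memory plus memory PyTorch has reserved but not allocated), compare it against the size of the full softmax(Q @ K^T) matrix with a 2.5x working-set overhead, and split the query dimension into a power-of-two number of slices when the matrix does not fit. The following sketch is not part of the patch; it only isolates the slice-count calculation under the assumptions that a CUDA device is available and that q and k are laid out as (batch*heads, tokens, dim), which matches the attention.py hunk but not the (b, c, h*w) layout used in model.py.

import math
import torch

def attention_slice_steps(q: torch.Tensor, k: torch.Tensor, overhead: float = 2.5) -> int:
    """Illustrative sketch: pick a power-of-two slice count so the attention matrix fits in free VRAM."""
    # Free memory = free memory reported by the CUDA driver
    #             + memory PyTorch has reserved but is not actively using.
    stats = torch.cuda.memory_stats(q.device)
    mem_free_torch = stats['reserved_bytes.all.current'] - stats['active_bytes.all.current']
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_total = mem_free_cuda + mem_free_torch

    # Assumed layout: q is (batch*heads, q_tokens, dim), k is (batch*heads, k_tokens, dim).
    # The full attention matrix is (batch*heads, q_tokens, k_tokens) in float32, i.e. 4 bytes/element.
    mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * overhead

    # Smallest power-of-two number of query slices that brings the per-slice matrix under the budget.
    steps = 1
    if mem_required > mem_free_total:
        steps = 2 ** math.ceil(math.log2(mem_required / mem_free_total))
    return steps

With that steps value, the loops in both hunks then process q.shape[1] // steps query rows at a time, deleting each intermediate tensor as soon as it has been consumed, which is the same trade of extra kernel launches for a much smaller peak allocation that the model.half() changes to scripts/txt2img.py and scripts/img2img.py complement.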