diff --git a/README.md b/README.md index c9e6c3bb13..169399aac8 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,10 @@ +# Automatic memory management + +This version can be run stand alone, but it's more meant as proof of concept so other forks can implement similar changes. + +Allows to use resolutions that require up to 64x more VRAM than possible on the default CompVis build + + # Stable Diffusion *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:* diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39ccb..a0f4f18b80 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,3 +1,4 @@ +import gc from inspect import isfunction import math import torch @@ -89,7 +90,7 @@ def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) @@ -161,7 +162,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) @@ -170,27 +170,58 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def forward(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + k_in = self.to_k(context) + v_in = self.to_v(context) + del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + gb = 1024 ** 3 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier + steps = 1 - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + + if steps > 64: + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale + + s2 = s1.softmax(dim=-1, dtype=q.dtype) + del s1 + + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + del q, k, v + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class BasicTransformerBlock(nn.Module): diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a20..de3ce38c6e 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -32,7 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return x*torch.sigmoid(x) + t = torch.sigmoid(x) + x *= t + del t + + return x def Normalize(in_channels, num_groups=32): @@ -119,18 +124,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -138,7 +155,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + return x + h8 class LinAttnBlock(LinearAttention): @@ -174,32 +191,68 @@ def __init__(self, in_channels): stride=1, padding=0) - def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_active = stats['active_bytes.all.current'] + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch + + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() + mem_required = tensor_size * 2.5 + steps = 1 + + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w2 = w1 * (int(c)**(-0.5)) + del w1 + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) + del w2 - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + # attend to values + v1 = v.reshape(b, c, h*w) + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 - h_ = self.proj_out(h_) + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 - return x+h_ + h2 = h_.reshape(b, c, h, w) + del h_ + + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 def make_attn(in_channels, attn_type="vanilla"): @@ -540,31 +593,54 @@ def forward(self, z): temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + gc.collect() + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + return h diff --git a/scripts/img2img.py b/scripts/img2img.py index 421e2151d9..04b88b54f1 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -196,8 +196,12 @@ def main(): opt = parser.parse_args() seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 59c16a1db8..a08f522988 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -236,8 +236,12 @@ def main(): seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device)