From afaefd2e5f2e391d8b84fe787916fbf4fb1ee740 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Mon, 5 Sep 2022 09:26:27 +0200 Subject: [PATCH 01/13] Update attention.py Run attention in a loop to allow for much higher resolutions (over 1920x1920 on a 3090) --- ldm/modules/attention.py | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f4eff39ccb..2b7214c711 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -174,23 +174,34 @@ def forward(self, x, context=None, mask=None): context = default(context, x) k = self.to_k(context) v = self.to_v(context) + del context, x q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') - max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) - sim.masked_fill_(~mask, max_neg_value) + # valid values for steps = 2,4,8,16,32,64 + # higher steps is slower but less memory usage + # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920 + # speed seems to be impacted more on 30x series cards + steps = 16 + slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) + s1 *= self.scale - # attention, what we cannot get enough of - attn = sim.softmax(dim=-1) + s2 = s1.softmax(dim=-1) + del s1 - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) - return self.to_out(out) + r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) + del s2 + + + r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) + del r1 + + return self.to_out(r2) class BasicTransformerBlock(nn.Module): From 5065b41ce12c3b043ba5196283a4907cd2e1df5b Mon Sep 17 00:00:00 2001 From: Doggettx Date: Mon, 5 Sep 2022 10:13:59 +0200 Subject: [PATCH 02/13] Update attention.py Correction to comment --- ldm/modules/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 2b7214c711..39fc20af0b 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -182,7 +182,7 @@ def forward(self, x, context=None, mask=None): # valid values for steps = 2,4,8,16,32,64 # higher steps is slower but less memory usage - # at 16 can run 1920x1536 on a 3090, at 64 can run over 1920x1920 + # at 16 can run 1920x1536 on a 3090, at 32 can run over 1920x1920 # speed seems to be impacted more on 30x series cards steps = 16 slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1] From 8283bb5b84580487e7a9e25c37816484bf4ed42b Mon Sep 17 00:00:00 2001 From: Doggettx Date: Mon, 5 Sep 2022 12:03:59 +0200 Subject: [PATCH 03/13] Update attention.py --- ldm/modules/attention.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 39fc20af0b..89d8b6db78 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -180,12 +180,18 @@ def forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) - # valid values for steps = 2,4,8,16,32,64 - # higher steps is slower but less memory usage - # at 16 can run 1920x1536 on a 3090, at 32 can run over 1920x1920 - # speed seems to be impacted more on 30x series cards - steps = 16 - slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1] + stats = torch.cuda.memory_stats(q.device) + mem_total = torch.cuda.get_device_properties(0).total_memory + mem_active = stats['active_bytes.all.current'] + mem_free = mem_total - mem_active + + mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5 + steps = 1 + + if mem_required > mem_free: + steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) @@ -204,6 +210,7 @@ def forward(self, x, context=None, mask=None): return self.to_out(r2) + class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() From d3c91ec937a4f1d4fc79b68875931bdb5550bb6e Mon Sep 17 00:00:00 2001 From: Doggettx Date: Mon, 5 Sep 2022 19:49:45 +0200 Subject: [PATCH 04/13] Fixed memory handling for model.decode_first_stage Better memory handling for model.decode_first_stage so it doesn't crash anymore after 100% rendering --- ldm/modules/attention.py | 11 +- ldm/modules/diffusionmodules/model.py | 142 +++++++++++++++++++------- 2 files changed, 112 insertions(+), 41 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 89d8b6db78..7d3f8c2be8 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -170,13 +170,14 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. def forward(self, x, context=None, mask=None): h = self.heads - q = self.to_q(x) + q_in = self.to_q(x) context = default(context, x) - k = self.to_k(context) - v = self.to_v(context) + k_in = self.to_k(context) + v_in = self.to_v(context) del context, x - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in)) + del q_in, k_in, v_in r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) @@ -203,6 +204,7 @@ def forward(self, x, context=None, mask=None): r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) del s2 + del q, k, v r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h) del r1 @@ -210,7 +212,6 @@ def forward(self, x, context=None, mask=None): return self.to_out(r2) - class BasicTransformerBlock(nn.Module): def __init__(self, dim, n_heads, d_head, dropout=0., context_dim=None, gated_ff=True, checkpoint=True): super().__init__() diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 533e589a20..fd16dd50ab 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -1,4 +1,5 @@ # pytorch_diffusion + derived encoder decoder +import gc import math import torch import torch.nn as nn @@ -119,18 +120,30 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, padding=0) def forward(self, x, temb): - h = x - h = self.norm1(h) - h = nonlinearity(h) - h = self.conv1(h) + h1 = x + h2 = self.norm1(h1) + del h1 + + h3 = nonlinearity(h2) + del h2 + + h4 = self.conv1(h3) + del h3 if temb is not None: - h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 += self.temb_proj(nonlinearity(temb))[:,:,None,None] - h = self.norm2(h) - h = nonlinearity(h) - h = self.dropout(h) - h = self.conv2(h) + h5 = self.norm2(h4) + del h4 + + h6 = nonlinearity(h5) + del h5 + + h7 = self.dropout(h6) + del h6 + + h8 = self.conv2(h7) + del h7 if self.in_channels != self.out_channels: if self.use_conv_shortcut: @@ -138,7 +151,8 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - return x+h + h8 += x + return h8 class LinAttnBlock(LinearAttention): @@ -178,28 +192,61 @@ def __init__(self, in_channels): def forward(self, x): h_ = x h_ = self.norm(h_) - q = self.q(h_) - k = self.k(h_) + q1 = self.q(h_) + k1 = self.k(h_) v = self.v(h_) # compute attention - b,c,h,w = q.shape - q = q.reshape(b,c,h*w) - q = q.permute(0,2,1) # b,hw,c - k = k.reshape(b,c,h*w) # b,c,hw - w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w_ = w_ * (int(c)**(-0.5)) - w_ = torch.nn.functional.softmax(w_, dim=2) + b, c, h, w = q1.shape + + q2 = q1.reshape(b, c, h*w) + del q1 + + q = q2.permute(0, 2, 1) # b,hw,c + del q2 + + k = k1.reshape(b, c, h*w) # b,c,hw + del k1 + + h_ = torch.zeros_like(k, device=q.device) + + stats = torch.cuda.memory_stats(q.device) + mem_total = torch.cuda.get_device_properties(0).total_memory + mem_active = stats['active_bytes.all.current'] + mem_free = mem_total - mem_active + + mem_required = q.shape[0] * q.shape[1] * k.shape[2] * 4 * 2.5 + steps = 1 + + if mem_required > mem_free: + steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) + + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] + for i in range(0, q.shape[1], slice_size): + end = i + slice_size + + w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] + w1 *= (int(c)**(-0.5)) + w2 = torch.nn.functional.softmax(w1, dim=2) + del w1 + + # attend to values + v1 = v.reshape(b, c, h*w) + w3 = w2.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w2 - # attend to values - v = v.reshape(b,c,h*w) - w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q) - h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - h_ = h_.reshape(b,c,h,w) + h_[:, :, i:end] = torch.bmm(v1, w3) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w3 - h_ = self.proj_out(h_) + h2 = h_.reshape(b, c, h, w) + del h_ - return x+h_ + h3 = self.proj_out(h2) + del h2 + + h3 += x + + return h3 def make_attn(in_channels, attn_type="vanilla"): @@ -540,31 +587,54 @@ def forward(self, z): temb = None # z to block_in - h = self.conv_in(z) + h1 = self.conv_in(z) # middle - h = self.mid.block_1(h, temb) - h = self.mid.attn_1(h) - h = self.mid.block_2(h, temb) + h2 = self.mid.block_1(h1, temb) + del h1 + + h3 = self.mid.attn_1(h2) + del h2 + + h = self.mid.block_2(h3, temb) + del h3 + + # prepare for up sampling + gc.collect() + torch.cuda.empty_cache() # upsampling for i_level in reversed(range(self.num_resolutions)): for i_block in range(self.num_res_blocks+1): h = self.up[i_level].block[i_block](h, temb) if len(self.up[i_level].attn) > 0: - h = self.up[i_level].attn[i_block](h) + t = h + h = self.up[i_level].attn[i_block](t) + del t + if i_level != 0: - h = self.up[i_level].upsample(h) + t = h + h = self.up[i_level].upsample(t) + del t # end if self.give_pre_end: return h - h = self.norm_out(h) - h = nonlinearity(h) - h = self.conv_out(h) + h1 = self.norm_out(h) + del h + + h2 = nonlinearity(h1) + del h1 + + h = self.conv_out(h2) + del h2 + if self.tanh_out: - h = torch.tanh(h) + t = h + h = torch.tanh(t) + del t + return h From 507ddec578d54ddf7eb39fac5d646c9937526565 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 09:08:46 +0200 Subject: [PATCH 05/13] Fixed free memory calculation Old version gave incorrect free memory results causing in crashes on edge cases. --- ldm/modules/attention.py | 16 +++++++++++----- ldm/modules/diffusionmodules/model.py | 13 ++++++++----- 2 files changed, 19 insertions(+), 10 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 7d3f8c2be8..e6db2ddfce 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -182,15 +182,21 @@ def forward(self, x, context=None, mask=None): r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device) stats = torch.cuda.memory_stats(q.device) - mem_total = torch.cuda.get_device_properties(0).total_memory mem_active = stats['active_bytes.all.current'] - mem_free = mem_total - mem_active + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - mem_required = q.shape[0] * q.shape[1] * k.shape[1] * 4 * 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 + mem_required = tensor_size * 2.5 steps = 1 - if mem_required > mem_free: - steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) + gb = 1024**3 + print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index fd16dd50ab..7c78f465a2 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -211,15 +211,18 @@ def forward(self, x): h_ = torch.zeros_like(k, device=q.device) stats = torch.cuda.memory_stats(q.device) - mem_total = torch.cuda.get_device_properties(0).total_memory mem_active = stats['active_bytes.all.current'] - mem_free = mem_total - mem_active + mem_reserved = stats['reserved_bytes.all.current'] + mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device()) + mem_free_torch = mem_reserved - mem_active + mem_free_total = mem_free_cuda + mem_free_torch - mem_required = q.shape[0] * q.shape[1] * k.shape[2] * 4 * 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + mem_required = tensor_size * 2.5 steps = 1 - if mem_required > mem_free: - steps = 2**(math.ceil(math.log(mem_required / mem_free, 2))) + if mem_required > mem_free_total: + steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): From a1fbe55f85dd6e7e4fdb3c9081f8f272e3233b59 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 09:09:49 +0200 Subject: [PATCH 06/13] Set model to half Set model to half in txt2img and img2img for less memory usage. --- scripts/img2img.py | 1 + scripts/txt2img.py | 1 + 2 files changed, 2 insertions(+) diff --git a/scripts/img2img.py b/scripts/img2img.py index 421e2151d9..5b4537d4e2 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -198,6 +198,7 @@ def main(): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 59c16a1db8..28db4e78a9 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -238,6 +238,7 @@ def main(): config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") + model = model.half() device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") model = model.to(device) From 7a32fd649360aca42c12b411f05cd47f3bbb13ab Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 09:12:09 +0200 Subject: [PATCH 07/13] Commented out debug info Forgot to comment out debug info --- ldm/modules/attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index e6db2ddfce..f82038f67e 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -194,9 +194,9 @@ def forward(self, x, context=None, mask=None): if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - gb = 1024**3 - print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + # gb = 1024**3 + # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): From 755bec892369f04bccb430b05a8dcad97dbd6703 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 09:39:10 +0200 Subject: [PATCH 08/13] Raise error when steps too high Technically you could run at higher steps as long as the resolution is dividable by the steps but you're going to run into memory issues later on anyhow. --- ldm/modules/attention.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f82038f67e..f556c7bc08 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -188,16 +188,20 @@ def forward(self, x, context=None, mask=None): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch + gb = 1024 ** 3 tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 mem_required = tensor_size * 2.5 steps = 1 if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) - # gb = 1024**3 # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + if steps > 64: + raise RuntimeError(f'Not enough memory, use lower resolution. ' + f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') + slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size From f134e245ba8d53d2fa5e3268050dcc752ae1edf0 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Tue, 6 Sep 2022 10:43:19 +0200 Subject: [PATCH 09/13] Added max. res info to memory exception --- ldm/modules/attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f556c7bc08..7ad8165c6e 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -199,7 +199,8 @@ def forward(self, x, context=None, mask=None): # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") if steps > 64: - raise RuntimeError(f'Not enough memory, use lower resolution. ' + max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 + raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). ' f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free') slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] From 830f6946f9aa29a7683b1c3b4d92537c0aecc827 Mon Sep 17 00:00:00 2001 From: Doggettx Date: Wed, 7 Sep 2022 08:43:36 +0200 Subject: [PATCH 10/13] Reverted in place tensor functions back to CompVis version Improves performance and is no longer needed. --- ldm/modules/attention.py | 4 ++-- ldm/modules/diffusionmodules/model.py | 14 +++++++------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index 7ad8165c6e..f848a7c75f 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -1,3 +1,4 @@ +import gc from inspect import isfunction import math import torch @@ -206,8 +207,7 @@ def forward(self, x, context=None, mask=None): slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1] for i in range(0, q.shape[1], slice_size): end = i + slice_size - s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) - s1 *= self.scale + s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale s2 = s1.softmax(dim=-1) del s1 diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 7c78f465a2..cd3328cbe6 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -188,7 +188,6 @@ def __init__(self, in_channels): stride=1, padding=0) - def forward(self, x): h_ = x h_ = self.norm(h_) @@ -229,17 +228,18 @@ def forward(self, x): end = i + slice_size w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] - w1 *= (int(c)**(-0.5)) - w2 = torch.nn.functional.softmax(w1, dim=2) + w2 = w1 * (int(c)**(-0.5)) del w1 + w3 = torch.nn.functional.softmax(w2, dim=2) + del w2 # attend to values v1 = v.reshape(b, c, h*w) - w3 = w2.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) - del w2 + w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q) + del w3 - h_[:, :, i:end] = torch.bmm(v1, w3) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] - del v1, w3 + h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j] + del v1, w4 h2 = h_.reshape(b, c, h, w) del h_ From c2d72c5c23492343bce6c090dbfd20ae90006deb Mon Sep 17 00:00:00 2001 From: Doggettx Date: Wed, 7 Sep 2022 09:03:19 +0200 Subject: [PATCH 11/13] Missed one function to revert --- ldm/modules/diffusionmodules/model.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index cd3328cbe6..9be8922e9d 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -131,7 +131,7 @@ def forward(self, x, temb): del h3 if temb is not None: - h4 += self.temb_proj(nonlinearity(temb))[:,:,None,None] + h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None] h5 = self.norm2(h4) del h4 @@ -151,8 +151,7 @@ def forward(self, x, temb): else: x = self.nin_shortcut(x) - h8 += x - return h8 + return x + h8 class LinAttnBlock(LinearAttention): From cd3d653f79cedc1849a02323f36b9b33fd089ff3 Mon Sep 17 00:00:00 2001 From: Doggettx <110817577+Doggettx@users.noreply.github.com> Date: Wed, 7 Sep 2022 12:29:15 +0200 Subject: [PATCH 12/13] Update README.md --- README.md | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/README.md b/README.md index c9e6c3bb13..169399aac8 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,10 @@ +# Automatic memory management + +This version can be run stand alone, but it's more meant as proof of concept so other forks can implement similar changes. + +Allows to use resolutions that require up to 64x more VRAM than possible on the default CompVis build + + # Stable Diffusion *Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:* From 5fe97c69c9e3738d722690063c9492b297737cee Mon Sep 17 00:00:00 2001 From: Doggettx Date: Sat, 10 Sep 2022 11:11:05 +0200 Subject: [PATCH 13/13] Performance boost and fix sigmoid for higher resolutions Significant performance boost at higher resolutions when running in auto_cast or half mode on 3090 went from 1.13it/s to 1.63it/s at 1024x1024 Will also allow for higher resolutions due to sigmoid fix and using half memory --- ldm/modules/attention.py | 13 +++++++------ ldm/modules/diffusionmodules/model.py | 10 +++++++--- scripts/img2img.py | 3 +++ scripts/txt2img.py | 3 +++ 4 files changed, 20 insertions(+), 9 deletions(-) diff --git a/ldm/modules/attention.py b/ldm/modules/attention.py index f848a7c75f..a0f4f18b80 100644 --- a/ldm/modules/attention.py +++ b/ldm/modules/attention.py @@ -90,7 +90,7 @@ def forward(self, x): b, c, h, w = x.shape qkv = self.to_qkv(x) q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3) - k = k.softmax(dim=-1) + k = k.softmax(dim=-1) context = torch.einsum('bhdn,bhen->bhde', k, v) out = torch.einsum('bhde,bhdn->bhen', context, q) out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w) @@ -162,7 +162,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0. self.to_q = nn.Linear(query_dim, inner_dim, bias=False) self.to_k = nn.Linear(context_dim, inner_dim, bias=False) self.to_v = nn.Linear(context_dim, inner_dim, bias=False) - self.to_out = nn.Sequential( nn.Linear(inner_dim, query_dim), nn.Dropout(dropout) @@ -190,14 +189,16 @@ def forward(self, x, context=None, mask=None): mem_free_total = mem_free_cuda + mem_free_torch gb = 1024 ** 3 - tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * 4 - mem_required = tensor_size * 2.5 + tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size() + modifier = 3 if q.element_size() == 2 else 2.5 + mem_required = tensor_size * modifier steps = 1 + if mem_required > mem_free_total: steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2))) # print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB " - # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") + # f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}") if steps > 64: max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64 @@ -209,7 +210,7 @@ def forward(self, x, context=None, mask=None): end = i + slice_size s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale - s2 = s1.softmax(dim=-1) + s2 = s1.softmax(dim=-1, dtype=q.dtype) del s1 r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py index 9be8922e9d..de3ce38c6e 100644 --- a/ldm/modules/diffusionmodules/model.py +++ b/ldm/modules/diffusionmodules/model.py @@ -33,7 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim): def nonlinearity(x): # swish - return x*torch.sigmoid(x) + t = torch.sigmoid(x) + x *= t + del t + + return x def Normalize(in_channels, num_groups=32): @@ -215,7 +219,7 @@ def forward(self, x): mem_free_torch = mem_reserved - mem_active mem_free_total = mem_free_cuda + mem_free_torch - tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4 + tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size() mem_required = tensor_size * 2.5 steps = 1 @@ -229,7 +233,7 @@ def forward(self, x): w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j] w2 = w1 * (int(c)**(-0.5)) del w1 - w3 = torch.nn.functional.softmax(w2, dim=2) + w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype) del w2 # attend to values diff --git a/scripts/img2img.py b/scripts/img2img.py index 5b4537d4e2..04b88b54f1 100644 --- a/scripts/img2img.py +++ b/scripts/img2img.py @@ -196,6 +196,9 @@ def main(): opt = parser.parse_args() seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") model = model.half() diff --git a/scripts/txt2img.py b/scripts/txt2img.py index 28db4e78a9..a08f522988 100644 --- a/scripts/txt2img.py +++ b/scripts/txt2img.py @@ -236,6 +236,9 @@ def main(): seed_everything(opt.seed) + # needed when model is in half mode, remove if not using half mode + torch.set_default_tensor_type(torch.HalfTensor) + config = OmegaConf.load(f"{opt.config}") model = load_model_from_config(config, f"{opt.ckpt}") model = model.half()