README.md: 7 additions & 0 deletions
@@ -1,3 +1,10 @@
# Automatic memory management

This version can be run standalone, but it is meant more as a proof of concept so that other forks can implement similar changes.

Allows using resolutions that require up to 64x more VRAM than is possible on the default CompVis build.


# Stable Diffusion
*Stable Diffusion was made possible thanks to a collaboration with [Stability AI](https://stability.ai/) and [Runway](https://runwayml.com/) and builds upon our previous work:*

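The README's 64x figure lines up with the `steps > 64` cutoff in `ldm/modules/attention.py` below: the attention score matrix is computed in up to 64 slices along the query dimension, so only a fraction of it is ever resident at once. A minimal standalone sketch of that slicing idea, simplified from the diff that follows (the function name and fixed step count here are illustrative, not from the PR):

import torch

def sliced_attention(q, k, v, steps):
    # q, k, v: (batch*heads, seq, dim). Scores for one slice of queries are
    # materialized, softmaxed, and applied, then freed before the next slice,
    # so the peak score-matrix allocation shrinks by roughly a factor of `steps`.
    scale = q.shape[-1] ** -0.5
    out = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
    slice_size = q.shape[1] // steps if q.shape[1] % steps == 0 else q.shape[1]
    for i in range(0, q.shape[1], slice_size):
        end = i + slice_size
        scores = torch.einsum('b i d, b j d -> b i j', q[:, i:end], k) * scale
        attn = scores.softmax(dim=-1)
        del scores
        out[:, i:end] = torch.einsum('b i j, b j d -> b i d', attn, v)
        del attn
    return out

The result matches unsliced attention exactly; slicing trades a few extra kernel launches for a much smaller peak allocation.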
ldm/modules/attention.py: 48 additions & 17 deletions
@@ -1,3 +1,4 @@
import gc
from inspect import isfunction
import math
import torch
@@ -89,7 +90,7 @@ def forward(self, x):
b, c, h, w = x.shape
qkv = self.to_qkv(x)
q, k, v = rearrange(qkv, 'b (qkv heads c) h w -> qkv b heads c (h w)', heads = self.heads, qkv=3)
k = k.softmax(dim=-1)
context = torch.einsum('bhdn,bhen->bhde', k, v)
out = torch.einsum('bhde,bhdn->bhen', context, q)
out = rearrange(out, 'b heads c (h w) -> b (heads c) h w', heads=self.heads, h=h, w=w)
@@ -161,7 +162,6 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

self.to_out = nn.Sequential(
nn.Linear(inner_dim, query_dim),
nn.Dropout(dropout)
@@ -170,27 +170,58 @@ def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.
def forward(self, x, context=None, mask=None):
h = self.heads

q = self.to_q(x)
q_in = self.to_q(x)
context = default(context, x)
k = self.to_k(context)
v = self.to_v(context)
k_in = self.to_k(context)
v_in = self.to_v(context)
del context, x

q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
del q_in, k_in, v_in

sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)

if exists(mask):
mask = rearrange(mask, 'b ... -> b (...)')
max_neg_value = -torch.finfo(sim.dtype).max
mask = repeat(mask, 'b j -> (b h) () j', h=h)
sim.masked_fill_(~mask, max_neg_value)
stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch

# attention, what we cannot get enough of
attn = sim.softmax(dim=-1)
gb = 1024 ** 3
tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
modifier = 3 if q.element_size() == 2 else 2.5
mem_required = tensor_size * modifier
steps = 1

out = einsum('b i j, b j d -> b i d', attn, v)
out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
return self.to_out(out)

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))
# print(f"Expected tensor size:{tensor_size/gb:0.1f}GB, cuda free:{mem_free_cuda/gb:0.1f}GB "
# f"torch free:{mem_free_torch/gb:0.1f} total:{mem_free_total/gb:0.1f} steps:{steps}")

if steps > 64:
max_res = math.floor(math.sqrt(math.sqrt(mem_free_total / 2.5)) / 8) * 64
raise RuntimeError(f'Not enough memory, use lower resolution (max approx. {max_res}x{max_res}). '
f'Need: {mem_required/64/gb:0.1f}GB free, Have:{mem_free_total/gb:0.1f}GB free')

slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size
s1 = einsum('b i d, b j d -> b i j', q[:, i:end], k) * self.scale

s2 = s1.softmax(dim=-1, dtype=q.dtype)
del s1

r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2

del q, k, v

r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
del r1

return self.to_out(r2)


class BasicTransformerBlock(nn.Module):
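For readers following the arithmetic in the CrossAttention hunk above: the heuristic counts memory PyTorch has reserved but is not actively using as free, adds it to the free CUDA memory, then sizes the full score matrix as (batch*heads) x n_q x n_k elements times a fudge factor (3x for fp16, 2.5x for fp32) for softmax intermediates. A hedged standalone restatement (the packaging and function name are mine):

import math
import torch

def plan_attention_steps(q, k):
    # free VRAM = free CUDA memory + torch's reserved-but-inactive cache
    stats = torch.cuda.memory_stats(q.device)
    mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
    mem_free_torch = stats['reserved_bytes.all.current'] - stats['active_bytes.all.current']
    mem_free_total = mem_free_cuda + mem_free_torch

    # full score matrix: (batch*heads) * n_q * n_k elements
    tensor_size = q.shape[0] * q.shape[1] * k.shape[1] * q.element_size()
    mem_required = tensor_size * (3 if q.element_size() == 2 else 2.5)

    if mem_required <= mem_free_total:
        return 1
    # round the slice count up to the next power of two
    return 2 ** math.ceil(math.log2(mem_required / mem_free_total))

For example, a 1024x1024 image gives a 128x128 latent, i.e. 16384 tokens; with batch size 2 and 8 heads in fp16 the full matrix needs 16 * 16384^2 * 2 bytes, about 8.6 GB, times 3, about 25.8 GB, so with roughly 7 GB free the loop runs in 4 slices.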
ldm/modules/diffusionmodules/model.py: 114 additions & 38 deletions
@@ -1,4 +1,5 @@
# pytorch_diffusion + derived encoder decoder
import gc
import math
import torch
import torch.nn as nn
@@ -32,7 +33,11 @@ def get_timestep_embedding(timesteps, embedding_dim):

def nonlinearity(x):
# swish
return x*torch.sigmoid(x)
t = torch.sigmoid(x)
x *= t
del t

return x


def Normalize(in_channels, num_groups=32):
@@ -119,26 +124,38 @@ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
padding=0)

def forward(self, x, temb):
h = x
h = self.norm1(h)
h = nonlinearity(h)
h = self.conv1(h)
h1 = x
h2 = self.norm1(h1)
del h1

h3 = nonlinearity(h2)
del h2

h4 = self.conv1(h3)
del h3

if temb is not None:
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]

h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
h5 = self.norm2(h4)
del h4

h6 = nonlinearity(h5)
del h5

h7 = self.dropout(h6)
del h6

h8 = self.conv2(h7)
del h7

if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
x = self.conv_shortcut(x)
else:
x = self.nin_shortcut(x)

return x+h
return x + h8


class LinAttnBlock(LinearAttention):
@@ -174,32 +191,68 @@ def __init__(self, in_channels):
stride=1,
padding=0)


def forward(self, x):
h_ = x
h_ = self.norm(h_)
q = self.q(h_)
k = self.k(h_)
q1 = self.q(h_)
k1 = self.k(h_)
v = self.v(h_)

# compute attention
b,c,h,w = q.shape
q = q.reshape(b,c,h*w)
q = q.permute(0,2,1) # b,hw,c
k = k.reshape(b,c,h*w) # b,c,hw
w_ = torch.bmm(q,k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w_ = w_ * (int(c)**(-0.5))
w_ = torch.nn.functional.softmax(w_, dim=2)
b, c, h, w = q1.shape

q2 = q1.reshape(b, c, h*w)
del q1

q = q2.permute(0, 2, 1) # b,hw,c
del q2

k = k1.reshape(b, c, h*w) # b,c,hw
del k1

h_ = torch.zeros_like(k, device=q.device)

stats = torch.cuda.memory_stats(q.device)
mem_active = stats['active_bytes.all.current']
mem_reserved = stats['reserved_bytes.all.current']
mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
mem_free_torch = mem_reserved - mem_active
mem_free_total = mem_free_cuda + mem_free_torch

tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * q.element_size()
mem_required = tensor_size * 2.5
steps = 1

if mem_required > mem_free_total:
steps = 2**(math.ceil(math.log(mem_required / mem_free_total, 2)))

slice_size = q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
for i in range(0, q.shape[1], slice_size):
end = i + slice_size

w1 = torch.bmm(q[:, i:end], k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
w2 = w1 * (int(c)**(-0.5))
del w1
w3 = torch.nn.functional.softmax(w2, dim=2, dtype=q.dtype)
del w2

# attend to values
v = v.reshape(b,c,h*w)
w_ = w_.permute(0,2,1) # b,hw,hw (first hw of k, second of q)
h_ = torch.bmm(v,w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
h_ = h_.reshape(b,c,h,w)
# attend to values
v1 = v.reshape(b, c, h*w)
w4 = w3.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
del w3

h_ = self.proj_out(h_)
h_[:, :, i:end] = torch.bmm(v1, w4) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
del v1, w4

return x+h_
h2 = h_.reshape(b, c, h, w)
del h_

h3 = self.proj_out(h2)
del h2

h3 += x

return h3


def make_attn(in_channels, attn_type="vanilla"):
@@ -540,31 +593,54 @@ def forward(self, z):
temb = None

# z to block_in
h = self.conv_in(z)
h1 = self.conv_in(z)

# middle
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
h2 = self.mid.block_1(h1, temb)
del h1

h3 = self.mid.attn_1(h2)
del h2

h = self.mid.block_2(h3, temb)
del h3

# prepare for up sampling
gc.collect()
torch.cuda.empty_cache()

# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
h = self.up[i_level].attn[i_block](h)
t = h
h = self.up[i_level].attn[i_block](t)
del t

if i_level != 0:
h = self.up[i_level].upsample(h)
t = h
h = self.up[i_level].upsample(t)
del t

# end
if self.give_pre_end:
return h

h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
h1 = self.norm_out(h)
del h

h2 = nonlinearity(h1)
del h1

h = self.conv_out(h2)
del h2

if self.tanh_out:
h = torch.tanh(h)
t = h
h = torch.tanh(t)
del t

return h


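The model.py changes all follow one memory discipline: give each intermediate its own name, `del` it as soon as it is consumed, and at phase boundaries (here, between the middle blocks and upsampling) run `gc.collect()` plus `torch.cuda.empty_cache()` to hand cached blocks back to the allocator. A toy illustration of the pattern (the modules are stand-ins, not the PR's; note this pays off mainly under `torch.no_grad()`, since autograd would otherwise keep the intermediates alive anyway):

import gc
import torch
import torch.nn as nn

norm = nn.GroupNorm(32, 64)
conv = nn.Conv2d(64, 64, 3, padding=1)

@torch.no_grad()
def forward_memory_lean(x):
    h1 = norm(x)
    h2 = h1 * torch.sigmoid(h1)  # swish, as in nonlinearity() above
    del h1                       # drop the buffer before allocating the next one
    h3 = conv(h2)
    del h2
    gc.collect()                 # phase boundary: release cached CUDA blocks
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return h3

out = forward_memory_lean(torch.randn(1, 64, 32, 32))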
scripts/img2img.py: 4 additions & 0 deletions
@@ -196,8 +196,12 @@ def main():
opt = parser.parse_args()
seed_everything(opt.seed)

# needed when model is in half mode, remove if not using half mode
torch.set_default_tensor_type(torch.HalfTensor)

config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
model = model.half()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
scripts/txt2img.py: 4 additions & 0 deletions
@@ -236,8 +236,12 @@ def main():

seed_everything(opt.seed)

# needed when model is in half mode, remove if not using half mode
torch.set_default_tensor_type(torch.HalfTensor)

config = OmegaConf.load(f"{opt.config}")
model = load_model_from_config(config, f"{opt.ckpt}")
model = model.half()

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model = model.to(device)
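Both scripts make the same two-line change: switch the default tensor type to `torch.HalfTensor` and call `model.half()`, so weights, activations, and any buffers created without an explicit dtype all use 2 bytes per element instead of 4. A quick sketch of the effect (illustrative only):

import torch

x32 = torch.randn(1, 4, 64, 64)               # fp32: 4 bytes per element
x16 = x32.half()                               # fp16: 2 bytes per element
print(x32.element_size(), x16.element_size())  # -> 4 2

torch.set_default_tensor_type(torch.HalfTensor)
print(torch.ones(3).dtype)                     # -> torch.float16

Combined with the attention slicing above, this halving of per-element cost is what lets the scripts run at resolutions far beyond the default build.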