
Commit 720e5cd

lstein and bakkot authored
Refactoring simplet2i (CompVis#387)
* start refactoring - not yet functional
* first phase of refactor done - not sure weighted prompts working
* Second phase of refactoring. Everything mostly working.
* The refactoring has moved all the hard-core inference work into ldm.dream.generator.*, where there are submodules for txt2img and img2img. inpaint will go in there as well.
* Some additional refactoring will be done soon, but relatively minor work.
* fix -save_orig flag to actually work
* add @neonsecret attention.py memory optimization
* remove unneeded imports
* move token logging into conditioning.py
* add placeholder version of inpaint; porting in progress
* fix crash in img2img
* inpainting working; not tested on variations
* fix crashes in img2img
* ported attention.py memory optimization basujindal#117 from basujindal branch
* added @torch.no_grad() decorators to img2img, txt2img, inpaint closures
* Final commit prior to PR against development
* fixup crash when generating intermediate images in web UI
* rename ldm.simplet2i to ldm.generate
* add backward-compatibility simplet2i shell with deprecation warning
* add back in mps exception, addresses @Vargol comment in CompVis#354
* replaced Conditioning class with exported functions
* fix wrong type of with_variations attribute during initialization
* changed "image_iterator()" to "get_make_image()"
* raise NotImplementedError for calling get_make_image() in parent class
* Update ldm/generate.py: better error message (Co-authored-by: Kevin Gibbons <[email protected]>)
* minor stylistic fixes and assertion checks from code review
* moved get_noise() method into img2img class
* break get_noise() into two methods, one for txt2img and the other for img2img
* inpainting works on non-square images now
* make get_noise() an abstract method in base class
* much improved inpainting

Co-authored-by: Kevin Gibbons <[email protected]>
1 parent 1ad2a8e commit 720e5cd

File tree

16 files changed: +1261 -990 lines
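For orientation, the new modules reproduced in the diffs below expose roughly the following import surface. This is a sketch limited to the five files shown here; the remaining changed files (including the renamed ldm/generate.py) are not reproduced on this page.

# Sketch of the import surface introduced by the files shown below.
from ldm.dream.conditioning import get_uc_and_c, split_weighted_subprompts, log_tokenization
from ldm.dream.devices import choose_torch_device, choose_autocast_device
from ldm.dream.generator import Generator            # re-exported from ldm.dream.generator.base
from ldm.dream.generator.img2img import Img2Img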

ldm/dream/conditioning.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
'''
This module handles the generation of the conditioning tensors, including management of
weighted subprompts.

Useful function exports:

get_uc_and_c()                  get the conditioned and unconditioned latent
split_weighted_subprompts()     split subprompts, normalize and weight them
log_tokenization()              print out colour-coded tokens and warn if truncated

'''
import re
import torch

def get_uc_and_c(prompt, model, log_tokens=False, skip_normalize=False):
    uc = model.get_learned_conditioning([''])

    # get weighted sub-prompts
    weighted_subprompts = split_weighted_subprompts(
        prompt, skip_normalize
    )

    if len(weighted_subprompts) > 1:
        # i dont know if this is correct.. but it works
        c = torch.zeros_like(uc)
        # normalize each "sub prompt" and add it
        for subprompt, weight in weighted_subprompts:
            log_tokenization(subprompt, model, log_tokens)
            c = torch.add(
                c,
                model.get_learned_conditioning([subprompt]),
                alpha=weight,
            )
    else:  # just standard 1 prompt
        log_tokenization(prompt, model, log_tokens)
        c = model.get_learned_conditioning([prompt])
    return (uc, c)

def split_weighted_subprompts(text, skip_normalize=False) -> list:
    """
    Grabs all text up to the first occurrence of ':',
    uses the grabbed text as a sub-prompt, and takes the value following ':' as its weight.
    If ':' has no value defined, the weight defaults to 1.0.
    Repeats until no text remains.
    """
    prompt_parser = re.compile("""
            (?P<prompt>         # capture group for 'prompt'
            (?:\\\:|[^:])+      # match one or more non ':' characters or escaped colons '\:'
            )                   # end 'prompt'
            (?:                 # non-capture group
            :+                  # match one or more ':' characters
            (?P<weight>         # capture group for 'weight'
            -?\d+(?:\.\d+)?     # match positive or negative integer or decimal number
            )?                  # end weight capture group, make optional
            \s*                 # strip spaces after weight
            |                   # OR
            $                   # else, if no ':' then match end of line
            )                   # end non-capture group
            """, re.VERBOSE)
    parsed_prompts = [(match.group("prompt").replace("\\:", ":"), float(
        match.group("weight") or 1)) for match in re.finditer(prompt_parser, text)]
    if skip_normalize:
        return parsed_prompts
    weight_sum = sum(map(lambda x: x[1], parsed_prompts))
    if weight_sum == 0:
        print(
            "Warning: Subprompt weights add up to zero. Discarding and using even weights instead.")
        equal_weight = 1 / len(parsed_prompts)
        return [(x[0], equal_weight) for x in parsed_prompts]
    return [(x[0], x[1] / weight_sum) for x in parsed_prompts]

# shows how the prompt is tokenized
# usually tokens have '</w>' to indicate end-of-word,
# but for readability it has been replaced with ' '
def log_tokenization(text, model, log=False):
    if not log:
        return
    tokens = model.cond_stage_model.tokenizer._tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
    totalTokens = len(tokens)
    for i in range(0, totalTokens):
        token = tokens[i].replace('</w>', ' ')
        # alternate color
        s = (usedTokens % 6) + 1
        if i < model.cond_stage_model.max_length:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1
        else:  # over max token length
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
    print(f"\n>> Tokens ({usedTokens}):\n{tokenized}\x1b[0m")
    if discarded != "":
        print(
            f">> Tokens Discarded ({totalTokens-usedTokens}):\n{discarded}\x1b[0m"
        )
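A quick sanity check of the weighting behaviour above, runnable against this module as-is: weights 2 and 1 normalize to 2/3 and 1/3 unless skip_normalize is set.

# Sanity check of split_weighted_subprompts() normalization.
from ldm.dream.conditioning import split_weighted_subprompts

print(split_weighted_subprompts('mountain:2 lake:1'))
# -> [('mountain', 0.666...), ('lake', 0.333...)]

print(split_weighted_subprompts('mountain:2 lake:1', skip_normalize=True))
# -> [('mountain', 2.0), ('lake', 1.0)]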

ldm/dream/devices.py

Lines changed: 7 additions & 4 deletions
@@ -1,4 +1,6 @@
 import torch
+from torch import autocast
+from contextlib import contextmanager, nullcontext
 
 def choose_torch_device() -> str:
     '''Convenience routine for guessing which GPU device to run model on'''
@@ -8,10 +10,11 @@ def choose_torch_device() -> str:
         return 'mps'
     return 'cpu'
 
-def choose_autocast_device(device) -> str:
+def choose_autocast_device(device):
     '''Returns an autocast compatible device from a torch device'''
     device_type = device.type  # this returns 'mps' on M1
     # autocast only supports cuda or cpu
-    if device_type not in ('cuda','cpu'):
-        return 'cpu'
-    return device_type
+    if device_type in ('cuda','cpu'):
+        return device_type, autocast
+    else:
+        return 'cpu', nullcontext
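Callers now unpack a (device_type, scope) pair instead of a bare string: scope is torch.autocast on cuda/cpu and a no-op nullcontext on mps. A minimal usage sketch, assuming only what the diff above shows:

# Minimal sketch of the new calling convention.
import torch
from ldm.dream.devices import choose_autocast_device

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device_type, scope = choose_autocast_device(device)
with scope(device_type):
    pass  # run inference here; mixed precision is used where supported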

ldm/dream/generator/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
'''
Initialization file for the ldm.dream.generator package
'''
from .base import Generator

ldm/dream/generator/base.py

Lines changed: 158 additions & 0 deletions
@@ -0,0 +1,158 @@
'''
Base class for ldm.dream.generator.*
including img2img, txt2img, and inpaint
'''
import torch
import numpy as np
import random
from tqdm import tqdm, trange
from PIL import Image
from einops import rearrange, repeat
from pytorch_lightning import seed_everything
from ldm.dream.devices import choose_autocast_device

downsampling = 8

class Generator():
    def __init__(self, model):
        self.model = model
        self.seed = None
        self.latent_channels = model.channels
        self.downsampling_factor = downsampling  # BUG: should come from model or config
        self.variation_amount = 0
        self.with_variations = []

    # this is going to be overridden in img2img.py, txt2img.py and inpaint.py
    def get_make_image(self, prompt, **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image.
        The return value depends on the seed at the time you call it.
        """
        raise NotImplementedError("get_make_image() must be implemented in a descendant class")

    def set_variation(self, seed, variation_amount, with_variations):
        self.seed = seed
        self.variation_amount = variation_amount
        self.with_variations = with_variations

    def generate(self, prompt, init_image, width, height, iterations=1, seed=None,
                 image_callback=None, step_callback=None,
                 **kwargs):
        device_type, scope = choose_autocast_device(self.model.device)
        make_image = self.get_make_image(
            prompt,
            init_image    = init_image,
            width         = width,
            height        = height,
            step_callback = step_callback,
            **kwargs
        )

        results = []
        seed = seed if seed else self.new_seed()
        seed, initial_noise = self.generate_initial_noise(seed, width, height)
        with scope(device_type), self.model.ema_scope():
            for n in trange(iterations, desc='Generating'):
                x_T = None
                if self.variation_amount > 0:
                    seed_everything(seed)
                    target_noise = self.get_noise(width, height)
                    x_T = self.slerp(self.variation_amount, initial_noise, target_noise)
                elif initial_noise is not None:
                    # i.e. we specified particular variations
                    x_T = initial_noise
                else:
                    seed_everything(seed)
                    if self.model.device.type == 'mps':
                        x_T = self.get_noise(width, height)

                # make_image will do the equivalent of get_noise itself
                image = make_image(x_T)
                results.append([image, seed])
                if image_callback is not None:
                    image_callback(image, seed)
                seed = self.new_seed()
        return results

    def sample_to_image(self, samples):
        """
        Decodes a batch of exactly one latent sample into a PIL Image.
        """
        x_samples = self.model.decode_first_stage(samples)
        x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
        if len(x_samples) != 1:
            raise Exception(
                f'>> expected to get a single image, but got {len(x_samples)}')
        x_sample = 255.0 * rearrange(
            x_samples[0].cpu().numpy(), 'c h w -> h w c'
        )
        return Image.fromarray(x_sample.astype(np.uint8))

    def generate_initial_noise(self, seed, width, height):
        initial_noise = None
        if self.variation_amount > 0 or len(self.with_variations) > 0:
            # use fixed initial noise plus random noise per iteration
            seed_everything(seed)
            initial_noise = self.get_noise(width, height)
            for v_seed, v_weight in self.with_variations:
                seed = v_seed
                seed_everything(seed)
                next_noise = self.get_noise(width, height)
                initial_noise = self.slerp(v_weight, initial_noise, next_noise)
            if self.variation_amount > 0:
                random.seed()  # reset RNG to an actually random state, so we can get a random seed for variations
                seed = random.randrange(0, np.iinfo(np.uint32).max)
            return (seed, initial_noise)
        else:
            return (seed, None)

    def get_noise(self, width, height):
        """
        Returns a tensor filled with random numbers, either from a normal distribution
        (txt2img) or from the latent image (img2img, inpaint)
        """
        raise NotImplementedError("get_noise() must be implemented in a descendant class")

    def new_seed(self):
        self.seed = random.randrange(0, np.iinfo(np.uint32).max)
        return self.seed

    def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
        '''
        Spherical linear interpolation
        Args:
            t (float/np.ndarray): Float value between 0.0 and 1.0
            v0 (np.ndarray): Starting vector
            v1 (np.ndarray): Final vector
            DOT_THRESHOLD (float): Threshold for considering the two vectors as
                                   collinear. Not recommended to alter this.
        Returns:
            v2 (np.ndarray): Interpolation vector between v0 and v1
        '''
        inputs_are_torch = False
        if not isinstance(v0, np.ndarray):
            inputs_are_torch = True
            v0 = v0.detach().cpu().numpy()
        if not isinstance(v1, np.ndarray):
            inputs_are_torch = True
            v1 = v1.detach().cpu().numpy()

        dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
        if np.abs(dot) > DOT_THRESHOLD:
            v2 = (1 - t) * v0 + t * v1
        else:
            theta_0 = np.arccos(dot)
            sin_theta_0 = np.sin(theta_0)
            theta_t = theta_0 * t
            sin_theta_t = np.sin(theta_t)
            s0 = np.sin(theta_0 - theta_t) / sin_theta_0
            s1 = sin_theta_t / sin_theta_0
            v2 = s0 * v0 + s1 * v1

        if inputs_are_torch:
            v2 = torch.from_numpy(v2).to(self.model.device)

        return v2
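To make the slerp() math concrete, here is a standalone NumPy check that repeats the same formula outside the class (values are illustrative; the torch conversion branch is not exercised):

# Standalone check of the slerp math above, NumPy inputs only.
import numpy as np

def slerp(t, v0, v1, DOT_THRESHOLD=0.9995):
    dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
    if np.abs(dot) > DOT_THRESHOLD:
        return (1 - t) * v0 + t * v1   # nearly collinear: fall back to lerp
    theta_0 = np.arccos(dot)
    theta_t = theta_0 * t
    s0 = np.sin(theta_0 - theta_t) / np.sin(theta_0)
    s1 = np.sin(theta_t) / np.sin(theta_0)
    return s0 * v0 + s1 * v1

v0 = np.array([1.0, 0.0])
v1 = np.array([0.0, 1.0])
print(slerp(0.5, v0, v1))  # ~[0.7071, 0.7071]: the midpoint stays on the unit circle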

ldm/dream/generator/img2img.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
'''
ldm.dream.generator.img2img descends from ldm.dream.generator
'''

import torch
import numpy as np
from ldm.dream.devices import choose_autocast_device
from ldm.dream.generator.base import Generator
from ldm.models.diffusion.ddim import DDIMSampler

class Img2Img(Generator):
    def __init__(self, model):
        super().__init__(model)
        self.init_latent = None  # set in get_make_image(), consumed by get_noise()

    @torch.no_grad()
    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, init_image, strength, step_callback=None, **kwargs):
        """
        Returns a function returning an image derived from the prompt and the initial image.
        The return value depends on the seed at the time you call it.
        """

        # PLMS sampler not supported yet, so ignore previous sampler
        if not isinstance(sampler, DDIMSampler):
            print(
                f">> sampler '{sampler.__class__.__name__}' is not yet supported. Using DDIM sampler"
            )
            sampler = DDIMSampler(self.model, device=self.model.device)

        sampler.make_schedule(
            ddim_num_steps=steps, ddim_eta=ddim_eta, verbose=False
        )

        device_type, scope = choose_autocast_device(self.model.device)
        with scope(device_type):
            self.init_latent = self.model.get_first_stage_encoding(
                self.model.encode_first_stage(init_image)
            )  # move to latent space

        t_enc = int(strength * steps)
        uc, c = conditioning

        @torch.no_grad()
        def make_image(x_T):
            # encode (scaled latent)
            z_enc = sampler.stochastic_encode(
                self.init_latent,
                torch.tensor([t_enc]).to(self.model.device),
                noise=x_T
            )
            # decode it
            samples = sampler.decode(
                z_enc,
                c,
                t_enc,
                img_callback=step_callback,
                unconditional_guidance_scale=cfg_scale,
                unconditional_conditioning=uc,
            )
            return self.sample_to_image(samples)

        return make_image

    def get_noise(self, width, height):
        device = self.model.device
        init_latent = self.init_latent
        assert init_latent is not None, 'call to get_noise() when init_latent not set'
        if device.type == 'mps':
            return torch.randn_like(init_latent, device='cpu').to(device)
        else:
            return torch.randn_like(init_latent, device=device)
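A hypothetical sketch of how these pieces compose end to end. Here `model` (a loaded latent-diffusion model) and `init_image_tensor` (an image already converted to a normalized torch tensor) are assumptions standing in for objects the caller must supply; they are not provided by this diff.

# Hypothetical wiring of conditioning + Img2Img; `model` and
# `init_image_tensor` are assumed to exist (not part of this diff).
from ldm.dream.conditioning import get_uc_and_c
from ldm.dream.generator.img2img import Img2Img
from ldm.models.diffusion.ddim import DDIMSampler

generator = Img2Img(model)
results = generator.generate(
    'a misty fjord at dawn',
    init_image   = init_image_tensor,
    width        = 512,
    height       = 512,
    iterations   = 1,
    sampler      = DDIMSampler(model, device=model.device),
    steps        = 50,
    cfg_scale    = 7.5,
    ddim_eta     = 0.0,
    strength     = 0.75,    # fraction of the schedule applied to the init latent
    conditioning = get_uc_and_c('a misty fjord at dawn', model),
)
image, seed = results[0]    # generate() returns [image, seed] pairs
image.save('fjord.png')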
