Skip to content

Commit eef7889

Browse files
authored
feat(txt2img): allow from_file to work with len(lines) < batch_size (#349)
1 parent 720e5cd commit eef7889

File tree

1 file changed: 8 additions (+), 3 deletions (−)

scripts/orig_scripts/txt2img.py

Lines changed: 8 additions & 3 deletions
Original file line number · Diff line number · Diff line change
```diff
@@ -232,7 +232,12 @@ def forward(self, x, sigma, uncond, cond, cond_scale):
             print(f"reading prompts from {opt.from_file}")
             with open(opt.from_file, "r") as f:
                 data = f.read().splitlines()
-                data = list(chunk(data, batch_size))
+                if (len(data) >= batch_size):
+                    data = list(chunk(data, batch_size))
+                else:
+                    while (len(data) < batch_size):
+                        data.append(data[-1])
+                    data = [data]

     sample_path = os.path.join(outpath, "samples")
     os.makedirs(sample_path, exist_ok=True)
```
(Indentation reconstructed from the surrounding context; the scrape dropped all leading whitespace. Note for review: if the prompt file is empty, `data[-1]` in the new `else` branch raises `IndexError` — the padding loop assumes at least one line.)
```diff
@@ -264,7 +269,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale):
                 prompts = list(prompts)
             c = model.get_learned_conditioning(prompts)
             shape = [opt.C, opt.H // opt.f, opt.W // opt.f]
-
+
             if not opt.klms:
                 samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
                                                  conditioning=c,
```
(Whitespace-only change: a line of trailing whitespace was replaced by an empty line; indentation reconstructed, since the scrape dropped it.)
```diff
@@ -284,7 +289,7 @@ def forward(self, x, sigma, uncond, cond, cond_scale):
                 model_wrap_cfg = CFGDenoiser(model_wrap)
                 extra_args = {'cond': c, 'uncond': uc, 'cond_scale': opt.scale}
                 samples_ddim = K.sampling.sample_lms(model_wrap_cfg, x, sigmas, extra_args=extra_args)
-
+
             x_samples_ddim = model.decode_first_stage(samples_ddim)
             x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)
```
(Whitespace-only change, same trailing-whitespace cleanup as the previous hunk; indentation reconstructed.)

0 commit comments

Comments (0)