Skip to content

Commit 723d074

Browse files
authored
Allow ctrl c when using --from_file (CompVis#472)
* added ansi escapes to highlight key parts of CLI session * adjust exception handling so that ^C will abort when reading prompts from a file
1 parent 75f633c commit 723d074

File tree

2 files changed

+21
-6
lines changed

2 files changed

+21
-6
lines changed

ldm/generate.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def __init__(
117117
seamless = False,
118118
embedding_path = None,
119119
device_type = 'cuda',
120+
ignore_ctrl_c = False,
120121
):
121122
self.iterations = iterations
122123
self.width = width
@@ -134,6 +135,7 @@ def __init__(
134135
self.seamless = seamless
135136
self.embedding_path = embedding_path
136137
self.device_type = device_type
138+
self.ignore_ctrl_c = ignore_ctrl_c # note, this logic probably doesn't belong here...
137139
self.model = None # empty for now
138140
self.sampler = None
139141
self.device = None
@@ -210,7 +212,7 @@ def prompt2image(
210212
**args,
211213
): # eat up additional cruft
212214
"""
213-
ldm.prompt2image() is the common entry point for txt2img() and img2img()
215+
ldm.generate.prompt2image() is the common entry point for txt2img() and img2img()
214216
It takes the following arguments:
215217
prompt // prompt string (no default)
216218
iterations // iterations (1); image count=iterations
@@ -341,6 +343,8 @@ def process_image(image,seed):
341343

342344
except KeyboardInterrupt:
343345
print('*interrupted*')
346+
if not self.ignore_ctrl_c:
347+
raise KeyboardInterrupt
344348
print(
345349
'>> Partial results will be returned; if --grid was requested, nothing will be returned.'
346350
)

scripts/dream.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@
1515
from ldm.dream.image_util import make_grid
1616
from omegaconf import OmegaConf
1717

18+
# Placeholder to be replaced with proper class that tracks the
19+
# outputs and associates with the prompt that generated them.
20+
# Just want to get the formatting look right for now.
21+
output_cntr = 0
22+
1823
def main():
1924
"""Initialize command-line parsers and the diffusion model"""
2025
arg_parser = create_argv_parser()
@@ -63,7 +68,8 @@ def main():
6368
# this is solely for recreating the prompt
6469
seamless = opt.seamless,
6570
embedding_path = opt.embedding_path,
66-
device_type = opt.device
71+
device_type = opt.device,
72+
ignore_ctrl_c = opt.infile is None,
6773
)
6874

6975
# make sure the output directory exists
@@ -292,16 +298,18 @@ def image_writer(image, seed, upscaled=False):
292298
print(e)
293299
continue
294300

295-
print('Outputs:')
301+
print('\033[1mOutputs:\033[0m')
296302
log_path = os.path.join(current_outdir, 'dream_log.txt')
297303
write_log_message(results, log_path)
298304

299-
print('goodbye!')
305+
print('goodbye!\033[0m')
300306

301307

302308
def get_next_command(infile=None) -> str: #command string
303309
if infile is None:
304-
command = input('dream> ')
310+
print('\033[1m') # add some boldface
311+
command = input('dream> ')
312+
print('\033[0m',end='')
305313
else:
306314
command = infile.readline()
307315
if not command:
@@ -339,8 +347,11 @@ def dream_server_loop(t2i, host, port, outdir):
339347

340348
def write_log_message(results, log_path):
341349
"""logs the name of the output image, prompt, and prompt args to the terminal and log file"""
350+
global output_cntr
342351
log_lines = [f'{path}: {prompt}\n' for path, prompt in results]
343-
print(*log_lines, sep='')
352+
for l in log_lines:
353+
output_cntr += 1
354+
print(f'\033[1m[{output_cntr}]\033[0m {l}',end='')
344355

345356
with open(log_path, 'a', encoding='utf-8') as file:
346357
file.writelines(log_lines)

0 commit comments

Comments
 (0)