Skip to content

Commit d922b53

Browse files
committed
added seamless tiling mode and commands
1 parent 3393643 commit d922b53

File tree

2 files changed

+35
-0
lines changed

2 files changed

+35
-0
lines changed

ldm/simplet2i.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from tqdm import tqdm, trange
1515
from itertools import islice
1616
from einops import rearrange, repeat
17+
from torch import nn
1718
from torchvision.utils import make_grid
1819
from pytorch_lightning import seed_everything
1920
from torch import autocast
@@ -109,6 +110,7 @@ class T2I:
109110
downsampling_factor
110111
precision
111112
strength
113+
seamless
112114
embedding_path
113115
114116
The vast majority of these arguments default to reasonable values.
@@ -132,6 +134,7 @@ def __init__(
132134
precision='autocast',
133135
full_precision=False,
134136
strength=0.75, # default in scripts/img2img.py
137+
seamless=False,
135138
embedding_path=None,
136139
device_type = 'cuda',
137140
# just to keep track of this parameter when regenerating prompt
@@ -153,6 +156,7 @@ def __init__(
153156
self.precision = precision
154157
self.full_precision = full_precision
155158
self.strength = strength
159+
self.seamless = seamless
156160
self.embedding_path = embedding_path
157161
self.device_type = device_type
158162
self.model = None # empty for now
@@ -217,6 +221,7 @@ def prompt2image(
217221
step_callback = None,
218222
width = None,
219223
height = None,
224+
seamless = False,
220225
# these are specific to img2img
221226
init_img = None,
222227
fit = False,
@@ -238,6 +243,7 @@ def prompt2image(
238243
width // width of image, in multiples of 64 (512)
239244
height // height of image, in multiples of 64 (512)
240245
cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
246+
seamless // whether the generated image should tile
241247
init_img // path to an initial image - its dimensions override width and height
242248
strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
243249
gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
@@ -265,6 +271,7 @@ def process_image(image,seed):
265271
seed = seed or self.seed
266272
width = width or self.width
267273
height = height or self.height
274+
seamless = seamless or self.seamless
268275
cfg_scale = cfg_scale or self.cfg_scale
269276
ddim_eta = ddim_eta or self.ddim_eta
270277
iterations = iterations or self.iterations
@@ -274,6 +281,10 @@ def process_image(image,seed):
274281
model = (
275282
self.load_model()
276283
) # will instantiate the model or return it from cache
284+
for m in model.modules():
285+
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
286+
m.padding_mode = 'circular' if seamless else m._orig_padding_mode
287+
277288
assert cfg_scale > 1.0, 'CFG_Scale (-C) must be >1.0'
278289
assert (
279290
0.0 <= strength <= 1.0
@@ -562,6 +573,10 @@ def load_model(self):
562573

563574
self._set_sampler()
564575

576+
for m in self.model.modules():
577+
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
578+
m._orig_padding_mode = m.padding_mode
579+
565580
return self.model
566581

567582
def _set_sampler(self):

scripts/dream.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import copy
1010
import warnings
1111
import time
12+
import torch.nn as nn
1213
from ldm.dream.devices import choose_torch_device
1314
import ldm.dream.readline
1415
from ldm.dream.pngwriter import PngWriter, PromptFormatter
@@ -60,6 +61,7 @@ def main():
6061
grid = opt.grid,
6162
# this is solely for recreating the prompt
6263
latent_diffusion_weights=opt.laion400m,
64+
seamless=opt.seamless,
6365
embedding_path=opt.embedding_path,
6466
device_type=opt.device
6567
)
@@ -92,6 +94,14 @@ def main():
9294
f'>> model loaded in', '%4.2fs' % (time.time() - tic)
9395
)
9496

97+
for m in t2i.model.modules():
98+
if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
99+
m._orig_padding_mode = m.padding_mode
100+
if opt.seamless:
101+
m.padding_mode = 'circular'
102+
if opt.seamless:
103+
print(">> changed to seamless tiling mode")
104+
95105
if not infile:
96106
print(
97107
"\n* Initialization done! Awaiting your command (-h for help, 'q' to quit)"
@@ -374,6 +384,11 @@ def create_argv_parser():
374384
default='outputs/img-samples',
375385
help='Directory to save generated images and a log of prompts and seeds. Default: outputs/img-samples',
376386
)
387+
parser.add_argument(
388+
'--seamless',
389+
action='store_true',
390+
help='Change the model to seamless tiling (circular) mode',
391+
)
377392
parser.add_argument(
378393
'--embedding_path',
379394
type=str,
@@ -474,6 +489,11 @@ def create_cmd_parser():
474489
default=None,
475490
help='Directory to save generated images and a log of prompts and seeds',
476491
)
492+
parser.add_argument(
493+
'--seamless',
494+
action='store_true',
495+
help='Change the model to seamless tiling (circular) mode',
496+
)
477497
parser.add_argument(
478498
'-i',
479499
'--individual',

0 commit comments

Comments
 (0)