1414from tqdm import tqdm , trange
1515from itertools import islice
1616from einops import rearrange , repeat
17+ from torch import nn
1718from torchvision .utils import make_grid
1819from pytorch_lightning import seed_everything
1920from torch import autocast
@@ -109,6 +110,7 @@ class T2I:
109110 downsampling_factor
110111 precision
111112 strength
113+ seamless
112114 embedding_path
113115
114116 The vast majority of these arguments default to reasonable values.
@@ -132,6 +134,7 @@ def __init__(
132134 precision = 'autocast' ,
133135 full_precision = False ,
134136 strength = 0.75 , # default in scripts/img2img.py
137+ seamless = False ,
135138 embedding_path = None ,
136139 device_type = 'cuda' ,
137140 # just to keep track of this parameter when regenerating prompt
@@ -153,6 +156,7 @@ def __init__(
153156 self .precision = precision
154157 self .full_precision = full_precision
155158 self .strength = strength
159+ self .seamless = seamless
156160 self .embedding_path = embedding_path
157161 self .device_type = device_type
158162 self .model = None # empty for now
@@ -217,6 +221,7 @@ def prompt2image(
217221 step_callback = None ,
218222 width = None ,
219223 height = None ,
224+ seamless = False ,
220225 # these are specific to img2img
221226 init_img = None ,
222227 fit = False ,
@@ -238,6 +243,7 @@ def prompt2image(
238243 width // width of image, in multiples of 64 (512)
239244 height // height of image, in multiples of 64 (512)
240245 cfg_scale // how strongly the prompt influences the image (7.5) (must be >1)
246+ seamless // whether the generated image should tile
241247 init_img // path to an initial image - its dimensions override width and height
242248 strength // strength for noising/unnoising init_img. 0.0 preserves image exactly, 1.0 replaces it completely
243249 gfpgan_strength // strength for GFPGAN. 0.0 preserves image exactly, 1.0 replaces it completely
@@ -265,6 +271,7 @@ def process_image(image,seed):
265271 seed = seed or self .seed
266272 width = width or self .width
267273 height = height or self .height
274+ seamless = seamless or self .seamless
268275 cfg_scale = cfg_scale or self .cfg_scale
269276 ddim_eta = ddim_eta or self .ddim_eta
270277 iterations = iterations or self .iterations
@@ -274,6 +281,10 @@ def process_image(image,seed):
274281 model = (
275282 self .load_model ()
276283 ) # will instantiate the model or return it from cache
284+ for m in model .modules ():
285+ if isinstance (m , (nn .Conv2d , nn .ConvTranspose2d )):
286+ m .padding_mode = 'circular' if seamless else m ._orig_padding_mode
287+
277288 assert cfg_scale > 1.0 , 'CFG_Scale (-C) must be >1.0'
278289 assert (
279290 0.0 <= strength <= 1.0
@@ -562,6 +573,10 @@ def load_model(self):
562573
563574 self ._set_sampler ()
564575
576+ for m in self .model .modules ():
577+ if isinstance (m , (nn .Conv2d , nn .ConvTranspose2d )):
578+ m ._orig_padding_mode = m .padding_mode
579+
565580 return self .model
566581
567582 def _set_sampler (self ):
0 commit comments