@@ -84,6 +84,78 @@ def get_3d_sincos_pos_embed(
     temporal_size: int,
     spatial_interpolation_scale: float = 1.0,
     temporal_interpolation_scale: float = 1.0,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+) -> torch.Tensor:
+    r"""
+    Creates 3D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension of inputs. It must be divisible by 16.
+        spatial_size (`int` or `Tuple[int, int]`):
+            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
+            spatial dimensions (height and width).
+        temporal_size (`int`):
+            The temporal dimension of positional embeddings (number of frames).
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for spatial grid interpolation.
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+            Scale factor for temporal grid interpolation.
+
+    Returns:
+        `torch.Tensor`:
+            The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
+            embed_dim]`.
+    """
+    if output_type == "np":
+        return _get_3d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            spatial_size=spatial_size,
+            temporal_size=temporal_size,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+        )
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = torch.arange(spatial_size[1], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid_w = torch.arange(spatial_size[0], device=device, dtype=torch.float32) / spatial_interpolation_scale
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid, output_type="pt")
+
+    # 2. Temporal
+    grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t, output_type="pt")
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[None, :, :]
+    pos_embed_spatial = pos_embed_spatial.repeat_interleave(temporal_size, dim=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, None, :]
+    pos_embed_temporal = pos_embed_temporal.repeat_interleave(
+        spatial_size[0] * spatial_size[1], dim=1
+    )  # [T, H*W, D // 4]
+
+    pos_embed = torch.concat([pos_embed_temporal, pos_embed_spatial], dim=-1)  # [T, H*W, D]
+    return pos_embed
+
+
+def _get_3d_sincos_pos_embed_np(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
 ) -> np.ndarray:
     r"""
     Creates 3D sinusoidal positional embeddings.
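For orientation, here is a minimal sketch of exercising the new torch-native path added above. The sizes are illustrative and not taken from the PR, and the import assumes the function keeps living in `diffusers.models.embeddings`:

    import torch

    from diffusers.models.embeddings import get_3d_sincos_pos_embed

    # embed_dim must be divisible by 16: 3/4 of it encodes the spatial grid, 1/4 encodes time.
    pos = get_3d_sincos_pos_embed(
        embed_dim=64,
        spatial_size=(30, 45),   # (width, height), mirroring how grid_w / grid_h are built above
        temporal_size=13,
        device=torch.device("cpu"),
        output_type="pt",
    )
    print(pos.shape)  # torch.Size([13, 1350, 64]) -> [temporal_size, spatial_size[0] * spatial_size[1], embed_dim]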
@@ -106,6 +178,12 @@ def get_3d_sincos_pos_embed(
             The 3D sinusoidal positional embeddings of shape `[temporal_size, spatial_size[0] * spatial_size[1],
             embed_dim]`.
     """
+    deprecation_message = (
+        "`get_3d_sincos_pos_embed` uses `torch` and supports `device`."
+        " `from_numpy` is no longer required."
184+ " Pass `output_type='pt' to use the new version now."
+    )
+    deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
     if embed_dim % 4 != 0:
         raise ValueError("`embed_dim` must be divisible by 4")
     if isinstance(spatial_size, int):
@@ -139,6 +217,143 @@ def get_3d_sincos_pos_embed(
 
 
 def get_2d_sincos_pos_embed(
+    embed_dim,
+    grid_size,
+    cls_token=False,
+    extra_tokens=0,
+    interpolation_scale=1.0,
+    base_size=16,
+    device: Optional[torch.device] = None,
+    output_type: str = "np",
+):
+    """
+    Creates 2D sinusoidal positional embeddings.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension.
+        grid_size (`int`):
+            The size of the grid height and width.
+        cls_token (`bool`, defaults to `False`):
+            Whether or not to add a classification token.
+        extra_tokens (`int`, defaults to `0`):
+            The number of extra tokens to add.
+        interpolation_scale (`float`, defaults to `1.0`):
+            The scale of the interpolation.
+
+    Returns:
+        pos_embed (`torch.Tensor`):
+            Shape is either `[grid_size * grid_size, embed_dim]` if not using cls_token, or `[1 + grid_size*grid_size,
+            embed_dim]` if using cls_token
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
253+ " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_np(
+            embed_dim=embed_dim,
+            grid_size=grid_size,
+            cls_token=cls_token,
+            extra_tokens=extra_tokens,
+            interpolation_scale=interpolation_scale,
+            base_size=base_size,
+        )
+    if isinstance(grid_size, int):
+        grid_size = (grid_size, grid_size)
+
+    grid_h = (
+        torch.arange(grid_size[0], device=device, dtype=torch.float32)
+        / (grid_size[0] / base_size)
+        / interpolation_scale
+    )
+    grid_w = (
+        torch.arange(grid_size[1], device=device, dtype=torch.float32)
+        / (grid_size[1] / base_size)
+        / interpolation_scale
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")  # here w goes first
+    grid = torch.stack(grid, dim=0)
+
+    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
+    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type=output_type)
+    if cls_token and extra_tokens > 0:
+        pos_embed = torch.concat([torch.zeros([extra_tokens, embed_dim]), pos_embed], dim=0)
+    return pos_embed
+
+
+def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
+    r"""
+    This function generates 2D sinusoidal positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension.
+        grid (`torch.Tensor`): Grid of positions with shape `(H * W,)`.
+
+    Returns:
+        `torch.Tensor`: The 2D sinusoidal positional embeddings with shape `(H * W, embed_dim)`
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
302+ " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_2d_sincos_pos_embed_from_grid_np(
+            embed_dim=embed_dim,
+            grid=grid,
+        )
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    # use half of dimensions to encode grid_h
+    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0], output_type=output_type)  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1], output_type=output_type)  # (H*W, D/2)
+
+    emb = torch.concat([emb_h, emb_w], dim=1)  # (H*W, D)
+    return emb
+
+
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+    """
+    This function generates 1D positional embeddings from a grid.
+
+    Args:
+        embed_dim (`int`): The embedding dimension `D`
+        pos (`torch.Tensor`): 1D tensor of positions with shape `(M,)`
+
+    Returns:
+        `torch.Tensor`: Sinusoidal positional embeddings of shape `(M, D)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_1d_sincos_pos_embed_from_grid` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
335+ " Pass `output_type='pt' to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
+    if embed_dim % 2 != 0:
+        raise ValueError("embed_dim must be divisible by 2")
+
+    omega = torch.arange(embed_dim // 2, device=pos.device, dtype=torch.float64)
+    omega /= embed_dim / 2.0
+    omega = 1.0 / 10000**omega  # (D/2,)
+
+    pos = pos.reshape(-1)  # (M,)
+    out = torch.outer(pos, omega)  # (M, D/2), outer product
+
+    emb_sin = torch.sin(out)  # (M, D/2)
+    emb_cos = torch.cos(out)  # (M, D/2)
+
+    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
+    return emb
+
+
+def get_2d_sincos_pos_embed_np(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
     """
@@ -170,13 +385,13 @@ def get_2d_sincos_pos_embed(
     grid = np.stack(grid, axis=0)
 
     grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
-    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+    pos_embed = get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid)
     if cls_token and extra_tokens > 0:
         pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
     return pos_embed
 
 
-def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+def get_2d_sincos_pos_embed_from_grid_np(embed_dim, grid):
     r"""
     This function generates 2D sinusoidal positional embeddings from a grid.
 
@@ -191,14 +406,14 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
         raise ValueError("embed_dim must be divisible by 2")
 
     # use half of dimensions to encode grid_h
-    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
-    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+    emb_h = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[0])  # (H*W, D/2)
+    emb_w = get_1d_sincos_pos_embed_from_grid_np(embed_dim // 2, grid[1])  # (H*W, D/2)
 
     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
     return emb
 
 
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+def get_1d_sincos_pos_embed_from_grid_np(embed_dim, pos):
     """
     This function generates 1D positional embeddings from a grid.
 
@@ -288,10 +503,14 @@ def __init__(
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
+                embed_dim,
+                grid_size,
+                base_size=self.base_size,
+                interpolation_scale=self.interpolation_scale,
+                output_type="pt",
             )
             persistent = True if pos_embed_max_size else False
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
@@ -341,8 +560,10 @@ def forward(self, latent):
                 grid_size=(height, width),
                 base_size=self.base_size,
                 interpolation_scale=self.interpolation_scale,
+                device=latent.device,
+                output_type="pt",
             )
-            pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+            pos_embed = pos_embed.float().unsqueeze(0)
         else:
             pos_embed = self.pos_embed
 
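With `device=latent.device` and `output_type="pt"`, the interpolated embedding in `forward` is now created where the latents live, so the old numpy -> CPU tensor -> `.to(latent.device)` round-trip disappears. A rough standalone sketch of that call, with made-up sizes:

    import torch

    from diffusers.models.embeddings import get_2d_sincos_pos_embed

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pos_embed = get_2d_sincos_pos_embed(
        1152,
        grid_size=(64, 64),
        base_size=64,
        interpolation_scale=1.0,
        device=device,
        output_type="pt",
    )
    pos_embed = pos_embed.float().unsqueeze(0)  # mirrors the new forward() line above
    print(pos_embed.shape, pos_embed.device)    # torch.Size([1, 4096, 1152]) on `device`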
@@ -453,7 +674,9 @@ def __init__(
             pos_embedding = self._get_positional_embeddings(sample_height, sample_width, sample_frames)
             self.register_buffer("pos_embedding", pos_embedding, persistent=persistent)
 
-    def _get_positional_embeddings(self, sample_height: int, sample_width: int, sample_frames: int) -> torch.Tensor:
+    def _get_positional_embeddings(
+        self, sample_height: int, sample_width: int, sample_frames: int, device: Optional[torch.device] = None
+    ) -> torch.Tensor:
         post_patch_height = sample_height // self.patch_size
         post_patch_width = sample_width // self.patch_size
         post_time_compression_frames = (sample_frames - 1) // self.temporal_compression_ratio + 1
@@ -465,8 +688,10 @@ def _get_positional_embeddings(self, sample_height: int, sample_width: int, samp
             post_time_compression_frames,
             self.spatial_interpolation_scale,
             self.temporal_interpolation_scale,
+            device=device,
+            output_type="pt",
         )
-        pos_embedding = torch.from_numpy(pos_embedding).flatten(0, 1)
+        pos_embedding = pos_embedding.flatten(0, 1)
         joint_pos_embedding = torch.zeros(
             1, self.max_text_seq_length + num_patches, self.embed_dim, requires_grad=False
         )
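For context, a standalone sketch of how this joint text+video table is assembled: the text slots stay at zero and the flattened 3D sincos grid fills the patch slots. The concrete numbers are only illustrative (roughly CogVideoX-like), and the final copy into the zero tensor is an assumption about the lines that follow this hunk:

    import torch

    from diffusers.models.embeddings import get_3d_sincos_pos_embed

    embed_dim, max_text_seq_length = 1920, 226
    post_patch_height, post_patch_width, post_frames = 30, 45, 13
    num_patches = post_patch_height * post_patch_width * post_frames

    pos_embedding = get_3d_sincos_pos_embed(
        embed_dim,
        (post_patch_width, post_patch_height),
        post_frames,
        spatial_interpolation_scale=1.875,
        temporal_interpolation_scale=1.0,
        output_type="pt",
    )
    pos_embedding = pos_embedding.flatten(0, 1)  # [frames * H * W, embed_dim]

    joint_pos_embedding = torch.zeros(1, max_text_seq_length + num_patches, embed_dim)
    joint_pos_embedding[0, max_text_seq_length:] = pos_embedding  # text tokens keep zero embeddings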
@@ -521,8 +746,10 @@ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
                 or self.sample_width != width
                 or self.sample_frames != pre_time_compression_frames
             ):
-                pos_embedding = self._get_positional_embeddings(height, width, pre_time_compression_frames)
-                pos_embedding = pos_embedding.to(embeds.device, dtype=embeds.dtype)
+                pos_embedding = self._get_positional_embeddings(
+                    height, width, pre_time_compression_frames, device=embeds.device
+                )
+                pos_embedding = pos_embedding.to(dtype=embeds.dtype)
             else:
                 pos_embedding = self.pos_embedding
 
@@ -552,9 +779,11 @@ def __init__(
         # Linear projection for text embeddings
         self.text_proj = nn.Linear(text_hidden_size, hidden_size)
 
-        pos_embed = get_2d_sincos_pos_embed(hidden_size, pos_embed_max_size, base_size=pos_embed_max_size)
+        pos_embed = get_2d_sincos_pos_embed(
+            hidden_size, pos_embed_max_size, base_size=pos_embed_max_size, output_type="pt"
+        )
         pos_embed = pos_embed.reshape(pos_embed_max_size, pos_embed_max_size, hidden_size)
-        self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float(), persistent=False)
+        self.register_buffer("pos_embed", pos_embed.float(), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, channel, height, width = hidden_states.shape