@@ -137,18 +137,18 @@ def _set_attention_slice(self, slice_size):
         for block in self.transformer_blocks:
             block._set_attention_slice(slice_size)

-    def forward(self, x, context=None):
+    def forward(self, hidden_states, context=None):
         # note: if no context is given, cross-attention defaults to self-attention
-        b, c, h, w = x.shape
-        x_in = x
-        x = self.norm(x)
-        x = self.proj_in(x)
-        x = x.permute(0, 2, 3, 1).reshape(b, h * w, c)
+        batch, channel, height, weight = hidden_states.shape
+        residual = hidden_states
+        hidden_states = self.norm(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, channel)
         for block in self.transformer_blocks:
-            x = block(x, context=context)
-        x = x.reshape(b, h, w, c).permute(0, 3, 1, 2)
-        x = self.proj_out(x)
-        return x + x_in
+            hidden_states = block(hidden_states, context=context)
+        hidden_states = hidden_states.reshape(batch, height, weight, channel).permute(0, 3, 1, 2)
+        hidden_states = self.proj_out(hidden_states)
+        return hidden_states + residual


 class BasicTransformerBlock(nn.Module):
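For readers skimming the rename: the forward above flattens the NCHW feature map into a sequence of per-pixel tokens before the transformer blocks and restores the original layout afterwards. A minimal standalone sketch of that round trip (the dimension the diff names weight is the spatial width; sizes below are made up):

import torch

batch, channel, height, width = 2, 8, 4, 4  # illustrative sizes; the diff calls the width "weight"
hidden_states = torch.randn(batch, channel, height, width)

# NCHW -> (batch, height*width, channel): every spatial position becomes a token
tokens = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)

# ... the transformer blocks operate on `tokens` here ...

# (batch, height*width, channel) -> NCHW, undoing the flattening
restored = tokens.reshape(batch, height, width, channel).permute(0, 3, 1, 2)
assert torch.equal(restored, hidden_states)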
@@ -192,12 +192,12 @@ def _set_attention_slice(self, slice_size):
         self.attn1._slice_size = slice_size
         self.attn2._slice_size = slice_size

-    def forward(self, x, context=None):
-        x = x.contiguous() if x.device.type == "mps" else x
-        x = self.attn1(self.norm1(x)) + x
-        x = self.attn2(self.norm2(x), context=context) + x
-        x = self.ff(self.norm3(x)) + x
-        return x
+    def forward(self, hidden_states, context=None):
+        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
+        hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
+        hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
+        hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+        return hidden_states


 class CrossAttention(nn.Module):
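The block above is the standard pre-norm residual wiring: self-attention, then cross-attention against context, then feed-forward, each added back onto its input. A minimal sketch of that wiring with identity stand-ins for the three sub-modules (dim and shapes are illustrative, not from the diff):

import torch
from torch import nn

dim = 320
norm1, norm2, norm3 = nn.LayerNorm(dim), nn.LayerNorm(dim), nn.LayerNorm(dim)
attn1 = nn.Identity()  # stand-in for self-attention
attn2 = nn.Identity()  # stand-in for cross-attention (the real module also takes `context`)
ff = nn.Identity()     # stand-in for the feed-forward block

hidden_states = torch.randn(2, 16, dim)
hidden_states = attn1(norm1(hidden_states)) + hidden_states  # self-attention + residual
hidden_states = attn2(norm2(hidden_states)) + hidden_states  # cross-attention + residual
hidden_states = ff(norm3(hidden_states)) + hidden_states     # feed-forward + residual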
@@ -247,22 +247,22 @@ def reshape_batch_dim_to_heads(self, tensor):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

-    def forward(self, x, context=None, mask=None):
-        batch_size, sequence_length, dim = x.shape
+    def forward(self, hidden_states, context=None, mask=None):
+        batch_size, sequence_length, dim = hidden_states.shape

-        q = self.to_q(x)
-        context = context if context is not None else x
-        k = self.to_k(context)
-        v = self.to_v(context)
+        query = self.to_q(hidden_states)
+        context = context if context is not None else hidden_states
+        key = self.to_k(context)
+        value = self.to_v(context)

-        q = self.reshape_heads_to_batch_dim(q)
-        k = self.reshape_heads_to_batch_dim(k)
-        v = self.reshape_heads_to_batch_dim(v)
+        query = self.reshape_heads_to_batch_dim(query)
+        key = self.reshape_heads_to_batch_dim(key)
+        value = self.reshape_heads_to_batch_dim(value)

         # TODO(PVP) - mask is currently never used. Remember to re-implement when used

         # attention, what we cannot get enough of
-        hidden_states = self._attention(q, k, v, sequence_length, dim)
+        hidden_states = self._attention(query, key, value, sequence_length, dim)

         return self.to_out(hidden_states)

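reshape_heads_to_batch_dim folds the attention heads into the batch dimension so that _attention (not shown in this hunk) can compute every head with one batched matmul. A standalone sketch of that folding plus plain scaled-dot-product attention, assuming the usual dim = heads * head_dim split; the real _attention may additionally slice over the batch when _slice_size is set:

import torch

batch_size, seq_len, dim, heads = 2, 77, 320, 8
head_dim = dim // heads
scale = head_dim ** -0.5

query = torch.randn(batch_size, seq_len, dim)
key = torch.randn(batch_size, seq_len, dim)
value = torch.randn(batch_size, seq_len, dim)

def heads_to_batch(tensor):
    # (batch, seq, heads*head_dim) -> (batch*heads, seq, head_dim)
    tensor = tensor.reshape(batch_size, seq_len, heads, head_dim)
    return tensor.permute(0, 2, 1, 3).reshape(batch_size * heads, seq_len, head_dim)

query, key, value = (heads_to_batch(t) for t in (query, key, value))
attention_probs = (query @ key.transpose(1, 2) * scale).softmax(dim=-1)  # (batch*heads, seq, seq)
hidden_states = attention_probs @ value                                  # (batch*heads, seq, head_dim)

# fold the heads back out, as reshape_batch_dim_to_heads does
hidden_states = hidden_states.reshape(batch_size, heads, seq_len, head_dim)
hidden_states = hidden_states.permute(0, 2, 1, 3).reshape(batch_size, seq_len, dim)
assert hidden_states.shape == (batch_size, seq_len, dim)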
@@ -308,8 +308,8 @@ def __init__(

         self.net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))

-    def forward(self, x):
-        return self.net(x)
+    def forward(self, hidden_states):
+        return self.net(hidden_states)


 # feedforward
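FeedForward just pushes its input through self.net. A standalone sketch of a net built the same way, with a plain Linear + GELU standing in for project_in (the diff's gated GEGLU variant is shown in the next hunk) and assuming the usual inner_dim = dim * mult:

import torch
from torch import nn

dim, mult, dropout = 320, 4, 0.0       # illustrative values
inner_dim = dim * mult
project_in = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())  # stand-in; GEGLU in the gated variant
net = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim))

hidden_states = torch.randn(2, 16, dim)
assert net(hidden_states).shape == hidden_states.shape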
@@ -326,6 +326,6 @@ def __init__(self, dim_in: int, dim_out: int):
         super().__init__()
         self.proj = nn.Linear(dim_in, dim_out * 2)

-    def forward(self, x):
-        x, gate = self.proj(x).chunk(2, dim=-1)
-        return x * F.gelu(gate)
+    def forward(self, hidden_states):
+        hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+        return hidden_states * F.gelu(gate)
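GEGLU projects to twice the output width, splits the result into two halves, and uses one half (passed through GELU) to gate the other. A standalone sketch of the same computation (sizes are illustrative):

import torch
import torch.nn.functional as F
from torch import nn

dim_in, dim_out = 320, 1280
proj = nn.Linear(dim_in, dim_out * 2)

hidden_states = torch.randn(2, 16, dim_in)
hidden_states, gate = proj(hidden_states).chunk(2, dim=-1)  # two (..., dim_out) halves
out = hidden_states * F.gelu(gate)                          # GELU-gated linear unit
assert out.shape == (2, 16, dim_out)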