Skip to content

Commit 7c4b38b

Browse files
authored
Removing .float() (autocast in fp16 will discard this (I think)). (#495)
1 parent ab7a78e commit 7c4b38b

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/models/resnet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def forward(self, x, temb):
333333

334334
# make sure hidden states is in float32
335335
# when running in half-precision
336-
hidden_states = self.norm1(hidden_states.float()).type(hidden_states.dtype)
336+
hidden_states = self.norm1(hidden_states).type(hidden_states.dtype)
337337
hidden_states = self.nonlinearity(hidden_states)
338338

339339
if self.upsample is not None:
@@ -351,7 +351,7 @@ def forward(self, x, temb):
351351

352352
# make sure hidden states is in float32
353353
# when running in half-precision
354-
hidden_states = self.norm2(hidden_states.float()).type(hidden_states.dtype)
354+
hidden_states = self.norm2(hidden_states).type(hidden_states.dtype)
355355
hidden_states = self.nonlinearity(hidden_states)
356356

357357
hidden_states = self.dropout(hidden_states)

0 commit comments

Comments
 (0)