1414
1515# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
1616
17- # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
18- from typing import Union
17+ import warnings
18+ from typing import Optional , Union
1919
2020import numpy as np
2121import torch
@@ -99,6 +99,11 @@ def get_adjacent_sigma(self, timesteps, t):
9999 raise ValueError (f"`self.tensor_format`: { self .tensor_format } is not valid." )
100100
101101 def set_seed (self , seed ):
102+ warnings .warn (
103+ "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
104+ " generator instead." ,
105+ DeprecationWarning ,
106+ )
102107 tensor_format = getattr (self , "tensor_format" , "pt" )
103108 if tensor_format == "np" :
104109 np .random .seed (seed )
@@ -112,14 +117,14 @@ def step_pred(
112117 model_output : Union [torch .FloatTensor , np .ndarray ],
113118 timestep : int ,
114119 sample : Union [torch .FloatTensor , np .ndarray ],
115- seed = None ,
120+ generator : Optional [torch .Generator ] = None ,
121+ ** kwargs ,
116122 ):
117123 """
118124 Predict the sample at the previous timestep by reversing the SDE.
119125 """
120- if seed is not None :
121- self .set_seed (seed )
122- # TODO(Patrick) non-PyTorch
126+ if "seed" in kwargs and kwargs ["seed" ] is not None :
127+ self .set_seed (kwargs ["seed" ])
123128
124129 if self .timesteps is None :
125130 raise ValueError (
@@ -141,7 +146,7 @@ def step_pred(
141146 drift = drift - diffusion [:, None , None , None ] ** 2 * model_output
142147
143148 # equation 6: sample noise for the diffusion term of
144- noise = self .randn_like (sample )
149+ noise = self .randn_like (sample , generator = generator )
145150 prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
146151 # TODO is the variable diffusion the correct scaling term for the noise?
147152 prev_sample = prev_sample_mean + diffusion [:, None , None , None ] * noise # add impact of diffusion field g
@@ -152,14 +157,15 @@ def step_correct(
152157 self ,
153158 model_output : Union [torch .FloatTensor , np .ndarray ],
154159 sample : Union [torch .FloatTensor , np .ndarray ],
155- seed = None ,
160+ generator : Optional [torch .Generator ] = None ,
161+ ** kwargs ,
156162 ):
157163 """
158164 Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
159165 after making the prediction for the previous timestep.
160166 """
161- if seed is not None :
162- self .set_seed (seed )
167+ if " seed" in kwargs and kwargs [ "seed" ] is not None :
168+ self .set_seed (kwargs [ " seed" ] )
163169
164170 if self .timesteps is None :
165171 raise ValueError (
@@ -168,7 +174,7 @@ def step_correct(
168174
169175 # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
170176 # sample noise for correction
171- noise = self .randn_like (sample )
177+ noise = self .randn_like (sample , generator = generator )
172178
173179 # compute step size from the model_output, the noise, and the snr
174180 grad_norm = self .norm (model_output )
0 commit comments