1414
1515# DISCLAIMER: This file is strongly influenced by https://github.com/yang-song/score_sde_pytorch
1616
17- # TODO(Patrick, Anton, Suraj) - make scheduler framework indepedent and clean-up a bit
18- from typing import Union
17+ import warnings
18+ from typing import Optional , Union
1919
2020import numpy as np
2121import torch
@@ -98,6 +98,11 @@ def get_adjacent_sigma(self, timesteps, t):
9898 raise ValueError (f"`self.tensor_format`: { self .tensor_format } is not valid." )
9999
100100 def set_seed (self , seed ):
101+ warnings .warn (
102+ "The method `set_seed` is deprecated and will be removed in version `0.4.0`. Please consider passing a"
103+ " generator instead." ,
104+ DeprecationWarning ,
105+ )
101106 tensor_format = getattr (self , "tensor_format" , "pt" )
102107 if tensor_format == "np" :
103108 np .random .seed (seed )
@@ -111,14 +116,14 @@ def step_pred(
111116 model_output : Union [torch .FloatTensor , np .ndarray ],
112117 timestep : int ,
113118 sample : Union [torch .FloatTensor , np .ndarray ],
114- seed = None ,
119+ generator : Optional [torch .Generator ] = None ,
120+ ** kwargs ,
115121 ):
116122 """
117123 Predict the sample at the previous timestep by reversing the SDE.
118124 """
119- if seed is not None :
120- self .set_seed (seed )
121- # TODO(Patrick) non-PyTorch
125+ if "seed" in kwargs and kwargs ["seed" ] is not None :
126+ self .set_seed (kwargs ["seed" ])
122127
123128 if self .timesteps is None :
124129 raise ValueError (
@@ -140,7 +145,7 @@ def step_pred(
140145 drift = drift - diffusion [:, None , None , None ] ** 2 * model_output
141146
142147 # equation 6: sample noise for the diffusion term of
143- noise = self .randn_like (sample )
148+ noise = self .randn_like (sample , generator = generator )
144149 prev_sample_mean = sample - drift # subtract because `dt` is a small negative timestep
145150 # TODO is the variable diffusion the correct scaling term for the noise?
146151 prev_sample = prev_sample_mean + diffusion [:, None , None , None ] * noise # add impact of diffusion field g
@@ -151,14 +156,15 @@ def step_correct(
151156 self ,
152157 model_output : Union [torch .FloatTensor , np .ndarray ],
153158 sample : Union [torch .FloatTensor , np .ndarray ],
154- seed = None ,
159+ generator : Optional [torch .Generator ] = None ,
160+ ** kwargs ,
155161 ):
156162 """
157163 Correct the predicted sample based on the output model_output of the network. This is often run repeatedly
158164 after making the prediction for the previous timestep.
159165 """
160- if seed is not None :
161- self .set_seed (seed )
166+ if " seed" in kwargs and kwargs [ "seed" ] is not None :
167+ self .set_seed (kwargs [ " seed" ] )
162168
163169 if self .timesteps is None :
164170 raise ValueError (
@@ -167,7 +173,7 @@ def step_correct(
167173
168174 # For small batch sizes, the paper "suggest replacing norm(z) with sqrt(d), where d is the dim. of z"
169175 # sample noise for correction
170- noise = self .randn_like (sample )
176+ noise = self .randn_like (sample , generator = generator )
171177
172178 # compute step size from the model_output, the noise, and the snr
173179 grad_norm = self .norm (model_output )
0 commit comments