diff --git a/src/diffusers/configuration_utils.py b/src/diffusers/configuration_utils.py
index f06586b23698..ecf23010c3c1 100644
--- a/src/diffusers/configuration_utils.py
+++ b/src/diffusers/configuration_utils.py
@@ -24,6 +24,8 @@
 from collections import OrderedDict
 from typing import Any, Dict, Tuple, Union
 
+import numpy as np
+
 from huggingface_hub import hf_hub_download
 from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
 from requests import HTTPError
@@ -502,6 +504,12 @@ def to_json_string(self) -> str:
         config_dict["_class_name"] = self.__class__.__name__
         config_dict["_diffusers_version"] = __version__
 
+        def to_json_saveable(value):
+            if isinstance(value, np.ndarray):
+                value = value.tolist()
+            return value
+
+        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
         return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
 
     def to_json_file(self, json_file_path: Union[str, os.PathLike]):
diff --git a/src/diffusers/schedulers/scheduling_ddim.py b/src/diffusers/schedulers/scheduling_ddim.py
index a2e571f9982b..b1f765a7284e 100644
--- a/src/diffusers/schedulers/scheduling_ddim.py
+++ b/src/diffusers/schedulers/scheduling_ddim.py
@@ -17,7 +17,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -123,7 +123,7 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         clip_sample: bool = True,
         set_alpha_to_one: bool = True,
         steps_offset: int = 0,
@@ -139,7 +139,7 @@ def __init__(
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
diff --git a/src/diffusers/schedulers/scheduling_ddpm.py b/src/diffusers/schedulers/scheduling_ddpm.py
index d1dfa1a44b99..a9522c5ab4bc 100644
--- a/src/diffusers/schedulers/scheduling_ddpm.py
+++ b/src/diffusers/schedulers/scheduling_ddpm.py
@@ -16,7 +16,7 @@
 import math
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
@@ -115,7 +115,7 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         variance_type: str = "fixed_small",
         clip_sample: bool = True,
         prediction_type: str = "epsilon",
@@ -130,7 +130,7 @@ def __init__(
            self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
diff --git a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
index 6258933dfe16..6e7804814077 100644
--- a/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
+++ b/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -127,7 +127,7 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         solver_order: int = 2,
         prediction_type: str = "epsilon",
         thresholding: bool = False,
@@ -147,7 +147,7 @@ def __init__(
             self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")
 
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
diff --git a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
index 301ad2cebeb8..fe8a36c43f51 100644
--- a/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,10 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
diff --git a/src/diffusers/schedulers/scheduling_euler_discrete.py b/src/diffusers/schedulers/scheduling_euler_discrete.py
index 10b0138abdac..a3208bfdbb21 100644
--- a/src/diffusers/schedulers/scheduling_euler_discrete.py
+++ b/src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -78,11 +78,11 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         prediction_type: str = "epsilon",
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
diff --git a/src/diffusers/schedulers/scheduling_heun.py b/src/diffusers/schedulers/scheduling_heun.py
index e6e5335e0dfb..d21591b3df21 100644
--- a/src/diffusers/schedulers/scheduling_heun.py
+++ b/src/diffusers/schedulers/scheduling_heun.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -53,10 +53,10 @@ def __init__(
         beta_start: float = 0.00085,  # sensible defaults
         beta_end: float = 0.012,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
diff --git a/src/diffusers/schedulers/scheduling_ipndm.py b/src/diffusers/schedulers/scheduling_ipndm.py
index 1bcebe65a370..f22261d3ecd2 100644
--- a/src/diffusers/schedulers/scheduling_ipndm.py
+++ b/src/diffusers/schedulers/scheduling_ipndm.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 
 import math
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 
 from ..configuration_utils import ConfigMixin, register_to_config
@@ -40,7 +41,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
     order = 1
 
     @register_to_config
-    def __init__(self, num_train_timesteps: int = 1000):
+    def __init__(
+        self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
+    ):
         # set `betas`, `alphas`, `timesteps`
         self.set_timesteps(num_train_timesteps)
 
@@ -67,7 +70,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
         steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
         steps = torch.cat([steps, torch.tensor([0.0])])
 
-        self.betas = torch.sin(steps * math.pi / 2) ** 2
+        if self.config.trained_betas is not None:
+            self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
+        else:
+            self.betas = torch.sin(steps * math.pi / 2) ** 2
+
         self.alphas = (1.0 - self.betas**2) ** 0.5
 
         timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
diff --git a/src/diffusers/schedulers/scheduling_lms_discrete.py b/src/diffusers/schedulers/scheduling_lms_discrete.py
index 68deae8943a4..4c28db591a62 100644
--- a/src/diffusers/schedulers/scheduling_lms_discrete.py
+++ b/src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import warnings
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -77,10 +77,10 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
diff --git a/src/diffusers/schedulers/scheduling_pndm.py b/src/diffusers/schedulers/scheduling_pndm.py
index e2a076925cfd..bb3e098f7e42 100644
--- a/src/diffusers/schedulers/scheduling_pndm.py
+++ b/src/diffusers/schedulers/scheduling_pndm.py
@@ -15,7 +15,7 @@
 # DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim
 
 import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union
 
 import numpy as np
 import torch
@@ -99,13 +99,13 @@ def __init__(
         beta_start: float = 0.0001,
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
-        trained_betas: Optional[np.ndarray] = None,
+        trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
         skip_prk_steps: bool = False,
         set_alpha_to_one: bool = False,
         steps_offset: int = 0,
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.tensor(trained_betas, dtype=torch.float32)
         elif beta_schedule == "linear":
             self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
         elif beta_schedule == "scaled_linear":
diff --git a/tests/test_scheduler.py b/tests/test_scheduler.py
index f90246b337a5..9d294c3005cf 100755
--- a/tests/test_scheduler.py
+++ b/tests/test_scheduler.py
@@ -584,6 +584,20 @@ def test_deprecated_kwargs(self):
                 " deprecated argument from `_deprecated_kwargs = []`"
             )
 
+    def test_trained_betas(self):
+        for scheduler_class in self.scheduler_classes:
+            if scheduler_class == VQDiffusionScheduler:
+                continue
+
+            scheduler_config = self.get_scheduler_config()
+            scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
+
+            with tempfile.TemporaryDirectory() as tmpdirname:
+                scheduler.save_pretrained(tmpdirname)
+                new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+
+            assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+
 
 class DDPMSchedulerTest(SchedulerCommonTest):
     scheduler_classes = (DDPMScheduler,)
@@ -1407,7 +1421,6 @@ def get_scheduler_config(self, **kwargs):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1489,7 +1502,6 @@ def get_scheduler_config(self, **kwargs):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1580,7 +1592,6 @@ def get_scheduler_config(self, **kwargs):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
@@ -1889,7 +1900,6 @@ def get_scheduler_config(self, **kwargs):
             "beta_start": 0.0001,
             "beta_end": 0.02,
             "beta_schedule": "linear",
-            "trained_betas": None,
         }
 
         config.update(**kwargs)
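
What this fixes, in short: `trained_betas` previously only worked in memory. `save_pretrained()` crashed because `json.dumps()` cannot serialize an `np.ndarray`, and on reload the config hands the scheduler a plain Python list, which `torch.from_numpy()` rejects. The new `to_json_saveable()` helper converts arrays to lists on save, and `torch.tensor(..., dtype=torch.float32)` accepts both arrays and lists on load. A minimal sketch of the round-trip that `test_trained_betas` exercises (the scheduler class and beta values here are illustrative; any scheduler touched by this diff behaves the same):

```python
import tempfile

import numpy as np
from diffusers import DDPMScheduler

# Arbitrary two-element betas, mirroring the values used in test_trained_betas.
scheduler = DDPMScheduler(trained_betas=np.array([0.0, 0.1]))

with tempfile.TemporaryDirectory() as tmpdir:
    scheduler.save_pretrained(tmpdir)  # betas now serialized as a JSON list
    reloaded = DDPMScheduler.from_pretrained(tmpdir)  # list accepted via torch.tensor(...)

assert scheduler.betas.tolist() == reloaded.betas.tolist()
```

As a side effect, `torch.tensor(..., dtype=torch.float32)` also pins the loaded betas to float32, where `torch.from_numpy` would have silently inherited float64 from a default NumPy array.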