8 changes: 8 additions & 0 deletions src/diffusers/configuration_utils.py
@@ -24,6 +24,8 @@
from collections import OrderedDict
from typing import Any, Dict, Tuple, Union

+import numpy as np
+
from huggingface_hub import hf_hub_download
from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError
from requests import HTTPError
@@ -502,6 +504,12 @@ def to_json_string(self) -> str:
config_dict["_class_name"] = self.__class__.__name__
config_dict["_diffusers_version"] = __version__

+def to_json_saveable(value):
+    if isinstance(value, np.ndarray):
+        value = value.tolist()
+    return value
+
+config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"

def to_json_file(self, json_file_path: Union[str, os.PathLike]):
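Note on the `configuration_utils.py` change: `json.dumps` raises a `TypeError` on `numpy.ndarray` values, so a config carrying `trained_betas` as an array could not previously be written to JSON. The nested `to_json_saveable` helper converts arrays to plain lists before dumping. A minimal standalone sketch of the failure mode and the fix (outside diffusers):

```python
import json

import numpy as np

betas = np.array([0.0, 0.1])

# Without the conversion, serialization fails:
# TypeError: Object of type ndarray is not JSON serializable
try:
    json.dumps({"trained_betas": betas})
except TypeError as err:
    print(err)

# to_json_saveable turns the array into a plain list first, so the
# config can round-trip through JSON:
print(json.dumps({"trained_betas": betas.tolist()}, indent=2, sort_keys=True))
```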
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_ddim.py
@@ -17,7 +17,7 @@

import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -123,7 +123,7 @@ def __init__(
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
-trained_betas: Optional[np.ndarray] = None,
+trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
clip_sample: bool = True,
set_alpha_to_one: bool = True,
steps_offset: int = 0,
@@ -139,7 +139,7 @@
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

if trained_betas is not None:
-self.betas = torch.from_numpy(trained_betas)
+self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
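The `torch.from_numpy` → `torch.tensor` swap, repeated in every scheduler below, is what makes the JSON round trip work: after saving and reloading, `trained_betas` comes back as a plain Python list, which `torch.from_numpy` rejects. `torch.tensor` accepts both inputs, and the explicit `dtype` keeps the betas in float32 instead of inheriting numpy's float64 default. A minimal standalone sketch:

```python
import numpy as np
import torch

betas_as_list = [0.0, 0.1]  # what from_pretrained yields after a JSON round trip
betas_as_array = np.array([0.0, 0.1])

# torch.from_numpy only accepts numpy arrays:
# TypeError: expected np.ndarray (got list)
try:
    torch.from_numpy(betas_as_list)
except TypeError as err:
    print(err)

# torch.tensor handles both, and dtype=torch.float32 avoids silently
# promoting to float64 (numpy's default float dtype):
print(torch.tensor(betas_as_list, dtype=torch.float32))
print(torch.tensor(betas_as_array, dtype=torch.float32))
```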
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_ddpm.py
@@ -16,7 +16,7 @@

import math
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -115,7 +115,7 @@ def __init__(
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
-trained_betas: Optional[np.ndarray] = None,
+trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
variance_type: str = "fixed_small",
clip_sample: bool = True,
prediction_type: str = "epsilon",
@@ -130,7 +130,7 @@
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

if trained_betas is not None:
-self.betas = torch.from_numpy(trained_betas)
+self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
4 changes: 2 additions & 2 deletions src/diffusers/schedulers/scheduling_dpmsolver_multistep.py
@@ -127,7 +127,7 @@ def __init__(
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
-trained_betas: Optional[np.ndarray] = None,
+trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
solver_order: int = 2,
prediction_type: str = "epsilon",
thresholding: bool = False,
@@ -147,7 +147,7 @@
self.register_to_config(prediction_type="epsilon" if predict_epsilon else "sample")

if trained_betas is not None:
-self.betas = torch.from_numpy(trained_betas)
+self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_euler_ancestral_discrete.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -77,10 +77,10 @@ def __init__(
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
-trained_betas: Optional[np.ndarray] = None,
+trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
if trained_betas is not None:
-self.betas = torch.from_numpy(trained_betas)
+self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_euler_discrete.py
@@ -13,7 +13,7 @@
# limitations under the License.

from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -78,11 +78,11 @@ def __init__(
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
-trained_betas: Optional[np.ndarray] = None,
+trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
prediction_type: str = "epsilon",
):
if trained_betas is not None:
-self.betas = torch.from_numpy(trained_betas)
+self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_heun.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -53,10 +53,10 @@ def __init__(
beta_start: float = 0.00085, # sensible defaults
beta_end: float = 0.012,
beta_schedule: str = "linear",
-trained_betas: Optional[np.ndarray] = None,
+trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
if trained_betas is not None:
-self.betas = torch.from_numpy(trained_betas)
+self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
13 changes: 10 additions & 3 deletions src/diffusers/schedulers/scheduling_ipndm.py
@@ -13,8 +13,9 @@
# limitations under the License.

import math
-from typing import Tuple, Union
+from typing import List, Optional, Tuple, Union

+import numpy as np
import torch

from ..configuration_utils import ConfigMixin, register_to_config
@@ -40,7 +41,9 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
order = 1

@register_to_config
-def __init__(self, num_train_timesteps: int = 1000):
+def __init__(
+    self, num_train_timesteps: int = 1000, trained_betas: Optional[Union[np.ndarray, List[float]]] = None
+):
# set `betas`, `alphas`, `timesteps`
self.set_timesteps(num_train_timesteps)

@@ -67,7 +70,11 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
steps = torch.linspace(1, 0, num_inference_steps + 1)[:-1]
steps = torch.cat([steps, torch.tensor([0.0])])

-self.betas = torch.sin(steps * math.pi / 2) ** 2
+if self.config.trained_betas is not None:
+    self.betas = torch.tensor(self.config.trained_betas, dtype=torch.float32)
+else:
+    self.betas = torch.sin(steps * math.pi / 2) ** 2
+
self.alphas = (1.0 - self.betas**2) ** 0.5

timesteps = (torch.atan2(self.betas, self.alphas) / math.pi * 2)[:-1]
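`IPNDMScheduler` is the one scheduler here that previously had no `trained_betas` argument at all; it always derived its betas from the sine schedule inside `set_timesteps`. With this change the argument is recorded by `@register_to_config` and, when set, overrides the sine schedule. A hedged usage sketch (assumes a diffusers build with this patch applied):

```python
import numpy as np

from diffusers import IPNDMScheduler

# Default behavior is unchanged: betas come from the sine schedule.
scheduler = IPNDMScheduler(num_train_timesteps=1000)

# User-supplied betas are stored on the config by @register_to_config...
scheduler = IPNDMScheduler(num_train_timesteps=1000, trained_betas=np.array([0.0, 0.1]))
print(scheduler.config.trained_betas)

# ...and picked up inside set_timesteps in place of the sine schedule.
print(scheduler.betas)
```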
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_lms_discrete.py
@@ -13,7 +13,7 @@
# limitations under the License.
import warnings
from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -77,10 +77,10 @@ def __init__(
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
-trained_betas: Optional[np.ndarray] = None,
+trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
):
if trained_betas is not None:
-self.betas = torch.from_numpy(trained_betas)
+self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
6 changes: 3 additions & 3 deletions src/diffusers/schedulers/scheduling_pndm.py
@@ -15,7 +15,7 @@
# DISCLAIMER: This file is strongly influenced by https://github.com/ermongroup/ddim

import math
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

import numpy as np
import torch
@@ -99,13 +99,13 @@ def __init__(
beta_start: float = 0.0001,
beta_end: float = 0.02,
beta_schedule: str = "linear",
-trained_betas: Optional[np.ndarray] = None,
+trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
skip_prk_steps: bool = False,
set_alpha_to_one: bool = False,
steps_offset: int = 0,
):
if trained_betas is not None:
-self.betas = torch.from_numpy(trained_betas)
+self.betas = torch.tensor(trained_betas, dtype=torch.float32)
elif beta_schedule == "linear":
self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
elif beta_schedule == "scaled_linear":
18 changes: 14 additions & 4 deletions tests/test_scheduler.py
@@ -584,6 +584,20 @@ def test_deprecated_kwargs(self):
" deprecated argument from `_deprecated_kwargs = [<deprecated_argument>]`"
)

+def test_trained_betas(self):
+    for scheduler_class in self.scheduler_classes:
+        if scheduler_class == VQDiffusionScheduler:
+            continue
+
+        scheduler_config = self.get_scheduler_config()
+        scheduler = scheduler_class(**scheduler_config, trained_betas=np.array([0.0, 0.1]))
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            scheduler.save_pretrained(tmpdirname)
+            new_scheduler = scheduler_class.from_pretrained(tmpdirname)
+
+        assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
+

class DDPMSchedulerTest(SchedulerCommonTest):
scheduler_classes = (DDPMScheduler,)
@@ -1407,7 +1421,6 @@ def get_scheduler_config(self, **kwargs):
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"trained_betas": None,
}

config.update(**kwargs)
@@ -1489,7 +1502,6 @@ def get_scheduler_config(self, **kwargs):
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"trained_betas": None,
}

config.update(**kwargs)
@@ -1580,7 +1592,6 @@ def get_scheduler_config(self, **kwargs):
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"trained_betas": None,
}

config.update(**kwargs)
@@ -1889,7 +1900,6 @@ def get_scheduler_config(self, **kwargs):
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"trained_betas": None,
}

config.update(**kwargs)
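The new `test_trained_betas` exercises the full round trip the pieces above enable: construct a scheduler with ndarray betas, save (the config writes them as a JSON list via `to_json_saveable`), reload (the constructor now accepts the list via `torch.tensor`), and compare. `VQDiffusionScheduler` is skipped because its constructor takes no `trained_betas` argument, and the `"trained_betas": None` entries dropped from the per-scheduler configs were redundant with the constructors' defaults. A sketch of the same round trip with one concrete scheduler (assumes this patch is applied):

```python
import tempfile

import numpy as np

from diffusers import DDPMScheduler

# Betas supplied as a numpy array...
scheduler = DDPMScheduler(trained_betas=np.array([0.0, 0.1]))

with tempfile.TemporaryDirectory() as tmpdirname:
    # ...are written to scheduler_config.json as a plain list...
    scheduler.save_pretrained(tmpdirname)
    # ...and come back as a list, which torch.tensor happily consumes.
    new_scheduler = DDPMScheduler.from_pretrained(tmpdirname)

assert scheduler.betas.tolist() == new_scheduler.betas.tolist()
```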