feature(wrh): add edm initial implementation #1
@@ -0,0 +1,255 @@
from typing import Optional, Tuple, Literal
from dataclasses import dataclass

import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from easydict import EasyDict
from functools import partial

from .edm_preconditioner import PreConditioner
from .edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV, SCALE_T, SCALE_T_DERIV, INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN
from .edm_utils import DEFAULT_SOLVER_PARAM
from grl.generative_models.intrinsic_model import IntrinsicModel
from grl.utils import set_seed
from grl.utils.log import log


class Simple(nn.Module):
    """A minimal 2-D MLP denoiser used as a placeholder backbone."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(2, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 2)
        )

    def forward(self, x, noise, class_labels=None):
        # The noise level and class labels are accepted for interface
        # compatibility but ignored by this toy network.
        return self.model(x)


class EDMModel(nn.Module):

    def __init__(self, config: Optional[EasyDict] = None) -> None:
        super().__init__()
        self.config = config
        # self.x_size = config.x_size
        self.device = config.device

        # EDM type: one of ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"]
        self.edm_type: str = config.edm_model.path.edm_type
        assert self.edm_type in ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"], \
            f"Your edm_type should be in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}"

        #* 1. Construct the basic denoising architecture through params in config (placeholder MLP for now)
        self.base_denoise_network = Simple()

        #* 2. Preconditioner setup
        self.params = config.edm_model.path.params
        self.preconditioner = PreConditioner(
            self.edm_type,
            base_denoise_model=self.base_denoise_network,
            use_mixes_precision=False,
            **self.params
        )

        #* 3. Solver setup
        self.solver_type = config.edm_model.solver.solver_type
        assert self.solver_type in ['euler', 'heun']

        self.solver_params = DEFAULT_SOLVER_PARAM
        self.solver_params.update(config.edm_model.solver.params)

        # Initialize sigma_min and sigma_max if not provided
        self.sigma_min = INITIAL_SIGMA_MIN[self.edm_type] if "sigma_min" not in self.params else self.params.sigma_min
        self.sigma_max = INITIAL_SIGMA_MAX[self.edm_type] if "sigma_max" not in self.params else self.params.sigma_max

    def get_type(self):
        return "EDMModel"

    # Sample per-sample noise levels and the corresponding loss weights for training.
    def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]:
        # Assume the first dim of x is the batch size
        log.info(f"Training params are: {params}")
        rand_shape = [x.shape[0]] + [1] * (x.ndim - 1)
        if self.edm_type == "VP_edm":
            epsilon_t = params.get("epsilon_t", 1e-5)
            beta_d = params.get("beta_d", 19.9)
            beta_min = params.get("beta_min", 0.1)

            rand_uniform = torch.rand(*rand_shape, device=x.device)
            sigma = SIGMA_T["VP_edm"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min)
            weight = 1 / sigma ** 2
        elif self.edm_type == "VE_edm":
            rand_uniform = torch.rand(*rand_shape, device=x.device)
            sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform)
            weight = 1 / sigma ** 2
        elif self.edm_type == "EDM":
            P_mean = params.get("P_mean", -1.2)
            P_std = params.get("P_std", 1.2)
            sigma_data = params.get("sigma_data", 0.5)

            rand_normal = torch.randn(*rand_shape, device=x.device)
            sigma = (rand_normal * P_std + P_mean).exp()
            weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
        return sigma, weight
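For reference, the "EDM" branch above implements the log-normal noise sampling and loss weighting from Karras et al. (2022); the returned sigma and weight correspond to

    \ln\sigma \sim \mathcal{N}(P_{\text{mean}},\, P_{\text{std}}^{2}), \qquad \lambda(\sigma) = \frac{\sigma^{2} + \sigma_{\text{data}}^{2}}{(\sigma\,\sigma_{\text{data}})^{2}}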

    def forward(self,
                x: Tensor,
                class_labels=None) -> Tensor:
        # Compute the per-element weighted denoising loss for a batch x.
        x = x.to(self.device)
        sigma, weight = self._sample_sigma_weight_train(x, **self.params)
        n = torch.randn_like(x) * sigma
        D_xn = self.preconditioner(x + n, sigma, class_labels=class_labels)
        loss = weight * ((D_xn - x) ** 2)
        return loss
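Since forward returns the per-element weighted squared error rather than a scalar, a caller would typically reduce it before backpropagation. Below is a minimal, hypothetical usage sketch (not part of this PR) in a script that imports EDMModel; the config field values, the learning rate, and the toy 2-D data are illustrative only, and the config layout simply mirrors the fields read in __init__:

# Hypothetical training sketch: config fields mirror those accessed in EDMModel.__init__.
import torch
from easydict import EasyDict

config = EasyDict(dict(
    device="cpu",
    edm_model=dict(
        path=dict(edm_type="EDM", params=dict(P_mean=-1.2, P_std=1.2, sigma_data=0.5)),
        solver=dict(solver_type="heun", params=dict()),
    ),
))
model = EDMModel(config)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(256, 2)      # toy 2-D batch matching the Simple backbone
loss = model(x).mean()       # reduce the per-element weighted squared error
optimizer.zero_grad()
loss.backward()
optimizer.step()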

    def _get_sigma_steps_t_steps(self, num_steps=18, epsilon_s=1e-3, rho=7):
        """
        Overview:
            Get the schedule of sigma according to different t schedules.
        """
        self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min)
        self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max)

        # Define time steps in terms of noise level
        step_indices = torch.arange(num_steps, dtype=torch.float64, device=self.device)
        sigma_steps = None
        if self.edm_type == "VP_edm":
            vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)
            vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d

            orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
            sigma_steps = SIGMA_T["VP_edm"](orig_t_steps, vp_beta_d, vp_beta_min)

        elif self.edm_type == "VE_edm":
            orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (num_steps - 1)))
            sigma_steps = SIGMA_T["VE_edm"](orig_t_steps)

        elif self.edm_type == "iDDPM_edm":
            M, C_1, C_2 = self.params.M, self.params.C_1, self.params.C_2

            u = torch.zeros(M + 1, dtype=torch.float64, device=self.device)
            alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
            for j in torch.arange(self.params.M, 0, -1, device=self.device):  # M, ..., 1
                u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
            u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)]

            sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
            orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps))

        elif self.edm_type == "EDM":
            sigma_steps = (self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) *
                           (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))) ** rho
            orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps))

        t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])])  # t_N = 0

        return sigma_steps, t_steps
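In the "EDM" branch, the noise levels follow the rho-spaced schedule of Karras et al., which is exactly the expression computed above:

    \sigma_i = \Big( \sigma_{\max}^{1/\rho} + \tfrac{i}{N-1}\big(\sigma_{\min}^{1/\rho} - \sigma_{\max}^{1/\rho}\big) \Big)^{\rho}, \qquad i = 0, \dots, N-1,

with t_N = 0 appended so the final step integrates all the way to zero noise.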

    def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s=1e-3):
        """
        Overview:
            Get the noise and scale schedules as callables for the chosen solver.
        Returns:
            sigma(t), sigma'(t), sigma^{-1}(sigma), s(t), s'(t)
        """
        vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)
        vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d
        sigma = partial(SIGMA_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
        sigma_deriv = partial(SIGMA_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
        sigma_inv = partial(SIGMA_T_INV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
        scale = partial(SCALE_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
        scale_deriv = partial(SCALE_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)

        return sigma, sigma_deriv, sigma_inv, scale, scale_deriv
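As a point of reference (an assumption about edm_utils, which is not shown in this diff), if SIGMA_T and SCALE_T follow the VP parameterization of the EDM paper, the callables built here would correspond to

    \sigma(t) = \sqrt{e^{\tfrac{1}{2}\beta_d t^2 + \beta_{\min} t} - 1}, \qquad s(t) = \frac{1}{\sqrt{e^{\tfrac{1}{2}\beta_d t^2 + \beta_{\min} t}}}

with vp_beta_d and vp_beta_min chosen above so that sigma(1) = sigma_max and sigma(epsilon_s) = sigma_min.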

    def sample(self,
               t_span,
               batch_size,
               latents: Tensor,
               class_labels: Tensor = None,
               use_stochastic: bool = False,
               **solver_kwargs
               ) -> Tensor:
        # Get sigmas, scales, and timesteps
        log.info(f"Solver param is {self.solver_params}")
        num_steps = self.solver_params.num_steps
        epsilon_s = self.solver_params.epsilon_s
        rho = self.solver_params.rho

        latents = latents.to(self.device)
        sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=num_steps, epsilon_s=epsilon_s, rho=rho)
        sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv()

        S_churn = self.solver_params.S_churn
        S_min = self.solver_params.S_min
        S_max = self.solver_params.S_max
        S_noise = self.solver_params.S_noise
        alpha = self.solver_params.alpha

        if not use_stochastic:
            # Main sampling loop
            t_next = t_steps[0]
            x_next = latents.to(torch.float64) * (sigma(t_next) * scale(t_next))
            for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
                x_cur = x_next

                # Increase noise temporarily.
                gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
                t_hat = sigma_inv(self.preconditioner.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
                x_hat = scale(t_hat) / scale(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * scale(t_hat) * S_noise * torch.randn_like(x_cur)

                # Euler step.
                h = t_next - t_hat
                denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels).to(torch.float64)
                d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised
                x_prime = x_hat + alpha * h * d_cur
                t_prime = t_hat + alpha * h

                # Apply 2nd order correction.
                if self.solver_type == 'euler' or i == num_steps - 1:
                    x_next = x_hat + h * d_cur
                else:
                    assert self.solver_type == 'heun'
                    denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels).to(torch.float64)
                    d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised
                    x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

        else:
            assert self.edm_type == "EDM", f"Stochastic sampling can only be used with EDM, but your precond type is {self.edm_type}"
            x_next = latents.to(torch.float64) * t_steps[0]
            for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
                x_cur = x_next

                # Increase noise temporarily.
                gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
                t_hat = self.preconditioner.round_sigma(t_cur + gamma * t_cur)
                x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

                # Euler step.
                denoised = self.preconditioner(x_hat, t_hat, class_labels).to(torch.float64)
                d_cur = (x_hat - denoised) / t_hat
                x_next = x_hat + (t_next - t_hat) * d_cur

                # Apply 2nd order correction.
                if i < num_steps - 1:
                    denoised = self.preconditioner(x_next, t_next, class_labels).to(torch.float64)
                    d_prime = (x_next - denoised) / t_next
                    x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        return x_next
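Continuing the hypothetical usage sketch from forward above (not part of this PR), sampling would look roughly as follows. Note that t_span and batch_size appear in the signature but are not used by the body shown here, and DEFAULT_SOLVER_PARAM is assumed to supply num_steps, epsilon_s, rho, and the S_* / alpha parameters:

# Hypothetical sampling sketch: draw 2-D latents matching the toy backbone
# and run the deterministic Heun sampler defined above.
latents = torch.randn(256, 2)
with torch.no_grad():
    samples = model.sample(
        t_span=None,
        batch_size=256,
        latents=latents,
        use_stochastic=False,
    )
print(samples.shape)  # expected: torch.Size([256, 2])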
@@ -0,0 +1,111 @@
from typing import Optional, Tuple, Literal
from dataclasses import dataclass

import numpy as np
import torch
from torch import Tensor, as_tensor
import torch.nn as nn
import torch.nn.functional as F

from .edm_utils import SIGMA_T, SIGMA_T_INV


class PreConditioner(nn.Module):

    def __init__(self,
                 precondition_type: Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] = "EDM",
                 base_denoise_model: nn.Module = None,
                 use_mixes_precision: bool = False,
                 **precond_config_kwargs) -> None:
        super().__init__()
        self.precondition_type = precondition_type
        self.base_denoise_model = base_denoise_model
        self.use_mixes_precision = use_mixes_precision

        if self.precondition_type == "VP_edm":
            self.beta_d = precond_config_kwargs.get("beta_d", 19.9)
            self.beta_min = precond_config_kwargs.get("beta_min", 0.1)
            self.M = precond_config_kwargs.get("M", 1000)
            self.epsilon_t = precond_config_kwargs.get("epsilon_t", 1e-5)

            self.sigma_min = SIGMA_T["VP_edm"](self.epsilon_t, self.beta_d, self.beta_min)
            self.sigma_max = SIGMA_T["VP_edm"](1, self.beta_d, self.beta_min)

        elif self.precondition_type == "VE_edm":
            self.sigma_min = precond_config_kwargs.get("sigma_min", 0.02)
            self.sigma_max = precond_config_kwargs.get("sigma_max", 100)

        elif self.precondition_type == "iDDPM_edm":
            self.C_1 = precond_config_kwargs.get("C_1", 0.001)
            self.C_2 = precond_config_kwargs.get("C_2", 0.008)
            self.M = precond_config_kwargs.get("M", 1000)
            u = torch.zeros(self.M + 1)
            for j in range(self.M, 0, -1):  # M, ..., 1
                u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=self.C_1) - 1).sqrt()
            self.register_buffer('u', u)
            self.sigma_min = float(u[self.M - 1])
            self.sigma_max = float(u[0])

        elif self.precondition_type == "EDM":
            self.sigma_min = precond_config_kwargs.get("sigma_min", 0.002)
            self.sigma_max = precond_config_kwargs.get("sigma_max", 80)
            self.sigma_data = precond_config_kwargs.get("sigma_data", 0.5)

        else:
            raise ValueError(f"Please check that your precond type {self.precondition_type} is in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']")

    # For iDDPM_edm
    def alpha_bar(self, j):
        assert self.precondition_type == "iDDPM_edm", f"Only iDDPM_edm supports the alpha_bar function, but your precond type is {self.precondition_type}"
        j = torch.as_tensor(j)
        return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2

    def round_sigma(self, sigma, return_index=False):
        # Snap sigma to the nearest discrete iDDPM noise level; other types
        # simply pass sigma through as a tensor.
        if self.precondition_type == "iDDPM_edm":
            sigma = torch.as_tensor(sigma)
            index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
            result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
            return result.reshape(sigma.shape).to(sigma.device)
        else:
            return torch.as_tensor(sigma)

    def get_precondition_c(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
        # Return the preconditioning coefficients (c_skip, c_out, c_in, c_noise) for the given sigma.
        if self.precondition_type == "VP_edm":
            c_skip = 1
            c_out = -sigma
            c_in = 1 / (sigma ** 2 + 1).sqrt()
            c_noise = (self.M - 1) * SIGMA_T_INV["VP_edm"](sigma, self.beta_d, self.beta_min)
        elif self.precondition_type == "VE_edm":
            c_skip = 1
            c_out = sigma
            c_in = 1
            c_noise = (0.5 * sigma).log()
        elif self.precondition_type == "iDDPM_edm":
            c_skip = 1
            c_out = -sigma
            c_in = 1 / (sigma ** 2 + 1).sqrt()
            c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
        elif self.precondition_type == "EDM":
            c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
            c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
            c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
            c_noise = sigma.log() / 4
        return c_skip, c_out, c_in, c_noise
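All four branches instantiate the shared EDM preconditioning form, which forward below applies verbatim:

    D_\theta(x;\sigma) = c_{\text{skip}}(\sigma)\,x + c_{\text{out}}(\sigma)\,F_\theta\big(c_{\text{in}}(\sigma)\,x;\ c_{\text{noise}}(\sigma)\big)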

    def forward(self, x: Tensor, sigma: Tensor, class_labels=None, **model_kwargs):
        # Suppose the first dim of x is batch size
        x = x.to(torch.float32)
        sigma_shape = [x.shape[0]] + [1] * (x.ndim - 1)
        if sigma.numel() == 1:
            sigma = sigma.view(-1).expand(*sigma_shape)

        dtype = torch.float16 if (self.use_mixes_precision and x.device.type == 'cuda') else torch.float32
        c_skip, c_out, c_in, c_noise = self.get_precondition_c(sigma)
        F_x = self.base_denoise_model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
        assert F_x.dtype == dtype
        D_x = c_skip * x + c_out * F_x.to(torch.float32)
        return D_x
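A minimal standalone sketch of the wrapper (not part of this PR), assuming a toy backbone such as the Simple MLP defined in the model file above:

# Hypothetical example: wrap any denoising network and query it at one noise level.
precond = PreConditioner("EDM", base_denoise_model=Simple(), sigma_data=0.5)
x = torch.randn(8, 2)
sigma = torch.tensor(1.0)
denoised = precond(x, sigma)   # D(x; sigma) = c_skip * x + c_out * F(c_in * x; c_noise)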