diff --git a/docs/source/api_doc/generative_models/index.rst b/docs/source/api_doc/generative_models/index.rst
index 6e560e2..43f478f 100644
--- a/docs/source/api_doc/generative_models/index.rst
+++ b/docs/source/api_doc/generative_models/index.rst
@@ -28,3 +28,9 @@ OptimalTransportConditionalFlowModel
 .. autoclass:: OptimalTransportConditionalFlowModel
     :special-members: __init__
     :members:
+
+EDMDiffusionModel
+-------------------------------
+.. autoclass:: EDMModel
+    :special-members: __init__
+    :members:
\ No newline at end of file
diff --git a/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py
new file mode 100644
index 0000000..db7503e
--- /dev/null
+++ b/grl/generative_models/edm_diffusion_model/edm_diffusion_model.py
@@ -0,0 +1,541 @@
+from typing import Any, Callable, Dict, List, Tuple, Union, Optional
+from dataclasses import dataclass
+
+import numpy as np
+from torch import Tensor
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import treetensor
+from tensordict import TensorDict
+
+import torch.optim as optim
+from easydict import EasyDict
+from functools import partial
+
+from grl.generative_models.intrinsic_model import IntrinsicModel
+from grl.generative_models.random_generator import gaussian_random_variable
+from grl.numerical_methods.numerical_solvers import get_solver
+from grl.numerical_methods.numerical_solvers.dpm_solver import DPMSolver
+from grl.numerical_methods.numerical_solvers.ode_solver import (
+    DictTensorODESolver,
+    ODESolver,
+)
+from grl.numerical_methods.numerical_solvers.sde_solver import SDESolver
+
+from grl.utils import find_parameters
+from grl.utils import set_seed
+from grl.utils.log import log
+
+from grl.generative_models.edm_diffusion_model.edm_preconditioner import PreConditioner
+from grl.generative_models.edm_diffusion_model.edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV
+from grl.generative_models.edm_diffusion_model.edm_utils import SCALE_T, SCALE_T_DERIV
+from grl.generative_models.edm_diffusion_model.edm_utils import INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN
+from grl.generative_models.edm_diffusion_model.edm_utils import DEFAULT_PARAM, DEFAULT_SOLVER_PARAM
+
+
+class EDMModel(nn.Module):
+    """
+    Overview:
+        An implementation of EDM, which elucidates diffusion-based generative models
+        through their preconditioning, training, and sampling procedures.
+        This implementation supports 4 types: ``VP_edm`` (DDPM-SDE), ``VE_edm`` (SGM-SDE), ``iDDPM_edm``, and ``EDM``;
+        see Table 1 in the paper for details.
+        The EDM class selects the corresponding parameters and schedules for the preconditioning, training, and sampling processes.
+        Sampling supports the 1st-order Euler step and the 2nd-order Heun step of Algorithm 1 in the paper.
+        For the EDM type itself, the stochastic sampler of Algorithm 2 in the paper is also supported.
+
+    Interface:
+        ``__init__``, ``forward``, ``sample``
+
+    Reference:
+        EDM original paper link: https://arxiv.org/abs/2206.00364
+        Code reference: https://github.com/NVlabs/edm
+    """
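+
+    # A minimal usage sketch (hypothetical config values; see the Swiss Roll
+    # pipeline added in this PR for a complete, runnable configuration):
+    #
+    #   config = EasyDict(dict(
+    #       device="cuda:0",
+    #       x_size=[2],
+    #       solver=dict(type="ODESolver", args=dict(library="torchdyn", ode_solver="euler")),
+    #       path=dict(edm_type="EDM", params=dict()),  # empty params fall back to DEFAULT_PARAM["EDM"]
+    #       sample_params=dict(),                      # empty params fall back to DEFAULT_SOLVER_PARAM
+    #       model=dict(...),                           # IntrinsicModel config (backbone, t_encoder, ...)
+    #   ))
+    #   model = EDMModel(config).to(config.device)
+    #   loss = model.L2_denoising_matching_loss(x)     # training objective
+    #   samples = model.sample(batch_size=128)         # reverse-process sampling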
+ """ + super().__init__() + self.config = config + self.x_size = config.x_size + self.device = config.device + + + self.gaussian_generator = gaussian_random_variable( + config.x_size, + config.device, + config.use_tree_tensor if hasattr(config, "use_tree_tensor") else False, + ) + + if hasattr(config, "solver"): + self.solver = get_solver(config.solver.type)(**config.solver.args) + + # EDM Type ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] + self.edm_type = config.path.edm_type + assert self.edm_type in ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"], \ + f"Your edm type should in 'VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}" + + #* 1. Construct basic Unet architecture through params in config + self.model = IntrinsicModel(config.model.args) + + #* 2. Precond setup + self.params = EasyDict(DEFAULT_PARAM[self.edm_type]) + self.params.update(config.path.params) + log.info(f"Using edm type: {self.edm_type}\nParam is {self.params}") + self.preconditioner = PreConditioner( + self.edm_type, + denoise_model=self.model, + use_mixes_precision=False, + **self.params + ) + + self.solver_params = EasyDict(DEFAULT_SOLVER_PARAM) + self.solver_params.update(config.sample_params) + + # Initialize sigma_min and sigma_max if not provided + + self.sigma_min = INITIAL_SIGMA_MIN[self.edm_type] if "sigma_min" not in self.params else self.params.sigma_min + self.sigma_max = INITIAL_SIGMA_MAX[self.edm_type] if "sigma_max" not in self.params else self.params.sigma_max + + @property + def get_type(self) -> str: + return "EDMModel" + + def _sample_sigma_weight_train(self, x: Tensor) -> Tuple[Tensor, Tensor]: + """ + Overview: + Sample sigma from given distribution for training according to edm type. + More details refer to Training section in the Table 1 of EDM paper. + + ..math: + \sigma\sim p_{\mathrm{train}}, \lambda(\sigma) + + Arguments: + x (:obj:`torch.Tensor`): The sample which needs to add noise. + + Returns: + sigma (:obj:`torch.Tensor`): Sampled sigma from the distribution. + weight (:obj:`torch.Tensor`): Loss weight lambda(sigma) obtained from sampled sigma. + """ + # assert the first dim of x is batch size + + rand_shape = [x.shape[0]] + [1] * (x.ndim - 1) + if self.edm_type == "VP_edm": + epsilon_t = self.params.epsilon_t + beta_d = self.params.beta_d + beta_min = self.params.beta_min + + rand_uniform = torch.rand(*rand_shape, device=self.device) + sigma = SIGMA_T["VP_edm"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min) + weight = 1 / sigma ** 2 + elif self.edm_type == "VE_edm": + rand_uniform = torch.rand(*rand_shape, device=self.device) + sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform) + weight = 1 / sigma ** 2 + elif self.edm_type == "iDDPM_edm": + u = self.preconditioner.u + sigma_index = torch.randint(0, self.params.M - 1, rand_shape, device=self.device) + sigma = u[sigma_index] + weight = 1 / sigma ** 2 + elif self.edm_type == "EDM": + P_mean = self.params.P_mean + P_std = self.params.P_std + sigma_data = self.params.sigma_data + + rand_normal = torch.randn(*rand_shape, device=self.device) + sigma = (rand_normal * P_std + P_mean).exp() + weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2 + return sigma, weight + + def forward(self, x: Tensor, condition: Optional[Tensor]=None): + return self.sample(x, condition) + + def L2_denoising_matching_loss( + self, + x: Tensor, + condition: Optional[Tensor]=None + ) -> Tensor: + """ + Overview: + Calculate the L2 denoising matching loss. 
+
+    def L2_denoising_matching_loss(
+        self,
+        x: Tensor,
+        condition: Optional[Tensor] = None
+    ) -> Tensor:
+        """
+        Overview:
+            Calculate the L2 denoising matching loss.
+            The denoising matching loss is given in Equations (2) and (3) of the EDM paper.
+
+        .. math::
+            \mathbb{E}_{\sigma\sim p_{\mathrm{train}}}\mathbb{E}_{y\sim p_\mathrm{data}}\mathbb{E}_{n\sim \mathcal{N}(0, \sigma^2 \mathbf{I})} \left[\lambda(\sigma) \| \mathbf{D}(y+n) - y \|_2^2\right]
+
+        Arguments:
+            x (:obj:`torch.Tensor`): The sample to which noise will be added.
+            condition (:obj:`Optional[torch.Tensor]`): The condition for the sample.
+
+        Returns:
+            loss (:obj:`torch.Tensor`): The L2 denoising matching loss.
+        """
+        sigma, weight = self._sample_sigma_weight_train(x)
+        n = torch.randn_like(x) * sigma
+        inv_t = SIGMA_T_INV[self.edm_type](sigma)  # TODO: Use t? or sigma? as input
+        D_xn = self.preconditioner(sigma, x + n, condition=condition)
+        loss = weight * ((D_xn - x) ** 2)
+        return loss.mean()
+
+    def _get_sigma_steps_t_steps(self,
+                                 num_steps: int = 18,
+                                 epsilon_s: float = 1e-3,
+                                 rho: Union[int, float] = 7
+                                 ) -> Tuple[Tensor, Tensor]:
+        """
+        Overview:
+            Get the schedule of sigma steps and t steps according to the different time (or sigma) schedules.
+
+        .. math::
+            \sigma_{i<N}, \quad t_{i<N}, \quad t_N = 0
+
+        Arguments:
+            num_steps (:obj:`int`): The number of sampling steps N.
+            epsilon_s (:obj:`float`): The terminal time of the VP schedule.
+            rho (:obj:`Union[int, float]`): The exponent of the EDM time schedule.
+
+        Returns:
+            sigma_steps (:obj:`torch.Tensor`): The sigma schedule of length num_steps.
+            t_steps (:obj:`torch.Tensor`): The time schedule of length num_steps + 1, with t_N = 0 appended.
+        """
+        step_indices = torch.arange(num_steps, device=self.device)
+
+        if self.edm_type == "VP_edm":
+            orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
+            sigma_steps = SIGMA_T["VP_edm"](orig_t_steps, self.params.beta_d, self.params.beta_min)
+        elif self.edm_type == "VE_edm":
+            orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (num_steps - 1)))
+            sigma_steps = SIGMA_T["VE_edm"](orig_t_steps)
+        elif self.edm_type == "iDDPM_edm":
+            u = self.preconditioner.u
+            u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)]
+
+            sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
+            orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps))
+
+        elif self.edm_type == "EDM":
+            sigma_steps = (self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * \
+                           (self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))) ** rho
+            orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps))
+        else:
+            raise NotImplementedError(f"Please check your edm_type: {self.edm_type}, which is not in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']")
+
+        t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])])  # t_N = 0
+
+        return sigma_steps, t_steps
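+
+    # Worked example of the "EDM" branch of _get_sigma_steps_t_steps above: with
+    # sigma_max=80, sigma_min=0.002, rho=7 and num_steps=18, step i in {0, ..., 17}
+    # maps to
+    #
+    #   sigma_i = (80^(1/7) + i/17 * (0.002^(1/7) - 80^(1/7)))^7,
+    #
+    # so sigma_0 = 80 and sigma_17 = 0.002, with t_18 = 0 appended; the solver
+    # therefore integrates from high noise down to the data distribution.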
+
+    def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s: Union[int, float] = 1e-3) \
+            -> Tuple[Callable, Callable, Callable, Callable, Callable]:
+        """
+        Overview:
+            Get sigma(t) and scale(t), together with their derivatives and inverse, for the different solver schedules.
+            More details in the sampling section of Table 1 in the EDM paper.
+
+        .. math::
+            \sigma(t), \sigma^\prime(t), \sigma^{-1}(\sigma), s(t), s^\prime(t)
+
+        Returns:
+            sigma (:obj:`Callable`): sigma(t)
+            sigma_deriv (:obj:`Callable`): sigma'(t)
+            sigma_inv (:obj:`Callable`): sigma^{-1}(sigma)
+            scale (:obj:`Callable`): s(t)
+            scale_deriv (:obj:`Callable`): s'(t)
+        """
+        vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)
+        vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d
+        sigma = partial(SIGMA_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
+        sigma_deriv = partial(SIGMA_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
+        sigma_inv = partial(SIGMA_T_INV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
+        scale = partial(SCALE_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
+        scale_deriv = partial(SCALE_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
+
+        return sigma, sigma_deriv, sigma_inv, scale, scale_deriv
+
+    def sample(
+        self,
+        t_span: Tensor = None,
+        batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
+        x_0: Union[Tensor, TensorDict, treetensor.torch.Tensor] = None,
+        condition: Union[Tensor, TensorDict, treetensor.torch.Tensor] = None,
+        with_grad: bool = False,
+        solver_config: EasyDict = None,
+    ) -> Tensor:
+        """
+        Overview:
+            Sample from the diffusion model by integrating the reverse process from high noise
+            down to the data distribution, returning the final sample. If ``x_0`` is given,
+            it is used as the initial value of the reverse process instead of freshly drawn
+            Gaussian noise.
+
+        Arguments:
+            t_span (:obj:`torch.Tensor`): The time span.
+            batch_size (:obj:`Union[torch.Size, int, Tuple[int], List[int]]`): The batch size of sampling.
+            x_0 (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The initial state.
+            condition (:obj:`Union[torch.Tensor, TensorDict, treetensor.torch.Tensor]`): The input condition.
+            with_grad (:obj:`bool`): Whether to return the gradient.
+            solver_config (:obj:`EasyDict`): The configuration of the solver.
+        """
+        return self.sample_forward_process(
+            t_span=t_span,
+            batch_size=batch_size,
+            x_0=x_0,
+            condition=condition,
+            with_grad=with_grad,
+            solver_config=solver_config,
+        )[-1]
+
+    def sample_forward_process(
+        self,
+        t_span: torch.Tensor = None,
+        batch_size: Union[torch.Size, int, Tuple[int], List[int]] = None,
+        x_0: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
+        condition: Union[torch.Tensor, TensorDict, treetensor.torch.Tensor] = None,
+        with_grad: bool = False,
+        solver_config: EasyDict = None,
+    ):
+        sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=self.solver_params.num_steps, epsilon_s=self.solver_params.epsilon_s, rho=self.solver_params.rho)
+
+        sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv()
+
+        S_churn = self.solver_params.S_churn
+        S_min = self.solver_params.S_min
+        S_max = self.solver_params.S_max
+        S_noise = self.solver_params.S_noise
+        alpha = self.solver_params.alpha
+
+        t_next = t_steps[0]
+        # x_next = x_0 * (sigma(t_next) * scale(t_next))
+        if t_span is not None:
+            t_span = t_span.to(self.device)
+
+        if batch_size is None:
+            extra_batch_size = torch.tensor((1,), device=self.device)
+        elif isinstance(batch_size, int):
+            extra_batch_size = torch.tensor((batch_size,), device=self.device)
+        else:
+            if (
+                isinstance(batch_size, torch.Size)
+                or isinstance(batch_size, Tuple)
+                or isinstance(batch_size, List)
+            ):
+                extra_batch_size = torch.tensor(batch_size, device=self.device)
+            else:
+                assert False, "Invalid batch size"
+
+        if x_0 is not None and condition is not None:
+            assert (
+                x_0.shape[0] == condition.shape[0]
+            ), "The batch size of x_0 and condition must be the same"
+            data_batch_size = x_0.shape[0]
+        elif x_0 is not None:
+            data_batch_size = x_0.shape[0]
+        elif condition is not None:
+            data_batch_size = condition.shape[0]
+        else:
+            data_batch_size = 1
+
+        if solver_config is not None:
+            solver = get_solver(solver_config.type)(**solver_config.args)
+        else:
+            assert hasattr(
+                self, "solver"
+            ), "solver must be specified in config or solver_config"
+            solver = self.solver
+
+        if x_0 is None:
+            x = self.gaussian_generator(
+                batch_size=torch.prod(extra_batch_size) * data_batch_size
+            )
+            x = x * (sigma(t_next) * scale(t_next))
+            # x.shape = (B*N, D)
+        else:
+            if isinstance(self.x_size, int):
+                assert (
+                    torch.Size([self.x_size]) == x_0[0].shape
+                ), "The shape of x_0 must be the same as the x_size that is specified in the config"
+            elif (
+                isinstance(self.x_size, Tuple)
+                or isinstance(self.x_size, List)
+                or isinstance(self.x_size, torch.Size)
+            ):
+                assert (
+                    torch.Size(self.x_size) == x_0[0].shape
+                ), "The shape of x_0 must be the same as the x_size that is specified in the config"
+            else:
+                assert False, "Invalid x_size"
+
+            x = torch.repeat_interleave(x_0, torch.prod(extra_batch_size), dim=0)
+            # x.shape = (B*N, D)
+
+        if condition is not None:
+            condition = torch.repeat_interleave(
+                condition, torch.prod(extra_batch_size), dim=0
+            )
+            # condition.shape = (B*N, D)
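+
+        # Both the commented-out Euler loop and the drift/ODE-solver path below
+        # integrate the generalized probability-flow ODE of the EDM paper
+        # (cf. Algorithm 1), written in terms of sigma(t) and scale s(t):
+        #
+        #   dx/dt = [sigma'(t)/sigma(t) + s'(t)/s(t)] * x
+        #           - [sigma'(t) * s(t) / sigma(t)] * D(x / s(t); sigma(t))
+        #
+        # For edm_type == "EDM", s(t) = 1 and sigma(t) = t, so this reduces to
+        # dx/dt = (x - D(x; t)) / t.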
+
+        # # Main sampling loop (explicit 1st-order Euler reference implementation)
+        # x_next = torch.randn_like(x)
+        # x_list = [x_next]
+        # for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1
+        #     x_cur = x_next
+
+        #     # Euler step.
+        #     h = t_next - t_cur
+        #     denoised = self.preconditioner(sigma(t_cur), x_cur / scale(t_cur), condition)
+        #     d_cur = ((sigma_deriv(t_cur) / sigma(t_cur)) + (scale_deriv(t_cur) / scale(t_cur))) * x_cur - ((sigma_deriv(t_cur) * scale(t_cur)) / sigma(t_cur)) * denoised
+
+        #     x_next = x_cur + h * d_cur
+        #     x_list.append(x_next)
+
+        # return x_list
+
+        def drift(t, x):
+            t_shape = [x.shape[0]] + [1] * (x.ndim - 1)
+            t = t.view(*t_shape)
+            denoised = self.preconditioner(sigma(t), x / scale(t), condition)
+            f = (sigma_deriv(t) / sigma(t) + scale_deriv(t) / scale(t)) * x - sigma_deriv(t) * scale(t) / sigma(t) * denoised
+            return f
+
+        t_span = t_steps.to(self.device)
+        if isinstance(solver, ODESolver):
+            # TODO: make it compatible with TensorDict
+            if with_grad:
+                data = solver.integrate(
+                    drift=drift,
+                    x0=x,
+                    t_span=t_span,
+                    adjoint_params=find_parameters(self.model),
+                )
+            else:
+                with torch.no_grad():
+                    data = solver.integrate(
+                        drift=drift,
+                        x0=x,
+                        t_span=t_span,
+                        adjoint_params=find_parameters(self.model),
+                    )
+        else:
+            raise NotImplementedError(
+                f"Solver type {type(solver)} is not supported by EDMModel yet"
+            )
+
+        if isinstance(data, torch.Tensor):
+            # data.shape = (T, B*N, D)
+            if len(extra_batch_size.shape) == 0:
+                if isinstance(self.x_size, int):
+                    data = data.reshape(
+                        -1, extra_batch_size, data_batch_size, self.x_size
+                    )
+                elif (
+                    isinstance(self.x_size, Tuple)
+                    or isinstance(self.x_size, List)
+                    or isinstance(self.x_size, torch.Size)
+                ):
+                    data = data.reshape(
+                        -1, extra_batch_size, data_batch_size, *self.x_size
+                    )
+                else:
+                    assert False, "Invalid x_size"
+            else:
+                if isinstance(self.x_size, int):
+                    data = data.reshape(
+                        -1, *extra_batch_size, data_batch_size, self.x_size
+                    )
+                elif (
+                    isinstance(self.x_size, Tuple)
+                    or isinstance(self.x_size, List)
+                    or isinstance(self.x_size, torch.Size)
+                ):
+                    data = data.reshape(
+                        -1, *extra_batch_size, data_batch_size, *self.x_size
+                    )
+                else:
+                    assert False, "Invalid x_size"
+            # data.shape = (T, B, N, D)
+
+            if batch_size is None:
+                if x_0 is None and condition is None:
+                    data = data.squeeze(1).squeeze(1)
+                    # data.shape = (T, D)
+                else:
+                    data = data.squeeze(1)
+                    # data.shape = (T, N, D)
+            else:
+                if x_0 is None and condition is None:
+                    data = data.squeeze(1 + len(extra_batch_size.shape))
+                    # data.shape = (T, B, D)
+                else:
+                    # data.shape = (T, B, N, D)
+                    pass
+        elif isinstance(data, TensorDict):
+            raise NotImplementedError("Not implemented")
+        elif isinstance(data, treetensor.torch.Tensor):
+            for key in data.keys():
+                if len(extra_batch_size.shape) == 0:
+                    if isinstance(self.x_size[key], int):
+                        data[key] = data[key].reshape(
+                            -1, extra_batch_size, data_batch_size, self.x_size[key]
+                        )
+                    elif (
+                        isinstance(self.x_size[key], Tuple)
+                        or isinstance(self.x_size[key], List)
+                        or isinstance(self.x_size[key], torch.Size)
+                    ):
+                        data[key] = data[key].reshape(
+                            -1, extra_batch_size, data_batch_size, *self.x_size[key]
+                        )
+                    else:
+                        assert False, "Invalid x_size"
+                else:
+                    if isinstance(self.x_size[key], int):
+                        data[key] = data[key].reshape(
+                            -1, *extra_batch_size, data_batch_size, self.x_size[key]
+                        )
+                    elif (
+                        isinstance(self.x_size[key], Tuple)
+                        or isinstance(self.x_size[key], List)
+                        or isinstance(self.x_size[key], torch.Size)
+                    ):
+                        data[key] = data[key].reshape(
+                            -1, *extra_batch_size, data_batch_size, *self.x_size[key]
+                        )
+                    else:
+                        assert False, "Invalid x_size"
+                # data.shape = (T, B, N, D)
+
+                if batch_size is None:
+                    if x_0 is None and condition is None:
+                        data[key] = data[key].squeeze(1).squeeze(1)
+                        # data.shape = (T, D)
+                    else:
+                        data[key] = data[key].squeeze(1)
+                        # data.shape = (T, N, D)
+                else:
+                    if x_0 is None and condition is None:
+                        data[key] = data[key].squeeze(1 + len(extra_batch_size.shape))
+                        # data.shape = (T, B, D)
+                    else:
+                        # data.shape = (T, B, N, D)
+                        pass
+        else:
+            raise NotImplementedError("Not implemented")
+
+        return data
\ No newline at end of file
diff --git a/grl/generative_models/edm_diffusion_model/edm_preconditioner.py b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py
new file mode 100644
index 0000000..e562554
--- /dev/null
+++ b/grl/generative_models/edm_diffusion_model/edm_preconditioner.py
@@ -0,0 +1,178 @@
+from typing import Union, Optional, Tuple, Literal
+from dataclasses import dataclass
+
+from torch import Tensor, as_tensor
+from easydict import EasyDict
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from grl.utils.log import log
+from grl.generative_models.edm_diffusion_model.edm_utils import SIGMA_T, SIGMA_T_INV
+
+class PreConditioner(nn.Module):
+    """
+    Overview:
+        The preconditioning step in EDM.
+
+    Interface:
+        ``__init__``, ``round_sigma``, ``get_precondition_c``, ``forward``
+    """
+    def __init__(self,
+                 precondition_type: Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] = "EDM",
+                 denoise_model: Optional[nn.Module] = None,
+                 use_mixed_precision: bool = False,
+                 **precond_params) -> None:
+        """
+        Overview:
+            Initialize the preconditioner for network preconditioning in EDM.
+            More details in the "Network and preconditioning" part of Section 5 in the EDM paper.
+
+        Arguments:
+            precondition_type (:obj:`Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"]`): The preconditioning type.
+            denoise_model (:obj:`Optional[nn.Module]`): The basic denoising network.
+            use_mixed_precision (:obj:`bool`): Whether mixed (fp16) precision is used.
+
+        Reference:
+            EDM original paper link: https://arxiv.org/abs/2206.00364
+            Code reference: https://github.com/NVlabs/edm
+        """
+        super().__init__()
+        log.info(f"Precond_params: {precond_params}")
+        precond_params = EasyDict(precond_params)
+        self.precondition_type = precondition_type
+        self.denoise_model = denoise_model
+        self.use_mixed_precision = use_mixed_precision
+
+        if self.precondition_type == "VP_edm":
+            self.beta_d = precond_params.beta_d
+            self.beta_min = precond_params.beta_min
+            self.M = precond_params.M
+            self.epsilon_t = precond_params.epsilon_t
+
+            self.sigma_min = float(SIGMA_T["VP_edm"](torch.tensor(self.epsilon_t), self.beta_d, self.beta_min))
+            self.sigma_max = float(SIGMA_T["VP_edm"](torch.tensor(1), self.beta_d, self.beta_min))
+
+        elif self.precondition_type == "VE_edm":
+            self.sigma_min = precond_params.sigma_min
+            self.sigma_max = precond_params.sigma_max
+
+        elif self.precondition_type == "iDDPM_edm":
+            self.C_1 = precond_params.C_1
+            self.C_2 = precond_params.C_2
+            self.M = precond_params.M
+
+            # For iDDPM_edm
+            def alpha_bar(j):
+                assert self.precondition_type == "iDDPM_edm", f"Only iDDPM_edm supports the alpha_bar function, but your precondition type is {self.precondition_type}"
+                j = torch.as_tensor(j)
+                return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2
+
+            u = torch.zeros(self.M + 1)
+            for j in range(self.M, 0, -1):  # M, ..., 1
+                u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=self.C_1) - 1).sqrt()
+            self.register_buffer('u', u)
+            self.sigma_min = float(u[self.M - 1])
+            self.sigma_max = float(u[0])
+
+        elif self.precondition_type == "EDM":
+            self.sigma_min = precond_params.sigma_min
+            self.sigma_max = precond_params.sigma_max
+            self.sigma_data = precond_params.sigma_data
+
+        else:
+            raise ValueError(f"The precondition type {self.precondition_type} should be one of ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']")
+
+    def round_sigma(self, sigma: Union[Tensor, float], return_index: bool = False) -> Tensor:
+        """
+        Overview:
+            Return sigma as a tensor. In iDDPM_edm mode, the index of the nearest entry in the
+            precomputed noise schedule ``u`` can be returned instead.
+
+        Arguments:
+            sigma (:obj:`Union[torch.Tensor, float]`): The input sigma.
+            return_index (:obj:`bool`): Whether the index is returned. Only the iDDPM_edm type needs it.
+
+        Returns:
+            sigma (:obj:`torch.Tensor`): The output sigma in Tensor format.
+        """
+        if self.precondition_type == "iDDPM_edm":
+            sigma = torch.as_tensor(sigma)
+            index = torch.cdist(sigma.to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
+            result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
+            return result.reshape(sigma.shape)
+        else:
+            return torch.as_tensor(sigma)
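+
+    # For the "EDM" branch of get_precondition_c below, the scalings are chosen
+    # (Section 5 of the EDM paper) so that the network sees unit-variance inputs
+    # and regresses a unit-variance target:
+    #
+    #   c_in(sigma)    = 1 / sqrt(sigma^2 + sigma_data^2)  =>  Var[c_in * (y + n)] = 1
+    #   c_skip(sigma)  = sigma_data^2 / (sigma^2 + sigma_data^2)
+    #   c_out(sigma)   = sigma * sigma_data / sqrt(sigma^2 + sigma_data^2)
+    #   c_noise(sigma) = ln(sigma) / 4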
+ """ + if self.precondition_type == "VP_edm": + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1).sqrt() + c_noise = (self.M - 1) * SIGMA_T_INV["VP_edm"](sigma, self.beta_d, self.beta_min) + elif self.precondition_type == "VE_edm": + c_skip = 1 + c_out = sigma + c_in = 1 + c_noise = (0.5 * sigma).log() + elif self.precondition_type == "iDDPM_edm": + c_skip = 1 + c_out = -sigma + c_in = 1 / (sigma ** 2 + 1).sqrt() + c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32) + elif self.precondition_type == "EDM": + c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2) + c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt() + c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt() + c_noise = sigma.log() / 4 + return c_skip, c_out, c_in, c_noise + + def forward(self, sigma: Tensor, x: Tensor, condition: Optional[Tensor]=None, **model_kwargs): + """ + Overview: + Obtain denoiser from basic denoise network and precondition scaling functions, which is given as follows: + + .. math: + \mathbf{D}_{\theta} (\mathbf{x}; \sigma; c) = \mathbf{c}_{\mathrm{skip}}(\sigma) \mathbf{x} + \mathbf{c}_{\mathrm{out}}(\sigma) \mathbf{F}_{\theta}(\mathbf{c}_{\mathrm{in}}(\sigma)\mathbf{x}; \mathbf{c}_{\mathrm{noise}}(\sigma); c) + + Arguments: + sigma (:obj:`torch.Tensor`): Input sigma. + x (:obj:`torch.Tensor`): Input x. + condition: (:obj:`Optional[torch.Tensor]`): Input condition. + + Returns: + D_x (:obj:`torch.Tensor`): Output denoiser. + """ + # Suppose the first dim of x is batch size + x = x.to(torch.float32) + sigma_shape = [x.shape[0]] + [1] * (x.ndim - 1) + + + if sigma.numel() == 1: + sigma = sigma.view(-1).expand(*sigma_shape) + else: + sigma = sigma.reshape(*sigma_shape) + dtype = torch.float16 if (self.use_mixes_precision and x.device.type == 'cuda') else torch.float32 + c_skip, c_out, c_in, c_noise = self.get_precondition_c(sigma) + F_x = self.denoise_model(c_noise.flatten(), (c_in * x).to(dtype), condition=condition, **model_kwargs) + assert F_x.dtype == dtype + D_x = c_skip * x + c_out * F_x.to(torch.float32) + return D_x \ No newline at end of file diff --git a/grl/generative_models/edm_diffusion_model/edm_utils.py b/grl/generative_models/edm_diffusion_model/edm_utils.py new file mode 100644 index 0000000..7ddb26c --- /dev/null +++ b/grl/generative_models/edm_diffusion_model/edm_utils.py @@ -0,0 +1,100 @@ +import numpy as np +import torch +from easydict import EasyDict + +############# Sampling Section ############# + +# Scheduling in Table 1 of paper https://arxiv.org/abs/2206.00364 +SIGMA_T = { + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: ((0.5 * beta_d * (t ** 2) + beta_min * t).exp() - 1) ** 0.5, + "VE_edm": lambda t, **kwargs: t.sqrt(), + "iDDPM_edm": lambda t, **kwargs: t, + "EDM": lambda t, **kwargs: t +} + +SIGMA_T_DERIV = { + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 0.5 * (beta_min + beta_d * t) * (SIGMA_T["VP_edm"](t, beta_d, beta_min) + (1 / SIGMA_T["VP_edm"](t, beta_d, beta_min))), + "VE_edm": lambda t, **kwargs: 1 / (2 * t.sqrt()), + "iDDPM_edm": lambda t, **kwargs: 1, + "EDM": lambda t, **kwargs: 1 +} + +SIGMA_T_INV = { + "VP_edm": lambda sigma, beta_d=19.9, beta_min=0.1: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1)).log() - beta_min).sqrt() / beta_d, + "VE_edm": lambda sigma, **kwargs: sigma ** 2, + "iDDPM_edm": lambda sigma, **kwargs: sigma, + "EDM": lambda sigma, **kwargs: sigma +} + +# Scaling in Table 1 +SCALE_T = { + "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 1 / (1 + SIGMA_T["VP_edm"](t, 
+
+# Scaling in Table 1
+SCALE_T = {
+    "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: 1 / (1 + SIGMA_T["VP_edm"](t, beta_d, beta_min) ** 2).sqrt(),
+    "VE_edm": lambda t, **kwargs: 1,
+    "iDDPM_edm": lambda t, **kwargs: 1,
+    "EDM": lambda t, **kwargs: 1
+}
+
+SCALE_T_DERIV = {
+    "VP_edm": lambda t, beta_d=19.9, beta_min=0.1: -SIGMA_T["VP_edm"](t, beta_d, beta_min) * SIGMA_T_DERIV["VP_edm"](t, beta_d, beta_min) * (SCALE_T["VP_edm"](t, beta_d, beta_min) ** 3),
+    "VE_edm": lambda t, **kwargs: 0,
+    "iDDPM_edm": lambda t, **kwargs: 0,
+    "EDM": lambda t, **kwargs: 0
+}
+
+
+INITIAL_SIGMA_MIN = {
+    "VP_edm": float(SIGMA_T["VP_edm"](torch.tensor(1e-3), 19.9, 0.1)),
+    "VE_edm": 0.02,
+    "iDDPM_edm": 0.002,
+    "EDM": 0.002
+}
+
+INITIAL_SIGMA_MAX = {
+    "VP_edm": float(SIGMA_T["VP_edm"](torch.tensor(1.), 19.9, 0.1)),
+    "VE_edm": 100,
+    "iDDPM_edm": 81,
+    "EDM": 80
+}
+
+###### Default Params ######
+
+DEFAULT_PARAM = EasyDict({
+    "VP_edm":
+        {
+            "beta_d": 19.9,
+            "beta_min": 0.1,
+            "M": 1000,
+            "epsilon_t": 1e-5,
+        },
+    "VE_edm":
+        {
+            "sigma_min": 0.02,
+            "sigma_max": 100
+        },
+    "iDDPM_edm":
+        {
+            "C_1": 0.001,
+            "C_2": 0.008,
+            "M": 1000
+        },
+    "EDM":
+        {
+            "sigma_min": 0.002,
+            "sigma_max": 80,
+            "sigma_data": 0.5,
+            "P_mean": -1.2,
+            "P_std": 1.2
+        }
+})
+
+DEFAULT_SOLVER_PARAM = EasyDict(
+    {
+        "num_steps": 18,
+        "epsilon_s": 1e-3,
+        "rho": 7,
+        "S_churn": 0.,
+        "S_min": 0.,
+        "S_max": float("inf"),
+        "S_noise": 1.,
+        "alpha": 1
+})
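+
+# Round-trip sanity sketch for the schedules above (illustrative only, not part
+# of the public API): SIGMA_T and SIGMA_T_INV should invert each other, e.g.
+#
+#   t = torch.tensor(0.5)
+#   sigma = SIGMA_T["VP_edm"](t, 19.9, 0.1)      # ~= 3.41
+#   SIGMA_T_INV["VP_edm"](sigma, 19.9, 0.1)      # recovers ~= 0.5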
diff --git a/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py
new file mode 100644
index 0000000..93bf2c6
--- /dev/null
+++ b/grl_pipelines/tutorials/toy_examples/swiss_roll/swiss_roll_edm_diffusion.py
@@ -0,0 +1,263 @@
+################################################################################################
+# This script demonstrates how to train an EDM diffusion model on the Swiss Roll dataset.
+################################################################################################
+
+import os
+import signal
+import sys
+
+import matplotlib
+import numpy as np
+from easydict import EasyDict
+from rich.progress import track
+from sklearn.datasets import make_swiss_roll
+
+matplotlib.use("Agg")
+import matplotlib.pyplot as plt
+import torch
+from matplotlib import animation
+
+from grl.generative_models.edm_diffusion_model.edm_diffusion_model import EDMModel
+from grl.utils import set_seed
+from grl.utils.log import log
+
+x_size = 2
+device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
+t_embedding_dim = 32
+t_encoder = dict(
+    type="GaussianFourierProjectionTimeEncoder",
+    args=dict(
+        embed_dim=t_embedding_dim,
+        scale=30.0,
+    ),
+)
+config = EasyDict(
+    dict(
+        device=device,
+        edm_model=dict(
+            device=device,
+            x_size=[2],
+            sample_params=dict(
+                num_steps=18,
+                alpha=1,
+                S_churn=0.0,
+                S_min=0.0,
+                S_max=float("inf"),
+                S_noise=1.0,
+                rho=7,  # * EDM needs rho
+                epsilon_s=1e-3,  # * VP needs epsilon_s
+            ),
+            solver=dict(
+                type="ODESolver",
+                args=dict(
+                    library="torchdyn",
+                    ode_solver="euler",
+                ),
+            ),
+            path=dict(
+                edm_type="EDM",  # * ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"]
+                # solver=dict(
+                #     solver_type="heun",
+                #     # *['euler', 'heun']
+                #     params=dict(
+                #         num_steps=18,
+                #         alpha=1,
+                #         S_churn=0.0,
+                #         S_min=0.0,
+                #         S_max=float("inf"),
+                #         S_noise=1.0,
+                #         rho=7,  # * EDM needs rho
+                #         epsilon_s=1e-3,  # * VP needs epsilon_s
+                #     ),
+                # ),
+                params=dict(
+                    # ^ 1: VP_edm
+                    # beta_d=19.9,
+                    # beta_min=0.1,
+                    # M=1000,
+                    # epsilon_t=1e-5,
+                    # epsilon_s=1e-4,
+                    # ^ 2: VE_edm
+                    # sigma_min=0.02,
+                    # sigma_max=100,
+                    # ^ 3: iDDPM_edm
+                    # C_1=0.001,
+                    # C_2=0.008,
+                    # M=1000,
+                    # ^ 4: EDM
+                    # sigma_min=0.002,
+                    # sigma_max=80,
+                    # sigma_data=0.5,
+                    # P_mean=-1.21,
+                    # P_std=1.21,
+                ),
+            ),
+            model=dict(
+                type="noise_function",
+                args=dict(
+                    t_encoder=t_encoder,
+                    backbone=dict(
+                        type="TemporalSpatialResidualNet",
+                        args=dict(
+                            hidden_sizes=[512, 256, 128],
+                            output_dim=x_size,
+                            t_dim=t_embedding_dim,
+                        ),
+                    ),
+                ),
+            ),
+        ),
+        parameter=dict(
+            training_loss_type="score_matching",
+            lr=1e-4,
+            data_num=10000,
+            iterations=1000,
+            batch_size=2048,
+            clip_grad_norm=1.0,
+            eval_freq=500,
+            checkpoint_freq=100,
+            checkpoint_path="./checkpoint",
+            video_save_path="./video",
+            device=device,
+        ),
+    )
+)
+if __name__ == "__main__":
+    seed_value = set_seed()
+    log.info(f"start exp with seed value {seed_value}.")
+    edm_diffusion_model = EDMModel(config=config.edm_model).to(config.device)
+    # edm_diffusion_model = torch.compile(edm_diffusion_model)
+    # get data
+    data = make_swiss_roll(n_samples=config.parameter.data_num, noise=0.01)[0].astype(
+        np.float32
+    )[:, [0, 2]]
+    # transform data
+    data[:, 0] = data[:, 0] / np.max(np.abs(data[:, 0]))
+    data[:, 1] = data[:, 1] / np.max(np.abs(data[:, 1]))
+    data = (data - data.min()) / (data.max() - data.min())
+    data = data * 10 - 5
+
+    optimizer = torch.optim.Adam(
+        edm_diffusion_model.parameters(),
+        lr=config.parameter.lr,
+    )
+    if config.parameter.checkpoint_path is not None:
+
+        if (
+            not os.path.exists(config.parameter.checkpoint_path)
+            or len(os.listdir(config.parameter.checkpoint_path)) == 0
+        ):
+            log.warning(
+                f"Checkpoint path {config.parameter.checkpoint_path} does not exist"
+            )
+            last_iteration = -1
+        else:
+            checkpoint_files = [
+                f
+                for f in os.listdir(config.parameter.checkpoint_path)
+                if f.endswith(".pt")
+            ]
+            checkpoint_files = sorted(
+                checkpoint_files, key=lambda x: int(x.split("_")[-1].split(".")[0])
+            )
+            checkpoint = torch.load(
+                os.path.join(config.parameter.checkpoint_path, checkpoint_files[-1]),
+                map_location="cpu",
+            )
+            edm_diffusion_model.load_state_dict(checkpoint["model"])
+            optimizer.load_state_dict(checkpoint["optimizer"])
+            last_iteration = checkpoint["iteration"]
+    else:
+        last_iteration = -1
+
+    data_loader = torch.utils.data.DataLoader(
+        data, batch_size=config.parameter.batch_size, shuffle=True
+    )
+
+    def get_train_data(dataloader):
+        while True:
+            yield from dataloader
+
+    data_generator = get_train_data(data_loader)
+
+    gradient_sum = 0.0
+    loss_sum = 0.0
+    counter = 0
+    iteration = 0
+
+    def plot2d(data):
+
+        plt.scatter(data[:, 0], data[:, 1])
+        plt.show()
+
+    def render_video(data_list, video_save_path, iteration, fps=100, dpi=100):
+        if not os.path.exists(video_save_path):
+            os.makedirs(video_save_path)
+        fig = plt.figure(figsize=(6, 6))
+        plt.xlim([-10, 10])
+        plt.ylim([-10, 10])
+        ims = []
+        colors = np.linspace(0, 1, len(data_list))
+
+        for i, data in enumerate(data_list):
+            # image alpha from 0 to 1
+            im = plt.scatter(data[:, 0], data[:, 1], s=1)
+            ims.append([im])
+        ani = animation.ArtistAnimation(fig, ims, interval=0.1, blit=True)
+        ani.save(
+            os.path.join(video_save_path, f"iteration_{iteration}.mp4"),
+            fps=fps,
+            dpi=dpi,
+        )
+        # clean up
+        plt.close(fig)
+        plt.clf()
+
+    def save_checkpoint(model, optimizer, iteration):
+        if not os.path.exists(config.parameter.checkpoint_path):
+            os.makedirs(config.parameter.checkpoint_path)
+        torch.save(
+            dict(
+                model=model.state_dict(),
+                optimizer=optimizer.state_dict(),
+                iteration=iteration,
+            ),
+            f=os.path.join(
+                config.parameter.checkpoint_path, f"checkpoint_{iteration}.pt"
+            ),
+        )
+
+    history_iteration = [-1]
+    # batch_data = next(data_generator).to(config.device)
+
+    for i in range(config.parameter.iterations):
+        batch_data = next(data_generator).to(config.device)
+        edm_diffusion_model.train()
+        loss = edm_diffusion_model.L2_denoising_matching_loss(batch_data)
+        optimizer.zero_grad()
+        loss.backward()
+        gradient_norm = torch.nn.utils.clip_grad_norm_(
+            edm_diffusion_model.parameters(), config.parameter.clip_grad_norm
+        )
+        optimizer.step()
+        gradient_sum += gradient_norm.item()
+        loss_sum += loss.item()
+        counter += 1
+        iteration += 1
+        log.info(
+            f"iteration {iteration}, gradient {gradient_sum/counter}, loss {loss_sum/counter}"
+        )
+
+    edm_diffusion_model.eval()
+
+    sampled = edm_diffusion_model.sample(batch_size=1000)
+    log.info(f"Sampled size: {sampled.shape}")
+
+    plt.scatter(sampled[:, 0].detach().cpu(), sampled[:, 1].detach().cpu(), s=1)
+
+    plt.savefig("./result.png")
\ No newline at end of file