771 changes: 771 additions & 0 deletions density_func.ipynb

Large diffs are not rendered by default.

255 changes: 255 additions & 0 deletions grl/generative_models/edm_diffusion_model/edm_diffusion_model.py
@@ -0,0 +1,255 @@
from typing import Optional, Tuple, Literal
from dataclasses import dataclass

import numpy as np
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from easydict import EasyDict
from functools import partial

from .edm_preconditioner import PreConditioner
from .edm_utils import SIGMA_T, SIGMA_T_DERIV, SIGMA_T_INV, SCALE_T, SCALE_T_DERIV, INITIAL_SIGMA_MAX, INITIAL_SIGMA_MIN
from .edm_utils import DEFAULT_SOLVER_PARAM
from grl.generative_models.intrinsic_model import IntrinsicModel
from grl.utils import set_seed
from grl.utils.log import log

class Simple(nn.Module):
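# A minimal MLP placeholder for the denoising network; it ignores the noise level and class labels.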
def __init__(self):
super().__init__()
self.model = nn.Sequential(
nn.Linear(2, 32),
nn.ReLU(),
nn.Linear(32, 32),
nn.ReLU(),
nn.Linear(32, 2)
)
def forward(self, x, noise, class_labels=None):
return self.model(x)

class EDMModel(nn.Module):

def __init__(self, config: Optional[EasyDict]=None) -> None:

super().__init__()
self.config = config
Review comment (Member): use single-layer config. (See the config sketch after __init__ below.)

# self.x_size = config.x_size
self.device = config.device

# EDM Type ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"]
self.edm_type: str = config.edm_model.path.edm_type
assert self.edm_type in ["VP_edm", "VE_edm", "iDDPM_edm", "EDM"], \
f"Your edm type should be in ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM'], but got {self.edm_type}"

#* 1. Construct the base denoising network (a simple MLP placeholder here, rather than a UNet built from config params)
self.base_denoise_network = Simple()

#* 2. Precond setup
self.params = config.edm_model.path.params
self.preconditioner = PreConditioner(
self.edm_type,
base_denoise_model=self.base_denoise_network,
use_mixes_precision=False,
**self.params
)

#* 3. Solver setup
self.solver_type = config.edm_model.solver.solver_type
assert self.solver_type in ['euler', 'heun']

self.solver_params = EasyDict(DEFAULT_SOLVER_PARAM)  # copy the defaults so the shared module-level dict is not mutated by update() below
self.solver_params.update(config.edm_model.solver.params)

# Initialize sigma_min and sigma_max if not provided


self.sigma_min = INITIAL_SIGMA_MIN[self.edm_type] if "sigma_min" not in self.params else self.params.sigma_min
self.sigma_max = INITIAL_SIGMA_MAX[self.edm_type] if "sigma_max" not in self.params else self.params.sigma_max
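Relating to the review note above about a single-layer config, a hypothetical sketch of a flattened layout (keys below are illustrative assumptions; the constructor as written still expects the nested config.edm_model.path / config.edm_model.solver layout):

from easydict import EasyDict

flat_config = EasyDict(
    device="cuda",
    edm_type="EDM",  # one of "VP_edm", "VE_edm", "iDDPM_edm", "EDM"
    params=EasyDict(sigma_data=0.5, sigma_min=0.002, sigma_max=80),
    solver_type="heun",
    solver_params=EasyDict(num_steps=18, epsilon_s=1e-3, rho=7),
)
# __init__ would then read e.g. flat_config.edm_type instead of config.edm_model.path.edm_type.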


def get_type(self):
return "EDMModel"

# For VP_edm
def _sample_sigma_weight_train(self, x: Tensor, **params) -> Tuple[Tensor, Tensor]:
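# Sample per-example noise levels sigma and the matching loss weights lambda(sigma): 1 / sigma^2 for VP/VE, and (sigma^2 + sigma_data^2) / (sigma * sigma_data)^2 for EDM.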
# assert the first dim of x is batch size
log.info(f"Params of training are: {params}")
rand_shape = [x.shape[0]] + [1] * (x.ndim - 1)
if self.edm_type == "VP_edm":
epsilon_t = params.get("epsilon_t", 1e-5)
beta_d = params.get("beta_d", 19.9)
beta_min = params.get("beta_min", 0.1)

rand_uniform = torch.rand(*rand_shape, device=x.device)
sigma = SIGMA_T["VP_edm"](1 + rand_uniform * (epsilon_t - 1), beta_d, beta_min)
weight = 1 / sigma ** 2
elif self.edm_type == "VE_edm":
rand_uniform = torch.rand(*rand_shape, device=x.device)
sigma = self.sigma_min * ((self.sigma_max / self.sigma_min) ** rand_uniform)
weight = 1 / sigma ** 2
elif self.edm_type == "EDM":
P_mean = params.get("P_mean", -1.2)
P_std = params.get("P_std", 1.2)
sigma_data = params.get("sigma_data", 0.5)

rand_normal = torch.randn(*rand_shape, device=x.device)
sigma = (rand_normal * P_std + P_mean).exp()
weight = (sigma ** 2 + sigma_data ** 2) / (sigma * sigma_data) ** 2
else:
raise NotImplementedError(f"Sampling of sigma and loss weights is not implemented for edm_type {self.edm_type}")
return sigma, weight

def forward(self,
x: Tensor,
class_labels=None) -> Tensor:
x = x.to(self.device)
sigma, weight = self._sample_sigma_weight_train(x, **self.params)
n = torch.randn_like(x) * sigma
D_xn = self.preconditioner(x+n, sigma, class_labels=class_labels)
loss = weight * ((D_xn - x) ** 2)
return loss


def _get_sigma_steps_t_steps(self, num_steps=18, epsilon_s=1e-3, rho=7):
Review comment (Member): if this method is called with the same arguments during training, you should just call it once. (A caching sketch follows after this method.)

"""
Overview:
Get the sigma and t schedules according to the different time-step schedules of each EDM type.

"""
self.sigma_min = max(self.sigma_min, self.preconditioner.sigma_min)
self.sigma_max = min(self.sigma_max, self.preconditioner.sigma_max)

# Define time steps in terms of noise level
step_indices = torch.arange(num_steps, dtype=torch.float64, device=self.device)
sigma_steps = None
if self.edm_type == "VP_edm":
vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d

orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
sigma_steps = SIGMA_T["VP_edm"](orig_t_steps, vp_beta_d, vp_beta_min)

elif self.edm_type == "VE_edm":
orig_t_steps = (self.sigma_max ** 2) * ((self.sigma_min ** 2 / self.sigma_max ** 2) ** (step_indices / (num_steps - 1)))
sigma_steps = SIGMA_T["VE_edm"](orig_t_steps)

elif self.edm_type == "iDDPM_edm":
M, C_1, C_2 = self.params.M, self.params.C_1, self.params.C_2

u = torch.zeros(M + 1, dtype=torch.float64, device=self.device)
alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
for j in torch.arange(self.params.M, 0, -1, device=self.device): # M, ..., 1
u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
u_filtered = u[torch.logical_and(u >= self.sigma_min, u <= self.sigma_max)]

sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps))

elif self.edm_type == "EDM":
sigma_steps = (self.sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * \
(self.sigma_min ** (1 / rho) - self.sigma_max ** (1 / rho))) ** rho
orig_t_steps = SIGMA_T_INV[self.edm_type](self.preconditioner.round_sigma(sigma_steps))

t_steps = torch.cat([orig_t_steps, torch.zeros_like(orig_t_steps[:1])]) # t_N = 0

return sigma_steps, t_steps
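Regarding the review note above, a minimal caching sketch (hypothetical method and attribute names, not part of the PR): compute the schedule once and reuse it while the arguments stay the same.

def _get_sigma_steps_t_steps_cached(self, num_steps=18, epsilon_s=1e-3, rho=7):
    # Recompute only when the arguments change; otherwise return the cached schedule.
    key = (num_steps, epsilon_s, rho)
    if getattr(self, "_schedule_cache_key", None) != key:
        self._schedule_cache_key = key
        self._schedule_cache = self._get_sigma_steps_t_steps(num_steps, epsilon_s, rho)
    return self._schedule_cache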


def _get_sigma_deriv_inv_scale_deriv(self, epsilon_s=1e-3):
"""
Overview:
Get the noise schedule sigma(t) and scale schedule s(t), together with their derivatives and the inverse of sigma, for the solver.

Returns:
sigma(t), sigma'(t), sigma^{-1}(sigma), s(t), s'(t)
"""
vp_beta_d = 2 * (np.log(self.sigma_min ** 2 + 1) / epsilon_s - np.log(self.sigma_max ** 2 + 1)) / (epsilon_s - 1)
vp_beta_min = np.log(self.sigma_max ** 2 + 1) - 0.5 * vp_beta_d
sigma = partial(SIGMA_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
sigma_deriv = partial(SIGMA_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
sigma_inv = partial(SIGMA_T_INV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
scale = partial(SCALE_T[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)
scale_deriv = partial(SCALE_T_DERIV[self.edm_type], beta_d=vp_beta_d, beta_min=vp_beta_min)

return sigma, sigma_deriv, sigma_inv, scale, scale_deriv


def sample(self,
t_span,
batch_size,
latents: Tensor,
class_labels: Tensor=None,
use_stochastic: bool=False,
**solver_kwargs
) -> Tensor:

# Get sigmas, scales, and timesteps
log.info(f"Solver param is {self.solver_params}")
num_steps = self.solver_params.num_steps
epsilon_s = self.solver_params.epsilon_s
rho = self.solver_params.rho

latents = latents.to(self.device)
sigma_steps, t_steps = self._get_sigma_steps_t_steps(num_steps=num_steps, epsilon_s=epsilon_s, rho=rho)
sigma, sigma_deriv, sigma_inv, scale, scale_deriv = self._get_sigma_deriv_inv_scale_deriv()

S_churn = self.solver_params.S_churn
S_min = self.solver_params.S_min
S_max = self.solver_params.S_max
S_noise = self.solver_params.S_noise
alpha = self.solver_params.alpha


if not use_stochastic:
# Main sampling loop
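# Deterministic sampler: integrate the probability-flow ODE from t_0 down to t_N = 0 with Euler steps and an optional Heun (2nd order) correction, in the generalized sigma(t) / s(t) parameterization.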
t_next = t_steps[0]
x_next = latents.to(torch.float64) * (sigma(t_next) * scale(t_next))
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
x_cur = x_next

# Increase noise temporarily.
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
t_hat = sigma_inv(self.preconditioner.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
x_hat = scale(t_hat) / scale(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * scale(t_hat) * S_noise * torch.randn_like(x_cur)

# Euler step.
h = t_next - t_hat
denoised = self.preconditioner(x_hat / scale(t_hat), sigma(t_hat), class_labels).to(torch.float64)
d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + scale_deriv(t_hat) / scale(t_hat)) * x_hat - sigma_deriv(t_hat) * scale(t_hat) / sigma(t_hat) * denoised
x_prime = x_hat + alpha * h * d_cur
t_prime = t_hat + alpha * h

# Apply 2nd order correction.
if self.solver_type == 'euler' or i == num_steps - 1:
x_next = x_hat + h * d_cur
else:
assert self.solver_type == 'heun'
denoised = self.preconditioner(x_prime / scale(t_prime), sigma(t_prime), class_labels).to(torch.float64)
d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + scale_deriv(t_prime) / scale(t_prime)) * x_prime - sigma_deriv(t_prime) * scale(t_prime) / sigma(t_prime) * denoised
x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

else:
assert self.edm_type == "EDM", f"The stochastic sampler can only be used with EDM, but your precond type is {self.edm_type}"
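# Stochastic EDM sampler: at each step, temporarily raise the noise level by a factor gamma, take an Euler step, then apply a 2nd order correction.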
x_next = latents.to(torch.float64) * t_steps[0]
for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
x_cur = x_next

# Increase noise temporarily.
gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
t_hat = self.preconditioner.round_sigma(t_cur + gamma * t_cur)
x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * torch.randn_like(x_cur)

# Euler step.
denoised = self.preconditioner(x_hat, t_hat, class_labels).to(torch.float64)
d_cur = (x_hat - denoised) / t_hat
x_next = x_hat + (t_next - t_hat) * d_cur

# Apply 2nd order correction.
if i < num_steps - 1:
denoised = self.preconditioner(x_next, t_next, class_labels).to(torch.float64)
d_prime = (x_next - denoised) / t_next
x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)


return x_next
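A hedged usage sketch of the class above (illustrative values only; it assumes DEFAULT_SOLVER_PARAM supplies the remaining solver settings and that the edm_utils schedules accept these defaults):

import torch
from easydict import EasyDict

config = EasyDict(
    device="cpu",
    edm_model=EasyDict(
        path=EasyDict(edm_type="EDM", params=EasyDict(sigma_data=0.5)),
        solver=EasyDict(solver_type="heun", params=EasyDict(num_steps=18)),
    ),
)
model = EDMModel(config)
loss = model(torch.randn(16, 2)).mean()  # per-element weighted denoising loss, reduced to a scalar
samples = model.sample(t_span=None, batch_size=16, latents=torch.randn(16, 2))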
111 changes: 111 additions & 0 deletions grl/generative_models/edm_diffusion_model/edm_preconditioner.py
@@ -0,0 +1,111 @@
from typing import Optional, Tuple, Literal
from dataclasses import dataclass

import numpy as np
import torch
from torch import Tensor, as_tensor
import torch.nn as nn
import torch.nn.functional as F

from .edm_utils import SIGMA_T, SIGMA_T_INV

class PreConditioner(nn.Module):

def __init__(self,
precondition_type: Literal["VP_edm", "VE_edm", "iDDPM_edm", "EDM"] = "EDM",
base_denoise_model: nn.Module = None,
use_mixes_precision: bool = False,
**precond_config_kwargs) -> None:

super().__init__()
self.precondition_type = precondition_type
self.base_denoise_model = base_denoise_model
self.use_mixes_precision = use_mixes_precision

if self.precondition_type == "VP_edm":
self.beta_d = precond_config_kwargs.get("beta_d", 19.9)
self.beta_min = precond_config_kwargs.get("beta_min", 0.1)
self.M = precond_config_kwargs.get("M", 1000)
self.epsilon_t = precond_config_kwargs.get("epsilon_t", 1e-5)

self.sigma_min = SIGMA_T["VP_edm"](self.epsilon_t, self.beta_d, self.beta_min)
self.sigma_max = SIGMA_T["VP_edm"](1, self.beta_d, self.beta_min)

elif self.precondition_type == "VE_edm":
self.sigma_min = precond_config_kwargs.get("sigma_min", 0.02)
self.sigma_max = precond_config_kwargs.get("sigma_max", 100)

elif self.precondition_type == "iDDPM_edm":
self.C_1 = precond_config_kwargs.get("C_1", 0.001)
self.C_2 = precond_config_kwargs.get("C_2", 0.008)
self.M = precond_config_kwargs.get("M", 1000)
u = torch.zeros(self.M + 1)
for j in range(self.M, 0, -1): # M, ..., 1
u[j - 1] = ((u[j] ** 2 + 1) / (self.alpha_bar(j - 1) / self.alpha_bar(j)).clip(min=self.C_1) - 1).sqrt()
self.register_buffer('u', u)
self.sigma_min = float(u[self.M - 1])
self.sigma_max = float(u[0])

elif self.precondition_type == "EDM":
self.sigma_min = precond_config_kwargs.get("sigma_min", 0.002)
self.sigma_max = precond_config_kwargs.get("sigma_max", 80)
self.sigma_data = precond_config_kwargs.get("sigma_data", 0.5)

else:
raise ValueError(f"Unsupported precond type {self.precondition_type}; expected one of ['VP_edm', 'VE_edm', 'iDDPM_edm', 'EDM']")


# For iDDPM_edm
def alpha_bar(self, j):
assert self.precondition_type == "iDDPM_edm", f"Only iDDPM_edm supports the alpha bar function, but your precond type is {self.precondition_type}"
j = torch.as_tensor(j)
return (0.5 * np.pi * j / self.M / (self.C_2 + 1)).sin() ** 2


def round_sigma(self, sigma, return_index=False):

if self.precondition_type == "iDDPM_edm":
sigma = torch.as_tensor(sigma)
index = torch.cdist(sigma.to(self.u.device).to(torch.float32).reshape(1, -1, 1), self.u.reshape(1, -1, 1)).argmin(2)
result = index if return_index else self.u[index.flatten()].to(sigma.dtype)
return result.reshape(sigma.shape).to(sigma.device)
else:
return torch.as_tensor(sigma)

def get_precondition_c(self, sigma: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:

if self.precondition_type == "VP_edm":
c_skip = 1
c_out = -sigma
c_in = 1 / (sigma ** 2 + 1).sqrt()
c_noise = (self.M - 1) * SIGMA_T_INV["VP_edm"](sigma, self.beta_d, self.beta_min)
elif self.precondition_type == "VE_edm":
c_skip = 1
c_out = sigma
c_in = 1
c_noise = (0.5 * sigma).log()
elif self.precondition_type == "iDDPM_edm":
c_skip = 1
c_out = -sigma
c_in = 1 / (sigma ** 2 + 1).sqrt()
c_noise = self.M - 1 - self.round_sigma(sigma, return_index=True).to(torch.float32)
elif self.precondition_type == "EDM":
c_skip = self.sigma_data ** 2 / (sigma ** 2 + self.sigma_data ** 2)
Review comment (Member): for constant values such as self.sigma_data ** 2, we can pre-compute them to save computation. (A hedged sketch follows after this method.)

c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data ** 2).sqrt()
c_in = 1 / (self.sigma_data ** 2 + sigma ** 2).sqrt()
c_noise = sigma.log() / 4
return c_skip, c_out, c_in, c_noise
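Acting on the review note above, one hedged possibility (the attribute name sigma_data_sq is hypothetical): compute sigma_data ** 2 once in the EDM branch of __init__ and reuse it here.

# in __init__ (EDM branch):
self.sigma_data_sq = self.sigma_data ** 2  # precomputed once
# the EDM branch of get_precondition_c could then read:
c_skip = self.sigma_data_sq / (sigma ** 2 + self.sigma_data_sq)
c_out = sigma * self.sigma_data / (sigma ** 2 + self.sigma_data_sq).sqrt()
c_in = 1 / (self.sigma_data_sq + sigma ** 2).sqrt()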

def forward(self, x: Tensor, sigma: Tensor, class_labels=None, **model_kwargs):
# Suppose the first dim of x is batch size
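# EDM preconditioning: D(x; sigma) = c_skip * x + c_out * F(c_in * x, c_noise), where F is the raw denoising network.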
x = x.to(torch.float32)
sigma_shape = [x.shape[0]] + [1] * (x.ndim - 1)
if sigma.numel() == 1:
sigma = sigma.view(-1).expand(*sigma_shape)

dtype = torch.float16 if (self.use_mixes_precision and x.device.type == 'cuda') else torch.float32
c_skip, c_out, c_in, c_noise = self.get_precondition_c(sigma)
F_x = self.base_denoise_model((c_in * x).to(dtype), c_noise.flatten(), class_labels=class_labels, **model_kwargs)
assert F_x.dtype == dtype
D_x = c_skip * x + c_out * F_x.to(torch.float32)
return D_x
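A hedged usage sketch (not part of the PR): wrapping a toy MLP denoiser with the EDM preconditioning and evaluating D(x; sigma) for a batch of 2-D points.

import torch
import torch.nn as nn

class ToyDenoiser(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 2))

    def forward(self, x, noise, class_labels=None):
        return self.net(x)

precond = PreConditioner("EDM", base_denoise_model=ToyDenoiser(), sigma_data=0.5)
x = torch.randn(16, 2)
sigma = torch.full((16, 1), 0.7)
D_x = precond(x, sigma)  # denoised estimate with EDM input/output/skip scalings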