Skip to content

Commit c4024ce

Browse files
Nathan Lambertbglick13
andauthored
Add UNet 1d for RL model for planning + colab (huggingface#105)
* re-add RL model code * match model forward api * add register_to_config, pass training tests * fix tests, update forward outputs * remove unused code, some comments * add to docs * remove extra embedding code * unify time embedding * remove conv1d output sequential * remove sequential from conv1dblock * style and deleting duplicated code * clean files * remove unused variables * clean variables * add 1d resnet block structure for downsample * rename as unet1d * fix renaming * rename files * add get_block(...) api * unify args for model1d like model2d * minor cleaning * fix docs * improve 1d resnet blocks * fix tests, remove permuts * fix style * add output activation * rename flax blocks file * Add Value Function and corresponding example script to Diffuser implementation (huggingface#884) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review Co-authored-by: Nathan Lambert <[email protected]> * update post merge of scripts * add mdiblock / outblock architecture * Pipeline cleanup (huggingface#947) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review * clean up comments * convert older script to using pipeline and add readme * rename scripts * style, update tests * delete unet rl model file * remove imports in src Co-authored-by: Nathan Lambert <[email protected]> * Update src/diffusers/models/unet_1d_blocks.py * Update tests/test_models_unet.py * RL Cleanup v2 (huggingface#965) * valuefunction code * start example scripts * missing imports * bug fixes and placeholder example script * add value function scheduler * load value function from hub and get best actions in example * very close to working example * larger batch size for planning * more tests * merge unet1d changes * wandb for debugging, use newer models * success! * turns out we just need more diffusion steps * run on modal * merge and code cleanup * use same api for rl model * fix variance type * wrong normalization function * add tests * style * style and quality * edits based on comments * style and quality * remove unused var * hack unet1d into a value function * add pipeline * fix arg order * add pipeline to core library * community pipeline * fix couple shape bugs * style * Apply suggestions from code review * clean up comments * convert older script to using pipeline and add readme * rename scripts * style, update tests * delete unet rl model file * remove imports in src * add specific vf block and update tests * style * Update tests/test_models_unet.py Co-authored-by: Nathan Lambert <[email protected]> * fix quality in tests * fix quality style, split test file * fix checks / tests * make timesteps closer to main * unify block API * unify forward api * delete lines in examples * style * examples style * all tests pass * make style * make dance_diff test pass * Refactoring RL PR (huggingface#1200) * init file changes * add import utils * finish cleaning files, imports * remove import flags * clean examples * fix imports, tests for merge * update readmes * hotfix for tests * quality * fix some tests * change defaults * more mps test fixes * unet1d defaults * do not default import experimental * defaults for tests * fix tests * fix-copies * fix * changes per Patrik's comments (huggingface#1285) * changes per Patrik's comments * update conversion script * fix renaming * skip more mps tests * last test fix * Update examples/rl/README.md Co-authored-by: Ben Glickenhaus <[email protected]>
1 parent 313afb7 commit c4024ce

File tree

9 files changed

+696
-58
lines changed

9 files changed

+696
-58
lines changed

experimental/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# 🧨 Diffusers Experimental
2+
3+
We are adding experimental code to support novel applications and usages of the Diffusers library.
4+
Currently, the following experiments are supported:
5+
* Reinforcement learning via an implementation of the [Diffuser](https://arxiv.org/abs/2205.09991) model.

experimental/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .rl import ValueGuidedRLPipeline

experimental/rl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .value_guided_sampling import ValueGuidedRLPipeline
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
# Copyright 2022 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import numpy as np
16+
import torch
17+
18+
import tqdm
19+
20+
from ...models.unet_1d import UNet1DModel
21+
from ...pipeline_utils import DiffusionPipeline
22+
from ...utils.dummy_pt_objects import DDPMScheduler
23+
24+
25+
class ValueGuidedRLPipeline(DiffusionPipeline):
26+
def __init__(
27+
self,
28+
value_function: UNet1DModel,
29+
unet: UNet1DModel,
30+
scheduler: DDPMScheduler,
31+
env,
32+
):
33+
super().__init__()
34+
self.value_function = value_function
35+
self.unet = unet
36+
self.scheduler = scheduler
37+
self.env = env
38+
self.data = env.get_dataset()
39+
self.means = dict()
40+
for key in self.data.keys():
41+
try:
42+
self.means[key] = self.data[key].mean()
43+
except:
44+
pass
45+
self.stds = dict()
46+
for key in self.data.keys():
47+
try:
48+
self.stds[key] = self.data[key].std()
49+
except:
50+
pass
51+
self.state_dim = env.observation_space.shape[0]
52+
self.action_dim = env.action_space.shape[0]
53+
54+
def normalize(self, x_in, key):
55+
return (x_in - self.means[key]) / self.stds[key]
56+
57+
def de_normalize(self, x_in, key):
58+
return x_in * self.stds[key] + self.means[key]
59+
60+
def to_torch(self, x_in):
61+
if type(x_in) is dict:
62+
return {k: self.to_torch(v) for k, v in x_in.items()}
63+
elif torch.is_tensor(x_in):
64+
return x_in.to(self.unet.device)
65+
return torch.tensor(x_in, device=self.unet.device)
66+
67+
def reset_x0(self, x_in, cond, act_dim):
68+
for key, val in cond.items():
69+
x_in[:, key, act_dim:] = val.clone()
70+
return x_in
71+
72+
def run_diffusion(self, x, conditions, n_guide_steps, scale):
73+
batch_size = x.shape[0]
74+
y = None
75+
for i in tqdm.tqdm(self.scheduler.timesteps):
76+
# create batch of timesteps to pass into model
77+
timesteps = torch.full((batch_size,), i, device=self.unet.device, dtype=torch.long)
78+
for _ in range(n_guide_steps):
79+
with torch.enable_grad():
80+
x.requires_grad_()
81+
y = self.value_function(x.permute(0, 2, 1), timesteps).sample
82+
grad = torch.autograd.grad([y.sum()], [x])[0]
83+
84+
posterior_variance = self.scheduler._get_variance(i)
85+
model_std = torch.exp(0.5 * posterior_variance)
86+
grad = model_std * grad
87+
grad[timesteps < 2] = 0
88+
x = x.detach()
89+
x = x + scale * grad
90+
x = self.reset_x0(x, conditions, self.action_dim)
91+
prev_x = self.unet(x.permute(0, 2, 1), timesteps).sample.permute(0, 2, 1)
92+
x = self.scheduler.step(prev_x, i, x, predict_epsilon=False)["prev_sample"]
93+
94+
# apply conditions to the trajectory
95+
x = self.reset_x0(x, conditions, self.action_dim)
96+
x = self.to_torch(x)
97+
return x, y
98+
99+
def __call__(self, obs, batch_size=64, planning_horizon=32, n_guide_steps=2, scale=0.1):
100+
# normalize the observations and create batch dimension
101+
obs = self.normalize(obs, "observations")
102+
obs = obs[None].repeat(batch_size, axis=0)
103+
104+
conditions = {0: self.to_torch(obs)}
105+
shape = (batch_size, planning_horizon, self.state_dim + self.action_dim)
106+
107+
# generate initial noise and apply our conditions (to make the trajectories start at current state)
108+
x1 = torch.randn(shape, device=self.unet.device)
109+
x = self.reset_x0(x1, conditions, self.action_dim)
110+
x = self.to_torch(x)
111+
112+
# run the diffusion process
113+
x, y = self.run_diffusion(x, conditions, n_guide_steps, scale)
114+
115+
# sort output trajectories by value
116+
sorted_idx = y.argsort(0, descending=True).squeeze()
117+
sorted_values = x[sorted_idx]
118+
actions = sorted_values[:, :, : self.action_dim]
119+
actions = actions.detach().cpu().numpy()
120+
denorm_actions = self.de_normalize(actions, key="actions")
121+
122+
# select the action with the highest value
123+
if y is not None:
124+
selected_index = 0
125+
else:
126+
# if we didn't run value guiding, select a random action
127+
selected_index = np.random.randint(0, batch_size)
128+
denorm_actions = denorm_actions[selected_index, 0]
129+
return denorm_actions

models/embeddings.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,14 +62,21 @@ def get_timestep_embedding(
6262

6363

6464
class TimestepEmbedding(nn.Module):
65-
def __init__(self, channel: int, time_embed_dim: int, act_fn: str = "silu"):
65+
def __init__(self, in_channels: int, time_embed_dim: int, act_fn: str = "silu", out_dim: int = None):
6666
super().__init__()
6767

68-
self.linear_1 = nn.Linear(channel, time_embed_dim)
68+
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
6969
self.act = None
7070
if act_fn == "silu":
7171
self.act = nn.SiLU()
72-
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
72+
elif act_fn == "mish":
73+
self.act = nn.Mish()
74+
75+
if out_dim is not None:
76+
time_embed_dim_out = out_dim
77+
else:
78+
time_embed_dim_out = time_embed_dim
79+
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out)
7380

7481
def forward(self, sample):
7582
sample = self.linear_1(sample)

models/resnet.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,84 @@
55
import torch.nn.functional as F
66

77

8+
class Upsample1D(nn.Module):
9+
"""
10+
An upsampling layer with an optional convolution.
11+
12+
Parameters:
13+
channels: channels in the inputs and outputs.
14+
use_conv: a bool determining if a convolution is applied.
15+
use_conv_transpose:
16+
out_channels:
17+
"""
18+
19+
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
20+
super().__init__()
21+
self.channels = channels
22+
self.out_channels = out_channels or channels
23+
self.use_conv = use_conv
24+
self.use_conv_transpose = use_conv_transpose
25+
self.name = name
26+
27+
self.conv = None
28+
if use_conv_transpose:
29+
self.conv = nn.ConvTranspose1d(channels, self.out_channels, 4, 2, 1)
30+
elif use_conv:
31+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, padding=1)
32+
33+
def forward(self, x):
34+
assert x.shape[1] == self.channels
35+
if self.use_conv_transpose:
36+
return self.conv(x)
37+
38+
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
39+
40+
if self.use_conv:
41+
x = self.conv(x)
42+
43+
return x
44+
45+
46+
class Downsample1D(nn.Module):
47+
"""
48+
A downsampling layer with an optional convolution.
49+
50+
Parameters:
51+
channels: channels in the inputs and outputs.
52+
use_conv: a bool determining if a convolution is applied.
53+
out_channels:
54+
padding:
55+
"""
56+
57+
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
58+
super().__init__()
59+
self.channels = channels
60+
self.out_channels = out_channels or channels
61+
self.use_conv = use_conv
62+
self.padding = padding
63+
stride = 2
64+
self.name = name
65+
66+
if use_conv:
67+
self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
68+
else:
69+
assert self.channels == self.out_channels
70+
self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)
71+
72+
def forward(self, x):
73+
assert x.shape[1] == self.channels
74+
return self.conv(x)
75+
76+
877
class Upsample2D(nn.Module):
978
"""
1079
An upsampling layer with an optional convolution.
1180
1281
Parameters:
1382
channels: channels in the inputs and outputs.
1483
use_conv: a bool determining if a convolution is applied.
15-
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then upsampling occurs in the inner-two dimensions.
84+
use_conv_transpose:
85+
out_channels:
1686
"""
1787

1888
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
@@ -80,7 +150,8 @@ class Downsample2D(nn.Module):
80150
Parameters:
81151
channels: channels in the inputs and outputs.
82152
use_conv: a bool determining if a convolution is applied.
83-
dims: determines if the signal is 1D, 2D, or 3D. If 3D, then downsampling occurs in the inner-two dimensions.
153+
out_channels:
154+
padding:
84155
"""
85156

86157
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
@@ -415,6 +486,69 @@ def forward(self, hidden_states):
415486
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
416487

417488

489+
# unet_rl.py
490+
def rearrange_dims(tensor):
491+
if len(tensor.shape) == 2:
492+
return tensor[:, :, None]
493+
if len(tensor.shape) == 3:
494+
return tensor[:, :, None, :]
495+
elif len(tensor.shape) == 4:
496+
return tensor[:, :, 0, :]
497+
else:
498+
raise ValueError(f"`len(tensor)`: {len(tensor)} has to be 2, 3 or 4.")
499+
500+
501+
class Conv1dBlock(nn.Module):
502+
"""
503+
Conv1d --> GroupNorm --> Mish
504+
"""
505+
506+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=8):
507+
super().__init__()
508+
509+
self.conv1d = nn.Conv1d(inp_channels, out_channels, kernel_size, padding=kernel_size // 2)
510+
self.group_norm = nn.GroupNorm(n_groups, out_channels)
511+
self.mish = nn.Mish()
512+
513+
def forward(self, x):
514+
x = self.conv1d(x)
515+
x = rearrange_dims(x)
516+
x = self.group_norm(x)
517+
x = rearrange_dims(x)
518+
x = self.mish(x)
519+
return x
520+
521+
522+
# unet_rl.py
523+
class ResidualTemporalBlock1D(nn.Module):
524+
def __init__(self, inp_channels, out_channels, embed_dim, kernel_size=5):
525+
super().__init__()
526+
self.conv_in = Conv1dBlock(inp_channels, out_channels, kernel_size)
527+
self.conv_out = Conv1dBlock(out_channels, out_channels, kernel_size)
528+
529+
self.time_emb_act = nn.Mish()
530+
self.time_emb = nn.Linear(embed_dim, out_channels)
531+
532+
self.residual_conv = (
533+
nn.Conv1d(inp_channels, out_channels, 1) if inp_channels != out_channels else nn.Identity()
534+
)
535+
536+
def forward(self, x, t):
537+
"""
538+
Args:
539+
x : [ batch_size x inp_channels x horizon ]
540+
t : [ batch_size x embed_dim ]
541+
542+
returns:
543+
out : [ batch_size x out_channels x horizon ]
544+
"""
545+
t = self.time_emb_act(t)
546+
t = self.time_emb(t)
547+
out = self.conv_in(x) + rearrange_dims(t)
548+
out = self.conv_out(out)
549+
return out + self.residual_conv(x)
550+
551+
418552
def upsample_2d(hidden_states, kernel=None, factor=2, gain=1):
419553
r"""Upsample2D a batch of 2D images with the given filter.
420554
Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given

0 commit comments

Comments
 (0)