
Commit defcfd6

nifleisch authored and davidberenstein1957 committed
feat: add fastercache and pab (#92)
* fix: correct docstring in deepcache
* feat: add model checks
* feat: add pyramid attention broadcast (pab) cacher
* feat: add fastercache cacher
* tests: add flux tiny random fixture
* tests: add algorithms tests for pab and fastercache
* tests: add combination tests for pab and fastercache
* fix: add 1 as value for interval parameter
1 parent 3422766 commit defcfd6
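
For context, a minimal sketch of how the new cacher might be enabled through pruna's public entry points. The `smash`/`SmashConfig` calls follow the library's documented usage, but the exact config keys (`cacher`, the `fastercache_interval` prefix) are assumptions inferred from `algorithm_name` and `SmashConfigPrefixWrapper` in the diff below, not confirmed by this commit:

    import torch
    from diffusers import FluxPipeline
    from pruna import SmashConfig, smash

    # Load a pipeline that model_check_fn (below) accepts; Flux is one of them.
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )

    smash_config = SmashConfig()
    smash_config["cacher"] = "fastercache"    # matches algorithm_name in the new module
    smash_config["fastercache_interval"] = 2  # assumed prefixed form of the "interval" hyperparameter

    # smash() routes the prefixed config through SmashConfigPrefixWrapper into _apply().
    smashed_pipe = smash(model=pipe, smash_config=smash_config)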

File tree

7 files changed: +514 −2 lines

src/pruna/algorithms/caching/deepcache.py

Lines changed: 2 additions & 2 deletions
@@ -80,12 +80,12 @@ def model_check_fn(self, model: Any) -> bool:

     def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         """
-        Apply the step caching algorithm to the model.
+        Apply the deepcache algorithm to the model.

         Parameters
         ----------
         model : Any
-            The model to apply the step caching algorithm to.
+            The model to apply the deepcache algorithm to.
         smash_config : SmashConfigPrefixWrapper
             The configuration for the caching.
Lines changed: 212 additions & 0 deletions
@@ -0,0 +1,212 @@

# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from typing import Any, Dict, Optional, Tuple

from ConfigSpace import OrdinalHyperparameter

from pruna.algorithms.caching import PrunaCacher
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.engine.model_checks import (
    is_allegro_pipeline,
    is_cogvideo_pipeline,
    is_flux_pipeline,
    is_hunyuan_pipeline,
    is_latte_pipeline,
    is_mochi_pipeline,
    is_wan_pipeline,
)
from pruna.logging.logger import pruna_logger


class FasterCacheCacher(PrunaCacher):
    """
    Implement FasterCache.

    FasterCache is a method that speeds up inference in diffusion transformers by:
    - Reusing attention states between successive inference steps, due to the high similarity between them
    - Skipping the unconditional branch prediction used in classifier-free guidance by exploiting redundancies
      between unconditional and conditional branch outputs for the same timestep, and therefore approximating
      the unconditional branch output using the conditional branch output
    This implementation reduces the number of tunable parameters by setting pipeline-specific parameters according to
    https://github.com/huggingface/diffusers/pull/9562.
    """

    algorithm_name = "fastercache"
    references = {"GitHub": "https://github.com/Vchitect/FasterCache", "Paper": "https://arxiv.org/abs/2410.19355"}
    tokenizer_required = False
    processor_required = False
    dataset_required = False
    run_on_cpu = True
    run_on_cuda = True
    compatible_algorithms = dict(quantizer=["hqq_diffusers", "diffusers_int8"])

    def get_hyperparameters(self) -> list:
        """
        Get the hyperparameters for the algorithm.

        Returns
        -------
        list
            The hyperparameters.
        """
        return [
            OrdinalHyperparameter(
                "interval",
                sequence=[1, 2, 3, 4, 5],
                default_value=2,
                meta=dict(
                    desc="Interval at which to cache spatial attention blocks - 1 disables caching. "
                    "Higher is faster but might degrade quality."
                ),
            ),
        ]

    def model_check_fn(self, model: Any) -> bool:
        """
        Check if the model is a valid model for the algorithm.

        Parameters
        ----------
        model : Any
            The model to check.

        Returns
        -------
        bool
            True if the model is a valid model for the algorithm, False otherwise.
        """
        pipeline_check_fns = [
            is_allegro_pipeline,
            is_cogvideo_pipeline,
            is_flux_pipeline,
            is_hunyuan_pipeline,
            is_mochi_pipeline,
            is_wan_pipeline,
        ]
        return any(is_pipeline(model) for is_pipeline in pipeline_check_fns)

    def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
        """
        Apply the fastercache algorithm to the model.

        Parameters
        ----------
        model : Any
            The model to apply the fastercache algorithm to.
        smash_config : SmashConfigPrefixWrapper
            The configuration for the caching.

        Returns
        -------
        Any
            The smashed model.
        """
        imported_modules = self.import_algorithm_packages()
        # set default values according to https://huggingface.co/docs/diffusers/en/api/cache
        temporal_attention_block_skip_range: Optional[int] = None
        spatial_attention_timestep_skip_range: Tuple[int, int] = (-1, 681)
        temporal_attention_timestep_skip_range: Optional[Tuple[int, int]] = None
        low_frequency_weight_update_timestep_range: Tuple[int, int] = (99, 901)
        high_frequency_weight_update_timestep_range: Tuple[int, int] = (-1, 301)
        unconditional_batch_skip_range: int = 5
        unconditional_batch_timestep_skip_range: Tuple[int, int] = (-1, 641)
        spatial_attention_block_identifiers: Tuple[str, ...] = (
            "blocks.*attn1",
            "transformer_blocks.*attn1",
            "single_transformer_blocks.*attn1",
        )
        temporal_attention_block_identifiers: Tuple[str, ...] = ("temporal_transformer_blocks.*attn1",)
        attention_weight_callback = lambda _: 0.5  # noqa: E731
        tensor_format: str = "BFCHW"
        is_guidance_distilled: bool = False

        # set configs according to https://github.com/huggingface/diffusers/pull/9562
        if is_allegro_pipeline(model):
            low_frequency_weight_update_timestep_range = (99, 641)
            spatial_attention_block_identifiers = ("transformer_blocks",)
        elif is_cogvideo_pipeline(model):
            low_frequency_weight_update_timestep_range = (99, 641)
            spatial_attention_block_identifiers = ("transformer_blocks",)
            attention_weight_callback = lambda _: 0.3  # noqa: E731
        elif is_flux_pipeline(model):
            spatial_attention_timestep_skip_range = (-1, 961)
            spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks")
            tensor_format = "BCHW"
            is_guidance_distilled = True
        elif is_hunyuan_pipeline(model):
            spatial_attention_timestep_skip_range = (99, 941)
            spatial_attention_block_identifiers = ("transformer_blocks", "single_transformer_blocks")
            tensor_format = "BCFHW"
            is_guidance_distilled = True
        elif is_latte_pipeline(model):
            temporal_attention_block_skip_range = 2
            temporal_attention_timestep_skip_range = (-1, 681)
            low_frequency_weight_update_timestep_range = (99, 641)
            spatial_attention_block_identifiers = ("transformer_blocks.*attn1",)
            temporal_attention_block_identifiers = ("temporal_transformer_blocks",)
        elif is_mochi_pipeline(model):
            spatial_attention_timestep_skip_range = (-1, 981)
            low_frequency_weight_update_timestep_range = (301, 961)
            high_frequency_weight_update_timestep_range = (-1, 851)
            unconditional_batch_skip_range = 4
            unconditional_batch_timestep_skip_range = (-1, 975)
            spatial_attention_block_identifiers = ("transformer_blocks",)
            attention_weight_callback = lambda _: 0.6  # noqa: E731
        elif is_wan_pipeline(model):
            spatial_attention_block_identifiers = ("blocks",)
            tensor_format = "BCFHW"
            is_guidance_distilled = True

        fastercache_config = imported_modules["FasterCacheConfig"](
            spatial_attention_block_skip_range=smash_config["interval"],
            temporal_attention_block_skip_range=temporal_attention_block_skip_range,
            spatial_attention_timestep_skip_range=spatial_attention_timestep_skip_range,
            temporal_attention_timestep_skip_range=temporal_attention_timestep_skip_range,
            low_frequency_weight_update_timestep_range=low_frequency_weight_update_timestep_range,
            high_frequency_weight_update_timestep_range=high_frequency_weight_update_timestep_range,
            alpha_low_frequency=1.1,
            alpha_high_frequency=1.1,
            unconditional_batch_skip_range=unconditional_batch_skip_range,
            unconditional_batch_timestep_skip_range=unconditional_batch_timestep_skip_range,
            spatial_attention_block_identifiers=spatial_attention_block_identifiers,
            temporal_attention_block_identifiers=temporal_attention_block_identifiers,
            attention_weight_callback=attention_weight_callback,
            tensor_format=tensor_format,
            current_timestep_callback=lambda: model.current_timestep,
            is_guidance_distilled=is_guidance_distilled,
        )
        model.transformer.enable_cache(fastercache_config)
        return model

    def import_algorithm_packages(self) -> Dict[str, Any]:
        """
        Import the algorithm packages.

        Returns
        -------
        Dict[str, Any]
            The algorithm packages.
        """
        try:
            from diffusers import FasterCacheConfig
        except ModuleNotFoundError:
            pruna_logger.error(
                "You are trying to use FasterCache, but the FasterCacheConfig can not be imported from diffusers. "
                "This is likely because your diffusers version is too old."
            )
            raise

        return dict(FasterCacheConfig=FasterCacheConfig)
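
To make the Flux branch of `_apply` concrete, here is a rough hand-written equivalent expressed directly against diffusers. It is only a sketch: the model id and dtype are illustrative, and only the Flux-specific overrides plus the "interval" hyperparameter are shown; `_apply` additionally passes the remaining defaults listed at the top of the method.

    import torch
    from diffusers import FasterCacheConfig, FluxPipeline

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )

    config = FasterCacheConfig(
        spatial_attention_block_skip_range=2,             # the "interval" hyperparameter (default 2)
        spatial_attention_timestep_skip_range=(-1, 961),  # Flux-specific override from _apply
        spatial_attention_block_identifiers=("transformer_blocks", "single_transformer_blocks"),
        tensor_format="BCHW",
        is_guidance_distilled=True,                       # Flux has no separate unconditional branch
        current_timestep_callback=lambda: pipe.current_timestep,
    )
    pipe.transformer.enable_cache(config)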
