
Commit e50e92c

add prx
1 parent de6c2c4 commit e50e92c

File tree

14 files changed: +1939 -0 lines changed

docs/diffusers/_toctree.yml

Lines changed: 2 additions & 0 deletions
```diff
@@ -479,6 +479,8 @@
     title: PixArt-α
   - local: api/pipelines/pixart_sigma
     title: PixArt-Σ
+  - local: api/pipelines/prx
+    title: PRX
   - local: api/pipelines/qwenimage
     title: QwenImage
   - local: api/pipelines/sana
```
docs/diffusers/api/pipelines/prx.md

Lines changed: 56 additions & 0 deletions (new file)
<!-- Copyright 2025 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License. -->

# PRX

PRX generates high-quality images from text using a simplified MMDiT architecture in which the text tokens are not updated through the transformer blocks. It employs flow matching with discrete scheduling for efficient sampling and uses Google's T5Gemma-2B-2B-UL2 model for multi-language text encoding. The ~1.3B-parameter transformer delivers fast inference without sacrificing quality. You can choose between the Flux VAE (8x compression, 16 latent channels) for balanced quality and speed, or DC-AE (32x compression, 32 latent channels) for higher latent compression and faster processing.
## Available models

PRX offers multiple variants with different VAE configurations, each optimized for specific resolutions. Base models excel with detailed prompts, capturing complex compositions and subtle details. Fine-tuned models trained on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) improve aesthetic quality, especially with simpler prompts.

| Model | Resolution | Fine-tuned | Distilled | Description | Suggested prompts | Suggested parameters | Recommended dtype |
|:-----:|:----------:|:----------:|:---------:|:-----------:|:-----------------:|:--------------------:|:-----------------:|
| [`Photoroom/prx-256-t2i`](https://huggingface.co/Photoroom/prx-256-t2i) | 256 | No | No | Base model pre-trained at 256 with the Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `mindspore.bfloat16` |
| [`Photoroom/prx-256-t2i-sft`](https://huggingface.co/Photoroom/prx-256-t2i-sft) | 256 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the Flux VAE | Can handle less detailed prompts | 28 steps, cfg=5.0 | `mindspore.bfloat16` |
| [`Photoroom/prx-512-t2i`](https://huggingface.co/Photoroom/prx-512-t2i) | 512 | No | No | Base model pre-trained at 512 with the Flux VAE | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `mindspore.bfloat16` |
| [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the Flux VAE | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `mindspore.bfloat16` |
| [`Photoroom/prx-512-t2i-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-sft`](https://huggingface.co/Photoroom/prx-512-t2i-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `mindspore.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae) | 512 | No | No | Base model pre-trained at 512 with the [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Works best with detailed prompts in natural language | 28 steps, cfg=5.0 | `mindspore.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | 512 | Yes | No | Fine-tuned on the [Alchemist dataset](https://huggingface.co/datasets/yandex/alchemist) with the [Deep Compression Autoencoder (DC-AE)](https://hanlab.mit.edu/projects/dc-ae) | Can handle less detailed prompts in natural language | 28 steps, cfg=5.0 | `mindspore.bfloat16` |
| [`Photoroom/prx-512-t2i-dc-ae-sft-distilled`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft-distilled) | 512 | Yes | Yes | 8-step distilled model from [`Photoroom/prx-512-t2i-dc-ae-sft`](https://huggingface.co/Photoroom/prx-512-t2i-dc-ae-sft) | Can handle less detailed prompts in natural language | 8 steps, cfg=1.0 | `mindspore.bfloat16` |

Refer to [this collection](https://huggingface.co/collections/Photoroom/prx-models-68e66254c202ebfab99ad38e) for more information.
## Loading the pipeline

Load the pipeline with [`~DiffusionPipeline.from_pretrained`].

```py
import mindspore as ms
from mindone.diffusers.pipelines.prx import PRXPipeline

# Load the pipeline - the VAE and text encoder will be downloaded from the HuggingFace Hub
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft", mindspore_dtype=ms.bfloat16)

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=28, guidance_scale=5.0).images[0]
image.save("prx_output.png")
```
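The distilled checkpoints use different sampling settings (8 steps, `cfg=1.0`, per the table above). A minimal sketch, assuming the distilled variants expose the same `PRXPipeline` API as the example above:

```py
import mindspore as ms
from mindone.diffusers.pipelines.prx import PRXPipeline

# Distilled variant: 8 steps with guidance disabled (cfg=1.0), per the suggested parameters
pipe = PRXPipeline.from_pretrained("Photoroom/prx-512-t2i-sft-distilled", mindspore_dtype=ms.bfloat16)

prompt = "A front-facing portrait of a lion in the golden savanna at sunset."
image = pipe(prompt, num_inference_steps=8, guidance_scale=1.0).images[0]
image.save("prx_distilled_output.png")
```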
::: mindone.diffusers.PRXPipeline

::: mindone.diffusers.pipelines.prx.pipeline_output.PRXPipelineOutput

mindone/diffusers/__init__.py

Lines changed: 4 additions & 0 deletions
```diff
@@ -106,6 +106,7 @@
         "OmniGenTransformer2DModel",
         "PixArtTransformer2DModel",
         "PriorTransformer",
+        "PRXTransformer2DModel",
         "QwenImageTransformer2DModel",
         "SanaControlNetModel",
         "SanaTransformer2DModel",
@@ -263,6 +264,7 @@
         "PixArtAlphaPipeline",
         "PixArtSigmaPAGPipeline",
         "PixArtSigmaPipeline",
+        "PRXPipeline",
         "QwenImageImg2ImgPipeline",
         "QwenImageInpaintPipeline",
         "QwenImagePipeline",
@@ -487,6 +489,7 @@
         OmniGenTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        PRXTransformer2DModel,
         QwenImageTransformer2DModel,
         SanaControlNetModel,
         SanaTransformer2DModel,
@@ -655,6 +658,7 @@
         PixArtAlphaPipeline,
         PixArtSigmaPAGPipeline,
         PixArtSigmaPipeline,
+        PRXPipeline,
         QwenImageEditInpaintPipeline,
         QwenImageEditPipeline,
         QwenImageImg2ImgPipeline,
```
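With these export entries in place, both new classes are importable from the package root. A minimal check (hypothetical, not part of the commit; assumes an environment with this commit installed):

```py
from mindone.diffusers import PRXPipeline, PRXTransformer2DModel

# Both symbols resolve through the lazy-import machinery declared above
print(PRXPipeline.__name__, PRXTransformer2DModel.__name__)
```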

mindone/diffusers/models/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -83,6 +83,7 @@
     "transformers.transformer_lumina2": ["Lumina2Transformer2DModel"],
     "transformers.transformer_mochi": ["MochiTransformer3DModel"],
     "transformers.transformer_omnigen": ["OmniGenTransformer2DModel"],
+    "transformers.transformer_prx": ["PRXTransformer2DModel"],
     "transformers.transformer_qwenimage": ["QwenImageTransformer2DModel"],
     "transformers.transformer_sd3": ["SD3Transformer2DModel"],
     "transformers.transformer_skyreels_v2": ["SkyReelsV2Transformer3DModel"],
@@ -167,6 +168,7 @@
         OmniGenTransformer2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        PRXTransformer2DModel,
         QwenImageTransformer2DModel,
         SanaTransformer2DModel,
         SD3Transformer2DModel,
```
Lines changed: 97 additions & 0 deletions (new file)

```py
import math
from typing import Optional

import mindspore as ms
from mindspore import mint, ops


def dispatch_attention_fn(
    query: ms.Tensor,
    key: ms.Tensor,
    value: ms.Tensor,
    attn_mask: Optional[ms.Tensor] = None,
    dropout_p: float = 0.0,
    is_causal: bool = False,
    scale: Optional[float] = None,
):
    query, key, value = (x.permute(0, 2, 1, 3) for x in (query, key, value))
    # Note: PyTorch's SDPA and MindSpore's FA handle `attention_mask` slightly differently.
    # In PyTorch, if the mask is not boolean (e.g., float32 with 0/1 values), it is interpreted
    # as an additive bias: `attn_bias = attn_mask + attn_bias`.
    # This implicit branch may lead to issues if the pipeline mistakenly provides
    # a 0/1 float mask instead of a boolean mask.
    # While this behavior is consistent with HF Diffusers for now,
    # it may still be a potential bug source worth validating.
    if attn_mask is not None and attn_mask.dtype != ms.bool_ and 1.0 in attn_mask:
        L, S = query.shape[-2], key.shape[-2]
        scale_factor = 1 / math.sqrt(query.shape[-1]) if scale is None else scale
        attn_bias = mint.zeros((L, S), dtype=query.dtype)
        if is_causal:
            if attn_mask is not None:
                if attn_mask.dtype == ms.bool_:
                    attn_mask = mint.logical_and(attn_mask, mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0))
                else:
                    attn_mask = attn_mask + mint.triu(
                        mint.full((L, S), float("-inf"), dtype=attn_mask.dtype), diagonal=1
                    )
            else:
                temp_mask = mint.ones((L, S), dtype=ms.bool_).tril(diagonal=0)
                attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
                attn_bias = attn_bias.to(query.dtype)

        if attn_mask is not None:
            if attn_mask.dtype == ms.bool_:
                attn_bias = attn_bias.masked_fill(attn_mask.logical_not(), float("-inf"))
            else:
                attn_bias = attn_mask + attn_bias

        attn_weight = mint.matmul(query, key.swapaxes(-2, -1)) * scale_factor
        attn_weight += attn_bias
        attn_weight = mint.softmax(attn_weight, dim=-1)
        attn_weight = ops.dropout(attn_weight, dropout_p, training=True)
        return mint.matmul(attn_weight, value).permute(0, 2, 1, 3)

    if query.dtype in (ms.float16, ms.bfloat16):
        out = flash_attention_op(query, key, value, attn_mask, keep_prob=1 - dropout_p, scale=scale)
    else:
        out = flash_attention_op(
            query.to(ms.float16),
            key.to(ms.float16),
            value.to(ms.float16),
            attn_mask,
            keep_prob=1 - dropout_p,
            scale=scale,
        ).to(query.dtype)
    return out.permute(0, 2, 1, 3)


def flash_attention_op(
    query: ms.Tensor,
    key: ms.Tensor,
    value: ms.Tensor,
    attn_mask: Optional[ms.Tensor] = None,
    keep_prob: float = 1.0,
    scale: Optional[float] = None,
):
    # For most scenarios, qkv has been processed into a BNSD layout before sdp
    input_layout = "BNSD"
    head_num = query.shape[1]
    if scale is None:
        scale = query.shape[-1] ** (-0.5)

    # In case qkv is 3-dim after `head_to_batch_dim`
    if query.ndim == 3:
        input_layout = "BSH"
        head_num = 1

    # Process `attn_mask`, as the semantics differ between PyTorch and MindSpore:
    # in MindSpore, False indicates retention and True indicates discard; in PyTorch it is the opposite
    if attn_mask is not None:
        attn_mask = mint.logical_not(attn_mask) if attn_mask.dtype == ms.bool_ else attn_mask.bool()
        attn_mask = mint.broadcast_to(
            attn_mask, (attn_mask.shape[0], attn_mask.shape[1], query.shape[-2], key.shape[-2])
        )[:, :1, :, :]

    return ops.operations.nn_ops.FlashAttentionScore(
        head_num=head_num, keep_prob=keep_prob, scale_value=scale, input_layout=input_layout
    )(query, key, value, None, None, None, attn_mask)[3]
```
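As a rough smoke test for the dispatcher (hypothetical, not part of this diff; it assumes an Ascend device where `FlashAttentionScore` is available), inputs use a `(batch, seq_len, num_heads, head_dim)` layout, matching the `permute(0, 2, 1, 3)` at the top of `dispatch_attention_fn`:

```py
import mindspore as ms
from mindspore import mint

# Hypothetical shapes: batch=2, seq_len=16, heads=8, head_dim=64
query = mint.randn(2, 16, 8, 64, dtype=ms.bfloat16)
key = mint.randn(2, 16, 8, 64, dtype=ms.bfloat16)
value = mint.randn(2, 16, 8, 64, dtype=ms.bfloat16)

# bfloat16 inputs take the flash-attention path directly; float32 inputs
# would be cast to float16 for FlashAttentionScore and cast back afterwards.
out = dispatch_attention_fn(query, key, value)
print(out.shape)  # (2, 16, 8, 64) - same layout as the inputs
```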

mindone/diffusers/models/transformers/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -28,6 +28,7 @@
 from .transformer_lumina2 import Lumina2Transformer2DModel
 from .transformer_mochi import MochiTransformer3DModel
 from .transformer_omnigen import OmniGenTransformer2DModel
+from .transformer_prx import PRXTransformer2DModel
 from .transformer_qwenimage import QwenImageTransformer2DModel
 from .transformer_sd3 import SD3Transformer2DModel
 from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
```
