5 changes: 3 additions & 2 deletions examples/hunyuanvideo/README.md
@@ -30,12 +30,13 @@ Here is the development plan of the project:

| MindSpore | Ascend Driver | Firmware | CANN toolkit/kernel |
|:---------:|:-------------:|:-----------:|:-------------------:|
-| 2.5.0 | 24.1.RC2 | 7.5.0.2.220 | 8.0.RC3.beta1 |
+| 2.6.0 | 24.1.RC2 | 7.5.0.2.220 | 8.1.RC1 |
+| 2.7.0 | 24.1.RC2 | 7.5.0.2.220 | 8.2.RC1 |

</div>

1. Install
-[CANN 8.0.RC3.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC3.beta1)
+[CANN 8.2.RC1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.2.RC1)
and MindSpore according to the [official instructions](https://www.mindspore.cn/install).
2. Install requirements
```shell
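# Not part of this diff: a quick post-install sanity check (illustrative only).
# mindspore.run_check() runs a tiny computation and reports whether the install,
# including the Ascend backend set up via CANN above, is working.
python -c "import mindspore as ms; print(ms.__version__); ms.run_check()"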
4 changes: 2 additions & 2 deletions examples/hunyuanvideo/hyvideo/text_encoder/__init__.py
@@ -295,7 +295,7 @@ def encode(
if model_return_dict:
last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)]
else:
-last_hidden_state = outputs[2][-(hidden_state_skip_layer + 1)]
+last_hidden_state = outputs[1][-(hidden_state_skip_layer + 1)]
# last_hidden_state = outputs[0][-(hidden_state_skip_layer + 1)]
# Real last hidden state already has layer norm applied. So here we only apply it
# for intermediate layers.
@@ -307,7 +307,7 @@ def encode(
outputs_hidden_states = outputs.hidden_states
else:
last_hidden_state = outputs[self.key_idx]
-outputs_hidden_states = outputs[2] if len(outputs) >= 3 else None # TODO: double-check if use t5
+outputs_hidden_states = outputs[1] if len(outputs) >= 2 else None # TODO: double-check if use t5

# Remove hidden states of instruction tokens, only keep prompt tokens.
if self.use_template:
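The edits above switch the tuple index used to pull `hidden_states` out of the encoder output when `model_return_dict` is false; with positional outputs the correct index depends on the model and on the pinned transformers version (bumped elsewhere in this PR). A more defensive pattern is sketched below with a hypothetical helper, not code from this repo, that locates the hidden-states entry by type rather than by position:

```python
# Illustrative helper (hypothetical, not from this repo): find hidden_states in a
# text-encoder output without hard-coding a tuple index, since the tuple layout of
# Hugging Face-style outputs varies with the model and with which optional fields
# (pooled output, past_key_values, attentions, ...) are returned.
def get_hidden_states(outputs):
    # ModelOutput objects (return_dict=True) expose hidden_states as an attribute.
    hidden_states = getattr(outputs, "hidden_states", None)
    if hidden_states is not None:
        return hidden_states
    # Tuple fallback (return_dict=False): hidden_states is the element that is a
    # tuple/list of per-layer tensors, each of which has a .shape.
    for item in outputs:
        if isinstance(item, (tuple, list)) and len(item) > 0 and hasattr(item[0], "shape"):
            return item
    return None
```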
83 changes: 80 additions & 3 deletions examples/hunyuanvideo/hyvideo/vae/unet_causal_3d_blocks.py
@@ -25,11 +25,13 @@

import mindspore as ms
import mindspore.mint.nn.functional as F
-from mindspore import mint, nn, ops
+from mindspore import Parameter, Tensor, mint, nn, ops
+from mindspore.common.initializer import initializer

from mindone.diffusers.models.activations import get_activation
from mindone.diffusers.models.attention_processor import Attention, SpatialNorm
-from mindone.diffusers.models.normalization import AdaGroupNorm, GroupNorm, RMSNorm
+from mindone.diffusers.models.layers_compat import group_norm
+from mindone.diffusers.models.normalization import AdaGroupNorm, RMSNorm
from mindone.diffusers.utils import logging

logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -38,6 +40,81 @@
MAX_VALUE = 1e5


[Review comment from a Collaborator on `class GroupNorm`] why not reuse the previous GroupNorm, e.g. the one in mindone.diffusers?

class GroupNorm(nn.Cell):
r"""Applies Group Normalization over a mini-batch of inputs.

This layer implements the operation as described in
the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__

.. math::
y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta

The input channels are separated into :attr:`num_groups` groups, each containing
``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by
:attr:`num_groups`. The mean and standard-deviation are calculated
separately over each group. :math:`\gamma` and :math:`\beta` are learnable
per-channel affine transform parameter vectors of size :attr:`num_channels` if
:attr:`affine` is ``True``.

This layer uses statistics computed from input data in both training and
evaluation modes.

Args:
num_groups (int): number of groups to separate the channels into
num_channels (int): number of channels expected in input
eps: a value added to the denominator for numerical stability. Default: 1e-5
affine: a boolean value that when set to ``True``, this module
has learnable per-channel affine parameters initialized to ones (for weights)
and zeros (for biases). Default: ``True``.

Shape:
- Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}`
- Output: :math:`(N, C, *)` (same shape as input)

Examples::

>>> input = mint.randn(20, 6, 10, 10)
>>> # Separate 6 channels into 3 groups
>>> m = GroupNorm(3, 6)
>>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm)
>>> m = GroupNorm(6, 6)
>>> # Put all 6 channels into a single group (equivalent with LayerNorm)
>>> m = GroupNorm(1, 6)
>>> # Activating the module
>>> output = m(input)
"""

num_groups: int
num_channels: int
eps: float
affine: bool

def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True, dtype=ms.float32):
super().__init__()
if num_channels % num_groups != 0:
raise ValueError("num_channels must be divisible by num_groups")

self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.affine = affine
weight = initializer("ones", num_channels, dtype=dtype)
bias = initializer("zeros", num_channels, dtype=dtype)
if self.affine:
self.weight = Parameter(weight, name="weight")
self.bias = Parameter(bias, name="bias")
else:
self.weight = None
self.bias = None

def construct(self, x: Tensor):
if self.affine:
x = group_norm(x, self.num_groups, self.weight.to(x.dtype), self.bias.to(x.dtype), self.eps)
else:
x = group_norm(x, self.num_groups, self.weight, self.bias, self.eps)
return x


def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, batch_size: int = None, return_fa_mask: bool = False):
seq_len = n_frame * n_hw
mask = mint.full((seq_len, seq_len), float("-inf"), dtype=dtype)
@@ -413,7 +490,7 @@ def __init__(
conv_3d_out_channels = conv_3d_out_channels or out_channels
self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1)

-self.nonlinearity = get_activation(non_linearity)()
+self.nonlinearity = get_activation(non_linearity)

self.upsample = self.downsample = None
if self.up:
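For context on the `GroupNorm` wrapper added above: it keeps standard group-normalization semantics but routes the computation through `group_norm` from `mindone.diffusers.models.layers_compat`. A minimal usage sketch (assuming `examples/hunyuanvideo` is on the Python path so the module touched in this diff is importable):

```python
# Usage sketch for the GroupNorm wrapper added in this diff (illustrative only).
import mindspore as ms
from mindspore import mint

from hyvideo.vae.unet_causal_3d_blocks import GroupNorm

norm = GroupNorm(num_groups=32, num_channels=128, eps=1e-5, affine=True, dtype=ms.float32)
x = mint.randn(1, 128, 4, 16, 16)  # (N, C, T, H, W) video latents
y = norm(x)                        # same shape; stats computed over each group of C/num_groups channels
print(y.shape)                     # (1, 128, 4, 16, 16)
```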
4 changes: 2 additions & 2 deletions examples/hunyuanvideo/requirements.txt
@@ -7,8 +7,8 @@ imageio
imageio-ffmpeg
safetensors
mindcv==0.3.0
-tokenizers==0.20.3
-transformers==4.46.3
+tokenizers==0.21.4
+transformers==4.50.0
gradio
albumentations>=2.0
ftfy
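Illustrative check, not part of the diff: confirming the environment picked up the new pins.

```python
# Verify that the installed versions match the updated pins in requirements.txt.
import tokenizers
import transformers

print(tokenizers.__version__)    # expect 0.21.4
print(transformers.__version__)  # expect 4.50.0
```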
8 changes: 4 additions & 4 deletions examples/hunyuanvideo/scripts/train.py
@@ -399,18 +399,18 @@ def main(args):
# validation
val_group = parser.add_argument_group("Validation")
val_group.add_argument(
"valid.sampling_steps", type=int, default=10, help="Number of sampling steps for validation."
"--valid.sampling_steps", type=int, default=10, help="Number of sampling steps for validation."
)
val_group.add_argument("valid.frequency", type=int, default=1, help="Frequency of validation in steps.")
val_group.add_argument("--valid.frequency", type=int, default=1, help="Frequency of validation in steps.")
val_group.add_subclass_arguments(
ImageVideoDataset,
"valid.dataset",
"--valid.dataset",
skip={"frames_mask_generator", "t_compress_func"},
instantiate=False,
required=False,
)
val_group.add_function_arguments(
create_dataloader, "valid.dataloader", skip={"dataset", "transforms", "device_num", "rank_id"}
create_dataloader, "--valid.dataloader", skip={"dataset", "transforms", "device_num", "rank_id"}
)
parser.link_arguments("env.debug", "valid.dataloader.debug", apply_on="parse")

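The train.py edits add the `--` prefix that jsonargparse needs to register these as optional, dotted arguments rather than positionals. A standalone sketch of that behavior (separate from the project's parser; argument names copied from the diff):

```python
# Minimal jsonargparse sketch (illustrative; not the project's parser).
# Without the leading "--", add_argument registers a positional argument, so the
# value could not be supplied as "--valid.sampling_steps 20" on the command line.
from jsonargparse import ArgumentParser

parser = ArgumentParser()
parser.add_argument("--valid.sampling_steps", type=int, default=10)
parser.add_argument("--valid.frequency", type=int, default=1)

cfg = parser.parse_args(["--valid.sampling_steps", "20"])
print(cfg.valid.sampling_steps)  # 20
print(cfg.valid.frequency)       # 1 (default)
```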