From daea787e5d1396cca8140a65d7627cce5c3b4ee9 Mon Sep 17 00:00:00 2001 From: fighting-ye <1138455646@qq.com> Date: Mon, 20 Oct 2025 10:10:12 +0800 Subject: [PATCH] HunyuanVideo is compatible with MindSpore 2.6 and 2.7 --- examples/hunyuanvideo/README.md | 5 +- .../hyvideo/text_encoder/__init__.py | 4 +- .../hyvideo/vae/unet_causal_3d_blocks.py | 83 ++++++++++++++++++- examples/hunyuanvideo/requirements.txt | 4 +- examples/hunyuanvideo/scripts/train.py | 8 +- 5 files changed, 91 insertions(+), 13 deletions(-) diff --git a/examples/hunyuanvideo/README.md b/examples/hunyuanvideo/README.md index a6bf02bcb4..097951bd63 100644 --- a/examples/hunyuanvideo/README.md +++ b/examples/hunyuanvideo/README.md @@ -30,12 +30,13 @@ Here is the development plan of the project: | MindSpore | Ascend Driver | Firmware | CANN toolkit/kernel | |:---------:|:-------------:|:-----------:|:-------------------:| -| 2.5.0 | 24.1.RC2 | 7.5.0.2.220 | 8.0.RC3.beta1 | +| 2.6.0 | 24.1.RC2 | 7.5.0.2.220 | 8.1.RC1 | +| 2.7.0 | 24.1.RC2 | 7.5.0.2.220 | 8.2.RC1 | 1. Install - [CANN 8.0.RC3.beta1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.0.RC3.beta1) + [CANN 8.2.RC1](https://www.hiascend.com/developer/download/community/result?module=cann&cann=8.2.RC1) and MindSpore according to the [official instructions](https://www.mindspore.cn/install). 2. 
Install requirements ```shell diff --git a/examples/hunyuanvideo/hyvideo/text_encoder/__init__.py b/examples/hunyuanvideo/hyvideo/text_encoder/__init__.py index b76817002c..c6ee74d444 100644 --- a/examples/hunyuanvideo/hyvideo/text_encoder/__init__.py +++ b/examples/hunyuanvideo/hyvideo/text_encoder/__init__.py @@ -295,7 +295,7 @@ def encode( if model_return_dict: last_hidden_state = outputs.hidden_states[-(hidden_state_skip_layer + 1)] else: - last_hidden_state = outputs[2][-(hidden_state_skip_layer + 1)] + last_hidden_state = outputs[1][-(hidden_state_skip_layer + 1)] # last_hidden_state = outputs[0][-(hidden_state_skip_layer + 1)] # Real last hidden state already has layer norm applied. So here we only apply it # for intermediate layers. @@ -307,7 +307,7 @@ def encode( outputs_hidden_states = outputs.hidden_states else: last_hidden_state = outputs[self.key_idx] - outputs_hidden_states = outputs[2] if len(outputs) >= 3 else None # TODO: double-check if use t5 + outputs_hidden_states = outputs[1] if len(outputs) >= 2 else None # TODO: double-check if use t5 # Remove hidden states of instruction tokens, only keep prompt tokens. 
if self.use_template: diff --git a/examples/hunyuanvideo/hyvideo/vae/unet_causal_3d_blocks.py b/examples/hunyuanvideo/hyvideo/vae/unet_causal_3d_blocks.py index e261e93755..c124951d83 100644 --- a/examples/hunyuanvideo/hyvideo/vae/unet_causal_3d_blocks.py +++ b/examples/hunyuanvideo/hyvideo/vae/unet_causal_3d_blocks.py @@ -25,11 +25,13 @@ import mindspore as ms import mindspore.mint.nn.functional as F -from mindspore import mint, nn, ops +from mindspore import Parameter, Tensor, mint, nn, ops +from mindspore.common.initializer import initializer from mindone.diffusers.models.activations import get_activation from mindone.diffusers.models.attention_processor import Attention, SpatialNorm -from mindone.diffusers.models.normalization import AdaGroupNorm, GroupNorm, RMSNorm +from mindone.diffusers.models.layers_compat import group_norm +from mindone.diffusers.models.normalization import AdaGroupNorm, RMSNorm from mindone.diffusers.utils import logging logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -38,6 +40,81 @@ MAX_VALUE = 1e5 +class GroupNorm(nn.Cell): + r"""Applies Group Normalization over a mini-batch of inputs. + + This layer implements the operation as described in + the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__ + + .. math:: + y = \frac{x - \mathrm{E}[x]}{ \sqrt{\mathrm{Var}[x] + \epsilon}} * \gamma + \beta + + The input channels are separated into :attr:`num_groups` groups, each containing + ``num_channels / num_groups`` channels. :attr:`num_channels` must be divisible by + :attr:`num_groups`. The mean and standard-deviation are calculated + separately over each group. :math:`\gamma` and :math:`\beta` are learnable + per-channel affine transform parameter vectors of size :attr:`num_channels` if + :attr:`affine` is ``True``. + + This layer uses statistics computed from input data in both training and + evaluation modes. 
+ + Args: + num_groups (int): number of groups to separate the channels into + num_channels (int): number of channels expected in input + eps: a value added to the denominator for numerical stability. Default: 1e-5 + affine: a boolean value that when set to ``True``, this module + has learnable per-channel affine parameters initialized to ones (for weights) + and zeros (for biases). Default: ``True``. + + Shape: + - Input: :math:`(N, C, *)` where :math:`C=\text{num\_channels}` + - Output: :math:`(N, C, *)` (same shape as input) + + Examples:: + + >>> input = mint.randn(20, 6, 10, 10) + >>> # Separate 6 channels into 3 groups + >>> m = GroupNorm(3, 6) + >>> # Separate 6 channels into 6 groups (equivalent with InstanceNorm) + >>> m = GroupNorm(6, 6) + >>> # Put all 6 channels into a single group (equivalent with LayerNorm) + >>> m = GroupNorm(1, 6) + >>> # Activating the module + >>> output = m(input) + """ + + num_groups: int + num_channels: int + eps: float + affine: bool + + def __init__(self, num_groups: int, num_channels: int, eps: float = 1e-5, affine: bool = True, dtype=ms.float32): + super().__init__() + if num_channels % num_groups != 0: + raise ValueError("num_channels must be divisible by num_groups") + + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + weight = initializer("ones", num_channels, dtype=dtype) + bias = initializer("zeros", num_channels, dtype=dtype) + if self.affine: + self.weight = Parameter(weight, name="weight") + self.bias = Parameter(bias, name="bias") + else: + self.weight = None + self.bias = None + + def construct(self, x: Tensor): + if self.affine: + x = group_norm(x, self.num_groups, self.weight.to(x.dtype), self.bias.to(x.dtype), self.eps) + else: + x = group_norm(x, self.num_groups, self.weight, self.bias, self.eps) + return x + + def prepare_causal_attention_mask(n_frame: int, n_hw: int, dtype, batch_size: int = None, return_fa_mask: bool = False): seq_len = n_frame * n_hw 
mask = mint.full((seq_len, seq_len), float("-inf"), dtype=dtype) @@ -413,7 +490,7 @@ def __init__( conv_3d_out_channels = conv_3d_out_channels or out_channels self.conv2 = CausalConv3d(out_channels, conv_3d_out_channels, kernel_size=3, stride=1) - self.nonlinearity = get_activation(non_linearity)() + self.nonlinearity = get_activation(non_linearity) self.upsample = self.downsample = None if self.up: diff --git a/examples/hunyuanvideo/requirements.txt b/examples/hunyuanvideo/requirements.txt index 8298a98164..a4ad3cee9a 100644 --- a/examples/hunyuanvideo/requirements.txt +++ b/examples/hunyuanvideo/requirements.txt @@ -7,8 +7,8 @@ imageio imageio-ffmpeg safetensors mindcv==0.3.0 -tokenizers==0.20.3 -transformers==4.46.3 +tokenizers==0.21.4 +transformers==4.50.0 gradio albumentations>=2.0 ftfy diff --git a/examples/hunyuanvideo/scripts/train.py b/examples/hunyuanvideo/scripts/train.py index 68c78d894b..425d4225ae 100644 --- a/examples/hunyuanvideo/scripts/train.py +++ b/examples/hunyuanvideo/scripts/train.py @@ -399,18 +399,18 @@ def main(args): # validation val_group = parser.add_argument_group("Validation") val_group.add_argument( - "valid.sampling_steps", type=int, default=10, help="Number of sampling steps for validation." + "--valid.sampling_steps", type=int, default=10, help="Number of sampling steps for validation." 
) - val_group.add_argument("valid.frequency", type=int, default=1, help="Frequency of validation in steps.") + val_group.add_argument("--valid.frequency", type=int, default=1, help="Frequency of validation in steps.") val_group.add_subclass_arguments( ImageVideoDataset, - "valid.dataset", + "--valid.dataset", skip={"frames_mask_generator", "t_compress_func"}, instantiate=False, required=False, ) val_group.add_function_arguments( - create_dataloader, "valid.dataloader", skip={"dataset", "transforms", "device_num", "rank_id"} + create_dataloader, "--valid.dataloader", skip={"dataset", "transforms", "device_num", "rank_id"} ) parser.link_arguments("env.debug", "valid.dataloader.debug", apply_on="parse")