
Conversation

@guangzlu
Contributor

Implement swizzle GEMM for Flux DiT and WAN DiT.

@gemini-code-assist
Contributor

Summary of Changes

Hello @guangzlu, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the performance of linear operations within the diffsynth_engine by integrating a new 'swizzled' GEMM utility. This optimization, which leverages hipblaslt, is now applied to both the Flux DiT and WAN DiT models. The change is designed to improve computational efficiency, particularly when working with FP8 precision, by temporarily overriding the standard torch.nn.functional.linear with highly optimized versions.

Highlights

  • New Swizzled GEMM Utility: Introduced a new utility module, aiter_linear.py, which provides optimized General Matrix Multiply (GEMM) operations using hipblaslt with 'swizzled' weights.
  • Flux DiT Integration: The Flux DiT model now utilizes the new use_swizzle_hipblaslt context manager within its forward method to enable performance-optimized linear operations.
  • WAN DiT Integration: The WAN DiT model has also been updated to incorporate the use_swizzle_hipblaslt context manager, benefiting from the same GEMM optimizations.
  • FP8 Linear Support: The new utility includes specific optimizations for FP8 linear operations, allowing for enhanced performance when fp8_linear_enabled is active.
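
To make the mechanism concrete: the override described in this summary amounts to monkey-patching torch.nn.functional.linear inside a context manager. The sketch below shows only that general pattern with illustrative names (patched_linear, naive_optimized_linear); it is not the code from aiter_linear.py, whose actual optimized_linear and optimized_linear_fp8 kernels appear in the review comments further down.

import contextlib

import torch
import torch.nn.functional as F


@contextlib.contextmanager
def patched_linear(replacement):
    # Temporarily route torch.nn.functional.linear through `replacement`,
    # restoring the original implementation even if the wrapped code raises.
    original = F.linear
    F.linear = replacement
    try:
        yield
    finally:
        F.linear = original


def naive_optimized_linear(input, weight, bias=None):
    # Stand-in for an optimized kernel such as the PR's optimized_linear;
    # here it simply reproduces the reference behaviour.
    output = input @ weight.t()
    if bias is not None:
        output = output + bias
    return output


# Every F.linear call made inside the block, including the ones issued by
# nn.Linear modules, is dispatched to the replacement.
with patched_linear(naive_optimized_linear):
    x = torch.randn(2, 8)
    w = torch.randn(4, 8)
    print(F.linear(x, w).shape)  # torch.Size([2, 4])
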
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, and code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

@gemini-code-assist
Contributor

Summary of Changes

Hello @guangzlu, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a significant performance enhancement for the flux_dit and wan_dit models by implementing a swizzled General Matrix Multiply (GEMM) optimization. A new utility file provides a context manager that dynamically replaces standard PyTorch linear operations with highly optimized hipblaslt-based versions, specifically designed to boost efficiency, particularly when using FP8 linear computations. This change aims to improve the computational speed of these models on compatible hardware.

Highlights

  • New Swizzle GEMM Utility: A new utility module, aiter_linear.py, has been added to provide optimized General Matrix Multiply (GEMM) operations using hipblaslt for enhanced performance, especially with FP8 linear computations.
  • Flux DiT Integration: The flux_dit.py model now integrates the use_swizzle_hipblaslt context manager within its forward method, enabling the new swizzled GEMM optimization when applicable.
  • WAN DiT Integration: Similarly, the wan_dit.py model has been updated to utilize the use_swizzle_hipblaslt context manager in its forward method, applying the swizzled GEMM optimization.


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces a swizzle GEMM optimization for the Flux DiT and WAN DiT models. The core logic is encapsulated in a new file, diffsynth_engine/utils/aiter_linear.py, which provides a context manager to monkey-patch torch.nn.functional.linear. The changes in the model files simply enable this new functionality. My review focuses on the new implementation, where I've identified a critical issue with a hardcoded device that could cause problems in multi-GPU environments. I have provided suggestions to fix this and to refactor the code for better maintainability.

Comment on lines +26 to +49
def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
                     scaleA=None, scaleB=None, device="cuda"):

    input_flat = input.reshape(-1, input.shape[-1])

    init_hipblas()

    weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(device)
    output_flat = hipb_mm(
        input_flat,
        weight_preshuffle.t(),
        bias=bias,
        solution_index=-1,
        out_dtype=otype,
        scaleA=scaleA,
        scaleB=scaleB,
        scaleOut=None,
        bpreshuffle=True
    )

    # Reshape output to match input dimensions
    new_shape = input.shape[:-1] + (weight.shape[0],)
    output = output_flat.view(new_shape)
    return output

Severity: high

The device parameter is hardcoded to "cuda", which will cause errors if the model is run on a different device (e.g., cuda:1). The device should be dynamically determined from the input tensor. The function signature should also be updated to not include the device parameter, to better align with the original torch.nn.functional.linear.

def optimized_linear(input, weight, bias=None, otype=torch.bfloat16, 
                    scaleA=None, scaleB=None):

    input_flat = input.reshape(-1, input.shape[-1])
    
    init_hipblas()

    weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(input.device)
    output_flat = hipb_mm(
        input_flat,
        weight_preshuffle.t(),
        bias=bias,
        solution_index=-1,
        out_dtype=otype,
        scaleA=scaleA,
        scaleB=scaleB,
        scaleOut=None,
        bpreshuffle=True
    )
    
    # Reshape output to match input dimensions
    new_shape = input.shape[:-1] + (weight.shape[0],)
    output = output_flat.view(new_shape)
    return output

Comment on lines +52 to +100
def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
                         scaleA=None, scaleB=None, device="cuda"):

    input_flat = input.reshape(-1, input.shape[-1])

    if use_scale_for_fp8:

        input_flat, a_scale = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
        weight = weight.to(DTYPE_FP8)

        init_hipblas()

        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
        output_flat = hipb_mm(
            input_flat,
            weight_preshuffle.t(),
            bias=bias,
            solution_index=-1,
            out_dtype=otype,
            scaleA=a_scale,
            scaleB=scaleB,
            scaleOut=None,
            bpreshuffle=True
        )

    else:
        input_flat = input_flat.to(DTYPE_FP8)
        weight = weight.to(DTYPE_FP8)

        init_hipblas()

        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
        output_flat = hipb_mm(
            input_flat,
            weight_preshuffle.t(),
            bias=bias,
            solution_index=-1,
            out_dtype=otype,
            scaleA=scaleA,
            scaleB=scaleB,
            scaleOut=None,
            bpreshuffle=True
        )

    # Reshape output to match input dimensions
    new_shape = input.shape[:-1] + (weight.shape[0],)
    output = output_flat.view(new_shape)
    return output

Severity: high

This function has a couple of issues:

  1. Similar to optimized_linear, it hardcodes device="cuda", which can lead to runtime errors on multi-GPU setups. The device should be inferred from the input tensor.
  2. There is significant code duplication between the if and else branches. This can be refactored to improve readability and maintainability.
def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
                    scaleA=None, scaleB=None):
    
    input_flat = input.reshape(-1, input.shape[-1])
    device = input.device

    if use_scale_for_fp8:
        input_flat, a_scale = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
        scaleA = a_scale
    else:
        input_flat = input_flat.to(DTYPE_FP8)

    weight = weight.to(DTYPE_FP8)
    init_hipblas()
    weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
    
    output_flat = hipb_mm(
        input_flat,
        weight_preshuffle.t(),
        bias=bias,
        solution_index=-1,
        out_dtype=otype,
        scaleA=scaleA,
        scaleB=scaleB,
        scaleOut=None,
        bpreshuffle=True
    )


    # Reshape output to match input dimensions
    new_shape = input.shape[:-1] + (weight.shape[0],)
    output = output_flat.view(new_shape)
    return output


@gemini-code-assist (bot) left a comment


Code Review

This pull request introduces swizzle GEMM optimizations for the Flux DiT and WAN DiT models, aiming to enhance performance, particularly on AMD hardware. This is accomplished by monkey-patching torch.nn.functional.linear within a new context manager. The implementation is sound, but I have identified a few areas for improvement to increase code quality, robustness, and configurability. My review includes suggestions to remove an unused import, refactor duplicated code, and address a hardcoded device name in aiter_linear.py which could lead to issues in multi-GPU setups. Additionally, I've recommended making the swizzle optimization a configurable option in flux_dit.py and wan_dit.py rather than hardcoding it.

Comment on lines +26 to +33
def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
                     scaleA=None, scaleB=None, device="cuda"):

    input_flat = input.reshape(-1, input.shape[-1])

    init_hipblas()

    weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(device)

Severity: high

The device parameter defaults to a hardcoded "cuda". This can cause runtime errors if the model is deployed on a different device (e.g., cuda:1). It's more robust to determine the device from the input tensor. Additionally, adding a non-standard device parameter to a function that monkey-patches F.linear can be risky. I suggest removing it and using input.device instead.

Suggested change

-def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
-                     scaleA=None, scaleB=None, device="cuda"):
-    input_flat = input.reshape(-1, input.shape[-1])
-    init_hipblas()
-    weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(device)
+def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
+                     scaleA=None, scaleB=None):
+    input_flat = input.reshape(-1, input.shape[-1])
+    init_hipblas()
+    weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(input.device)

Comment on lines +52 to +94
def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
                         scaleA=None, scaleB=None, device="cuda"):

    input_flat = input.reshape(-1, input.shape[-1])

    if use_scale_for_fp8:

        input_flat, a_scale = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
        weight = weight.to(DTYPE_FP8)

        init_hipblas()

        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
        output_flat = hipb_mm(
            input_flat,
            weight_preshuffle.t(),
            bias=bias,
            solution_index=-1,
            out_dtype=otype,
            scaleA=a_scale,
            scaleB=scaleB,
            scaleOut=None,
            bpreshuffle=True
        )

    else:
        input_flat = input_flat.to(DTYPE_FP8)
        weight = weight.to(DTYPE_FP8)

        init_hipblas()

        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
        output_flat = hipb_mm(
            input_flat,
            weight_preshuffle.t(),
            bias=bias,
            solution_index=-1,
            out_dtype=otype,
            scaleA=scaleA,
            scaleB=scaleB,
            scaleOut=None,
            bpreshuffle=True
        )

Severity: high

This function can be improved in two ways:

  1. Hardcoded Device: Similar to optimized_linear, the device parameter is hardcoded to "cuda", which can lead to errors in a multi-GPU environment. It's better to derive the device from input.device.
  2. Code Duplication: There's significant code duplication between the if use_scale_for_fp8: and else branches. Refactoring this will improve readability and maintainability.

I've provided a suggestion that addresses both points.

    def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
                        scaleA=None, scaleB=None):
        
        input_flat = input.reshape(-1, input.shape[-1])

        a_scale_for_mm = scaleA
        if use_scale_for_fp8:
            input_flat, a_scale_for_mm = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
        else:
            input_flat = input_flat.to(DTYPE_FP8)

        weight = weight.to(DTYPE_FP8)

        init_hipblas()

        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(input.device)
        output_flat = hipb_mm(
            input_flat,
            weight_preshuffle.t(),
            bias=bias,
            solution_index=-1,
            out_dtype=otype,
            scaleA=a_scale_for_mm,
            scaleB=scaleB,
            scaleOut=None,
            bpreshuffle=True
        )

use_cfg = hidden_states.shape[0] > 1
with (
    fp8_inference(fp8_linear_enabled),
    use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=fp8_linear_enabled),

Severity: medium

The swizzle parameter is hardcoded to True. This makes it difficult to disable this optimization without changing the code. Consider making this configurable, for example, through a class attribute self.swizzle_enabled that can be set during model initialization.
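
A minimal sketch of the kind of change being suggested (hypothetical class and attribute names; the real model constructor takes many more arguments, and the with-block in its forward also enters fp8_inference):

import torch
import torch.nn as nn


class SwizzleConfigurableBlock(nn.Module):
    # Hypothetical, trimmed-down module used only to illustrate the suggestion:
    # expose the swizzle toggle at construction time instead of hardcoding
    # swizzle=True inside forward().
    def __init__(self, dim: int = 64, swizzle_enabled: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        self.swizzle_enabled = swizzle_enabled

    def forward(self, hidden_states):
        # In the PR's forward this would become:
        #   with use_swizzle_hipblaslt(swizzle=self.swizzle_enabled,
        #                              use_fp8_linear=fp8_linear_enabled):
        # so callers can disable the optimization without editing the model.
        return self.proj(hidden_states)


block = SwizzleConfigurableBlock(swizzle_enabled=False)
out = block(torch.randn(2, 64))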

use_cfg = x.shape[0] > 1
with (
    fp8_inference(fp8_linear_enabled),
    use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=fp8_linear_enabled),

Severity: medium

The swizzle parameter is hardcoded to True. This makes it difficult to disable this optimization without changing the code. Consider making this configurable, for example, through a class attribute self.swizzle_enabled that can be set during model initialization.

import torch.nn.functional as F
from functools import lru_cache
from aiter import hipb_mm, hipb_create_extension, per_tensor_quant_hip
from aiter.tuned_gemm import tgemm

Severity: medium

This import of tgemm is not used within the file and can be safely removed to improve code cleanliness.
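
If this suggestion is applied, the import block shown above would simply become:

import torch.nn.functional as F
from functools import lru_cache
from aiter import hipb_mm, hipb_create_extension, per_tensor_quant_hip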
