2 changes: 2 additions & 0 deletions diffsynth_engine/models/flux/flux_dit.py
@@ -16,6 +16,7 @@
from diffsynth_engine.models.basic import attention as attention_ops
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.aiter_linear import use_swizzle_hipblaslt
from diffsynth_engine.utils.constants import FLUX_DIT_CONFIG_FILE
from diffsynth_engine.utils.parallel import (
    cfg_parallel,
@@ -409,6 +410,7 @@ def forward(
        use_cfg = hidden_states.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=fp8_linear_enabled),
Review comment (medium):
The swizzle parameter is hardcoded to True. This makes it difficult to disable this optimization without changing the code. Consider making this configurable, for example, through a class attribute self.swizzle_enabled that can be set during model initialization.
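
For illustration, a minimal sketch of what this suggestion could look like. The class name, constructor signature, and projection layer below are hypothetical stand-ins for the real FluxDiT/WanDiT modules; only the use_swizzle_hipblaslt call comes from this PR.

import torch
import torch.nn as nn
from diffsynth_engine.utils.aiter_linear import use_swizzle_hipblaslt

class DiTWithConfigurableSwizzle(nn.Module):  # hypothetical stand-in, not the PR's class
    def __init__(self, dim: int = 64, swizzle_enabled: bool = True):
        super().__init__()
        self.proj = nn.Linear(dim, dim)
        # Reviewer's idea: store the flag on the instance instead of hardcoding swizzle=True.
        self.swizzle_enabled = swizzle_enabled

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # The instance-level flag decides whether F.linear is rerouted for this call.
        with use_swizzle_hipblaslt(swizzle=self.swizzle_enabled, use_fp8_linear=False):
            return self.proj(hidden_states)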

            gguf_inference(),
            cfg_parallel(
                (
2 changes: 2 additions & 0 deletions diffsynth_engine/models/wan/wan_dit.py
@@ -21,6 +21,7 @@
)
from diffsynth_engine.utils.gguf import gguf_inference
from diffsynth_engine.utils.fp8_linear import fp8_inference
from diffsynth_engine.utils.aiter_linear import use_swizzle_hipblaslt
from diffsynth_engine.utils.parallel import (
    cfg_parallel,
    cfg_parallel_unshard,
@@ -390,6 +391,7 @@ def forward(
        use_cfg = x.shape[0] > 1
        with (
            fp8_inference(fp8_linear_enabled),
            use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=fp8_linear_enabled),
Review comment (medium):
The swizzle parameter is hardcoded to True. This makes it difficult to disable this optimization without changing the code. Consider making this configurable, for example, through a class attribute self.swizzle_enabled that can be set during model initialization.

            gguf_inference(),
            cfg_parallel((x, context, timestep, clip_feature, y), use_cfg=use_cfg),
        ):
110 changes: 110 additions & 0 deletions diffsynth_engine/utils/aiter_linear.py
@@ -0,0 +1,110 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import lru_cache
from aiter import hipb_mm, hipb_create_extension, per_tensor_quant_hip
from aiter.tuned_gemm import tgemm
Review comment (medium):
This import of tgemm is not used within the file and can be safely removed to improve code cleanliness.

from aiter.ops.shuffle import shuffle_weight
from diffsynth_engine.utils.platform import DTYPE_FP8
from contextlib import contextmanager


@lru_cache(maxsize=1)
def init_hipblas():
    hipb_create_extension()


@contextmanager
def use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=True, use_scale_for_fp8=False):
    if not swizzle:
        yield
        return

    # Preserve original F.linear
    _original_linear = F.linear

    def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
                         scaleA=None, scaleB=None, device="cuda"):

        input_flat = input.reshape(-1, input.shape[-1])

        init_hipblas()

        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(device)
Review comment on lines +26 to +33 (high):
The device parameter defaults to a hardcoded "cuda". This can cause runtime errors if the model is deployed on a different device (e.g., cuda:1). It's more robust to determine the device from the input tensor. Additionally, adding a non-standard device parameter to a function that monkey-patches F.linear can be risky. I suggest removing it and using input.device instead.

Suggested change:
-    def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
-                         scaleA=None, scaleB=None, device="cuda"):
-        input_flat = input.reshape(-1, input.shape[-1])
-        init_hipblas()
-        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(device)
+    def optimized_linear(input, weight, bias=None, otype=torch.bfloat16,
+                         scaleA=None, scaleB=None):
+        input_flat = input.reshape(-1, input.shape[-1])
+        init_hipblas()
+        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(input.device)

        output_flat = hipb_mm(
            input_flat,
            weight_preshuffle.t(),
            bias=bias,
            solution_index=-1,
            out_dtype=otype,
            scaleA=scaleA,
            scaleB=scaleB,
            scaleOut=None,
            bpreshuffle=True
        )

        # Reshape output to match input dimensions
        new_shape = input.shape[:-1] + (weight.shape[0],)
        output = output_flat.view(new_shape)
        return output
Review comment on lines +26 to +49 (high):
The device parameter is hardcoded to "cuda", which will cause errors if the model is run on a different device (e.g., cuda:1). The device should be dynamically determined from the input tensor. The function signature should also be updated to not include the device parameter, to better align with the original torch.nn.functional.linear.

def optimized_linear(input, weight, bias=None, otype=torch.bfloat16, 
                    scaleA=None, scaleB=None):

    input_flat = input.reshape(-1, input.shape[-1])
    
    init_hipblas()

    weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16), use_int4=False).to(input.device)
    output_flat = hipb_mm(
        input_flat,
        weight_preshuffle.t(),
        bias=bias,
        solution_index=-1,
        out_dtype=otype,
        scaleA=scaleA,
        scaleB=scaleB,
        scaleOut=None,
        bpreshuffle=True
    )
    
    # Reshape output to match input dimensions
    new_shape = input.shape[:-1] + (weight.shape[0],)
    output = output_flat.view(new_shape)
    return output



    def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
                             scaleA=None, scaleB=None, device="cuda"):

        input_flat = input.reshape(-1, input.shape[-1])

        if use_scale_for_fp8:
            input_flat, a_scale = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
            weight = weight.to(DTYPE_FP8)

            init_hipblas()

            weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
            output_flat = hipb_mm(
                input_flat,
                weight_preshuffle.t(),
                bias=bias,
                solution_index=-1,
                out_dtype=otype,
                scaleA=a_scale,
                scaleB=scaleB,
                scaleOut=None,
                bpreshuffle=True
            )

        else:
            input_flat = input_flat.to(DTYPE_FP8)
            weight = weight.to(DTYPE_FP8)

            init_hipblas()

            weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
            output_flat = hipb_mm(
                input_flat,
                weight_preshuffle.t(),
                bias=bias,
                solution_index=-1,
                out_dtype=otype,
                scaleA=scaleA,
                scaleB=scaleB,
                scaleOut=None,
                bpreshuffle=True
            )
Review comment on lines +52 to +94 (high):
This function can be improved in two ways:

  1. Hardcoded Device: Similar to optimized_linear, the device parameter is hardcoded to "cuda", which can lead to errors in a multi-GPU environment. It's better to derive the device from input.device.
  2. Code Duplication: There's significant code duplication between the if use_scale_for_fp8: and else branches. Refactoring this will improve readability and maintainability.

I've provided a suggestion that addresses both points.

    def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
                        scaleA=None, scaleB=None):
        
        input_flat = input.reshape(-1, input.shape[-1])

        a_scale_for_mm = scaleA
        if use_scale_for_fp8:
            input_flat, a_scale_for_mm = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
        else:
            input_flat = input_flat.to(DTYPE_FP8)

        weight = weight.to(DTYPE_FP8)

        init_hipblas()

        weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(input.device)
        output_flat = hipb_mm(
            input_flat,
            weight_preshuffle.t(),
            bias=bias,
            solution_index=-1,
            out_dtype=otype,
            scaleA=a_scale_for_mm,
            scaleB=scaleB,
            scaleOut=None,
            bpreshuffle=True
        )



        # Reshape output to match input dimensions
        new_shape = input.shape[:-1] + (weight.shape[0],)
        output = output_flat.view(new_shape)
        return output
Review comment on lines +52 to +100 (high):
This function has a couple of issues:

  1. Similar to optimized_linear, it hardcodes device="cuda", which can lead to runtime errors on multi-GPU setups. The device should be inferred from the input tensor.
  2. There is significant code duplication between the if and else branches. This can be refactored to improve readability and maintainability.
def optimized_linear_fp8(input, weight, bias=None, otype=torch.bfloat16,
                    scaleA=None, scaleB=None):
    
    input_flat = input.reshape(-1, input.shape[-1])
    device = input.device

    if use_scale_for_fp8:
        input_flat, a_scale = per_tensor_quant_hip(input_flat, quant_dtype=DTYPE_FP8)
        scaleA = a_scale
    else:
        input_flat = input_flat.to(DTYPE_FP8)

    weight = weight.to(DTYPE_FP8)
    init_hipblas()
    weight_preshuffle = shuffle_weight(weight.contiguous(), layout=(16, 16)).to(device)
    
    output_flat = hipb_mm(
        input_flat,
        weight_preshuffle.t(),
        bias=bias,
        solution_index=-1,
        out_dtype=otype,
        scaleA=scaleA,
        scaleB=scaleB,
        scaleOut=None,
        bpreshuffle=True
    )


    # Reshape output to match input dimensions
    new_shape = input.shape[:-1] + (weight.shape[0],)
    output = output_flat.view(new_shape)
    return output


    if use_fp8_linear:
        F.linear = optimized_linear_fp8
    else:
        F.linear = optimized_linear

    yield
    F.linear = _original_linear
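
For context, a minimal usage sketch of the new context manager. It assumes a ROCm build of PyTorch with aiter available; the layer and tensor shapes are made up for illustration. Inside the with block, F.linear (and therefore nn.Linear) is routed through the pre-shuffled hipBLASLt path; on exit the original F.linear is restored.

import torch
import torch.nn as nn
from diffsynth_engine.utils.aiter_linear import use_swizzle_hipblaslt

layer = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
x = torch.randn(2, 256, 4096, device="cuda", dtype=torch.bfloat16)

with use_swizzle_hipblaslt(swizzle=True, use_fp8_linear=False):
    y_swizzled = layer(x)  # dispatched to hipb_mm with a pre-shuffled weight

y_reference = layer(x)  # stock PyTorch F.linear again after the context exits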