@@ -72,8 +72,8 @@ def dummy_func(*args, **kwargs):
_is_cuda = is_cuda()

if _is_hip:
-    from aiter import ActivationType
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
+    from aiter import ActivationType, QuantType
+    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
    from aiter.ops.shuffle import shuffle_weight

if not _is_cuda:
@@ -484,7 +484,7 @@ def create_weights(
        if self.quant_config.is_checkpoint_fp8_serialized:
            params_dtype = (
                torch.uint32
-                if get_bool_env_var("USE_INT4_WEIGHT")
+                if get_bool_env_var("SGLANG_INT4_WEIGHT")
                else torch.float8_e4m3fn
            )
        tp_size = get_tensor_model_parallel_world_size()
@@ -511,7 +511,7 @@ def create_weights(
        )

        # WEIGHTS
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
            # INT4 MoE weight - INT32 packed
            w13_weight = torch.nn.Parameter(
                torch.empty(
@@ -585,7 +585,7 @@ def create_weights(

        if (
            _is_hip
-        ):  # and get_bool_env_var("CK_MOE"): TODO: add check back after triton kernel
+        ):  # and get_bool_env_var("SGLANG_AITER_MOE"): TODO: add check back after triton kernel
            # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
            w13_weight_scale1 = torch.nn.Parameter(
                torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -612,7 +612,7 @@ def create_weights(
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
            )
@@ -644,7 +644,7 @@ def create_weights(
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("SGLANG_INT4_WEIGHT"):
            self.process_weights_hip_int4(layer)
            return

@@ -675,7 +675,7 @@ def process_weights_after_loading(self, layer: Module) -> None:
                )
                layer.w2_input_scale = None

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                # Pre-shuffle weights
                layer.w13_weight.data = shuffle_weight(
                    layer.w13_weight.contiguous(), (16, 16)
@@ -798,17 +798,15 @@ def process_weights_after_loading(self, layer: Module) -> None:
        return

    def process_weights_hip_int4(self, layer: Module):
-        # TODO: and get_bool_env_var("CK_MOE"): add after triton kernel added
+        # TODO: and get_bool_env_var("SGLANG_AITER_MOE"): add after triton kernel added
        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
        # Weight Permutation
        layer.w13_weight = torch.nn.Parameter(
-            # permute_weight(layer.w13_weight.data),
            shuffle_weight(layer.w13_weight.data, (16, 16)),
            requires_grad=False,
        )
        torch.cuda.empty_cache()
        layer.w2_weight = torch.nn.Parameter(
-            # permute_weight(layer.w2_weight.data),
            shuffle_weight(layer.w2_weight.data, (16, 16)),
            requires_grad=False,
        )
@@ -847,23 +845,21 @@ def process_weights_hip_scale_padding(self, layer: Module):
            padding_size,  # Avoid circular import
        )

-        if get_bool_env_var("CK_MOE"):
+        if get_bool_env_var("SGLANG_AITER_MOE"):
            layer.w13_weight = torch.nn.Parameter(
-                # permute_weight(layer.w13_weight.data),
                shuffle_weight(layer.w13_weight.data, (16, 16)),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
            layer.w2_weight = torch.nn.Parameter(
-                # permute_weight(layer.w2_weight.data),
                shuffle_weight(layer.w2_weight.data, (16, 16)),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
-            # ROCm (CK_MOE): using column-wise scaling
+            # ROCm (SGLANG_AITER_MOE): using column-wise scaling
            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
-        elif get_bool_env_var("MOE_PADDING"):
+        elif get_bool_env_var("SGLANG_MOE_PADDING"):
            # If ROCm, apply weight padding (min. Mem channel contention) only if set
            layer.w13_weight = torch.nn.Parameter(
                F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
@@ -912,15 +908,16 @@ def apply(
        )

        if _is_hip:
-            if get_bool_env_var("USE_INT4_WEIGHT"):
-                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+            if get_bool_env_var("SGLANG_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("SGLANG_AITER_MOE")
                assert not no_combine, f"{no_combine=} is not supported."
-                return ck_moe_2stages_win4(
+                return ck_moe_2stages(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
                    topk_weights,
                    topk_ids,
+                    QuantType.per_Token,
                    layer.w13_weight_scale1,
                    layer.w2_weight_scale1,
                    activation=(
@@ -930,13 +927,13 @@ def apply(
                    ),
                )

-            if get_bool_env_var("CK_MOE"):
+            if get_bool_env_var("SGLANG_AITER_MOE"):
                assert not no_combine, f"{no_combine=} is not supported."
                if self.block_quant:
-                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    # TODO(SGLANG_AITER_MOE): FP8 block_quant only supports 'silu' for the time-being.
                    assert (
                        activation == "silu"
-                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    ), f"SGLANG_AITER_MOE: FP8 bloack_quant {activation=} will be supported later, unset SGLANG_AITER_MOE"
                    return asm_moe(
                        x,
                        layer.w13_weight,
@@ -955,6 +952,7 @@ def apply(
                        layer.w2_weight,
                        topk_weights,
                        topk_ids,
+                        QuantType.per_Token,
                        layer.w13_weight_scale1,
                        layer.w2_weight_scale1,
                        activation=(
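
Note: the flags renamed in this diff (SGLANG_AITER_MOE, SGLANG_INT4_WEIGHT, SGLANG_MOE_PADDING) are plain boolean environment variables read through get_bool_env_var. A minimal sketch of toggling them from Python follows; the helper below is a hypothetical stand-in for sglang's actual implementation, shown only to illustrate the gating, not to reproduce it.

    import os

    # Hypothetical stand-in for sglang's get_bool_env_var; assumed here to treat
    # "1"/"true" (case-insensitive) as enabled.
    def get_bool_env_var(name: str, default: str = "false") -> bool:
        return os.getenv(name, default).strip().lower() in ("1", "true")

    # The renamed flags gate the ROCm aiter MoE paths touched in this diff.
    os.environ["SGLANG_AITER_MOE"] = "1"    # aiter asm_moe / ck_moe_2stages kernels
    os.environ["SGLANG_INT4_WEIGHT"] = "1"  # INT4-FP8 MoE weights (INT32-packed)

    assert get_bool_env_var("SGLANG_AITER_MOE")
    assert get_bool_env_var("SGLANG_INT4_WEIGHT")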