33 changes: 26 additions & 7 deletions bitblas/base/utils.py
@@ -6,7 +6,7 @@
from tvm.contrib.popen_pool import PopenPoolExecutor, StatusKind
from concurrent.futures import ThreadPoolExecutor, as_completed
import numpy as np
from typing import List, Tuple, Optional, Dict, Union, Literal
from typing import List, Tuple, Optional, Dict, Union, Literal, Callable
from tvm import tir, IRModule
from tvm.runtime import Module
from tvm.tir import Schedule
@@ -455,13 +455,13 @@ def create_dispatch_func(g_var: str, func: tir.PrimFunc, refactored_funcs: List[


def create_dispatch_mod(g_var: str, original_func: tir.PrimFunc,
specialized_funcs: List[tir.PrimFunc]) -> IRModule:
specialized_funcs: List[tir.PrimFunc], function_symbols) -> IRModule:
dispatch_mod: IRModule = tvm.IRModule()
g_var_supply = GlobalVarSupply(dispatch_mod)
refactored_funcs = []
for func in specialized_funcs:
for f_var, func in zip(function_symbols, specialized_funcs):
params, buffers_to_declare = collect_buffers_to_declare(func)
global_symbol, device_func = refactor_specialized_func(g_var, func, params,
global_symbol, device_func = refactor_specialized_func(f_var, func, params,
buffers_to_declare)
global_symbol = g_var_supply.fresh_global(global_symbol, add_prefix=False)
dispatch_mod[global_symbol] = device_func
@@ -478,6 +478,7 @@ def fast_tune_with_dynamic_range(
parallel_build: bool = True,
global_symbol: Optional[str] = None,
dynamic_range: Optional[Dict[str, List[int]]] = None,
kernel_name_generator: Optional[Callable] = None,
) -> IRModule:
if dynamic_range is None:
dynamic_range = {}
@@ -517,12 +518,30 @@
# Convert the Cartesian product to a list of dictionaries
specialize_items: List[Dict] = [dict(zip(opt_shapes.keys(), values)) for values in product_list]

function_symbols: List[str] = []
specilized_tuned_funcs: List[tir.PrimFunc] = []
for item in specialize_items:
func = func.with_attr("opt_shapes", item)
_, best = fast_tune(func, target, topk, parallel_build)
if best is None:
return None
specilized_tuned_funcs.append(best.sch.mod["main"])

return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs)
specialized_func = best.sch.mod["main"]
function_symbol = global_symbol
if kernel_name_generator is not None:
scheduled_mod = best.sch.mod
best_hint = best.config
assert len(scheduled_mod.get_global_vars()) == 1, (
"The optimized module should only have one global variable for default schedule.")
assert "main" in scheduled_mod, (
"The optimized module should have a function named 'main' for default schedule.")
default_kernal_name = kernel_name_generator.generate(best_hint)
specialized_func = scheduled_mod["main"].with_attr("global_symbol", default_kernal_name)
function_symbol = default_kernal_name

function_symbols.append(function_symbol)
specilized_tuned_funcs.append(specialized_func)

assert global_symbol is not None, "The global_symbol should not be None"
assert len(function_symbols) == len(specilized_tuned_funcs), (
"The length of global_symbols should be equal to the length of specilized_tuned_funcs")
return create_dispatch_mod(global_symbol, func, specilized_tuned_funcs, function_symbols)
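
For orientation, a minimal sketch of how a caller might thread a kernel name generator through this path. The leading parameters (func, target, topk) are not visible in this hunk, and func, target and matmul_config below are placeholders rather than working values:

# Hypothetical usage sketch; the shapes, topk and generator instance are assumptions.
from bitblas.base.utils import fast_tune_with_dynamic_range

name_gen = MatmulKernelNameGenerator(matmul_config)  # generator defined in general_matmul below
dispatch_mod = fast_tune_with_dynamic_range(
    func,                                # tir.PrimFunc with a symbolic M dimension
    target,                              # e.g. a CUDA tvm.target.Target
    topk=20,
    parallel_build=True,
    global_symbol="matmul",              # symbol of the generated dispatch entry
    dynamic_range={"m": [16, 64, 256]},  # one specialized kernel per M bucket
    kernel_name_generator=name_gen,      # names each specialized kernel after its hint
)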
75 changes: 36 additions & 39 deletions bitblas/builder/wrapper/tir.py
@@ -13,6 +13,22 @@

logger = logging.getLogger(__name__)

PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY = """
cudaFuncSetAttribute({}, cudaFuncAttributeMaxDynamicSharedMemorySize, {});
"""

PREDEF_INIT_FUNC = """
extern "C" void init() {{
{}
}}
"""

PREDEF_HOST_FUNC = """
extern "C" void call({}) {{
{}
}}
"""


class TIRCUDASourceWrapper(object):
_TYPE_MAP = {
@@ -77,16 +93,11 @@ def get_cuda_init_func(self):
call_str = """"""
# If dynamic shared memory buffer is specified, prepare the cudaFuncSetAttribute call
if self.dynamic_smem_buf is not None:
call_str = """
cudaFuncSetAttribute({},
cudaFuncAttributeMaxDynamicSharedMemorySize, {});
""".format(self.function_name, self.dynamic_smem_buf)
call_str = (
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(self.function_name,
self.dynamic_smem_buf))
# Format the initialization function using the call_str
init_funcs = """
extern "C" void init() {{
{}
}}
""".format(call_str)
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs
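
Put together, this method just fills the two module-level templates defined above; a rough illustration (the kernel name and the 72 KB figure are invented):

# Invented values: "matmul_kernel" and 73728 are placeholders, not taken from a real build.
attr_call = PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format("matmul_kernel", 73728)
init_src = PREDEF_INIT_FUNC.format(attr_call)
# init_src reads roughly:
#   extern "C" void init() {
#   cudaFuncSetAttribute(matmul_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, 73728);
#   }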

def update_lib_code(self, code: str):
@@ -162,18 +173,19 @@ def legalize_c(p):
call_str += "{}<<<{}, {}, {}, stream>>>({});".format(function_name, grid_str, block_str,
smem_str, call_args)
# Create the host function wrapper for the CUDA kernel
host_func = """
extern "C" void call({}) {{
{}
}}
""".format(def_args, call_str)
host_func = PREDEF_HOST_FUNC.format(def_args, call_str)
# Combine the source, initialization function, and host function to form the complete library code
lib_code = self.source + init_func + host_func
return lib_code

@property
def prim_func(self):
return self.mod["main"]
if len(self.mod.get_global_vars()) == 1:
return self.mod[self.mod.get_global_vars()[0]]
elif "main" in self.mod:
return self.mod["main"]
else:
raise ValueError("Unable to determine primary function.")
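
A minimal, self-contained sketch of why the fallback matters: once kernels are renamed, a tuned module's single function may no longer be called "main" (the TVMScript function and the symbol below are invented):

import tvm
from tvm.script import tir as T

@T.prim_func
def copy16(a: T.handle, b: T.handle):
    A = T.match_buffer(a, (16,), "float32")
    B = T.match_buffer(b, (16,), "float32")
    for i in range(16):
        with T.block("copy"):
            vi = T.axis.spatial(16, i)
            B[vi] = A[vi]

# The only function carries a generated kernel name instead of "main".
mod = tvm.IRModule({"matmul_m16n1024k1024_f16xi4_default": copy16})
gvs = mod.get_global_vars()
entry = mod[gvs[0]] if len(gvs) == 1 else mod["main"]  # same lookup order as prim_func above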


class TIRCUDASourceWrapperWithDynamic(TIRCUDASourceWrapper):
@@ -188,16 +200,10 @@ def get_cuda_init_func(self):
for function_name, dynamic_smem_buf in self.dynamic_smem_buf.items():
if dynamic_smem_buf is not None:
# Format the cudaFuncSetAttribute call for dynamic shared memory
call_str += """
cudaFuncSetAttribute({},
cudaFuncAttributeMaxDynamicSharedMemorySize, {});
""".format(function_name, dynamic_smem_buf)
call_str += (
PREDEF_ARRTIBUTE_SET_DYNAMIC_MEMORY.format(function_name, dynamic_smem_buf))
# Define the init function that will set the attributes for each kernel
init_funcs = """
extern "C" void init() {{
{}
}}
""".format(call_str)
init_funcs = PREDEF_INIT_FUNC.format(call_str)
return init_funcs

def create_dispatch_func(self, code, function_informations):
@@ -278,8 +284,8 @@ def legalize_c(p):
(symbolic,) = list(dynamic_symbolic_set)
range_str = opt_shapes[symbolic]
if last_range == 0:
call_str = "if ({} == 0) return; \n".format(symbolic,)
call_str += "if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
call_str = " if ({} == 0) return; \n".format(symbolic,)
call_str += " if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
symbolic,
range_str,
function_name,
@@ -289,7 +295,7 @@ def legalize_c(p):
call_args,
)
else:
call_str = "\t\telse if ({} <= {}) {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
call_str = " else if ({} <= {}) {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
symbolic,
range_str,
function_name,
@@ -299,18 +305,13 @@
call_args,
)
if last_range == num_items - 1:
call_str += (
"\t\telse {{\n\t\t\t {}<<<{}, {}, {}, stream>>>({}); \n\t\t}}\n".format(
function_name, grid_str, block_str, smem_str, call_args))
call_str += (" else {{\n {}<<<{}, {}, {}, stream>>>({}); \n }}\n".format(
function_name, grid_str, block_str, smem_str, call_args))
last_range += 1
_call_str += call_str

# Wrap the kernel dispatch logic in an external C function
host_func = """
extern "C" void call({}) {{
{}
}}
""".format(def_args, _call_str)
host_func = PREDEF_HOST_FUNC.format(def_args, _call_str)
return host_func
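
For intuition, the string assembled above is a chain of if / else if guards on the dynamic symbol plus a final fallback branch; a rough sketch for one symbol m and two specialized kernels (kernel names, ranges and launch strings are all invented):

# Invented sketch of the generated body for dynamic symbol "m"; GRID*/BLOCK*/SMEM*
# stand in for the dim3/shared-memory strings computed above.
example_call_body = (
    " if (m == 0) return; \n"
    " if (m <= 64) {\n kernel_m64<<<GRID0, BLOCK0, SMEM0, stream>>>(A, B, C, m); \n }\n"
    " else if (m <= 256) {\n kernel_m256<<<GRID1, BLOCK1, SMEM1, stream>>>(A, B, C, m); \n }\n"
    " else {\n kernel_m256<<<GRID1, BLOCK1, SMEM1, stream>>>(A, B, C, m); \n }\n"
)
host_src = PREDEF_HOST_FUNC.format("half* A, half* B, half* C, int m, cudaStream_t stream", example_call_body)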

def parse_source_information(self):
@@ -381,10 +382,6 @@ def compare_map_objects(map_obj):
lib_code = self.source + init_func + host_func
return lib_code

@property
def prim_func(self):
return self.mod["main"]


class TIRWrapper(BaseWrapper):

4 changes: 2 additions & 2 deletions bitblas/cache/operator.py
@@ -108,8 +108,8 @@ def _save_operator_config_and_artifact(self, config, op_inst, config_path):
# For writing optimized.py file
optimized_file_path = os.path.join(config_path, "optimized.py")
with open(optimized_file_path, "w") as optimized_file:
if op_inst.optimized_func is not None:
optimized_file.write(op_inst.optimized_func.script(show_meta=False))
if op_inst.optimized_mod is not None:
optimized_file.write(op_inst.optimized_mod.script(show_meta=False))
if op_inst.libpath is not None:
# copy lib name to the same directory as the artifact
srcpath = op_inst.srcpath
85 changes: 84 additions & 1 deletion bitblas/ops/general_matmul/__init__.py
@@ -6,8 +6,9 @@
from functools import reduce
from enum import IntEnum
from bitblas.base.arch.cuda import CUDA
from bitblas.base.roller.hint import Hint
from typing import Any, Literal, Optional, Tuple, Union
from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU
from ..operator import OperatorConfig, Operator, TransformKind, OPExecutorCPU, BaseKernelNameGenerator
from .tirscript.matmul_dequantize_impl import select_implementation as weight_dequantize_implementation
from .tirscript.matmul_impl import select_implementation as consistent_implementation
from ...base.utils import tensor_replace_dp4a, tensor_remove_make_int4, tensor_remove_make_int2
@@ -226,6 +227,85 @@ def __post_init__(self):
object.__setattr__(self, "storage_dtype", self.W_dtype)


class MatmulKernelNameGenerator(BaseKernelNameGenerator):

KERNEL_PREFIX = "matmul"

@staticmethod
def serialize_hint(hint: Optional[Hint] = None) -> str:
if hint is None:
return "default"
else:
if hint.use_tc:
hint_prefix = "tc"
BM, BN = hint.block
WM, WN = hint.warp
BK = hint.rstep[-1]
reduce_k = hint.block_reduction_depth
pipeline_stage = hint.pipeline_stage
hint_name = f"{hint_prefix}x{BM}x{BN}x{BK}w{WM}x{WN}"
if reduce_k is not None and reduce_k > 1:
hint_name += f"xr{reduce_k}"
if pipeline_stage > 1:
hint_name += f"xp{pipeline_stage}"
return hint_name
else:
hint_prefix = "simt"
# do not annotate for simt currently
return hint_prefix

@staticmethod
def simplify_dtype(dtype: str) -> str:
if dtype == "float32":
return "f32"
elif dtype == "float16":
return "f16"
elif dtype == "bfloat16":
return "bf16"
elif dtype.startswith("int"):
return f"i{dtype[3:]}"
elif dtype.startswith("uint"):
return f"u{dtype[4:]}"
return dtype

def generate(self, hint=None) -> str:
config = self.config
kernel_name = self.KERNEL_PREFIX
shape_str = f"n{self.config.N}k{self.config.K}"
if isinstance(config.M, int):
shape_str = f"m{config.M}" + shape_str

A_dtype = self.simplify_dtype(config.A_dtype)
W_dtype = self.simplify_dtype(config.W_dtype)

precision_str = (f"{A_dtype}x{W_dtype}")
kernel_name = "_".join([kernel_name, shape_str, precision_str])

# if config.with_scaling:
# kernel_name += "Scale"

# if config.with_zeros:
# if config.zeros_mode == "original":
# kernel_name += "OriginalZeros"
# elif config.zeros_mode == "rescale":
# precision_str += "RescaleZeros"
# elif config.zeros_mode == "quantized":
# precision_str += "QuantizedZeros"
# else:
# raise ValueError(f"Unsupported zeros mode: {config.zeros_mode}")

# if config.propagate_a is not TransformKind.NonTransform:
# kernel_name += f"_pa{config.propagate_a.value}"
# if config.propagate_b is not TransformKind.NonTransform:
# kernel_name += f"_pb{config.propagate_b.value}"

kernel_name = "_".join([kernel_name, self.serialize_hint(hint)])
return kernel_name

def is_valid_config(self, config: OperatorConfig) -> bool:
return isinstance(config, MatmulConfig)
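
To make the naming scheme concrete, a couple of invented examples that follow generate() and serialize_hint() above (the config values and hint fields are assumptions, not results of a real tuning run):

# Invented config; only the formatting rules come from the methods above.
name_gen = MatmulKernelNameGenerator(
    MatmulConfig(M=16, N=1024, K=1024, A_dtype="float16", W_dtype="int4"))
name_gen.generate()         # hint=None -> "matmul_m16n1024k1024_f16xi4_default"
name_gen.generate(tc_hint)  # hypothetical tensor-core hint with block=(128, 128), warp=(64, 64),
                            # BK=32, pipeline_stage=2
                            # -> "matmul_m16n1024k1024_f16xi4_tcx128x128x32w64x64xp2"
# When config.M is a list of dynamic candidates, the leading "m16" component is dropped.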


class Matmul(Operator):

# TODO(lei): This should be improved into a general datatype class.
@@ -350,6 +430,9 @@ def dispatch_tir(self,
# output data type
self.torch_output_dtype = getattr(torch, self.out_dtype)

def get_kernel_name_generator(self):
return MatmulKernelNameGenerator(self.config)

def _alloc_workspace(self):
return torch.empty(WORKSPACE_SIZE, dtype=torch.float16).cuda()

2 changes: 1 addition & 1 deletion bitblas/ops/ladder_permutate/__init__.py
@@ -38,7 +38,7 @@ def __init__(

target = self.target
if target.kind.name == "cuda":
self.optimized_func = self.apply_default_schedule(self.prim_func_mod, target)
self.optimized_mod = self.apply_default_schedule(self.prim_func_mod, target)
if enable_tuning:
self.hardware_aware_finetune()
if not from_database: