diff --git a/problems/nvidia/eval.py b/problems/nvidia/eval.py new file mode 100644 index 0000000..981b932 --- /dev/null +++ b/problems/nvidia/eval.py @@ -0,0 +1,375 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional + +import torch.cuda + +from utils import set_seed, clear_l2_cache +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, 'w') + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a+b)*(a+b+1)//2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg)**2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats(runs=runs, mean=avg, std=std, err=err, best=float(best), + worst=float(worst)) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + data = generate_input(**test.args) + torch.cuda.synchronize() + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _run_single_benchmark(test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + # first, one obligatory correctness check + output = custom_kernel(data) + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 100 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if stats.err / stats.mean < 0.001 or stats.mean * stats.runs > max_time_ns or total_bm_duration > 120e9: + break + + return calculate_stats(durations) + + +def run_single_benchmark(pool: multiprocessing.Pool, test: TestCase, recheck: bool, max_repeats: int, + max_time_ns: float): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking(logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase]): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # warm up + run_single_benchmark(pool, tests[0], False, 100, 10e7) + + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 100, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log(f"benchmark.{idx}.report", base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8")) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + tests = get_test_cases(sys.argv[2], seed) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + mp_context = multiprocessing.get_context('spawn') + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # warmup + run_single_benchmark(pool, tests[0], False, 100, 1e7) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 100, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{i}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log(f"benchmark.{i}.error", str(result)) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/nvidia/fp16_gemm/eval.py b/problems/nvidia/fp16_gemm/eval.py new file mode 100644 index 0000000..4169b80 --- /dev/null +++ b/problems/nvidia/fp16_gemm/eval.py @@ -0,0 +1,505 @@ +import base64 +import dataclasses +import multiprocessing +import re +import time +import os +import sys +import math +from pathlib import Path +from typing import Any, Optional +import tempfile + +import torch.cuda +from cutlass.cute.nvgpu.common import OpError + +from utils import set_seed, clear_l2_cache + +try: + from task import TestSpec +except ImportError: + TestSpec = dict + +from reference import check_implementation, generate_input + + +class PopcornOutput: + def __init__(self, fd: int): + self.file = os.fdopen(fd, "w") + os.set_inheritable(fd, False) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.file.close() + + def print(self, *args, **kwargs): + print(*args, **kwargs, file=self.file, flush=True) + + def log(self, key, value): + self.print(f"{key}: {value}") + + +@dataclasses.dataclass +class TestCase: + args: dict + spec: str + + +def _combine(a: int, b: int) -> int: + # combine two integers into one: + # we need this to generate a secret seed based on the test-level seed and + # the global secret seed. + # the test-level seeds are public knowledge, and typically relatively small numbers, + # so we need to make sure they don't provide any useful info for the full seed. + # This Cantor construction ensures that if the secret seed is a large number, + # then so is the overall seed. + return int(a + (a + b) * (a + b + 1) // 2) + + +def get_test_cases(file_name: str, seed: Optional[int]) -> list[TestCase]: + try: + content = Path(file_name).read_text() + except Exception as E: + print(f"Could not open test file`{file_name}`: {E}", file=sys.stderr) + exit(113) + + tests = [] + lines = content.splitlines() + match = r"\s*([a-zA-Z]+):\s*([a-zA-Z]+|[+-]?[0-9]+)\s*" + for line in lines: + parts = line.split(";") + case = {} + for part in parts: + matched = re.match(match, part) + if not re.fullmatch(match, part): + print(f"invalid test case: '{line}': '{part}'", file=sys.stderr) + exit(113) + key = matched[1] + val = matched[2] + try: + val = int(val) + except ValueError: + pass + + case[key] = val + tests.append(TestCase(spec=line, args=case)) + + if seed is not None: + for test in tests: + if "seed" in test.args: + test.args["seed"] = _combine(test.args["seed"], seed) + + return tests + + +@dataclasses.dataclass +class Stats: + runs: int + mean: float + std: float + err: float + best: float + worst: float + + +def calculate_stats(durations: list[int]): + """ + Calculate statistical data from a list of durations. + + @param durations: A list of durations in nanoseconds. + @return: A Stats object containing the number of runs, mean, standard deviation, error, best, and worst durations. + """ + runs = len(durations) + total = sum(durations) + best = min(durations) + worst = max(durations) + + avg = total / runs + variance = sum(map(lambda x: (x - avg) ** 2, durations)) + std = math.sqrt(variance / (runs - 1)) + err = std / math.sqrt(runs) + + return Stats( + runs=runs, mean=avg, std=std, err=err, best=float(best), worst=float(worst) + ) + + +def _clone_data(data): + """ + Recursively goes through data and clones all tensors. + """ + if isinstance(data, tuple): + return tuple(_clone_data(x) for x in data) + elif isinstance(data, list): + return [_clone_data(x) for x in data] + elif isinstance(data, dict): + return {k: _clone_data(v) for k, v in data.items()} + elif isinstance(data, torch.Tensor): + return data.clone() + else: + return data + + +def _run_single_test(test: TestCase): + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + + data = generate_input(**test.args) + torch.cuda.synchronize() + try: + submission_output = custom_kernel(_clone_data(data)) + + except OpError as E: + print(f"Encountered {E}", file=sys.stderr) + return False, str(E) + torch.cuda.synchronize() + return check_implementation(data, submission_output) + + +def run_single_test(pool: multiprocessing.Pool, test: TestCase): + """ + Runs a single test in another process. + """ + return pool.apply(_run_single_test, (test,)) + + +def run_testing( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes the actual test case code and checks for correctness. + + @param logger: A PopcornOutput object used for logging test results. + @param tests: A list of TestCase objects representing the test cases to be executed. + @return: An integer representing the exit status: 0 if all tests pass, otherwise 112. + """ + # Step 1: Compile kernel once before running tests + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Run all tests with compiled kernel + passed = True + logger.log("test-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"test.{idx}.spec", test.spec) + good, message = run_single_test(pool, test) + if not good: + logger.log(f"test.{idx}.status", "fail") + logger.log(f"test.{idx}.error", message) + passed = False + else: + logger.log(f"test.{idx}.status", "pass") + if message: + logger.log(f"test.{idx}.message", message) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def _compile_kernel_once(): + """ + Compile the kernel once before any benchmarking. + This ensures compilation time is not included in benchmark results. + """ + from submission import compile_kernel + + # Use a minimal test case for compilation + compile_data = generate_input(m=128, n=128, k=64, seed=42) + a, b, c = compile_data + + try: + # Trigger compilation (will be cached) + compile_kernel(a, b, c) + torch.cuda.synchronize() + return True, None + except OpError as E: + return False, f"Compilation failed: {E}" + except Exception as E: + return False, f"Compilation failed: {E}" + + +def _run_single_benchmark( + test: TestCase, recheck: bool, max_repeats: int, max_time_ns: float +) -> Stats | Any: + """ + Runs one benchmark. Do not call directly. + """ + from submission import custom_kernel, compile_kernel + + durations = [] + # generate input data once + data = generate_input(**test.args) + check_copy = _clone_data(data) + + # Ensure kernel is compiled before any timing (compilation is cached) + try: + a, b, c = data + compile_kernel(a, b, c) + torch.cuda.synchronize() + except OpError as E: + return f"Compilation failed: {E}" + except Exception as E: + return f"Compilation failed: {E}" + + # first, one obligatory correctness check + try: + output = custom_kernel(_clone_data(data)) + except OpError as E: + return f"Encountered {E}" + good, message = check_implementation(check_copy, output) + if not good: + return message + + # now, do multiple timing runs without further correctness testing + # there is an upper bound of 200 runs, and a lower bound of 3 runs; + # otherwise, we repeat until we either measure at least 10 full seconds, + # or the relative error of the mean is below 1%. + + bm_start_time = time.perf_counter_ns() + for i in range(max_repeats): + if recheck: + # ensure we use a different seed for every benchmark + if "seed" in test.args: + test.args["seed"] += 13 + + data = generate_input(**test.args) + check_copy = _clone_data(data) + torch.cuda.synchronize() + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + clear_l2_cache() + + start_event.record() + output = custom_kernel(data) + end_event.record() + torch.cuda.synchronize() + duration = start_event.elapsed_time(end_event) * 1e6 # Convert ms to ns + + if recheck: + good, message = check_implementation(check_copy, output) + if not good: + return message + + del output + durations.append(duration) + + if i > 1: + total_bm_duration = time.perf_counter_ns() - bm_start_time + stats = calculate_stats(durations) + # stop if either + # a) relative error dips below 0.1% + # b) we exceed the total time limit for benchmarking the kernel + # c) we exceed 2 minutes of total wallclock time. + if ( + stats.err / stats.mean < 0.001 + or stats.mean * stats.runs > max_time_ns + or total_bm_duration > 120e9 + ): + break + + return calculate_stats(durations) + + +def run_single_benchmark( + pool: multiprocessing.Pool, + test: TestCase, + recheck: bool, + max_repeats: int, + max_time_ns: float, +): + """ + For a particular test case, check correctness (if applicable) and grab runtime results. + + @param pool: Process on which the benchmark will be launched. + @param test: TestCase object. + @param recheck: Flag for whether to explicitly check functional correctness. + @param max_repeats: Number of trials to repeat. + @param max_time_ns: Timeout time in nanoseconds. + @return: A Stats object for this particular benchmark case or an error if the test fails. + """ + return pool.apply(_run_single_benchmark, (test, recheck, max_repeats, max_time_ns)) + + +def run_benchmarking( + logger: PopcornOutput, pool: multiprocessing.Pool, tests: list[TestCase] +): + """ + Executes benchmarking code for a CUDA Kernel and logs runtimes. + + @param logger: A PopcornOutput object used for logging benchmark results. + @param pool: Process on which the benchmarks will be launched. + @param tests: A list of TestCase objects representing the test cases to be benchmarked. + @return: An integer representing the exit status: 0 if all benchmarks pass, otherwise 112. + """ + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warm up with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 10e7) + + # Step 3: Run benchmarks (compilation time excluded) + passed = True + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + result = run_single_benchmark(pool, test, False, 200, 10e9) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log(f"benchmark.{idx}.{field.name}", getattr(result, field.name)) + else: + passed = False + logger.log(f"benchmark.{idx}.status", "fail") + logger.log(f"benchmark.{idx}.error", result) + + if passed: + logger.log("check", "pass") + return 0 + else: + logger.log("check", "fail") + return 112 + + +def run_single_profile(test: TestCase) -> str: + """ + Runs a single test case. Do not call directly + """ + from submission import custom_kernel + from torch.profiler import profile, record_function, ProfilerActivity + + data = generate_input(**test.args) + torch.cuda.synchronize() + + with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof: + submission_output = custom_kernel(_clone_data(data)) + torch.cuda.synchronize() + return prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=20) + + +def run_profiling(logger: PopcornOutput, tests: list[TestCase]): + logger.log("benchmark-count", len(tests)) + for idx, test in enumerate(tests): + logger.log(f"benchmark.{idx}.spec", test.spec) + report = run_single_profile(test) + logger.log( + f"benchmark.{idx}.report", + base64.b64encode(report.encode("utf-8"), b"+*").decode("utf-8"), + ) + logger.log("check", "pass") + return 0 + + +def main(): + fd = os.getenv("POPCORN_FD") + if not fd: + return 111 + + if len(sys.argv) < 3: + return 2 + + mode = sys.argv[1] + seed = os.getenv("POPCORN_SEED") + os.unsetenv("POPCORN_SEED") + seed = int(seed) if seed else None + set_seed(seed or 42) + + filename = None + + with tempfile.NamedTemporaryFile(delete=False) as tmp: + + def build_test_string(tests: list[dict]): + as_str = "" + for test in tests: + kvs = [] + for k, v in test.items(): + kvs.append(f"{k}: {v}") + as_str += "; ".join(kvs) + "\n" + return as_str + + import yaml + + yaml_content = yaml.safe_load(open(sys.argv[2], "r")) + if mode == "test": + tests_str = build_test_string(yaml_content.get("tests", [])) + elif mode in ("benchmark", "leaderboard", "profile"): + tests_str = build_test_string(yaml_content.get("benchmarks", [])) + + tmp.write(tests_str.encode("utf-8")) + tmp.flush() + filename = tmp.name + + tests = get_test_cases(filename, seed) + + os.unlink(filename) + + with PopcornOutput(int(fd)) as logger: + import multiprocessing + + mp_context = multiprocessing.get_context("spawn") + with mp_context.Pool(1) as pool: + if mode == "test": + return run_testing(logger, pool, tests) + if mode == "benchmark": + return run_benchmarking(logger, pool, tests) + + if mode == "leaderboard": + # Step 1: Compile kernel once (outside of timing) + logger.log("compile", "start") + compile_success, compile_error = pool.apply(_compile_kernel_once) + if not compile_success: + logger.log("compile", "fail") + logger.log("compile.error", compile_error) + return 112 + logger.log("compile", "pass") + + # Step 2: Warmup with compiled kernel + run_single_benchmark(pool, tests[0], False, 200, 1e7) + + # Step 3: Run leaderboard benchmarks (compilation time excluded) + logger.log("benchmark-count", len(tests)) + passed = True + for i in range(len(tests)): + result = run_single_benchmark(pool, tests[i], True, 200, 30e9) + logger.log(f"benchmark.{i}.spec", tests[i].spec) + if isinstance(result, Stats): + for field in dataclasses.fields(Stats): + logger.log( + f"benchmark.{i}.{field.name}", + getattr(result, field.name), + ) + else: + passed = False + logger.log(f"benchmark.{i}.status", "fail") + logger.log( + f"benchmark.{i}.error", str(result) + ) # TODO: Make sure result implements __str__? + break + + logger.log("check", "pass" if passed else "fail") + elif mode == "profile": + run_profiling(logger, tests) + else: + # TODO: Implement script mode + return 2 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/problems/nvidia/fp16_gemm/reference.py b/problems/nvidia/fp16_gemm/reference.py new file mode 100644 index 0000000..8eca7b7 --- /dev/null +++ b/problems/nvidia/fp16_gemm/reference.py @@ -0,0 +1,32 @@ +import torch +from task import input_t, output_t +from utils import make_match_reference + + +def ref_kernel( + data: input_t, +)->output_t: + a, b, c = data + # call torch matmul operation + # b is N x K in column-major order, need to transpose b before matmul + c[...] = a @ b.t() + return c + + +def generate_input( + m: int, + n: int, + k: int, + seed: int, +) -> input_t: + torch.manual_seed(seed) + + # Generate a, b and c tensors + a = torch.empty(m, k, dtype=torch.float16).random_(-2, 2).to(device="cuda") + b = torch.empty(n, k, dtype=torch.float16).random_(-2, 2).to(device="cuda") + c = torch.empty(m, n, dtype=torch.float16).random_(-2, 2).to(device="cuda") + + return (a, b, c) + + +check_implementation = make_match_reference(ref_kernel, rtol=1e-01, atol=1e-05) \ No newline at end of file diff --git a/problems/nvidia/fp16_gemm/submission.py b/problems/nvidia/fp16_gemm/submission.py new file mode 100644 index 0000000..9e000be --- /dev/null +++ b/problems/nvidia/fp16_gemm/submission.py @@ -0,0 +1,399 @@ +import argparse +import torch +from task import input_t, output_t +from typing import Tuple + +import cutlass +import cutlass.cute as cute +import cutlass.utils as utils +import cutlass.torch as cutlass_torch +import cutlass.pipeline as pipeline +from cutlass.cute.nvgpu import cpasync, tcgen05 +import cutlass.utils.blackwell_helpers as sm100_utils +from cutlass.cute.runtime import from_dlpack + +# Kernel configuration +io_dtype = cutlass.Float16 +acc_dtype = cutlass.Float32 +mma_inst_shape_mnk = (128, 128, 16) +mma_tiler_mnk = (128, 128, 64) +threads_per_cta = 128 + +# Pipeline stage configuration +ab_stages = 4 +acc_stage = 1 + + +@cute.struct +class SharedStorage: + ab_mbar_ptr: cute.struct.MemRange[cutlass.Int64, ab_stages * 2] + acc_mbar_ptr: cute.struct.MemRange[cutlass.Int64, acc_stage * 2] + tmem_holding_buf: cutlass.Int32 + + +# The naive CuTe implementation of a fp16 matmul +@cute.kernel +def kernel( + tiled_mma: cute.TiledMma, + tma_atom_a: cute.CopyAtom, + mA_mkl: cute.Tensor, + tma_atom_b: cute.CopyAtom, + mB_nkl: cute.Tensor, + mC_mnl: cute.Tensor, + a_smem_layout: cute.ComposedLayout, + b_smem_layout: cute.ComposedLayout, +): + # Current thread/warp/block coordinates + tidx, _, _ = cute.arch.thread_idx() + warp_idx = cute.arch.warp_idx() + warp_idx = cute.arch.make_warp_uniform(warp_idx) + bidx, bidy, _ = cute.arch.block_idx() + mma_coord_mnk = (bidx, bidy, None) + + # + # 1. Prepare args + # + + # Allocate SMEM + smem = cutlass.utils.SmemAllocator() + storage = smem.allocate(SharedStorage) + sA = smem.allocate_tensor( + element_type=io_dtype, + layout=a_smem_layout.outer, + byte_alignment=128, + swizzle=a_smem_layout.inner, + ) + sB = smem.allocate_tensor( + element_type=io_dtype, + layout=b_smem_layout.outer, + byte_alignment=128, + swizzle=b_smem_layout.inner, + ) + + # Prefetch tma descriptor + if warp_idx == 0: + cpasync.prefetch_descriptor(tma_atom_a) + cpasync.prefetch_descriptor(tma_atom_b) + + # Pipeline configuration + num_tma_copy_bytes = cute.size_in_bytes( + io_dtype, cute.select(a_smem_layout, mode=[0, 1, 2]) + ) + cute.size_in_bytes(io_dtype, cute.select(b_smem_layout, mode=[0, 1, 2])) + ab_producer, ab_consumer = pipeline.PipelineTmaUmma.create( + num_stages=ab_stages, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + tx_count=num_tma_copy_bytes, + barrier_storage=storage.ab_mbar_ptr.data_ptr(), + ).make_participants() + acc_producer, acc_consumer = pipeline.PipelineUmmaAsync.create( + num_stages=acc_stage, + producer_group=pipeline.CooperativeGroup(pipeline.Agent.Thread), + consumer_group=pipeline.CooperativeGroup( + pipeline.Agent.Thread, threads_per_cta, threads_per_cta + ), + barrier_storage=storage.acc_mbar_ptr.data_ptr(), + ).make_participants() + + # Partition tensors for MMA and make fragments + # (bM, bK, RestK) + gA = cute.local_tile(mA_mkl, mma_tiler_mnk, mma_coord_mnk, proj=(1, None, 1)) + # (bN, bK, RestK) + gB = cute.local_tile(mB_nkl, mma_tiler_mnk, mma_coord_mnk, proj=(None, 1, 1)) + # (bM, bN) + gC = cute.local_tile(mC_mnl, mma_tiler_mnk, mma_coord_mnk, proj=(1, 1, None)) + thr_mma = tiled_mma.get_slice(0) + # (MMA, MMA_M, MMA_K) + tCgA = thr_mma.partition_A(gA) + # (MMA, MMA_N, MMA_K) + tCgB = thr_mma.partition_B(gB) + # (MMA, MMA_M, MMA_N) + tCgC = thr_mma.partition_C(gC) + # (MMA, MMA_M, MMA_K) + tCrA = tiled_mma.make_fragment_A(sA) + # (MMA, MMA_N, MMA_K) + tCrB = tiled_mma.make_fragment_B(sB) + # (MMA, MMA_M, MMA_N) + acc_shape = tiled_mma.partition_shape_C(mma_tiler_mnk[:2]) + # (MMA, MMA_M, MMA_N) + tCtAcc = tiled_mma.make_fragment_C(acc_shape) + # Partition tensors for TMA; This requires the tensors partitioned for MMA + tAsA, tAgA = cute.nvgpu.cpasync.tma_partition( + tma_atom_a, + 0, + cute.make_layout(1), + cute.group_modes(sA, 0, 3), + cute.group_modes(tCgA, 0, 3), + ) + tBsB, tBgB = cute.nvgpu.cpasync.tma_partition( + tma_atom_b, + 0, + cute.make_layout(1), + cute.group_modes(sB, 0, 3), + cute.group_modes(tCgB, 0, 3), + ) + + # Allocate all TMEM columns + # CTA-wide sync before retrieving the pointer to the start of the allocated TMEM + # Only warp 0 does the allocation so we need to sync before retrieving the TMEM start address + if warp_idx == 0: + cute.arch.alloc_tmem(512, storage.tmem_holding_buf) + cute.arch.barrier() + tmem_ptr = cute.arch.retrieve_tmem_ptr( + acc_dtype, + alignment=16, + ptr_to_buffer_holding_addr=storage.tmem_holding_buf, + ) + + # Create an accumulator Tensor + tCtAcc = cute.make_tensor(tmem_ptr, tCtAcc.layout) + + subtile_cnt = 4 + # (EpiTile) + epi_tiler = ( + (cute.size(tCtAcc, mode=[0, 0]), cute.size(tCtAcc, mode=[0, 1]) // subtile_cnt), + ) + # (EpiTile, NumTiles) + tCtAcc_epi = cute.zipped_divide(tCtAcc, epi_tiler) + # (EpiTile, NumTiles) + gC_epi = cute.zipped_divide(tCgC, epi_tiler) + + # Every thread loads 32x64 bits + tmem_atom = cute.make_copy_atom( + tcgen05.Ld32x32bOp(tcgen05.Repetition.x32), + cutlass.Float32, + ) + tmem_tiled_copy = tcgen05.make_tmem_copy(tmem_atom, tCtAcc_epi[None, 0]) + tmem_thr_copy = tmem_tiled_copy.get_slice(tidx) + + # (TmemCpy,NumTmemCpy,NumTiles) + tDtC = tmem_thr_copy.partition_S(tCtAcc_epi) + # (TmemCpy,NumTmemCpy,NumTiles) + tDgC = tmem_thr_copy.partition_D(gC_epi) + + # (TmemCpy,NumTmemCpy) + tCrAcc = cute.make_fragment(tDgC[None, None, 0].shape, acc_dtype) + # (TmemCpy,NumTmemCpy) + tCrC = cute.make_fragment(tDgC[None, None, 0].shape, io_dtype) + + # + # 2. Main loop + # + num_k_tiles = cute.size(gA, mode=[2]) + if warp_idx == 0: + # Wait for a empty accumulator buffer + acc_empty = acc_producer.acquire_and_advance() + for k_tile_idx in cutlass.range(num_k_tiles, prefetch_stages=ab_stages - 2): + # Issue TMA loads + ab_empty = ab_producer.acquire_and_advance() + cute.copy( + tma_atom_a, + tAgA[(None, ab_empty.count)], + tAsA[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + cute.copy( + tma_atom_b, + tBgB[(None, ab_empty.count)], + tBsB[(None, ab_empty.index)], + tma_bar_ptr=ab_empty.barrier, + ) + + # Execute one K-block worth of MMA instructions + ab_full = ab_consumer.wait_and_advance() + num_k_blocks = cute.size(tCrA, mode=[2]) + for k_block_idx in cutlass.range_constexpr(num_k_blocks): + k_block_coord = (None, None, k_block_idx, ab_full.index) + cute.gemm( + tiled_mma, + tCtAcc, + tCrA[k_block_coord], + tCrB[k_block_coord], + tCtAcc, + ) + tiled_mma.set(tcgen05.Field.ACCUMULATE, True) + + # Signal that the A/B buffers have been consumed and are ready for the next load + ab_full.release() + + # Signal that the accumulator is fully computed + acc_empty.commit() + + # + # 3. Epilogue + # + + # Release TMEM allocation lock + if warp_idx == 0: + cute.arch.relinquish_tmem_alloc_permit() + + # Wait for the accumulator buffer to be full + acc_full = acc_consumer.wait_and_advance() + + # TMEM -> RMEM -> GEMM + # Sub-tiling for better instruction-level parallelism + for i in cutlass.range(cute.size(tDtC, mode=[2])): + cute.copy(tmem_tiled_copy, tDtC[None, None, i], tCrAcc) + tCrC.store(tCrAcc.load().to(io_dtype)) + cute.autovec_copy(tCrC, tDgC[None, None, i]) + + acc_full.release() + + # Deallocate TMEM + cute.arch.barrier() + if warp_idx == 0: + cute.arch.dealloc_tmem(tmem_ptr, 512) + + +@cute.jit +def my_kernel( + a: cute.Tensor, + b: cute.Tensor, + c: cute.Tensor, +): + # Construct tiled MMA + op = tcgen05.MmaF16BF16Op( + io_dtype, + acc_dtype, + mma_inst_shape_mnk, + tcgen05.CtaGroup.ONE, + tcgen05.OperandSource.SMEM, + tcgen05.OperandMajorMode.K, + tcgen05.OperandMajorMode.K, + ) + tiled_mma = cute.make_tiled_mma(op) + + # Construct SMEM layouts for A and B + a_smem_layout = sm100_utils.make_smem_layout_a( + tiled_mma, + mma_tiler_mnk, + a.element_type, + ab_stages, + ) + b_smem_layout = sm100_utils.make_smem_layout_b( + tiled_mma, + mma_tiler_mnk, + b.element_type, + ab_stages, + ) + a_smem_layout_one_stage = cute.select(a_smem_layout, mode=[0, 1, 2]) + b_smem_layout_one_stage = cute.select(b_smem_layout, mode=[0, 1, 2]) + + # Construct TMA load atoms + op = cute.nvgpu.cpasync.CopyBulkTensorTileG2SOp(tcgen05.CtaGroup.ONE) + # Construct TMA load atoms and Tensors for A and B respectively + a_tma_atom, a_tma_tensor = cute.nvgpu.make_tiled_tma_atom_A( + op, + a, + a_smem_layout_one_stage, + mma_tiler_mnk, + tiled_mma, + (1, 1, 1), + ) + b_tma_atom, b_tma_tensor = cute.nvgpu.make_tiled_tma_atom_B( + op, + b, + b_smem_layout_one_stage, + mma_tiler_mnk, + tiled_mma, + (1, 1, 1), + ) + + # Calculate the grid shape + grid_shape = cute.ceil_div((*c.layout.shape, 1), mma_tiler_mnk[:2]) + + # Launch the kernel + kernel( + tiled_mma, + a_tma_atom, + a_tma_tensor, + b_tma_atom, + b_tma_tensor, + c, + a_smem_layout, + b_smem_layout, + ).launch( + grid=grid_shape, + block=(threads_per_cta, 1, 1), + ) + + +# Global cache for compiled kernel +_compiled_kernel_cache = None + + +def compile_kernel(a, b, c): + """ + Compile the kernel once and cache it. + This should be called before any timing measurements. + + Args: + a, b, c: Sample tensors with the expected shapes and types + + Returns: + The compiled kernel function + """ + global _compiled_kernel_cache + + if _compiled_kernel_cache is not None: + return _compiled_kernel_cache + + # Convert torch tensors to CuTe tensors via dlpack protocol + a_tensor = ( + from_dlpack(a, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=32) + ) + b_tensor = ( + from_dlpack(b, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=32) + ) + c_tensor = ( + from_dlpack(c, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=32) + ) + + # Compile the kernel + _compiled_kernel_cache = cute.compile(my_kernel, a_tensor, b_tensor, c_tensor) + + return _compiled_kernel_cache + + +def custom_kernel(data: input_t) -> output_t: + """ + Execute the kernel. If not already compiled, compile it first. + + Args: + data: Tuple of (a, b, c) tensors + + Returns: + Output tensor c + """ + # Get input tensors + a, b, c = data + + # Ensure kernel is compiled (will use cached version if available) + compiled_func = compile_kernel(a, b, c) + + # Convert torch tensors to CuTe tensors via dlpack protocol + a_tensor = ( + from_dlpack(a, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=32) + ) + b_tensor = ( + from_dlpack(b, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=32) + ) + c_tensor = ( + from_dlpack(c, assumed_align=32) + .mark_layout_dynamic(leading_dim=1) + .mark_compact_shape_dynamic(mode=1, divisibility=32) + ) + + # Execute the compiled kernel + compiled_func(a_tensor, b_tensor, c_tensor) + return c diff --git a/problems/nvidia/fp16_gemm/task.py b/problems/nvidia/fp16_gemm/task.py new file mode 100644 index 0000000..d35ed2e --- /dev/null +++ b/problems/nvidia/fp16_gemm/task.py @@ -0,0 +1,10 @@ +import torch +from typing import TypedDict, TypeVar + +input_t = TypeVar("input_t", bound=tuple[torch.Tensor, torch.Tensor, torch.Tensor]) +output_t = TypeVar("output_t", bound=torch.Tensor) +class TestSpec(TypedDict): + m: int + n: int + k: int + seed: int diff --git a/problems/nvidia/fp16_gemm/task.yml b/problems/nvidia/fp16_gemm/task.yml new file mode 100644 index 0000000..0d7a4a3 --- /dev/null +++ b/problems/nvidia/fp16_gemm/task.yml @@ -0,0 +1,62 @@ +# name: fp16-gemm + +files: + - {"name": "submission.py", "source": "@SUBMISSION@"} + - {"name": "task.py", "source": "task.py"} + - {"name": "utils.py", "source": "../utils.py"} + - {"name": "reference.py", "source": "reference.py"} + - {"name": "eval.py", "source": "../eval.py"} + +lang: "py" + +description: | + + You will implement a fp16 matrix multiplication kernel optimized for NVIDIA B200. + The shapes of tensors are from DeepSeek-R1. + To be explicit, you will be given a tuple of tensors: + ``` + (a, b, c) + ``` + where `a` and `b` are the input matrices, and `c` is the output matrix: + * `a` is M x K in row-major order in fp16 + * `b` is N x K in column-major order in fp16 + * `c` is M x N in row-major order in fp16 + + Matrix sizes `M` and `N` are divisible by mma_tiler_mnk defined in the kernel, `K` is divisible by 64. + + The ranking criteria is the geometric mean of the benchmark results. + + For the grand price, your kernel will be evaluated against the speed of light analysis + and the solution closest to the speed of light will be awarded the grand price. + ``` + The speed of light analysis is (using 1.5Ghz clock): + M N K time[us] + 7168 128 16384 17.38 + 4096 128 7168 4.34 + 7168 128 2048 2.17 + ``` + +config: + main: "eval.py" + +templates: + Python: "template.py" + +tests: + - {"m": 128, "n": 256, "k": 64, "seed": 1111} + - {"m": 128, "n": 1536, "k": 7168, "seed": 1111} + - {"m": 128, "n": 3072, "k": 1536, "seed": 1111} + - {"m": 256, "n": 7168, "k": 256, "seed": 1111} + - {"m": 256, "n": 7168, "k": 2048, "seed": 1111} + - {"m": 2384, "n": 4608, "k": 7168, "seed": 1111} + - {"m": 384, "n": 7168, "k": 2304, "seed": 1111} + - {"m": 512, "n": 512, "k": 7168, "seed": 1111} + - {"m": 512, "n": 4096, "k": 512, "seed": 1111} + - {"m": 512, "n": 1536, "k": 7168, "seed": 1111} + +benchmarks: + - {"m": 7168, "n": 128, "k": 16384, "seed": 1111} + - {"m": 4096, "n": 128, "k": 7168, "seed": 1111} + - {"m": 7168, "n": 128, "k": 2048, "seed": 1111} + +ranking_by: "geom" diff --git a/problems/nvidia/fp16_gemm/template.py b/problems/nvidia/fp16_gemm/template.py new file mode 100644 index 0000000..e4615a0 --- /dev/null +++ b/problems/nvidia/fp16_gemm/template.py @@ -0,0 +1,20 @@ +from task import input_t, output_t + + +def custom_kernel(data: input_t) -> output_t: + """ + Reference implementation of block-scale fp8 gemm + Args: + data: Tuple that expands to: + a: torch.Tensor[float16] of shape [m, k], + b: torch.Tensor[float16] of shape [n, k], + c: torch.Tensor[float16] of shape [m, n] + Returns: + Tensor containing output in float16 + """ + # c: [m, n] is pre-allocated memory to avoid timing allocation overhead. + a, b, c = data + + # Your implementation here + + return c diff --git a/problems/nvidia/utils.py b/problems/nvidia/utils.py new file mode 100644 index 0000000..e8a9082 --- /dev/null +++ b/problems/nvidia/utils.py @@ -0,0 +1,176 @@ +import os +import random +import numpy as np +import torch + + +def set_seed(seed=42): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def get_device(use_cuda: bool = True) -> torch.device: + """Get the appropriate device (GPU or CPU).""" + if use_cuda: + if torch.cuda.is_available(): + return torch.device("cuda") + elif torch.backends.mps.is_available(): + return torch.device("mps") + else: + print("No compatible GPU found. Falling back to CPU.") + return torch.device("cpu") + + +# Adapted from https://github.com/linkedin/Liger-Kernel/blob/main/test/utils.py +@torch.no_grad() +def verbose_allclose( + received: torch.Tensor, + expected: torch.Tensor, + rtol=1e-05, + atol=1e-08, + max_print=5 +) -> list[str]: + """ + Assert that two tensors are element-wise equal within a tolerance, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + rtol (float): Relative tolerance; relative to expected + atol (float): Absolute tolerance. + max_print (int): Maximum number of mismatched elements to print. + + Raises: + AssertionError: If the tensors are not all close within the given tolerance. + """ + # Check if the shapes of the tensors match + if received.shape != expected.shape: + return ["SIZE MISMATCH"] + + # Calculate the difference between the tensors + diff = torch.abs(received - expected) + + # Determine the tolerance + tolerance = atol + rtol * torch.abs(expected) + + # Find tolerance mismatched elements + tol_mismatched = diff > tolerance + + # Find nan mismatched elements + nan_mismatched = torch.logical_xor(torch.isnan(received), torch.isnan(expected)) + + # Find +inf mismatched elements + posinf_mismatched = torch.logical_xor(torch.isposinf(received), torch.isposinf(expected)) + # Find -inf mismatched elements + neginf_mismatched = torch.logical_xor(torch.isneginf(received), torch.isneginf(expected)) + + # Find all mismatched elements + mismatched = torch.logical_or( + torch.logical_or(tol_mismatched, nan_mismatched), + torch.logical_or(posinf_mismatched, neginf_mismatched), + ) + + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +@torch.no_grad() +def verbose_allequal(received: torch.Tensor, expected: torch.Tensor, max_print: int=5): + """ + Assert that two tensors are element-wise perfectly equal, providing detailed information about mismatches. + + Parameters: + received (torch.Tensor): Tensor we actually got. + expected (torch.Tensor): Tensor we expected to receive. + max_print (int): Maximum number of mismatched elements to print. + + Returns: + Empty string if tensors are equal, otherwise detailed error information + """ + mismatched = torch.not_equal(received, expected) + mismatched_indices = torch.nonzero(mismatched) + + # Count the number of mismatched elements + num_mismatched = mismatched.count_nonzero().item() + + # Generate detailed information if there are mismatches + if num_mismatched >= 1: + mismatch_details = [f"Number of mismatched elements: {num_mismatched}"] + + for index in mismatched_indices[:max_print]: + i = tuple(index.tolist()) + mismatch_details.append(f"ERROR AT {i}: {received[i]} {expected[i]}") + if num_mismatched > max_print: + mismatch_details.append(f"... and {num_mismatched - max_print} more mismatched elements.") + return mismatch_details + + return [] + + +def match_reference(data, output, reference: callable, rtol=1e-05, atol=1e-08) -> tuple[bool, str]: + """ + Convenient "default" implementation for tasks' `check_implementation` function. + """ + expected = reference(data) + reasons = verbose_allclose(output, expected, rtol=rtol, atol=atol) + + if len(reasons) > 0: + return False, "mismatch found! custom implementation doesn't match reference: " + " ".join(reasons) + + return True, '' + + +def make_match_reference(reference: callable, **kwargs): + def wrapped(data, output): + return match_reference(data, output, reference=reference, **kwargs) + return wrapped + + +class DeterministicContext: + def __init__(self): + self.allow_tf32 = None + self.deterministic = None + self.cublas = None + + def __enter__(self): + self.cublas = os.environ.get('CUBLAS_WORKSPACE_CONFIG', '') + self.allow_tf32 = torch.backends.cudnn.allow_tf32 + self.deterministic = torch.backends.cudnn.deterministic + torch.backends.cudnn.allow_tf32 = False + torch.backends.cudnn.deterministic = True + torch.use_deterministic_algorithms(True) + return self + + def __exit__(self, exc_type, exc_value, traceback): + torch.backends.cudnn.allow_tf32 = self.allow_tf32 + torch.backends.cudnn.deterministic = self.deterministic + torch.use_deterministic_algorithms(False) + os.environ['CUBLAS_WORKSPACE_CONFIG'] = self.cublas + +def clear_l2_cache(): + # import cupy as cp + # cp.cuda.runtime.deviceSetLimit(cp.cuda.runtime.cudaLimitPersistingL2CacheSize, 0) + # create a large dummy tensor + dummy = torch.empty((32, 1024, 1024), dtype=torch.int64, device="cuda") + # write stuff to + dummy.fill_(42) + del dummy \ No newline at end of file