Skip to content

Commit 86d9784

Browse files
authored
Script to generate starter code (#90)
Co-authored-by: cats-marin <[email protected]>
1 parent f0710ec commit 86d9784

File tree

2 files changed

+205
-0
lines changed

2 files changed

+205
-0
lines changed

CONTRIBUTING.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,13 @@ challenges/<difficulty>/<number>_<name>/
2929
└── starter.triton.py # Triton
3030
```
3131

32+
> [!NOTE]
33+
Use `generate_starter_code.py` in the `scripts` folder to help generate the starter code for medium and hard problems.
34+
35+
```bash
36+
python scripts/generate_starter_code.py path/to/challenge_dir # can be either absolute or relative path
37+
```
38+
3239
### Requirements
3340
- Clear problem description with 1 or more examples
3441
- Starter templates should follow the format of existing starter templates

scripts/generate_starter_code.py

Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import os
2+
import sys
3+
from importlib import util, machinery
4+
import types
5+
import ctypes
6+
import urllib.request
7+
import tempfile
8+
9+
# Make the repository root importable so the challenge modules loaded later
# can resolve package-relative imports regardless of the invocation cwd.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
12+
13+
# Parameter names that should be emitted as read-only (const) pointers in CUDA.
CONST_HINTS = {
    "input"
}

# ctypes scalar type -> CUDA C type name.
CTYPE_TO_CUDA = {
    ctypes.c_int: "int",
    ctypes.c_float: "float",
    ctypes.c_double: "double",
    ctypes.c_uint32: "unsigned int",
    ctypes.c_int64: "long long",
    ctypes.c_uint16: "__half",  # c_uint16 is used as the fp16 carrier type
}

# ctypes scalar type -> Mojo type name.
CTYPE_TO_MOJO = {
    ctypes.c_int: "Int32",
    ctypes.c_float: "Float32",
    ctypes.c_double: "Float64",
    ctypes.c_uint32: "UInt32",
    ctypes.c_int64: "Int64",
    ctypes.c_uint16: "Float16",
}

# ctypes scalar type -> torch dtype / Python type name.
# NOTE(review): this table is not consulted by ctype_to_torch below, which
# hard-codes its own int/float buckets — kept as-is for reference.
CTYPE_TO_TORCH = {
    ctypes.c_int: "int",
    ctypes.c_float: "torch.float32",
    ctypes.c_double: "torch.float64",
    ctypes.c_uint32: "int",
    ctypes.c_int64: "torch.int64",
    ctypes.c_uint16: "torch.float16",
}

def ctype_to_cuda(ctype, name) -> str:
    """Render a ctypes annotation as a CUDA C parameter type.

    Pointer types become "<base>*", prefixed with "const " when the
    parameter name is listed in CONST_HINTS; scalars map straight through
    CTYPE_TO_CUDA. Raises ValueError for any unmapped type.
    """
    pointer_like = isinstance(ctype, type) and issubclass(ctype, ctypes._Pointer)
    if pointer_like:
        base_type = getattr(ctype, "_type_", None)
        if base_type is None or base_type not in CTYPE_TO_CUDA:
            raise ValueError(
                f"Unsupported pointer base type: {base_type}. "
                "Please extend CTYPE_TO_CUDA mapping."
            )
        qualifier = "const " if name in CONST_HINTS else ""
        return f"{qualifier}{CTYPE_TO_CUDA[base_type]}*"

    if ctype not in CTYPE_TO_CUDA:
        raise ValueError(
            f"Unsupported scalar type: {ctype}. "
            "Please extend CTYPE_TO_CUDA mapping."
        )
    return CTYPE_TO_CUDA[ctype]

def ctype_to_mojo(ctype) -> str:
    """Render a ctypes annotation as a Mojo parameter type.

    Pointer types become "UnsafePointer[<base>]"; scalars map straight
    through CTYPE_TO_MOJO. Raises ValueError for any unmapped type.
    """
    pointer_like = isinstance(ctype, type) and issubclass(ctype, ctypes._Pointer)
    if pointer_like:
        base_type = getattr(ctype, "_type_", None)
        if base_type is None or base_type not in CTYPE_TO_MOJO:
            raise ValueError(
                f"Unsupported pointer base type: {base_type}. "
                "Please extend CTYPE_TO_MOJO mapping."
            )
        return f"UnsafePointer[{CTYPE_TO_MOJO[base_type]}]"

    if ctype not in CTYPE_TO_MOJO:
        raise ValueError(
            f"Unsupported scalar type: {ctype}. "
            "Please extend CTYPE_TO_MOJO mapping."
        )
    return CTYPE_TO_MOJO[ctype]

def ctype_to_torch(ctype, name) -> str:
    """Render a ctypes annotation as a "name: type" PyTorch parameter.

    Pointers become torch.Tensor; integer scalars become int and floating
    scalars become float. Raises ValueError for anything else.
    """
    pointer_like = isinstance(ctype, type) and issubclass(ctype, ctypes._Pointer)
    if pointer_like:
        return f"{name}: torch.Tensor"

    int_scalars = (ctypes.c_int, ctypes.c_uint32, ctypes.c_int64)
    float_scalars = (ctypes.c_float, ctypes.c_double)
    if ctype in int_scalars:
        return f"{name}: int"
    if ctype in float_scalars:
        return f"{name}: float"

    raise ValueError(
        f"Unsupported type {ctype} for PyTorch mapping. "
        "Please extend CTYPE_TO_TORCH mapping."
    )
91+
92+
def load_module(name: str, path: str):
    """Import the file at *path* as module *name* and register it in sys.modules.

    Returns the freshly executed module object. Raises ImportError when an
    import spec cannot be built for the file.
    """
    module_spec = util.spec_from_file_location(name, path)
    if module_spec is None or module_spec.loader is None:
        raise ImportError(f"Could not load {name} from {path}")

    module = util.module_from_spec(module_spec)
    module_spec.loader.exec_module(module)
    # Register after execution so later imports of *name* see this module.
    sys.modules[name] = module
    return module
101+
102+
def load_challenge(challenge_dir: str):
    """Fetch the shared challenge base class and instantiate the Challenge
    defined in <challenge_dir>/challenge.py.

    Requires network access to api.leetgpu.com; the base file is cached in
    the system temp directory for this run.
    """
    base_url = "https://api.leetgpu.com/api/v1/core-files/challenge_base.py"
    base_dst = os.path.join(tempfile.gettempdir(), "challenge_base.py")
    urllib.request.urlretrieve(base_url, base_dst)

    # NOTE(review): presumably challenge.py imports from core.challenge_base;
    # register a stub "core" package so that import path resolves.
    core_pkg = sys.modules.setdefault("core", types.ModuleType("core"))
    core_pkg.__path__ = []

    load_module("core.challenge_base", base_dst)
    challenge_module = load_module("challenge", os.path.join(challenge_dir, "challenge.py"))

    return challenge_module.Challenge()
113+
114+
def generate_starter_cuda(sig, starter_file):
    """Write a CUDA starter file exposing an empty extern "C" solve()."""
    params = ", ".join(
        f"{ctype_to_cuda(arg_type, arg_name)} {arg_name}"
        for arg_name, arg_type in sig.items()
    )
    # Only pull in the fp16 header when a __half parameter is present.
    include_half = "#include <cuda_fp16.h>\n" if "__half" in params else ""
    code = f"""#include <cuda_runtime.h>
{include_half}
extern "C" void solve({params}) {{

}}"""
    with open(starter_file, "w") as f:
        f.write(code)
124+
125+
def generate_starter_mojo(sig, starter_file):
    """Write a Mojo starter file exposing an empty exported solve()."""
    params = ", ".join(
        f"{arg_name}: {ctype_to_mojo(arg_type)}"
        for arg_name, arg_type in sig.items()
    )
    code = f"""from gpu.host import DeviceContext
from gpu.id import block_dim, block_idx, thread_idx
from memory import UnsafePointer
from math import ceildiv

@export
def solve({params}):
    pass"""

    with open(starter_file, "w") as f:
        f.write(code)
138+
139+
def generate_starter_pytorch(sig, starter_file):
    """Write a PyTorch starter file with an empty solve() stub."""
    params = ", ".join(
        ctype_to_torch(arg_type, arg_name) for arg_name, arg_type in sig.items()
    )
    code = f"""import torch

def solve({params}):
    pass
"""
    with open(starter_file, "w") as f:
        f.write(code)
148+
149+
def generate_starter_triton(sig, starter_file):
    """Write a Triton starter file with an empty solve() stub."""

    # NOTE(review): this duplicates ctype_to_torch; kept local to mirror the
    # original structure and keep this generator self-contained.
    def ctype_to_triton(ctype, name):
        pointer_like = isinstance(ctype, type) and issubclass(ctype, ctypes._Pointer)
        if pointer_like:
            return f"{name}: torch.Tensor"
        if ctype in (ctypes.c_int, ctypes.c_uint32, ctypes.c_int64):
            return f"{name}: int"
        if ctype in (ctypes.c_float, ctypes.c_double):
            return f"{name}: float"
        raise ValueError(f"Unsupported type {ctype} for Triton mapping. Please extend ctype_to_triton mapping.")

    params = [ctype_to_triton(arg_type, arg_name) for arg_name, arg_type in sig.items()]
    code = f"""import torch
import triton
import triton.language as tl

def solve({", ".join(params)}):
    pass
"""
    with open(starter_file, "w") as f:
        f.write(code)
169+
170+
def main():
    """CLI entry point: generate all four starter files for one challenge dir."""
    if len(sys.argv) != 2:
        print("Usage: python scripts/generate_starter_code.py path/to/challenge_dir")
        sys.exit(1)

    challenge_dir = sys.argv[1]

    # Refuse to run on 'easy' challenges (any path component named "easy").
    path_parts = os.path.normpath(challenge_dir).split(os.sep)
    if "easy" in (part.lower() for part in path_parts):
        print("Starter code generation script should not be used for 'easy' challenges.")
        sys.exit(1)

    starter_dir = os.path.join(challenge_dir, "starter")
    try:
        os.makedirs(starter_dir, exist_ok=True)
    except Exception as e:
        print(f"Error creating starter directory: {e}")
        sys.exit(1)

    challenge = load_challenge(challenge_dir)
    sig = challenge.get_solve_signature()

    # One generator per supported backend; file names match the repo layout.
    for filename, generator in (
        ("starter.cu", generate_starter_cuda),
        ("starter.mojo", generate_starter_mojo),
        ("starter.pytorch.py", generate_starter_pytorch),
        ("starter.triton.py", generate_starter_triton),
    ):
        generator(sig, os.path.join(starter_dir, filename))

if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)