Skip to content

Commit c7cbf86

Browse files
committed
Add a new stage to generate zebin when TRITON_XPU_GEN_NATIVE_CODE=1.
Signed-off-by: Lu,Chengjun <[email protected]>
1 parent 3de5f93 commit c7cbf86

File tree

2 files changed

+56
-47
lines changed

2 files changed

+56
-47
lines changed

python/triton/compiler/compiler.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def parse(full_name, ext, context):
138138
return module
139139
if ext == "llir" or ext == "ptx" or ext == "amdgcn":
140140
return Path(full_name).read_text()
141-
if ext == "cubin" or ext == "hsaco":
141+
if ext == "cubin" or ext == "hsaco" or ext == "zebin":
142142
return Path(full_name).read_bytes()
143143
if ext == "spv":
144144
return Path(full_name).read_bytes()
@@ -332,7 +332,7 @@ def compile(src, target=None, options=None, _env_vars=None):
332332
print(f"\nOverriding kernel with file {full_name}")
333333
next_module = parse(full_name, ext, context)
334334
# If TRITON_STORE_BINARY_ONLY is 1, only store cubin/hsaco/json
335-
if (not store_only_binary) or (ext in ("cubin", "hsaco", "json", "spv")):
335+
if (not store_only_binary) or (ext in ("cubin", "hsaco", "zebin", "json", "spv")):
336336
metadata_group[ir_filename] = fn_cache_manager.put(next_module, ir_filename)
337337
if fn_dump_manager is not None:
338338
fn_dump_manager.put(next_module, ir_filename)
@@ -422,11 +422,15 @@ def __init__(self, src, metadata_group, hash):
422422
self.name = self.metadata.name
423423
# stores the text of each level of IR that was generated during compilation
424424
asm_files = [Path(p) for c, p in metadata_group.items() if not c.endswith(".json")]
425+
426+
def read_file(path):
427+
try:
428+
return path.read_text()
429+
except UnicodeDecodeError:
430+
return path.read_bytes()
431+
432+
self.asm = AsmDict({file.suffix[1:]: read_file(file) for file in asm_files})
425433
binary_ext = backend.binary_ext
426-
self.asm = AsmDict({
427-
file.suffix[1:]: file.read_bytes() if file.suffix[1:] == binary_ext else file.read_text()
428-
for file in asm_files
429-
})
430434
self.metadata_group = metadata_group
431435
self.kernel = self.asm[binary_ext]
432436
# binaries are lazily initialized

third_party/intel/backend/compiler.py

Lines changed: 46 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def __init__(self, target: tuple) -> None:
9696
mod = compile_module_from_src(src=Path(os.path.join(dirname, "arch_parser.c")).read_text(), name="arch_utils")
9797
self.device_arch = knobs.intel.device_arch or mod.parse_device_arch(target.arch.get('architecture', 0))
9898
self.properties = self.parse_target(target.arch)
99-
self.binary_ext = "spv"
99+
self.binary_ext = "zebin" if knobs.intel.gen_native_code else "spv"
100100

101101
def get_target_name(self, options) -> str:
102102
return f"xpu:{self.device_arch}"
@@ -363,6 +363,10 @@ def make_llir(src, metadata, options):
363363
def make_spv(src, metadata, options, device_arch):
364364
spirv, name = intel.translate_to_spirv(src)
365365
metadata["name"] = name
366+
return spirv
367+
368+
@staticmethod
369+
def make_zebin(src, metadata, options, device_arch):
366370
if options.grf_mode == 'small':
367371
metadata["build_flags"] = "-cl-intel-128-GRF-per-thread"
368372
elif options.grf_mode == 'large':
@@ -381,50 +385,48 @@ def make_spv(src, metadata, options, device_arch):
381385
if knobs.intel.dump_shader_info:
382386
# The IGC (Intel Graphic Compiler) only parses the options at first time in JIT-ing the binary per process.
383387
# Have to use the `ocloc` to generate the binary in sub-process to work around the limitation.
384-
assert options.generate_native_code, "Only support native code generation with shader dump"
385388
shader_dump_opt = f" -igc_opts ',DumpToCustomDir={metadata['cache_dir']},ShaderDumpEnable=1'"
386389

387390
metadata["generate_native_code"] = options.generate_native_code
388391

389-
if options.generate_native_code:
390-
with track("generate_native_code"), tempfile.TemporaryDirectory() as temp_dir:
391-
with tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir, delete=False) as fsrc:
392-
fsrc.write(spirv)
393-
fbin = fsrc.name + '.o'
394-
395-
ocloc_cmd = [
396-
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch,
397-
'-options', metadata["build_flags"] + shader_dump_opt
398-
]
399-
400-
try:
401-
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
402-
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
403-
"""
404-
The exact message is something like:
405-
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
406-
is "spilled" enough for now?
407-
"""
408-
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
409-
# re-run with new build flags
410-
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
411-
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
412-
except subprocess.CalledProcessError as e:
413-
if e.returncode == 255:
414-
error = 'Internal Triton ZEBIN codegen error'
415-
elif e.returncode == 128 + signal.SIGSEGV:
416-
error = '`ocloc` raised SIGSEGV'
417-
else:
418-
error = f'`ocloc` failed with error code {e.returncode}'
419-
420-
raise RuntimeError(f'{error}\n'
421-
f'`ocloc` stderr:\n{e.output}\n'
422-
f'Repro command: {ocloc_cmd}\n') from e
423-
424-
with open(fbin, 'rb') as f:
425-
zebin = f.read()
426-
return zebin
427-
return spirv
392+
with tempfile.TemporaryDirectory() as temp_dir:
393+
with track("generate_native_code"), tempfile.NamedTemporaryFile(mode='wb', suffix='.spv', dir=temp_dir,
394+
delete=False) as fsrc:
395+
fsrc.write(src)
396+
fbin = fsrc.name + '.o'
397+
398+
ocloc_cmd = [
399+
'ocloc', 'compile', '-file', fsrc.name, '-o', fbin, '-spirv_input', '-device', device_arch, '-options',
400+
metadata["build_flags"] + shader_dump_opt
401+
]
402+
403+
try:
404+
output = subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
405+
if 'spilled' in output and metadata["build_flags"].find("-cl-intel-256-GRF-per-thread") == -1:
406+
"""
407+
The exact message is something like:
408+
warning: kernel matmul_kernel compiled SIMD16 allocated 128 regs and spilled around 217
409+
is "spilled" enough for now?
410+
"""
411+
metadata["build_flags"] += " -cl-intel-256-GRF-per-thread"
412+
# re-run with new build flags
413+
ocloc_cmd[-1] = metadata["build_flags"] + shader_dump_opt
414+
subprocess.check_output(ocloc_cmd, stderr=subprocess.STDOUT, text=True)
415+
except subprocess.CalledProcessError as e:
416+
if e.returncode == 255:
417+
error = 'Internal Triton ZEBIN codegen error'
418+
elif e.returncode == 128 + signal.SIGSEGV:
419+
error = '`ocloc` raised SIGSEGV'
420+
else:
421+
error = f'`ocloc` failed with error code {e.returncode}'
422+
423+
raise RuntimeError(f'{error}\n'
424+
f'`ocloc` stderr:\n{e.output}\n'
425+
f'Repro command: {ocloc_cmd}\n') from e
426+
427+
with open(fbin, 'rb') as f:
428+
zebin = f.read()
429+
return zebin
428430

429431
def add_stages(self, stages, options, language):
430432
if language == Language.TRITON:
@@ -434,9 +436,12 @@ def add_stages(self, stages, options, language):
434436
stages["ttgir"] = lambda src, metadata: self.gluon_to_ttgir(src, metadata, options)
435437
stages["llir"] = lambda src, metadata: self.make_llir(src, metadata, options)
436438
stages["spv"] = lambda src, metadata: self.make_spv(src, metadata, options, self.device_arch)
439+
if options.generate_native_code:
440+
stages["zebin"] = lambda src, metadata: self.make_zebin(src, metadata, options, self.device_arch)
437441
if knobs.runtime.add_stages_inspection_hook is not None:
438442
knobs.runtime.add_stages_inspection_hook(self, stages, options, language, None)
439443

444+
440445
@functools.lru_cache()
441446
def hash(self):
442447
return f'SPIR-V 1.5-{self.properties}'

0 commit comments

Comments
 (0)