Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlir/test/Examples/NVGPU/Ch5.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def producer_loop(
):
phase = const(True, ty=T.bool())

for iv, phase in scf.for_(0, (K // TILE_K), 1, [phase]):
for iv, phase, _ in scf.for_(0, (K // TILE_K), 1, [phase]):
stage = iv % num_stages
# Wait MMA to be done
mbar_mma[stage].try_wait(phase)
Expand Down
7 changes: 3 additions & 4 deletions mlir/test/Examples/NVGPU/tools/nvdsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ def arrive(self, txcount: int = 0, predicate=None):
self.mbar_group_op, txcount_op, self.id_op, predicate=predicate
)
else:
nvgpu.mbarrier_arrive(
ir.Type.parse("!nvgpu.mbarrier.token"), self.mbar_group_op, self.id_op
nvgpu.mbarrier_arrive(self.mbar_group_op, self.id_op
)

def try_wait(self, phase: bool = False, ticks: int = 10000000):
Expand Down Expand Up @@ -144,7 +143,7 @@ def create_descriptor(self, device_ptr):
device_ptr,
)
self.tma_descriptor = nvgpu.TmaCreateDescriptorOp(
tma_descriptor_ty, device_unranked_memref, map(const, self.tma_box_shape)
tma_descriptor_ty, device_unranked_memref, list(map(const, self.tma_box_shape))
)
return self.tma_descriptor.result

Expand All @@ -156,7 +155,7 @@ def load(self, dest, mbarrier: Mbarriers, coords=[0], predicate=None):
dest,
mbarrier.mbar_group_op,
self.tma_descriptor,
coordinates=map(const, coords),
coordinates=list(map(const, coords)),
mbarId=mbarrier.id_op,
predicate=predicate,
)
Expand Down
4 changes: 3 additions & 1 deletion mlir/test/Examples/NVGPU/tools/nvgpucompiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def compile(self, module: ir.Module):

def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Wraps the module in a JIT execution engine."""
return execution_engine.ExecutionEngine(
ee = execution_engine.ExecutionEngine(
module, opt_level=self.opt_level, shared_libs=self.shared_libs
)
ee.initialize()
return ee

def compile_and_jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
"""Compiles and jits the module."""
Expand Down