Skip to content

Commit f864bc0

Browse files
committed
Use stacked method tables
1 parent 5519c83 commit f864bc0

File tree

4 files changed

+35
-1
lines changed

4 files changed

+35
-1
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,3 +34,6 @@ SPIRV_LLVM_Backend_jll = "20"
3434
SPIRV_Tools_jll = "2025.1"
3535
StaticArrays = "1"
3636
julia = "1.10"
37+
38+
[sources]
39+
GPUCompiler = {url="https://github.com/JuliaGPU/GPUCompiler.jl", rev="vc/mtv"}

lib/intrinsics/src/SPIRVIntrinsics.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ macro import_all()
3232
# bring all the names of this module in scope
3333
name in (:SPIRVIntrinsics, :eval, :include) && continue
3434
startswith(string(name), "#") && continue
35+
name in (:method_table, :@device_function, :@device_override) && continue
3536
push!(code.args, :(using .SPIRVIntrinsics: $name))
3637
end
3738

src/OpenCL.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,40 @@ include("../lib/cl/CL.jl")
1515
@reexport using .cl
1616
export cl
1717

18+
## device overrides
19+
20+
# local method table for device functions
21+
Base.Experimental.@MethodTable(method_table)
22+
23+
macro device_override(ex)
24+
esc(quote
25+
Base.Experimental.@overlay($method_table, $ex)
26+
end)
27+
end
28+
29+
macro device_function(ex)
30+
ex = macroexpand(__module__, ex)
31+
def = ExprTools.splitdef(ex)
32+
33+
# generate a function that errors
34+
def[:body] = quote
35+
error("This function is not intended for use on the CPU")
36+
end
37+
38+
esc(quote
39+
$(ExprTools.combinedef(def))
40+
@device_override $ex
41+
end)
42+
end
43+
44+
1845
# device functionality
1946
import SPIRVIntrinsics
2047
SPIRVIntrinsics.@import_all
2148
SPIRVIntrinsics.@reexport_public
49+
50+
const spirv_method_table = SPIRVIntrinsics.method_table
51+
2252
include("device/runtime.jl")
2353
include("device/array.jl")
2454
include("device/quirks.jl")

src/compiler/compilation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ const OpenCLCompilerJob = CompilerJob{SPIRVCompilerTarget,OpenCLCompilerParams}
66

77
GPUCompiler.runtime_module(::CompilerJob{<:Any,OpenCLCompilerParams}) = OpenCL
88

9-
GPUCompiler.method_table(::OpenCLCompilerJob) = method_table
9+
GPUCompiler.method_table_view(job::OpenCLCompilerJob) = GPUCompiler.StackedMethodTable(job.world, method_table, spirv_method_table)
1010

1111
# filter out OpenCL built-ins
1212
# TODO: eagerly lower these using the translator API

0 commit comments

Comments
 (0)