Skip to content

Commit 8b28b4d

Browse files
authored
Merge pull request #43 from GunnarFarneback/do_not_get_extension
Stop using Base.get_extension.
2 parents b1b81be + f1a668d commit 8b28b4d

File tree

4 files changed

+26
-27
lines changed

4 files changed

+26
-27
lines changed

.github/workflows/CI.yml

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,24 @@ jobs:
2020
arch:
2121
- x64
2222
steps:
23-
- uses: actions/checkout@v2
24-
- uses: julia-actions/setup-julia@v1
23+
- uses: actions/checkout@v4
24+
- uses: julia-actions/setup-julia@v2
2525
with:
2626
version: ${{ matrix.version }}
2727
arch: ${{ matrix.arch }}
28-
- uses: actions/cache@v1
29-
env:
30-
cache-name: cache-artifacts
31-
with:
32-
path: ~/.julia/artifacts
33-
key: ${{ runner.os }}-test-${{ env.cache-name }}-${{ hashFiles('**/Project.toml') }}
34-
restore-keys: |
35-
${{ runner.os }}-test-${{ env.cache-name }}-
36-
${{ runner.os }}-test-
37-
${{ runner.os }}-
28+
- uses: julia-actions/cache@v2
3829
- uses: julia-actions/julia-buildpkg@v1
3930
- uses: julia-actions/julia-runtest@v1
4031
- uses: julia-actions/julia-processcoverage@v1
41-
- uses: codecov/codecov-action@v1
32+
- uses: codecov/codecov-action@v5
4233
with:
4334
file: lcov.info
4435
docs:
4536
name: Documentation
4637
runs-on: ubuntu-latest
4738
steps:
48-
- uses: actions/checkout@v2
49-
- uses: julia-actions/setup-julia@v1
39+
- uses: actions/checkout@v4
40+
- uses: julia-actions/setup-julia@v2
5041
with:
5142
version: '1'
5243
- run: |

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ONNXRunTime"
22
uuid = "e034b28e-924e-41b2-b98f-d2bbeb830c6a"
33
authors = ["Jan Weidner <[email protected]> and contributors"]
4-
version = "1.3.1"
4+
version = "1.3.2"
55

66
[deps]
77
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"

ext/CUDAExt.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
module CUDAExt
2+
import ONNXRunTime
3+
import CUDA
24

3-
# These functions are only defined for diagnostic purposes. Otherwise
5+
# These calls are only being made for diagnostic purposes. Otherwise
46
# the CUDA extension only relies on the CUDA and cuDNN dependencies to
57
# have loaded the libraries needed by ONNXRunTime's CUDA execution
68
# provider.
7-
import CUDA
8-
cuda_functional() = CUDA.functional()
9-
cuda_runtime_version() = CUDA.runtime_version()
9+
function __init__()
10+
ONNXRunTime.cuda_is_loaded[] = true
11+
ONNXRunTime.cuda_is_functional[] = CUDA.functional()
12+
ONNXRunTime.cuda_runtime_version[] = CUDA.runtime_version()
13+
end
1014

1115
end

src/highlevel.jl

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,16 @@ using .CAPI
1515
using .CAPI: juliatype, EXECUTION_PROVIDERS
1616
export InferenceSession, load_inference, release
1717

18+
# Interaction point with the CUDA extension. These values will be
19+
# updated by the extension when it is loaded.
20+
const cuda_is_loaded = Ref(false)
21+
const cuda_is_functional = Ref(false)
22+
const cuda_runtime_version = Ref(v"0.0.0")
23+
1824
"""
1925
$TYPEDEF
2026
21-
Represents an infernence session. Should only be created by calling [`load_inference`](@ref).
27+
Represents an inference session. Should only be created by calling [`load_inference`](@ref).
2228
"""
2329
struct InferenceSession
2430
api::OrtApi
@@ -99,20 +105,18 @@ function load_inference(path::AbstractString; execution_provider::Symbol=:cpu,
99105
end
100106
session_options = CreateSessionOptions(api)
101107
elseif execution_provider === :cuda
102-
CUDAExt = Base.get_extension(@__MODULE__, :CUDAExt)
103-
if isnothing(CUDAExt)
108+
if !cuda_is_loaded[]
104109
error("""
105110
The $(repr(execution_provider)) execution provider requires the CUDA.jl and cuDNN.jl packages to be available. Try adding `import CUDA, cuDNN` to your code.
106111
""")
107-
elseif !getfield(CUDAExt, :cuda_functional)()
112+
elseif !cuda_is_functional[]
108113
error("""
109114
The $(repr(execution_provider)) execution provider requires CUDA to be functional. See `CUDA.functional`.
110115
""")
111116
else
112-
cuda_runtime_version = getfield(CUDAExt, :cuda_runtime_version)()
113-
if !(cuda_runtime_supported_version <= cuda_runtime_version < cuda_runtime_upper_bound)
117+
if !(cuda_runtime_supported_version <= cuda_runtime_version[] < cuda_runtime_upper_bound)
114118
error("""
115-
Found CUDA runtime version $(cuda_runtime_version). The $(repr(execution_provider)) execution provider requires a CUDA runtime version of at least $(cuda_runtime_supported_version) but less than $(cuda_runtime_upper_bound). See `CUDA.set_runtime_version!` and the package README.
119+
Found CUDA runtime version $(cuda_runtime_version[]). The $(repr(execution_provider)) execution provider requires a CUDA runtime version of at least $(cuda_runtime_supported_version) but less than $(cuda_runtime_upper_bound). See `CUDA.set_runtime_version!` and the package README.
116120
""")
117121
end
118122
end

0 commit comments

Comments
 (0)