Skip to content

Commit 411077c

Browse files
jpsamaroo and kpamnany committed
Add threadpool support to runtime
Adds support for Julia to be started with `--threads=auto|N[,M]` where `N` specifies the number of threads in the default threadpool and `M`, if provided, specifies the number of threads in the new interactive threadpool.

Adds an optional first parameter to `Threads.@spawn`: `[:default|:interactive]`. If `:interactive` is specified, the task will be run by thread(s) in the interactive threadpool only (if there is one).

Co-authored-by: K Pamnany <[email protected]>
1 parent 3cff21e commit 411077c

20 files changed

+425
-96
lines changed

base/options.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
# NOTE: This type needs to be kept in sync with jl_options in src/julia.h
3+
# NOTE: This type needs to be kept in sync with jl_options in src/jloptions.h
44
struct JLOptions
55
quiet::Int8
66
banner::Int8
@@ -9,7 +9,9 @@ struct JLOptions
99
commands::Ptr{Ptr{UInt8}} # (e)eval, (E)print, (L)load
1010
image_file::Ptr{UInt8}
1111
cpu_target::Ptr{UInt8}
12-
nthreads::Int32
12+
nthreadpools::Int16
13+
nthreads::Int16
14+
nthreads_per_pool::Ptr{Int16}
1315
nprocs::Int32
1416
machine_file::Ptr{UInt8}
1517
project::Ptr{UInt8}

base/partr.jl

Lines changed: 45 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
module Partr
44

5-
using ..Threads: SpinLock, nthreads
5+
using ..Threads: SpinLock, nthreads, threadid
66

77
# a task minheap
88
mutable struct taskheap
@@ -16,12 +16,13 @@ end
1616

1717
# multiqueue minheap state
1818
const heap_d = UInt32(8)
19-
global heaps::Vector{taskheap} = Vector{taskheap}(undef, 0)
20-
const heaps_lock = SpinLock()
21-
global cong_unbias::UInt32 = typemax(UInt32)
19+
const heaps = [Vector{taskheap}(undef, 0), Vector{taskheap}(undef, 0)]
20+
const heaps_lock = [SpinLock(), SpinLock()]
21+
const cong_unbias = [typemax(UInt32), typemax(UInt32)]
2222

2323

24-
cong(max::UInt32, unbias::UInt32) = ccall(:jl_rand_ptls, UInt32, (UInt32, UInt32), max, unbias) + UInt32(1)
24+
cong(max::UInt32, unbias::UInt32) =
25+
ccall(:jl_rand_ptls, UInt32, (UInt32, UInt32), max, unbias) + UInt32(1)
2526

2627
function unbias_cong(max::UInt32)
2728
return typemax(UInt32) - ((typemax(UInt32) % max) + UInt32(1))
@@ -60,46 +61,52 @@ function multiq_sift_down(heap::taskheap, idx::Int32)
6061
end
6162

6263

63-
function multiq_size()
64+
function multiq_size(tpid::Int8)
65+
nt = UInt32(Threads._nthreads_in_pool(tpid))
66+
tp = tpid + 1
67+
tpheaps = heaps[tp]
6468
heap_c = UInt32(2)
65-
heap_p = UInt32(length(heaps))
66-
nt = UInt32(nthreads())
69+
heap_p = UInt32(length(tpheaps))
6770

6871
if heap_c * nt <= heap_p
6972
return heap_p
7073
end
7174

72-
@lock heaps_lock begin
73-
heap_p = UInt32(length(heaps))
74-
nt = UInt32(nthreads())
75+
@lock heaps_lock[tp] begin
76+
heap_p = UInt32(length(tpheaps))
77+
nt = UInt32(Threads._nthreads_in_pool(tpid))
7578
if heap_c * nt <= heap_p
7679
return heap_p
7780
end
7881

7982
heap_p += heap_c * nt
8083
newheaps = Vector{taskheap}(undef, heap_p)
81-
copyto!(newheaps, heaps)
82-
for i = (1 + length(heaps)):heap_p
84+
copyto!(newheaps, tpheaps)
85+
for i = (1 + length(tpheaps)):heap_p
8386
newheaps[i] = taskheap()
8487
end
85-
global heaps = newheaps
86-
global cong_unbias = unbias_cong(heap_p)
88+
heaps[tp] = newheaps
89+
cong_unbias[tp] = unbias_cong(heap_p)
8790
end
8891

8992
return heap_p
9093
end
9194

9295

9396
function multiq_insert(task::Task, priority::UInt16)
97+
tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), task)
98+
heap_p = multiq_size(tpid)
99+
tp = tpid + 1
100+
94101
task.priority = priority
95102

96-
heap_p = multiq_size()
97-
rn = cong(heap_p, cong_unbias)
98-
while !trylock(heaps[rn].lock)
99-
rn = cong(heap_p, cong_unbias)
103+
rn = cong(heap_p, cong_unbias[tp])
104+
tpheaps = heaps[tp]
105+
while !trylock(tpheaps[rn].lock)
106+
rn = cong(heap_p, cong_unbias[tp])
100107
end
101108

102-
heap = heaps[rn]
109+
heap = tpheaps[rn]
103110
if heap.ntasks >= length(heap.tasks)
104111
resize!(heap.tasks, length(heap.tasks) * 2)
105112
end
@@ -122,34 +129,37 @@ function multiq_deletemin()
122129
local rn1, rn2
123130
local prio1, prio2
124131

132+
tid = Threads.threadid()
133+
tp = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1) + 1
134+
tpheaps = heaps[tp]
135+
125136
@label retry
126137
GC.safepoint()
127-
heap_p = UInt32(length(heaps))
138+
heap_p = UInt32(length(tpheaps))
128139
for i = UInt32(0):heap_p
129140
if i == heap_p
130141
return nothing
131142
end
132-
rn1 = cong(heap_p, cong_unbias)
133-
rn2 = cong(heap_p, cong_unbias)
134-
prio1 = heaps[rn1].priority
135-
prio2 = heaps[rn2].priority
143+
rn1 = cong(heap_p, cong_unbias[tp])
144+
rn2 = cong(heap_p, cong_unbias[tp])
145+
prio1 = tpheaps[rn1].priority
146+
prio2 = tpheaps[rn2].priority
136147
if prio1 > prio2
137148
prio1 = prio2
138149
rn1 = rn2
139150
elseif prio1 == prio2 && prio1 == typemax(UInt16)
140151
continue
141152
end
142-
if trylock(heaps[rn1].lock)
143-
if prio1 == heaps[rn1].priority
153+
if trylock(tpheaps[rn1].lock)
154+
if prio1 == tpheaps[rn1].priority
144155
break
145156
end
146-
unlock(heaps[rn1].lock)
157+
unlock(tpheaps[rn1].lock)
147158
end
148159
end
149160

150-
heap = heaps[rn1]
161+
heap = tpheaps[rn1]
151162
task = heap.tasks[1]
152-
tid = Threads.threadid()
153163
if ccall(:jl_set_task_tid, Cint, (Any, Cint), task, tid-1) == 0
154164
unlock(heap.lock)
155165
@goto retry
@@ -171,9 +181,11 @@ end
171181

172182

173183
function multiq_check_empty()
174-
for i = UInt32(1):length(heaps)
175-
if heaps[i].ntasks != 0
176-
return false
184+
for j = UInt32(1):length(heaps)
185+
for i = UInt32(1):length(heaps[j])
186+
if heaps[j][i].ntasks != 0
187+
return false
188+
end
177189
end
178190
end
179191
return true

base/task.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,10 @@ true
251251
istaskfailed(t::Task) = (load_state_acquire(t) === task_state_failed)
252252

253253
Threads.threadid(t::Task) = Int(ccall(:jl_get_task_tid, Int16, (Any,), t)+1)
254+
function Threads.threadpool(t::Task)
255+
tpid = ccall(:jl_get_task_threadpoolid, Int8, (Any,), t)
256+
return tpid == 0 ? :default : :interactive
257+
end
254258

255259
task_result(t::Task) = t.result
256260

base/threadcall.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ const threadcall_restrictor = Semaphore(max_ccall_threads)
99
1010
The `@threadcall` macro is called in the same way as [`ccall`](@ref) but does the work
1111
in a different thread. This is useful when you want to call a blocking C
12-
function without causing the main `julia` thread to become blocked. Concurrency
12+
function without causing the current `julia` thread to become blocked. Concurrency
1313
is limited by size of the libuv thread pool, which defaults to 4 threads but
1414
can be increased by setting the `UV_THREADPOOL_SIZE` environment variable and
1515
restarting the `julia` process.

base/threadingconstructs.jl

Lines changed: 84 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,62 @@
11
# This file is a part of Julia. License is MIT: https://julialang.org/license
22

3-
export threadid, nthreads, @threads, @spawn
3+
export threadid, nthreads, @threads, @spawn,
4+
threadpool, nthreadpools
45

56
"""
6-
Threads.threadid()
7+
Threads.threadid() -> Int
78
8-
Get the ID number of the current thread of execution. The master thread has ID `1`.
9+
Get the ID number of the current thread of execution. The master thread has
10+
ID `1`.
911
"""
1012
threadid() = Int(ccall(:jl_threadid, Int16, ())+1)
1113

12-
# Inclusive upper bound on threadid()
1314
"""
14-
Threads.nthreads()
15+
Threads.nthreads([:default|:interactive]) -> Int
1516
16-
Get the number of threads available to the Julia process. This is the inclusive upper bound
17-
on [`threadid()`](@ref).
17+
Get the number of threads (across all thread pools or within the specified
18+
thread pool) available to Julia. The number of threads across all thread
19+
pools is the inclusive upper bound on [`threadid()`](@ref).
1820
1921
See also: `BLAS.get_num_threads` and `BLAS.set_num_threads` in the
2022
[`LinearAlgebra`](@ref man-linalg) standard library, and `nprocs()` in the
2123
[`Distributed`](@ref man-distributed) standard library.
2224
"""
25+
function nthreads end
26+
2327
nthreads() = Int(unsafe_load(cglobal(:jl_n_threads, Cint)))
28+
function nthreads(pool::Symbol)
29+
if pool == :default
30+
tpid = Int8(0)
31+
elseif pool == :interactive
32+
tpid = Int8(1)
33+
else
34+
error("invalid threadpool specified")
35+
end
36+
return _nthreads_in_pool(tpid)
37+
end
38+
function _nthreads_in_pool(tpid::Int8)
39+
p = unsafe_load(cglobal(:jl_n_threads_per_pool, Ptr{Cint}))
40+
return Int(unsafe_load(p, tpid + 1))
41+
end
42+
43+
"""
44+
Threads.threadpool(tid = threadid()) -> Symbol
45+
46+
Returns the specified thread's threadpool; either `:default` or `:interactive`.
47+
"""
48+
function threadpool(tid = threadid())
49+
tpid = ccall(:jl_threadpoolid, Int8, (Int16,), tid-1)
50+
return tpid == 0 ? :default : :interactive
51+
end
52+
53+
"""
54+
Threads.nthreadpools() -> Int
55+
56+
Returns the number of threadpools currently configured.
57+
"""
58+
nthreadpools() = Int(unsafe_load(cglobal(:jl_n_threadpools, Cint)))
59+
2460

2561
function threading_run(fun, static)
2662
ccall(:jl_enter_threaded_region, Cvoid, ())
@@ -48,7 +84,7 @@ function _threadsfor(iter, lbody, schedule)
4884
quote
4985
local threadsfor_fun
5086
let range = $(esc(range))
51-
function threadsfor_fun(tid=1; onethread=false)
87+
function threadsfor_fun(tid = 1; onethread = false)
5288
r = range # Load into local variable
5389
lenr = length(r)
5490
# divide loop iterations among threads
@@ -232,35 +268,63 @@ macro threads(args...)
232268
end
233269

234270
"""
235-
Threads.@spawn expr
271+
Threads.@spawn [:default|:interactive] expr
236272
237-
Create a [`Task`](@ref) and [`schedule`](@ref) it to run on any available thread.
238-
The task is allocated to a thread after it becomes available. To wait for the task
239-
to finish, call [`wait`](@ref) on the result of this macro, or call [`fetch`](@ref) to
240-
wait and then obtain its return value.
273+
Create a [`Task`](@ref) and [`schedule`](@ref) it to run on any available
274+
thread in the specified threadpool (`:default` if unspecified). The task is
275+
allocated to a thread once one becomes available. To wait for the task to
276+
finish, call [`wait`](@ref) on the result of this macro, or call
277+
[`fetch`](@ref) to wait and then obtain its return value.
241278
242-
Values can be interpolated into `@spawn` via `\$`, which copies the value directly into the
243-
constructed underlying closure. This allows you to insert the _value_ of a variable,
244-
isolating the asynchronous code from changes to the variable's value in the current task.
279+
Values can be interpolated into `@spawn` via `\$`, which copies the value
280+
directly into the constructed underlying closure. This allows you to insert
281+
the _value_ of a variable, isolating the asynchronous code from changes to
282+
the variable's value in the current task.
245283
246284
!!! note
247-
See the manual chapter on threading for important caveats.
285+
See the manual chapter on [multi-threading](@ref man-multithreading)
286+
for important caveats. See also the chapter on [threadpools](@ref man-threadpools).
248287
249288
!!! compat "Julia 1.3"
250289
This macro is available as of Julia 1.3.
251290
252291
!!! compat "Julia 1.4"
253292
Interpolating values via `\$` is available as of Julia 1.4.
293+
294+
!!! compat "Julia 1.9"
295+
A threadpool may be specified as of Julia 1.9.
254296
"""
255-
macro spawn(expr)
256-
letargs = Base._lift_one_interp!(expr)
297+
macro spawn(args...)
298+
tpid = Int8(0)
299+
na = length(args)
300+
if na == 2
301+
ttype, ex = args
302+
if ttype isa QuoteNode
303+
ttype = ttype.value
304+
elseif ttype isa Symbol
305+
# TODO: allow unquoted symbols
306+
ttype = nothing
307+
end
308+
if ttype === :interactive
309+
tpid = Int8(1)
310+
elseif ttype !== :default
311+
throw(ArgumentError("unsupported threadpool in @spawn: $ttype"))
312+
end
313+
elseif na == 1
314+
ex = args[1]
315+
else
316+
throw(ArgumentError("wrong number of arguments in @spawn"))
317+
end
318+
319+
letargs = Base._lift_one_interp!(ex)
257320

258-
thunk = esc(:(()->($expr)))
321+
thunk = esc(:(()->($ex)))
259322
var = esc(Base.sync_varname)
260323
quote
261324
let $(letargs...)
262325
local task = Task($thunk)
263326
task.sticky = false
327+
ccall(:jl_set_task_threadpoolid, Cint, (Any, Int8), task, $tpid)
264328
if $(Expr(:islocal, var))
265329
put!($var, task)
266330
end

doc/src/base/multi-threading.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ Base.Threads.foreach
66
Base.Threads.@spawn
77
Base.Threads.threadid
88
Base.Threads.nthreads
9+
Base.Threads.threadpool
10+
Base.Threads.nthreadpools
911
```
1012

1113
See also [Multi-Threading](@ref man-multithreading).
@@ -49,7 +51,7 @@ Base.Threads.atomic_min!
4951
Base.Threads.atomic_fence
5052
```
5153

52-
## ccall using a threadpool (Experimental)
54+
## ccall using a libuv threadpool (Experimental)
5355

5456
```@docs
5557
Base.@threadcall

0 commit comments

Comments
 (0)