Commit 02bed9c

Add Fiber::ExecutionContext::Parallel#resize (#15956)
Allows dynamically resizing a Parallel execution context to a new maximum parallelism. **Parallelism can grow**: this merely adds more schedulers, which may eventually start more threads. **Parallelism can shrink**: any overflow schedulers are immediately removed and told to shut down. The actual shutdown is cooperative: a scheduler won't notice the shutdown until its current fiber is rescheduled.
1 parent 3ec8169 commit 02bed9c
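
A minimal usage sketch of the new API (not part of the commit; `do_work` is a hypothetical placeholder and the `execution_context` compile-time flag is assumed):

```crystal
ctx = Fiber::ExecutionContext::Parallel.new("workers", 1)

ctx.spawn { do_work } # hypothetical work fiber

ctx.resize(4) # grow: more schedulers, threads may be started lazily
ctx.capacity  # => 4

ctx.resize(2) # shrink: overflow schedulers are told to shut down cooperatively
ctx.capacity  # => 2
```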

5 files changed (+207 −23 lines)

5 files changed

+207
-23
lines changed

spec/std/fiber/execution_context/parallel_spec.cr

Lines changed: 40 additions & 0 deletions
```diff
@@ -1,5 +1,6 @@
 {% skip_file unless flag?(:execution_context) %}
 require "spec"
+require "wait_group"
 
 describe Fiber::ExecutionContext::Parallel do
   it ".new" do
@@ -41,4 +42,43 @@ describe Fiber::ExecutionContext::Parallel do
       Fiber::ExecutionContext::Parallel.new("test", size: 5..1)
     end
   end
+
+  it "#resize" do
+    ctx = Fiber::ExecutionContext::Parallel.new("ctx", 1)
+    running = Atomic(Bool).new(true)
+    wg = WaitGroup.new
+
+    10.times do
+      wg.add(1)
+
+      ctx.spawn do
+        while running.get(:relaxed)
+          sleep(10.microseconds)
+        end
+      ensure
+        wg.done
+      end
+    end
+
+    # it grows
+    ctx.resize(4)
+    ctx.capacity.should eq(4)
+
+    # it shrinks
+    ctx.resize(2)
+    ctx.capacity.should eq(2)
+
+    # it doesn't change
+    ctx.resize(2)
+    ctx.capacity.should eq(2)
+
+    10.times do
+      n = rand(1..4)
+      ctx.resize(n)
+      ctx.capacity.should eq(n)
+    end
+
+    running.set(false)
+    wg.wait
+  end
 end
```

spec/std/fiber/execution_context/runnables_spec.cr

Lines changed: 26 additions & 0 deletions
```diff
@@ -71,6 +71,32 @@ describe Fiber::ExecutionContext::Runnables do
     end
   end
 
+  describe "#drain" do
+    it "drains the local queue into the global queue" do
+      fibers = 6.times.map { |i| new_fake_fiber("f#{i}") }.to_a
+
+      # local enqueue + overflow
+      g = Fiber::ExecutionContext::GlobalQueue.new(Thread::Mutex.new)
+      r = Fiber::ExecutionContext::Runnables(6).new(g)
+
+      # empty
+      r.drain
+      g.size.should eq(0)
+
+      # full
+      fibers.each { |f| r.push(f) }
+      r.drain
+      r.shift?.should be_nil
+      g.size.should eq(6)
+
+      # refill half (1 pop + 2 grab) and drain again
+      g.unsafe_grab?(r, divisor: 1)
+      r.drain
+      r.shift?.should be_nil
+      g.size.should eq(5)
+    end
+  end
+
   describe "#bulk_push" do
     it "fills the local queue" do
       l = Fiber::List.new
```

src/fiber/execution_context/parallel.cr

Lines changed: 75 additions & 10 deletions
```diff
@@ -69,7 +69,6 @@ module Fiber::ExecutionContext
 
     @parked = Atomic(Int32).new(0)
     @spinning = Atomic(Int32).new(0)
-    @capacity : Int32
 
     # :nodoc:
     protected def self.default(maximum : Int32) : self
@@ -102,12 +101,12 @@ module Fiber::ExecutionContext
       @condition = Thread::ConditionVariable.new
 
       @global_queue = GlobalQueue.new(@mutex)
-      @schedulers = Array(Scheduler).new(@capacity)
-      @threads = Array(Thread).new(@capacity)
+      @schedulers = Array(Scheduler).new(capacity)
+      @threads = Array(Thread).new(capacity)
 
       @rng = Random::PCG32.new
 
-      start_schedulers
+      start_schedulers(capacity)
       @threads << hijack_current_thread(@schedulers.first) if hijack
 
       ExecutionContext.execution_contexts.push(self)
@@ -120,7 +119,7 @@ module Fiber::ExecutionContext
 
     # The maximum number of threads that can be started.
     def capacity : Int32
-      @capacity
+      @schedulers.size
     end
 
     # :nodoc:
@@ -140,7 +139,7 @@ module Fiber::ExecutionContext
    # OPTIMIZE: consider storing schedulers to an array-like object that would
    # use an atomic/fence to make sure that @size can only be incremented
    # *after* the value has been written to @buffer.
-    private def start_schedulers
+    private def start_schedulers(capacity)
       capacity.times do |index|
         @schedulers << Scheduler.new(self, "#{@name}-#{index}")
       end
@@ -176,6 +175,71 @@ module Fiber::ExecutionContext
       end
     end
 
+    # Resizes the context to the new *maximum* parallelism.
+    #
+    # The new *maximum* can grow, in which case more schedulers are created to
+    # eventually increase the parallelism.
+    #
+    # The new *maximum* can also shrink, in which case the overflow schedulers
+    # are removed and told to shutdown immediately. The actual shutdown is
+    # cooperative, so running schedulers won't stop until their current fiber
+    # tries to switch to another fiber.
+    def resize(maximum : Int32) : Nil
+      raise ArgumentError.new("Parallelism can't be less than one.") if maximum < 1
+      removed_schedulers = nil
+
+      @mutex.synchronize do
+        # can run in parallel to #steal that dereferences @schedulers (once)
+        # without locking the mutex, so we dup the schedulers, mutate the copy,
+        # and eventually assign the copy as @schedulers; this way #steal can
+        # safely access the array (never mutated).
+        new_capacity = maximum
+        old_threads = @threads
+        old_schedulers = @schedulers
+        old_capacity = capacity
+
+        if new_capacity > old_capacity
+          @schedulers = Array(Scheduler).new(new_capacity) do |index|
+            old_schedulers[index]? || Scheduler.new(self, "#{@name}-#{index}")
+          end
+          threads = Array(Thread).new(new_capacity)
+          old_threads.each { |thread| threads << thread }
+          @threads = threads
+        elsif new_capacity < old_capacity
+          # tell the overflow schedulers to shutdown
+          removed_schedulers = old_schedulers[new_capacity..]
+          removed_schedulers.each(&.shutdown!)
+
+          # resize
+          @schedulers = old_schedulers[0...new_capacity]
+          @threads = old_threads[0...new_capacity]
+
+          # reset @parked counter (we wake all parked threads) so they can
+          # shutdown (if told to):
+          woken_threads = @parked.get(:relaxed)
+          @parked.set(0, :relaxed)
+
+          # update @spinning prior to unpark threads; we use acquire release
+          # semantics to make sure that all the above stores are visible before
+          # the following wakeup calls (maybe not needed, but let's err on the
+          # safe side)
+          @spinning.add(woken_threads, :acquire_release)
+
+          # wake every waiting thread:
+          @condition.broadcast
+          @event_loop.interrupt
+        end
+      end
+
+      return unless removed_schedulers
+
+      # drain the local queues of removed schedulers since they're no longer
+      # available for stealing
+      removed_schedulers.each do |scheduler|
+        scheduler.@runnables.drain
+      end
+    end
+
     # :nodoc:
     def spawn(*, name : String? = nil, same_thread : Bool, &block : ->) : Fiber
       raise ArgumentError.new("#{self.class.name}#spawn doesn't support same_thread:true") if same_thread
@@ -200,11 +264,12 @@ module Fiber::ExecutionContext
     protected def steal(& : Scheduler ->) : Nil
       return if capacity == 1
 
+      schedulers = @schedulers
       i = @rng.next_int
-      n = @schedulers.size
+      n = schedulers.size
 
       n.times do |j|
-        if scheduler = @schedulers[(i &+ j) % n]?
+        if scheduler = schedulers[(i &+ j) % n]?
          yield scheduler
        end
      end
@@ -282,11 +347,11 @@ module Fiber::ExecutionContext
      # check if we can start another thread; no need for atomics, the values
      # shall be rather stable over time and we check them again inside the
      # mutex
-      return if @threads.size == capacity
+      return if @threads.size >= capacity
 
      @mutex.synchronize do
        index = @threads.size
-        return if index == capacity # check again
+        return if index >= capacity # check again
 
        @threads << start_thread(@schedulers[index])
      end
```
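
The `#resize` comment above relies on a copy-on-write style publication so that `#steal` can dereference `@schedulers` once and iterate it without taking the mutex. A simplified, generic sketch of that pattern (illustrative only; `Registry` is not part of the commit):

```crystal
# Readers snapshot the array reference once and iterate an array that is never
# mutated; writers build a new array under the mutex and publish it by
# replacing the instance variable, so readers see either the old or the new
# array, never a partially mutated one.
class Registry(T)
  def initialize
    @items = [] of T
    @mutex = Thread::Mutex.new
  end

  def each(& : T ->)
    items = @items # single dereference, no lock needed
    items.each { |item| yield item }
  end

  def add(item : T) : Nil
    @mutex.synchronize do
      copy = @items.dup # mutate a copy, never the published array
      copy << item
      @items = copy # publish the new snapshot
    end
  end
end
```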

src/fiber/execution_context/parallel/scheduler.cr

Lines changed: 26 additions & 2 deletions
```diff
@@ -29,13 +29,18 @@ module Fiber::ExecutionContext
       @spinning = false
       @waiting = false
       @parked = false
+      @shutdown = false
 
       protected def initialize(@execution_context, @name)
         @global_queue = @execution_context.global_queue
         @runnables = Runnables(256).new(@global_queue)
         @event_loop = @execution_context.event_loop
       end
 
+      protected def shutdown! : Nil
+        @shutdown = true
+      end
+
       # :nodoc:
       def spawn(*, name : String? = nil, same_thread : Bool, &block : ->) : Fiber
         raise RuntimeError.new("#{self.class.name}#spawn doesn't support same_thread:true") if same_thread
@@ -86,6 +91,8 @@ module Fiber::ExecutionContext
       end
 
       private def quick_dequeue? : Fiber?
+        return if @shutdown
+
         # every once in a while: dequeue from global queue to avoid two fibers
         # constantly respawing each other to completely occupy the local queue
         if (@tick &+= 1) % 61 == 0
@@ -121,8 +128,21 @@ module Fiber::ExecutionContext
        Crystal.trace :sched, "started"
 
        loop do
+          if @shutdown
+            spin_stop
+            @runnables.drain
+
+            # we may have been the last running scheduler, waiting on the event
+            # loop while there are pending events for example; let's resume a
+            # scheduler to take our place
+            @execution_context.wake_scheduler
+
+            Crystal.trace :sched, "shutdown"
+            break
+          end
+
          if fiber = find_next_runnable
-            spin_stop if @spinning
+            spin_stop
            resume fiber
          else
            # the event loop enqueued a fiber (or was interrupted) or the
@@ -145,6 +165,8 @@ module Fiber::ExecutionContext
 
          # nothing to do: start spinning
          spinning do
+            return if @shutdown
+
            yield @global_queue.grab?(@runnables, divisor: @execution_context.size)
 
            if @execution_context.lock_evloop? { @event_loop.run(pointerof(list), blocking: false) }
@@ -189,10 +211,12 @@ module Fiber::ExecutionContext
          # loop: park the thread until another scheduler or another context
          # enqueues a fiber
          @execution_context.park_thread do
+            # don't park the thread when told to shutdown
+            return if @shutdown
+
            # by the time we acquire the lock, another thread may have enqueued
            # fiber(s) and already tried to wakeup a thread (race) so we must
            # check again; we don't check the scheduler's local queue (it's empty)
-
            yield @global_queue.unsafe_grab?(@runnables, divisor: @execution_context.size)
            yield try_steal?
```

src/fiber/execution_context/runnables.cr

Lines changed: 40 additions & 11 deletions
```diff
@@ -75,25 +75,54 @@ module Fiber::ExecutionContext
 
       # first, try to grab half of the fibers from local queue
       batch = uninitialized Fiber[N] # actually N // 2 + 1 but that doesn't compile
-      n.times do |i|
-        batch.to_unsafe[i] = @buffer.to_unsafe[(head &+ i) % N]
-      end
-      _, success = @head.compare_and_set(head, head &+ n, :acquire_release, :acquire)
+      _, success = try_grab(batch.to_unsafe, head, n)
       return false unless success
 
-      # append fiber to the batch
+      # append fiber to the batch and push to global queue
       batch.to_unsafe[n] = fiber
+      push_to_global_queue(batch.to_unsafe, n &+ 1)
+      true
+    end
 
-      # link the fibers
+    # Transfers every fiber in the local runnables queue to the global queue.
+    # This will grab the global lock.
+    #
+    # Can be executed by any scheduler.
+    def drain : Nil
+      batch = uninitialized Fiber[N]
+      n = 0
+
+      head = @head.get(:acquire) # sync with other consumers
+      loop do
+        tail = @tail.get(:acquire) # sync with the producer
+
+        n = (tail &- head)
+        return if n == 0 # queue is empty
+
+        # try to grab everything from local queue
+        head, success = try_grab(batch.to_unsafe, head, n)
+        break if success
+      end
+
+      push_to_global_queue(batch.to_unsafe, n)
+    end
+
+    private def try_grab(batch, head, n)
       n.times do |i|
-        batch.to_unsafe[i].list_next = batch.to_unsafe[i &+ 1]
+        batch[i] = @buffer.to_unsafe[(head &+ i) % N]
       end
-      list = Fiber::List.new(batch.to_unsafe[0], batch.to_unsafe[n], size: (n &+ 1).to_i32)
+      @head.compare_and_set(head, head &+ n, :acquire_release, :acquire)
+    end
 
-      # now put the batch on global queue (grabs the global lock)
-      @global_queue.bulk_push(pointerof(list))
+    private def push_to_global_queue(batch, n)
+      # link the fibers
+      (n &- 1).times do |i|
+        batch[i].list_next = batch[i &+ 1]
+      end
+      list = Fiber::List.new(batch[0], batch[n &- 1], size: n.to_i32)
 
-      true
+      # and put the batch on global queue (grabs the global lock)
+      @global_queue.bulk_push(pointerof(list))
     end
 
     # Tries to enqueue all the fibers in *list* into the local queue. If the
```

0 commit comments

Comments
 (0)