Skip to content

Commit e0907eb

Browse files
authored
Refined semantics for Settings thread-safety (#1947)
* Refined semantics for settings thread-safety: only one (any) thread can call dspy.configure and it affects a global state. Any thread can use dspy.context. It propagates to child threads created with DSPy primitives. * Update settings.py
1 parent 741ac10 commit e0907eb

File tree

3 files changed

+73
-61
lines changed

3 files changed

+73
-61
lines changed

dspy/dsp/utils/settings.py

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,36 +19,47 @@
1919
async_max_workers=8,
2020
)
2121

22-
# Global base configuration
22+
# Global base configuration and owner tracking
2323
main_thread_config = copy.deepcopy(DEFAULT_CONFIG)
24+
config_owner_thread_id = None
2425

26+
# Global lock for settings configuration
27+
global_lock = threading.Lock()
2528

2629
class ThreadLocalOverrides(threading.local):
2730
def __init__(self):
28-
self.overrides = dotdict() # Initialize thread-local overrides
31+
self.overrides = dotdict()
2932

30-
31-
# Create the thread-local storage
3233
thread_local_overrides = ThreadLocalOverrides()
3334

3435

3536
class Settings:
3637
"""
3738
A singleton class for DSPy configuration settings.
38-
39-
This is thread-safe. User threads are supported both through ParallelExecutor and native threading.
40-
- If native threading is used, the thread inherits the initial config from the main thread.
41-
- If ParallelExecutor is used, the thread inherits the initial config from its parent thread.
39+
Thread-safe global configuration.
40+
- 'configure' can be called by only one 'owner' thread (the first thread that calls it).
41+
- Other threads see the configured global values from 'main_thread_config'.
42+
- 'context' sets thread-local overrides. These overrides propagate to threads spawned
43+
inside that context block, when (and only when!) using a ParallelExecutor that copies overrides.
44+
45+
1. Only one unique thread (which can be any thread!) can call dspy.configure.
46+
2. It affects a global state, visible to all. As a result, user threads work, but they shouldn't be
47+
mixed with concurrent changes to dspy.configure from the "main" thread.
48+
(TODO: In the future, add warnings: if there are near-in-time user-thread reads followed by .configure calls.)
49+
3. Any thread can use dspy.context. It propagates to child threads created with DSPy primitives: Parallel, asyncify, etc.
4250
"""
4351

4452
_instance = None
4553

4654
def __new__(cls):
4755
if cls._instance is None:
4856
cls._instance = super().__new__(cls)
49-
cls._instance.lock = threading.Lock() # maintained here for DSPy assertions.py
5057
return cls._instance
5158

59+
@property
60+
def lock(self):
61+
return global_lock
62+
5263
def __getattr__(self, name):
5364
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
5465
if name in overrides:
@@ -64,8 +75,6 @@ def __setattr__(self, name, value):
6475
else:
6576
self.configure(**{name: value})
6677

67-
# Dictionary-like access
68-
6978
def __getitem__(self, key):
7079
return self.__getattr__(key)
7180

@@ -88,42 +97,40 @@ def copy(self):
8897

8998
@property
9099
def config(self):
91-
config = self.copy()
92-
if 'lock' in config:
93-
del config['lock']
94-
return config
95-
96-
# Configuration methods
100+
return self.copy()
97101

98102
def configure(self, **kwargs):
99-
global main_thread_config
103+
global main_thread_config, config_owner_thread_id
104+
current_thread_id = threading.get_ident()
100105

101-
# Get or initialize thread-local overrides
102-
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
103-
thread_local_overrides.overrides = dotdict(
104-
{**copy.deepcopy(DEFAULT_CONFIG), **main_thread_config, **overrides, **kwargs}
105-
)
106+
with self.lock:
107+
# First configuration: establish ownership. If ownership established, only that thread can configure.
108+
if config_owner_thread_id in [None, current_thread_id]:
109+
config_owner_thread_id = current_thread_id
110+
else:
111+
raise RuntimeError("dspy.settings can only be changed by the thread that initially configured it.")
106112

107-
# Update main_thread_config, in the main thread only
108-
if threading.current_thread() is threading.main_thread():
109-
main_thread_config = thread_local_overrides.overrides
113+
# Update global config
114+
for k, v in kwargs.items():
115+
main_thread_config[k] = v
110116

111117
@contextmanager
112118
def context(self, **kwargs):
113-
"""Context manager for temporary configuration changes."""
114-
global main_thread_config
119+
"""
120+
Context manager for temporary configuration changes at the thread level.
121+
Does not affect global configuration. Changes only apply to the current thread.
122+
If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides.
123+
"""
124+
115125
original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy()
116-
original_main_thread_config = main_thread_config.copy()
126+
new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs})
127+
thread_local_overrides.overrides = new_overrides
117128

118-
self.configure(**kwargs)
119129
try:
120130
yield
121131
finally:
122132
thread_local_overrides.overrides = original_overrides
123133

124-
if threading.current_thread() is threading.main_thread():
125-
main_thread_config = original_main_thread_config
126-
127134
def __repr__(self):
128135
overrides = getattr(thread_local_overrides, 'overrides', dotdict())
129136
combined_config = {**main_thread_config, **overrides}

dspy/utils/asyncify.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
def get_async_max_workers():
1212
import dspy
13-
1413
return dspy.settings.async_max_workers
1514

1615

@@ -31,28 +30,31 @@ def asyncify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]:
3130
Wraps a DSPy program so that it can be called asynchronously. This is useful for running a
3231
program in parallel with another task (e.g., another DSPy program).
3332
33+
This implementation propagates the current thread's configuration context to the worker thread.
34+
3435
Args:
3536
program: The DSPy program to be wrapped for asynchronous execution.
3637
3738
Returns:
38-
A function that takes the same arguments as the program, but returns an awaitable that
39-
resolves to the program's output.
40-
41-
Example:
42-
>>> class TestSignature(dspy.Signature):
43-
>>> input_text: str = dspy.InputField()
44-
>>> output_text: str = dspy.OutputField()
45-
>>>
46-
>>> # Create the program and wrap it for asynchronous execution
47-
>>> program = dspy.asyncify(dspy.Predict(TestSignature))
48-
>>>
49-
>>> # Use the program asynchronously
50-
>>> async def get_prediction():
51-
>>> prediction = await program(input_text="Test")
52-
>>> print(prediction) # Handle the result of the asynchronous execution
39+
An async function that, when awaited, runs the program in a worker thread. The current
40+
thread's configuration context is inherited for each call.
5341
"""
54-
import threading
55-
56-
assert threading.current_thread() is threading.main_thread(), "asyncify can only be called from the main thread"
57-
# NOTE: To allow this to be nested, we'd need behavior with contextvars like parallelizer.py
58-
return asyncer.asyncify(program, abandon_on_cancel=True, limiter=get_limiter())
42+
async def async_program(*args, **kwargs) -> Any:
43+
# Capture the current overrides at call-time.
44+
from dspy.dsp.utils.settings import thread_local_overrides
45+
parent_overrides = thread_local_overrides.overrides.copy()
46+
47+
def wrapped_program(*a, **kw):
48+
from dspy.dsp.utils.settings import thread_local_overrides
49+
original_overrides = thread_local_overrides.overrides
50+
thread_local_overrides.overrides = parent_overrides.copy()
51+
try:
52+
return program(*a, **kw)
53+
finally:
54+
thread_local_overrides.overrides = original_overrides
55+
56+
# Create a fresh asyncified callable each time, ensuring the latest context is used.
57+
call_async = asyncer.asyncify(wrapped_program, abandon_on_cancel=True, limiter=get_limiter())
58+
return await call_async(*args, **kwargs)
59+
60+
return async_program

dspy/utils/parallelizer.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010

1111
logger = logging.getLogger(__name__)
1212

13+
1314
class ParallelExecutor:
1415
def __init__(
1516
self,
@@ -20,7 +21,6 @@ def __init__(
2021
compare_results=False,
2122
):
2223
"""Offers isolation between the tasks (dspy.settings) irrespective of whether num_threads == 1 or > 1."""
23-
2424
self.num_threads = num_threads
2525
self.disable_progress_bar = disable_progress_bar
2626
self.max_errors = max_errors
@@ -72,15 +72,17 @@ def _execute_isolated_single_thread(self, function, data):
7272
file=sys.stdout
7373
)
7474

75+
from dspy.dsp.utils.settings import thread_local_overrides
76+
original_overrides = thread_local_overrides.overrides
77+
7578
for item in data:
7679
with logging_redirect_tqdm():
7780
if self.cancel_jobs.is_set():
7881
break
7982

80-
# Create an isolated context for each task using thread-local overrides
81-
from dspy.dsp.utils.settings import thread_local_overrides
82-
original_overrides = thread_local_overrides.overrides
83-
thread_local_overrides.overrides = thread_local_overrides.overrides.copy()
83+
# Create an isolated context for each task by copying current overrides
84+
# This way, even if an iteration modifies the overrides, it won't affect subsequent iterations
85+
thread_local_overrides.overrides = original_overrides.copy()
8486

8587
try:
8688
result = function(item)
@@ -122,6 +124,8 @@ def _execute_multi_thread(self, function, data):
122124
@contextlib.contextmanager
123125
def interrupt_handler_manager():
124126
"""Sets the cancel_jobs event when a SIGINT is received, only in the main thread."""
127+
128+
# TODO: Is this check conducive to nested usage of ParallelExecutor?
125129
if threading.current_thread() is threading.main_thread():
126130
default_handler = signal.getsignal(signal.SIGINT)
127131

@@ -145,7 +149,7 @@ def cancellable_function(parent_overrides, index_item):
145149
if self.cancel_jobs.is_set():
146150
return index, job_cancelled
147151

148-
# Create an isolated context for each task using thread-local overrides
152+
# Create an isolated context for each task by copying parent's overrides
149153
from dspy.dsp.utils.settings import thread_local_overrides
150154
original_overrides = thread_local_overrides.overrides
151155
thread_local_overrides.overrides = parent_overrides.copy()
@@ -156,7 +160,6 @@ def cancellable_function(parent_overrides, index_item):
156160
thread_local_overrides.overrides = original_overrides
157161

158162
with ThreadPoolExecutor(max_workers=self.num_threads) as executor, interrupt_handler_manager():
159-
# Capture the parent thread's overrides
160163
from dspy.dsp.utils.settings import thread_local_overrides
161164
parent_overrides = thread_local_overrides.overrides.copy()
162165

0 commit comments

Comments
 (0)