Commit 027312b
Streaming (#1874)
* dspy.streamify
* Update docs
* Fix ruff lint error
* Bring back send_stream to settings
* Improve doc
* Bring back request_cache setting
* sse => streaming_response
* Simplify dsp.utils.settings diff
* Add load/dump to LRUCache + drop callable request params
* ujson => pickle for dump/load
* Stream fix
* test streaming
* fix
* fix
* Streaming works
* Fix
* fix
* no ignore change
* Simple init
* Simple init

Signed-off-by: dbczumar <[email protected]>
Co-authored-by: dbczumar <[email protected]>
1 parent 7e102fe commit 027312b

File tree: 7 files changed, +254 −16 lines


docs/docs/tutorials/deployment/index.md

Lines changed: 37 additions & 4 deletions
@@ -41,9 +41,11 @@ class Question(BaseModel):
 # Configure your language model and 'asyncify' your DSPy program.
 lm = dspy.LM("openai/gpt-4o-mini")
 dspy.settings.configure(lm=lm, async_max_workers=4) # default is 8
-dspy_program = dspy.ChainOfThought("question -> answer")
-dspy_program = dspy.asyncify(dspy_program)
 
+dspy_program = dspy.asyncify(dspy.ChainOfThought("question -> answer"))
+streaming_dspy_program = dspy.streamify(dspy_program)
+
+# Define an endpoint (no streaming)
 @app.post("/predict")
 async def predict(question: Question):
     try:
@@ -54,14 +56,45 @@ async def predict(question: Question):
         }
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
+
+
+# Define an endpoint (streaming)
+from fastapi.responses import StreamingResponse
+
+@app.post("/predict/stream")
+async def stream(question: Question):
+    async def generate():
+        async for value in streaming_dspy_program(question=question.text):
+            if isinstance(value, dspy.Prediction):
+                data = {"prediction": value.labels().toDict()}
+            elif isinstance(value, litellm.ModelResponse):
+                data = {"chunk": value.json()}
+            yield f"data: {ujson.dumps(data)}\n\n"
+        yield "data: [DONE]\n\n"
+
+    return StreamingResponse(generate(), media_type="text/event-stream")
+
+# Since you're often going to want to stream the result of a DSPy program as server-sent events,
+# we've included a helper function for that, which is equivalent to the code above.
+
+from dspy.utils.streaming import streaming_response
+
+@app.post("/predict/stream")
+async def stream(question: Question):
+    stream = streaming_dspy_program(question=question.text)
+    return StreamingResponse(streaming_response(stream), media_type="text/event-stream")
 ```
 
 In the code above, we call `dspy.asyncify` to convert the dspy program to run in async mode for high-throughput FastAPI
-deployments. Currently, this runs the dspy program in a
-separate thread and awaits its result. By default, the limit of spawned threads is 8. Think of this like a worker pool.
+deployments. Currently, this runs the dspy program in a separate thread and awaits its result.
+
+By default, the limit of spawned threads is 8. Think of this like a worker pool.
 If you have 8 in-flight programs and call it once more, the 9th call will wait until one of the 8 returns.
 You can configure the async capacity using the new `async_max_workers` setting.
 
+We also use `dspy.streamify` to convert the dspy program to a streaming mode. This is useful when you want to stream
+the intermediate outputs (i.e. O1-style reasoning) to the client before the final prediction is ready. This uses
+asyncify under the hood and inherits the execution semantics.
+
 Write your code to a file, e.g., `fastapi_dspy.py`. Then you can serve the app with:
 
 ```bash
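
To complement the server-side snippet above, here is a minimal client-side sketch (not part of this commit) for consuming the /predict/stream endpoint. It assumes the app is served locally on port 8000, that `httpx` is installed, and that the request body matches the `Question` model's `text` field used in the docs:

import json

import httpx


def consume_stream(question: str) -> None:
    # POST to the assumed local deployment of the FastAPI app defined above.
    with httpx.stream(
        "POST",
        "http://127.0.0.1:8000/predict/stream",
        json={"text": question},
        timeout=None,
    ) as response:
        for line in response.iter_lines():
            if not line.startswith("data:"):
                continue  # skip blank keep-alive lines between SSE events
            body = line[len("data:"):].strip()
            if body == "[DONE]":
                break  # terminal sentinel emitted by the endpoint
            event = json.loads(body)
            if "chunk" in event:
                print("intermediate chunk:", event["chunk"])
            elif "prediction" in event:
                print("final prediction:", event["prediction"])


if __name__ == "__main__":
    consume_stream("What is DSPy?")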

dspy/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -12,6 +12,7 @@
 from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
 from dspy.utils.asyncify import asyncify
 from dspy.utils.saving import load
+from dspy.utils.streaming import streamify
 
 from dspy.dsp.utils.settings import settings
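
With this export, the wrapper is available both at the top level and from its module path. A tiny sketch of the equivalence, using only the two import paths that appear in the diffs:

import dspy
from dspy.utils.streaming import streamify  # module path, also used by the docs example

# The top-level alias added by this commit points at the same function.
assert dspy.streamify is streamify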

dspy/clients/lm.py

Lines changed: 32 additions & 4 deletions
@@ -5,14 +5,17 @@
 import uuid
 from datetime import datetime
 from hashlib import sha256
-from typing import Any, Dict, List, Literal, Optional
+from typing import Any, Dict, List, Literal, Optional, cast
 
 import litellm
 import pydantic
 import ujson
+from anyio.streams.memory import MemoryObjectSendStream
+from asyncer import syncify
 from cachetools import LRUCache, cached
 from litellm import RetryPolicy
 
+import dspy
 from dspy.adapters.base import Adapter
 from dspy.clients.openai import OpenAIProvider
 from dspy.clients.provider import Provider, TrainingJob
@@ -309,16 +312,41 @@ def cached_litellm_completion(request: Dict[str, Any], num_retries: int):
 
 
 def litellm_completion(request: Dict[str, Any], num_retries: int, cache={"no-cache": True, "no-store": True}):
-    return litellm.completion(
-        cache=cache,
+    retry_kwargs = dict(
         retry_policy=_get_litellm_retry_policy(num_retries),
         # In LiteLLM version 1.55.3 (the first version that supports retry_policy as an argument
         # to completion()), the default value of max_retries is non-zero for certain providers, and
        # max_retries is stacked on top of the retry_policy. To avoid this, we set max_retries=0
         max_retries=0,
-        **request,
     )
 
+    stream = dspy.settings.send_stream
+    if stream is None:
+        return litellm.completion(
+            cache=cache,
+            **retry_kwargs,
+            **request,
+        )
+
+    # The stream is already opened, and will be closed by the caller.
+    stream = cast(MemoryObjectSendStream, stream)
+
+    @syncify
+    async def stream_completion():
+        response = await litellm.acompletion(
+            cache=cache,
+            stream=True,
+            **retry_kwargs,
+            **request,
+        )
+        chunks = []
+        async for chunk in response:
+            chunks.append(chunk)
+            await stream.send(chunk)
+        return litellm.stream_chunk_builder(chunks)
+
+    return stream_completion()
+
 
 @request_cache(maxsize=None)
 def cached_litellm_text_completion(request: Dict[str, Any], num_retries: int):
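
The core idea above is that a synchronous call site can still stream: when `dspy.settings.send_stream` is set, the completion runs in a `syncify`-wrapped coroutine that pushes each chunk into an already-open anyio memory stream and then returns the assembled response. A self-contained sketch of that pattern (using only anyio and asyncer; the chunk strings and helper names are illustrative, not part of the commit):

from anyio import create_memory_object_stream, create_task_group, run, to_thread
from anyio.streams.memory import MemoryObjectSendStream
from asyncer import syncify


def produce_with_streaming(send_stream: MemoryObjectSendStream) -> str:
    # Synchronous call site, analogous to litellm_completion: it streams chunks as a
    # side effect through the provided stream, then returns the assembled result.
    @syncify
    async def _produce() -> str:
        chunks = []
        for chunk in ["Hel", "lo", "!"]:  # stand-in for LiteLLM streaming chunks
            chunks.append(chunk)
            await send_stream.send(chunk)
        return "".join(chunks)  # stand-in for litellm.stream_chunk_builder(chunks)

    return _produce()


async def main() -> None:
    send_stream, receive_stream = create_memory_object_stream(16)
    async with create_task_group() as tg, send_stream, receive_stream:
        # Run the sync producer in a worker thread, similar to how dspy.asyncify runs programs.
        tg.start_soon(to_thread.run_sync, produce_with_streaming, send_stream)
        async for chunk in receive_stream:
            print("chunk:", chunk)
            if chunk == "!":  # last illustrative chunk; stop consuming
                break


run(main)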

dspy/dsp/utils/settings.py

Lines changed: 12 additions & 8 deletions
@@ -1,6 +1,7 @@
 import copy
 import threading
 from contextlib import contextmanager
+
 from dspy.dsp.utils.utils import dotdict
 
 DEFAULT_CONFIG = dotdict(
@@ -17,6 +18,7 @@
     backoff_time=10,
     callbacks=[],
     async_max_workers=8,
+    send_stream=None,
 )
 
 # Global base configuration and owner tracking
@@ -26,20 +28,22 @@
 # Global lock for settings configuration
 global_lock = threading.Lock()
 
+
 class ThreadLocalOverrides(threading.local):
     def __init__(self):
         self.overrides = dotdict()
 
+
 thread_local_overrides = ThreadLocalOverrides()
 
 
 class Settings:
     """
     A singleton class for DSPy configuration settings.
-    Thread-safe global configuration.
+    Thread-safe global configuration.
     - 'configure' can be called by only one 'owner' thread (the first thread that calls it).
     - Other threads see the configured global values from 'main_thread_config'.
-    - 'context' sets thread-local overrides. These overrides propagate to threads spawned
+    - 'context' sets thread-local overrides. These overrides propagate to threads spawned
       inside that context block, when (and only when!) using a ParallelExecutor that copies overrides.
 
     1. Only one unique thread (which can be any thread!) can call dspy.configure.
@@ -61,7 +65,7 @@ def lock(self):
         return global_lock
 
     def __getattr__(self, name):
-        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
+        overrides = getattr(thread_local_overrides, "overrides", dotdict())
         if name in overrides:
             return overrides[name]
         elif name in main_thread_config:
@@ -70,7 +74,7 @@ def __getattr__(self, name):
             raise AttributeError(f"'Settings' object has no attribute '{name}'")
 
     def __setattr__(self, name, value):
-        if name in ('_instance',):
+        if name in ("_instance",):
            super().__setattr__(name, value)
         else:
            self.configure(**{name: value})
@@ -82,7 +86,7 @@ def __setitem__(self, key, value):
         self.__setattr__(key, value)
 
     def __contains__(self, key):
-        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
+        overrides = getattr(thread_local_overrides, "overrides", dotdict())
         return key in overrides or key in main_thread_config
 
     def get(self, key, default=None):
@@ -92,7 +96,7 @@ def get(self, key, default=None):
             return default
 
     def copy(self):
-        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
+        overrides = getattr(thread_local_overrides, "overrides", dotdict())
         return dotdict({**main_thread_config, **overrides})
 
     @property
@@ -122,7 +126,7 @@ def context(self, **kwargs):
         If threads are spawned inside this block using ParallelExecutor, they will inherit these overrides.
         """
 
-        original_overrides = getattr(thread_local_overrides, 'overrides', dotdict()).copy()
+        original_overrides = getattr(thread_local_overrides, "overrides", dotdict()).copy()
         new_overrides = dotdict({**main_thread_config, **original_overrides, **kwargs})
         thread_local_overrides.overrides = new_overrides
 
@@ -132,7 +136,7 @@ def context(self, **kwargs):
            thread_local_overrides.overrides = original_overrides
 
     def __repr__(self):
-        overrides = getattr(thread_local_overrides, 'overrides', dotdict())
+        overrides = getattr(thread_local_overrides, "overrides", dotdict())
         combined_config = {**main_thread_config, **overrides}
         return repr(combined_config)
 
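
The new `send_stream` entry follows the same override semantics the docstring describes: thread-local overrides set via `context` shadow the global config and are restored on exit. A small sketch (assuming only that `dspy` is importable; the printed values are the defaults from DEFAULT_CONFIG):

import dspy

print(dspy.settings.send_stream)        # None, the new default added by this commit
print(dspy.settings.async_max_workers)  # 8, from DEFAULT_CONFIG

with dspy.settings.context(async_max_workers=2):
    # Thread-local override: visible inside the block (and in threads spawned by a
    # ParallelExecutor that copies overrides), without touching the global config.
    print(dspy.settings.async_max_workers)  # 2

print(dspy.settings.async_max_workers)  # back to 8 once the block exits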

dspy/utils/streaming.py

Lines changed: 87 additions & 0 deletions
@@ -0,0 +1,87 @@
+from asyncio import iscoroutinefunction
+from typing import Any, AsyncGenerator, Awaitable, Callable
+
+import litellm
+import ujson
+from anyio import create_memory_object_stream, create_task_group
+from anyio.streams.memory import MemoryObjectSendStream
+
+from dspy.primitives.prediction import Prediction
+from dspy.primitives.program import Module
+from dspy.utils.asyncify import asyncify
+
+
+def streamify(program: Module) -> Callable[[Any, Any], Awaitable[Any]]:
+    """
+    Wrap a DSPy program so that it streams its outputs incrementally, rather than returning them
+    all at once.
+
+    Args:
+        program: The DSPy program to wrap with streaming functionality.
+    Returns:
+        A function that takes the same arguments as the original program, but returns an async
+        generator that yields the program's outputs incrementally.
+
+    Example:
+        >>> class TestSignature(dspy.Signature):
+        >>>     input_text: str = dspy.InputField()
+        >>>     output_text: str = dspy.OutputField()
+        >>>
+        >>> # Create the program and wrap it with streaming functionality
+        >>> program = dspy.streamify(dspy.Predict(TestSignature))
+        >>>
+        >>> # Use the program with streaming output
+        >>> async def use_streaming():
+        >>>     output_stream = program(input_text="Test")
+        >>>     async for value in output_stream:
+        >>>         print(value)  # Print each streamed value incrementally
+    """
+    import dspy
+
+    if not iscoroutinefunction(program):
+        program = asyncify(program)
+
+    async def generator(args, kwargs, stream: MemoryObjectSendStream):
+        with dspy.settings.context(send_stream=stream):
+            prediction = await program(*args, **kwargs)
+
+        await stream.send(prediction)
+
+    async def streamer(*args, **kwargs):
+        send_stream, receive_stream = create_memory_object_stream(16)
+        async with create_task_group() as tg, send_stream, receive_stream:
+            tg.start_soon(generator, args, kwargs, send_stream)
+
+            async for value in receive_stream:
+                yield value
+                if isinstance(value, Prediction):
+                    return
+
+    return streamer
+
+
+async def streaming_response(streamer: AsyncGenerator) -> AsyncGenerator:
+    """
+    Convert a DSPy program output stream to an OpenAI-compatible output stream that can be
+    used by a service as an API response to a streaming request.
+
+    Args:
+        streamer: An async generator that yields values from a DSPy program output stream.
+    Returns:
+        An async generator that yields OpenAI-compatible streaming response chunks.
+    """
+    async for value in streamer:
+        if isinstance(value, Prediction):
+            data = {"prediction": {k: v for k, v in value.items(include_dspy=False)}}
+            yield f"data: {ujson.dumps(data)}\n\n"
+        elif isinstance(value, litellm.ModelResponse):
+            data = {"chunk": value.json()}
+            yield f"data: {ujson.dumps(data)}\n\n"
+        elif isinstance(value, str) and value.startswith("data:"):
+            # The chunk value is an OpenAI-compatible streaming chunk value,
+            # e.g. "data: {"finish_reason": "stop", "index": 0, "is_finished": True, ...}",
+            # so yield it directly
+            yield value
+        else:
+            raise ValueError(f"Unknown chunk value type: {value}")
+    yield "data: [DONE]\n\n"

tests/test_utils/server/litellm_server.py

Lines changed: 25 additions & 0 deletions
@@ -1,8 +1,11 @@
 import json
 import os
+import time
+from typing import AsyncIterator, Iterator
 
 import litellm
 from litellm import CustomLLM
+from litellm.types.utils import GenericStreamingChunk
 
 LITELLM_TEST_SERVER_LOG_FILE_PATH_ENV_VAR = "LITELLM_TEST_SERVER_LOG_FILE_PATH"
 
@@ -16,6 +19,28 @@ async def acompletion(self, *args, **kwargs) -> litellm.ModelResponse:
         _append_request_to_log_file(kwargs)
         return _get_mock_llm_response(kwargs)
 
+    def streaming(self, *args, **kwargs) -> Iterator[GenericStreamingChunk]:
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": '{"output_text": "Hello!"}',
+            "tool_use": None,
+            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+        }
+        return generic_streaming_chunk  # type: ignore
+
+    async def astreaming(self, *args, **kwargs) -> AsyncIterator[GenericStreamingChunk]:
+        generic_streaming_chunk: GenericStreamingChunk = {
+            "finish_reason": "stop",
+            "index": 0,
+            "is_finished": True,
+            "text": '{"output_text": "Hello!"}',
+            "tool_use": None,
+            "usage": {"completion_tokens": 0, "prompt_tokens": 0, "total_tokens": 0},
+        }
+        yield generic_streaming_chunk
+
 
 def _get_mock_llm_response(request_kwargs):
     _throw_exception_based_on_content_if_applicable(request_kwargs)
