diff --git a/guidance/_ast.py b/guidance/_ast.py index 81d5574f3..461769a13 100644 --- a/guidance/_ast.py +++ b/guidance/_ast.py @@ -7,6 +7,7 @@ TYPE_CHECKING, Any, Callable, + AsyncIterable, Iterator, Optional, Sequence, @@ -17,7 +18,9 @@ from typing_extensions import assert_never from ._parser import ByteParser, ByteParserException -from .trace import OutputAttr +from .trace import InputAttr, OutputAttr, RoleOpenerInput, RoleCloserInput + +NodeAttr = Union[InputAttr, OutputAttr] if TYPE_CHECKING: from .models._base import Interpreter, State @@ -116,13 +119,13 @@ def __call__(self, model): return model def __add__(self, other): - if not isinstance(other, (str, GrammarNode, Function)): + if not isinstance(other, (str, ASTNode, Function)): return NotImplemented if isinstance(other, str): other = _parse_tags(other) - if isinstance(other, GrammarNode) and other.is_null: + if isinstance(other, ASTNode) and other.is_null: return self def __add__(model): @@ -131,13 +134,13 @@ def __add__(model): return Function(__add__, [], {}) def __radd__(self, other): - if not isinstance(other, (str, GrammarNode, Function)): + if not isinstance(other, (str, ASTNode, Function)): return NotImplemented if isinstance(other, str): other = _parse_tags(other) - if isinstance(other, GrammarNode) and other.is_null: + if isinstance(other, ASTNode) and other.is_null: return self def __radd__(model): @@ -145,40 +148,161 @@ def __radd__(model): return Function(__radd__, [], {}) +@dataclass +class AsyncFunction(Tagged): + name: str = field(init=False) + f: Callable + args: tuple[Any, ...] + kwargs: dict[str, Any] -S = TypeVar("S", bound="State") + def __post_init__(self): + self.name = self.f.__name__ + + async def __call__(self, model): + model = await self.f(model, *self.args, **self.kwargs) + if model is None: + raise Exception( + f"The guidance function `{self.f.__name__}` did not return a model object! You need to return an updated model object at the end of your guidance function." 
+ ) + return model + + def __add__(self, other): + if not isinstance(other, (str, ASTNode, Function, AsyncFunction)): + return NotImplemented + + if isinstance(other, str): + other = _parse_tags(other) + + if isinstance(other, ASTNode) and other.is_null: + return self + + async def __add__(model): + return (await self(model)) + other + + return AsyncFunction(__add__, [], {}) + + def __radd__(self, other): + if not isinstance(other, (str, ASTNode, Function, AsyncFunction)): + return NotImplemented + + if isinstance(other, str): + other = _parse_tags(other) + + if isinstance(other, ASTNode) and other.is_null: + return self + async def __radd__(model): + return await self(model + other) + + return AsyncFunction(__radd__, [], {}) + + +S = TypeVar("S", bound="State") class ASTNode(ABC): @abstractmethod - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]: pass def simplify(self) -> "ASTNode": return self + @property + def is_null(self) -> bool: + return False + + def __add__(self, other): + if isinstance(other, str): + other = _parse_tags(other) + + if isinstance(other, ASTNode): + return Concatenate((self, other)) + + return NotImplemented + + def __radd__(self, other): + if isinstance(other, str): + other = _parse_tags(other) + + if isinstance(other, ASTNode): + return Concatenate((other, self)) + + return NotImplemented + + @classmethod + def null(cls) -> "ASTNode": + return Concatenate(()) + +@dataclass(frozen=True) +class Concatenate(ASTNode): + nodes: tuple[ASTNode, ...] + + async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]: + buffer: Optional[GrammarNode] = None + for child in self: + assert not isinstance(child, Concatenate) # iter should be flat + if isinstance(child, GrammarNode): + if buffer is None: + buffer = child + else: + buffer = buffer + child + else: + if buffer is not None: + async for attr in interpreter.run(buffer, **kwargs): + yield attr + buffer = None + async for attr in interpreter.run(child, **kwargs): + yield attr + if buffer is not None: + async for attr in interpreter.run(buffer, **kwargs): + yield attr + + def __iter__(self) -> Iterator[ASTNode]: + for node in self.nodes: + if isinstance(node, Concatenate): + yield from node + else: + yield node @dataclass class RoleStart(ASTNode): role: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: - return interpreter._role_start(self, **kwargs) + async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]: + yield RoleOpenerInput(name=self.role) + async for output_attr in interpreter._role_start(self, **kwargs): + yield output_attr @dataclass class RoleEnd(ASTNode): role: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: - return interpreter._role_end(self, **kwargs) + async def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[NodeAttr]: + yield RoleCloserInput(name=self.role) + async for output_attr in interpreter._role_end(self, **kwargs): + yield output_attr + + +@dataclass +class CaptureStart(ASTNode): + name: str + + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: + return interpreter.capture_start(self, **kwargs) + +@dataclass +class CaptureEnd(ASTNode): + name: str + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: + return interpreter.capture_end(self, **kwargs) 
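To make the intent of the new lazy concatenation concrete, here is a minimal sketch of how the `ASTNode.__add__` added above composes nodes. It assumes only the classes defined in this file; the particular node instances are illustrative. Because `RoleStart` and `RoleEnd` are not `GrammarNode`s, the base `ASTNode.__add__` applies and wraps the operands in a `Concatenate`, whose `__iter__` flattens any nesting so that `Concatenate._run` can buffer adjacent grammar nodes into a single interpreter call:

from guidance._ast import Concatenate, LiteralNode, RoleEnd, RoleStart

start = RoleStart(role="user")
text = LiteralNode(value="Hello")
end = RoleEnd(role="user")

# ASTNode.__add__ wraps each pair in a Concatenate rather than a grammar join,
# since RoleStart/RoleEnd are not GrammarNodes
node = start + text + end
assert isinstance(node, Concatenate)

# __iter__ flattens the nested Concatenate, so an interpreter sees the three
# leaves in order and can dispatch them one by one
assert list(node) == [start, text, end]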
@dataclass class ImageBlob(ASTNode): data: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.image_blob(self, **kwargs) @@ -186,7 +310,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class ImageUrl(ASTNode): url: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.image_url(self, **kwargs) @@ -194,7 +318,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class AudioBlob(ASTNode): data: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.audio_blob(self, **kwargs) @@ -202,7 +326,7 @@ class GenAudio(ASTNode): def __init__(self, kwargs: dict[str, Any]): self.kwargs = kwargs - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.gen_audio(self, **kwargs) @@ -317,7 +441,7 @@ class LiteralNode(GrammarNode): def is_null(self) -> bool: return self.value == "" - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.text(self, **kwargs) @@ -325,7 +449,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class RegexNode(GrammarNode): regex: Optional[str] - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.regex(self, **kwargs) @@ -353,7 +477,7 @@ def simplify(self) -> "GrammarNode": def children(self) -> Sequence["GrammarNode"]: return self.alternatives - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.select(self, **kwargs) @@ -376,7 +500,7 @@ def simplify(self) -> "GrammarNode": def children(self) -> Sequence["GrammarNode"]: return self.nodes - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.join(self, **kwargs) @@ -402,7 +526,7 @@ def children(self) -> Sequence["GrammarNode"]: def simplify(self) -> GrammarNode: return RepeatNode(self.node.simplify(), self.min, self.max) - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.repeat(self, **kwargs) @@ -415,7 +539,7 @@ def is_terminal(self) -> bool: # this can be used as part of bigger regexes return True - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.substring(self, **kwargs) @@ -466,7 +590,7 @@ def is_terminal(self) -> bool: def children(self) -> Sequence["GrammarNode"]: return (self.value,) - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def 
_run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.rule(self, **kwargs) @dataclass(frozen=True, eq=False) @@ -485,7 +609,7 @@ def is_terminal(self) -> bool: # so it should never be terminal. return False - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: if self.target is None: raise ValueError("RuleRefNode target not set") return interpreter.rule(self.target) @@ -501,7 +625,7 @@ class SubgrammarNode(BaseSubgrammarNode): body: GrammarNode skip_regex: Optional[str] = None - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.subgrammar(self, **kwargs) @@ -509,7 +633,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class JsonNode(BaseSubgrammarNode): schema: dict[str, Any] - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.json(self, **kwargs) @@ -517,7 +641,7 @@ def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: class LarkNode(BaseSubgrammarNode): lark_grammar: str - def _run(self, interpreter: "Interpreter[S]", **kwargs) -> Iterator[OutputAttr]: + def _run(self, interpreter: "Interpreter[S]", **kwargs) -> AsyncIterable[OutputAttr]: return interpreter.lark(self, **kwargs) class LarkSerializer: diff --git a/guidance/_bg/__init__.py b/guidance/_bg/__init__.py index c6f150624..0490e7d4e 100644 --- a/guidance/_bg/__init__.py +++ b/guidance/_bg/__init__.py @@ -6,15 +6,16 @@ import asyncio import threading from asyncio import AbstractEventLoop, Future, Task -from typing import Tuple, Coroutine +from typing import Coroutine, Any, TypeVar +T = TypeVar("T") def _start_asyncio_loop(loop: AbstractEventLoop): asyncio.set_event_loop(loop) loop.run_forever() -def _asyncio_background_thread() -> Tuple[threading.Thread, AbstractEventLoop]: +def _asyncio_background_thread() -> tuple[threading.Thread, AbstractEventLoop]: loop = asyncio.new_event_loop() thread = threading.Thread(target=_start_asyncio_loop, args=(loop,)) thread.daemon = True @@ -29,7 +30,7 @@ def __init__(self): self._loop = None self._thread = None - def _thread_and_loop(self) -> Tuple[threading.Thread, AbstractEventLoop]: + def _thread_and_loop(self) -> tuple[threading.Thread, AbstractEventLoop]: if self._loop is None: self._thread, self._loop = _asyncio_background_thread() self._thread.start() @@ -41,7 +42,7 @@ def call_soon_threadsafe(self, cb, *args, context = None): _, loop = self._thread_and_loop() return loop.call_soon_threadsafe(cb, *args, context=context) - def run_async_coroutine(self, coroutine: Coroutine) -> Future: + def run_async_coroutine(self, coroutine: Coroutine[Any, Any, T]) -> Future[T]: """ Runs an asynchronous coroutine in the visual thread. Args: @@ -55,7 +56,7 @@ def run_async_coroutine(self, coroutine: Coroutine) -> Future: return future @staticmethod - async def async_task(coroutine: Coroutine) -> Task: + async def async_task(coroutine: Coroutine[Any, Any, T]) -> Task[T]: """ Creates an asyncio task from coroutine. 
Args: diff --git a/guidance/_guidance.py b/guidance/_guidance.py index e58d1699b..6155e88b2 100644 --- a/guidance/_guidance.py +++ b/guidance/_guidance.py @@ -8,7 +8,7 @@ from ._grammar import string -from ._ast import Function, RuleRefNode, RuleNode +from ._ast import AsyncFunction, Function, RuleRefNode, RuleNode from ._utils import strip_multiline_string_indents, make_weak_bound_method, signature_pop from .models import Model @@ -127,6 +127,11 @@ def _decorator(f, *, stateless, cache, model): @functools.wraps(f) def wrapped(*args, **kwargs): + from inspect import iscoroutinefunction + if iscoroutinefunction(f): + if stateless is True: + raise ValueError("Stateless functions cannot be async") + return AsyncFunction(f, args, kwargs) # make a stateless grammar if we can if stateless is True or ( diff --git a/guidance/_reentrant_async.py b/guidance/_reentrant_async.py new file mode 100644 index 000000000..48abbd2a4 --- /dev/null +++ b/guidance/_reentrant_async.py @@ -0,0 +1,82 @@ +""" +Adapted from greenletio (https://github.com/miguelgrinberg/greenletio), used under the MIT License. +See LICENSE file or https://github.com/miguelgrinberg/greenletio/blob/main/LICENSE for details. +""" + +import contextvars +import sys +import threading +from functools import wraps +from typing import Any, Callable, Coroutine, TypeVar, cast + +from greenlet import getcurrent, greenlet # type: ignore[import-untyped] +from typing_extensions import ParamSpec + + +class ReentrantAsyncException(RuntimeError): + """Exception raised when a coroutine is awaited in a non-greenlet context.""" + + pass + + +P = ParamSpec("P") +T = TypeVar("T") + + +def reentrant_await(coro: Coroutine[Any, Any, T]) -> T: + """ + Sends a coroutine to the parent greenlet, which is expected to await the coroutine + for us and send the result back. + + When there is no parent greenlet, we raise a ReentrantAsyncException. + """ + + parent_gl = getcurrent().parent + if parent_gl is None: + coro.close() + raise ReentrantAsyncException("Attempted to use synchronous entry-point in async context") + return cast(T, parent_gl.switch(coro)) + + +def sync_to_reentrant_async(fn: Callable[P, T]) -> Callable[P, Coroutine[Any, Any, T]]: + """ + Decorator to convert a synchronous function into a re-entrant asynchronous one. + + Calls to `reentrant_await` down the stack will bounce back here, and we'll await + the coroutine for them. + """ + + @wraps(fn) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: + gl = greenlet(fn) + gl.gr_context = contextvars.copy_context() + coro = gl.switch(*args, **kwargs) + while gl: + try: + result = await coro + except: # noqa: E722 + # this catches exceptions from async functions awaited in + # sync code, and re-raises them in the greenlet + coro = gl.throw(*sys.exc_info()) + else: + coro = gl.switch(result) + return coro + + return wrapper + + +def run_async_coroutine_in_bg_async(coro: Coroutine[Any, Any, T]) -> T: + """ + Run a coroutine in the background thread and wait for it to finish. + (This is a blocking call.)
+ """ + + from .registry import get_bg_async + + bg_async = get_bg_async() + thread, _ = bg_async._thread_and_loop() + if thread is threading.current_thread(): + coro.close() + raise RuntimeError("Cannot nest async call -- already in background thread.") + fut = bg_async.run_async_coroutine(coro) + return fut.result() diff --git a/guidance/library/_gen.py b/guidance/library/_gen.py index 88342cc28..6b7cd3488 100644 --- a/guidance/library/_gen.py +++ b/guidance/library/_gen.py @@ -133,10 +133,10 @@ def tool_gen(lm): ) grm = with_temperature(select(options), temperature) - initial_token_count = lm.token_count + initial_token_count = lm.get_token_count() tagged_name = "__LIST_APPEND:" + name if list_append and name is not None else name with block(tagged_name): - while lm.token_count <= max_tokens + initial_token_count: + while lm.get_token_count() <= max_tokens + initial_token_count: lm += grm tool_called = False for i in range(len(tools)): diff --git a/guidance/models/_azureai.py b/guidance/models/_azureai.py index fec29c49e..e6b58eb91 100644 --- a/guidance/models/_azureai.py +++ b/guidance/models/_azureai.py @@ -1,6 +1,6 @@ import logging -from typing import Callable, Iterator, Optional, Union, TYPE_CHECKING +from typing import Callable, AsyncIterable, Optional, Union, TYPE_CHECKING from pydantic import TypeAdapter @@ -44,7 +44,7 @@ def __init__( raise Exception( "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" ) - client = openai.AzureOpenAI( + client = openai.AsyncAzureOpenAI( azure_endpoint=azure_endpoint, azure_deployment=azure_deployment, api_version=api_version, @@ -153,18 +153,18 @@ def __init__( model_name: str, ): try: - import azure.ai.inference + import azure.ai.inference.aio except ImportError: raise Exception( "Please install the azure-ai-inference package using `pip install azure-ai-inference` in order to use guidance.models.AzureInference!" ) - client = azure.ai.inference.ChatCompletionsClient( + client = azure.ai.inference.aio.ChatCompletionsClient( endpoint=endpoint, credential=credential, ) super().__init__(client=client, model=model_name) - def _run(self, **kwargs) -> Iterator[OutputAttr]: + async def _run(self, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is None: # Should never happen? raise ValueError( @@ -179,7 +179,7 @@ def _run(self, **kwargs) -> Iterator[OutputAttr]: f"OpenAI models do not support pre-filled assistant messages: got data {self.state.content}." ) - with self.client.complete( + chunks = await self.client.complete( body={ "messages": TypeAdapter(list[Message]).dump_python(self.state.messages), "log_probs": self.log_probs, @@ -190,13 +190,14 @@ def _run(self, **kwargs) -> Iterator[OutputAttr]: headers={ "extra-parameters": "pass-through", }, - ) as chunks: - yield from self._handle_stream(chunks) + ) + for attr in self._handle_stream(chunks): + yield attr # def rule(self, node: RuleNode, **kwargs) -> Iterator[OutputAttr]: # raise ValueError("Rule nodes are not supported for Azure Inference") - def json(self, node: JsonNode, **kwargs) -> Iterator[OutputAttr]: + def json(self, node: JsonNode, **kwargs) -> AsyncIterable[OutputAttr]: return self._run( json_schema={ "name": "json_schema", # TODO? 
diff --git a/guidance/models/_base/__init__.py b/guidance/models/_base/__init__.py index 0eb8ddfaf..a8ce76d85 100644 --- a/guidance/models/_base/__init__.py +++ b/guidance/models/_base/__init__.py @@ -11,4 +11,5 @@ "ASTNode", "ContentChunk", "MessageChunk", + "Interpreter", ] diff --git a/guidance/models/_base/_interpreter.py b/guidance/models/_base/_interpreter.py index e64aeb0a0..8c1c0612c 100644 --- a/guidance/models/_base/_interpreter.py +++ b/guidance/models/_base/_interpreter.py @@ -1,9 +1,11 @@ import base64 -from typing import Generic, Iterator, TypeVar +from abc import ABC +from typing import Generic, Iterable, AsyncIterable, TypeVar, Union, Optional from ..._ast import ( ASTNode, AudioBlob, + Concatenate, GenAudio, GrammarNode, ImageBlob, @@ -20,89 +22,105 @@ SelectNode, SubgrammarNode, SubstringNode, + CaptureStart, + CaptureEnd, ) from ..._utils import bytes_from -from ...trace import OutputAttr +from ...trace import InputAttr, OutputAttr, TextOutput from ._state import State -S = TypeVar("S", bound=State) +NodeAttr = Union[InputAttr, OutputAttr] +S = TypeVar("S", bound=State) -class Interpreter(Generic[S]): +class Interpreter(Generic[S], ABC): def __init__(self, state: S): self.state = state - def run(self, node: ASTNode, **kwargs) -> Iterator[OutputAttr]: - yield from node.simplify()._run(self, **kwargs) + async def run(self, node: ASTNode, **kwargs) -> AsyncIterable[NodeAttr]: + async for attr in node.simplify()._run(self, **kwargs): + if isinstance(attr, TextOutput) and attr.is_generated: + # TODO: this should probably be a lower-level responsibility? Not sure. + self.state.token_count += attr.token_count + yield attr + + async def capture_start(self, node: CaptureStart, **kwargs) -> AsyncIterable[OutputAttr]: + self.state.open_capture(node.name) + if False: + # Yes, this is intentional. + yield - def _role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + async def capture_end(self, node: CaptureEnd, **kwargs) -> AsyncIterable[OutputAttr]: + yield self.state.close_capture(node.name) + + def _role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is not None: raise ValueError( f"Cannot open role {node.role!r}: {self.state.active_role!r} is already open." 
) return self.role_start(node, **kwargs) - def role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def _role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + def _role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is None: raise ValueError("Cannot close role without active role") if self.state.active_role != node.role: raise ValueError(f"Cannot close role {node.role!r}: {self.state.active_role!r} is open.") return self.role_end(node, **kwargs) - def role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def text(self, node: LiteralNode, **kwargs) -> Iterator[OutputAttr]: + def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def image_url(self, node: ImageUrl, **kwargs) -> Iterator[OutputAttr]: + def image_url(self, node: ImageUrl, **kwargs) -> AsyncIterable[OutputAttr]: + # TODO: we should be using something like httpx to fetch the image image_bytes = bytes_from(node.url, allow_local=False) base64_string = base64.b64encode(image_bytes).decode("utf-8") return self.image_blob(ImageBlob(data=base64_string), **kwargs) - def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: + def grammar(self, node: GrammarNode, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def regex(self, node: RegexNode, **kwargs) -> Iterator[OutputAttr]: + def regex(self, node: RegexNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def select(self, node: SelectNode, **kwargs) -> Iterator[OutputAttr]: + def select(self, node: SelectNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def join(self, node: JoinNode, **kwargs) -> Iterator[OutputAttr]: + def join(self, node: JoinNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def repeat(self, node: RepeatNode, **kwargs) -> Iterator[OutputAttr]: + def repeat(self, node: RepeatNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def substring(self, node: SubstringNode, **kwargs) -> Iterator[OutputAttr]: + def substring(self, node: SubstringNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def rule(self, node: RuleNode, **kwargs) -> Iterator[OutputAttr]: + def rule(self, node: RuleNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def subgrammar(self, node: SubgrammarNode, **kwargs) -> Iterator[OutputAttr]: + def subgrammar(self, node: SubgrammarNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def json(self, node: JsonNode, **kwargs) -> Iterator[OutputAttr]: + def json(self, node: JsonNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def lark(self, node: LarkNode, **kwargs) -> Iterator[OutputAttr]: + def lark(self, node: LarkNode, **kwargs) -> AsyncIterable[OutputAttr]: return self.grammar(node, **kwargs) - def 
audio_blob(self, node: AudioBlob, **kwargs) -> Iterator[OutputAttr]: + def audio_blob(self, node: AudioBlob, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - def gen_audio(self, node: GenAudio, **kwargs) -> Iterator[OutputAttr]: + def gen_audio(self, node: GenAudio, **kwargs) -> AsyncIterable[OutputAttr]: raise UnsupportedNodeError(interpreter=self, node=node) - class UnsupportedNodeError(ValueError): def __init__(self, interpreter: Interpreter, node: ASTNode): super().__init__(f"{interpreter} does not support {node!r} of type {type(node)}") diff --git a/guidance/models/_base/_model.py b/guidance/models/_base/_model.py index 95c347420..2e97457af 100644 --- a/guidance/models/_base/_model.py +++ b/guidance/models/_base/_model.py @@ -2,42 +2,34 @@ import queue import threading +import asyncio from contextvars import ContextVar, copy_context from copy import deepcopy -from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union - +from typing import TYPE_CHECKING, Any, Iterator, Optional, TypeVar, Union, Sequence from typing_extensions import Self -from ..._ast import ASTNode, Function, _parse_tags from ..._ast import ( ASTNode, Function, - GenAudio, - ImageBlob, - ImageUrl, - LiteralNode, - RoleEnd, - RoleStart, + AsyncFunction, + CaptureStart, + CaptureEnd, _parse_tags, ) from ...trace import ( - ImageInput, - LiteralInput, NodeAttr, - RoleCloserInput, - RoleOpenerInput, - StatelessGuidanceInput, - TextOutput, TraceNode, ) -from ...trace._trace import AudioInput from ...visual import TraceMessage +from ..._reentrant_async import sync_to_reentrant_async, reentrant_await, run_async_coroutine_in_bg_async + from ._interpreter import Interpreter from ._state import State if TYPE_CHECKING: from ...library._block import Block +_below_entry_point: ContextVar[bool] = ContextVar("below_entry_point", default=False) _active_blocks: ContextVar[tuple["Block", ...]] = ContextVar("active_blocks", default=()) _event_queues: ContextVar[tuple[queue.Queue["Model"], ...]] = ContextVar( "event_queues", default=() @@ -57,23 +49,45 @@ def _gen_id(): S = TypeVar("S", bound=State) D = TypeVar("D", bound=Any) +from dataclasses import dataclass, field, InitVar +@dataclass class Model: - def __init__( - self, - interpreter: Interpreter[S], - echo: bool = True, - ) -> None: - self.echo = echo + interpreter: InitVar[Interpreter[S]] + echo: bool = True + + # Private init attributes + _interpreter: Interpreter = field(init=False) + _parent: Optional["Model"] = None + _pending: Union[None, ASTNode, Function] = None + _active_blocks: tuple["Block", ...] 
= () + + # Private non-init attributes + _parent_id: Optional[int] = field(init=False, default=None) + _id: int = field(init=False, default_factory=_gen_id) + _trace_nodes: set[TraceNode] = field(init=False, default_factory=set) + + def __post_init__(self, interpreter: Interpreter) -> None: self._interpreter = interpreter - self._active_blocks: dict[Block, int] = {} - self.token_count: int = 0 + # Set the parent ID if we have a parent + if self._parent is not None: + self._parent_id = self._parent._id - self._parent: Optional["Model"] = None - self._parent_id: Optional[int] = None - self._id: int = _gen_id() - self._trace_nodes: set[TraceNode] = set() - self._update_trace_node(self._id, self._parent_id, None) + def copy(self) -> Self: + obj = object.__new__(self.__class__) + obj.__dict__.update(self.__dict__) + # Use the base-class's __init__ to set up the new object + # TODO: if we can move to having just the one Model class, + # we can replace this all with a simple `dataclasses.replace(self, ...)` + Model.__init__( + obj, + interpreter=deepcopy(self._interpreter), + # TODO: should this be our parent? Or is the copy really our child? + _parent=self, + _pending=self._pending, + _active_blocks=self._active_blocks, + ) + return obj def _update_trace_node( self, identifier: int, parent_id: Optional[int], node_attr: Optional[NodeAttr] = None @@ -91,60 +105,32 @@ def _update_trace_node( node_attr=node_attr, ), ) + pass + + def _increment_trace_id(self) -> None: + # This is a bit of a hack to get the trace ids working (only one output attr is allowed per id, so we need to increment.) + # Parent will be the real parent, so this is all a bit of a mess. TODO: allow multiple output attrs per id + self._parent_id = self._id + self._id = _gen_id() - def __add__(self, other: Union[str, Function, ASTNode]) -> Self: - self = self._apply_blocks() + def _add_to_pending(self, item: Union[ASTNode, Function]) -> None: + if self._pending is None: + self._pending = item + else: + self._pending += item + + def __add__(self, other: Union[str, Function, AsyncFunction, ASTNode]) -> Self: + self = self.copy() + self._apply_blocks() if isinstance(other, str): if other == "": return self other = _parse_tags(other) - if isinstance(other, Function): - return other(self) - if isinstance(other, ASTNode): - self = self._apply_node(other) - self = self._update_open_block_captures() + if isinstance(other, (ASTNode, Function, AsyncFunction)): + self._add_to_pending(other) return self return NotImplemented - def _apply_node(self, node: ASTNode) -> Self: - self = self.copy() - - # Input side of trace handler. - # TODO: StatefulGuidanceInput up in __add__? - if isinstance(node, RoleStart): - self._update_trace_node(self._id, self._parent_id, RoleOpenerInput(name=node.role)) - elif isinstance(node, RoleEnd): - self._update_trace_node(self._id, self._parent_id, RoleCloserInput(name=node.role)) - elif isinstance(node, LiteralNode): - self._update_trace_node(self._id, self._parent_id, LiteralInput(value=node.value)) - elif isinstance(node, ImageBlob): - self._update_trace_node(self._id, self._parent_id, ImageInput(value=node.data)) - elif isinstance(node, ImageUrl): - # TODO -- let's avoid downloading it here - pass - elif isinstance(node, GenAudio): - self._update_trace_node( - self._id, self._parent_id, AudioInput(value="") - ) # TODO -- what goes here? 
- else: - self._update_trace_node(self._id, self._parent_id, StatelessGuidanceInput(value=node)) - - for i, output_attr in enumerate(self._interpreter.run(node)): - if isinstance(output_attr, TextOutput): - # TODO: put this elsewhere (inside state?) - self.token_count += output_attr.token_count - if i != 0: - # On the first iteration, we already have a fresh trace node - # TODO: should be allowed to associate multiple output_attrs with a single input node? - # TODO: put this responsibility on the client in the case that it breaks a single input - # node into multiple input nodes to be handled sequentially? - self._parent_id = self._id - self._id = _gen_id() - self._update_trace_node(self._id, self._parent_id, output_attr) - # Stream current model state - self._send_to_event_queue() - return self - def _send_to_event_queue(self) -> None: """For streaming""" for event_queue in _event_queues.get(): @@ -154,63 +140,38 @@ def stream(self) -> "ModelStream": """Return a new model stream object that delays execution until it is iterated over.""" return ModelStream(self) - def _apply_blocks(self) -> Self: - self = self.copy() + def _apply_blocks(self) -> None: global_active_blocks = _active_blocks.get() - for block, start_index in list(reversed(self._active_blocks.items())): + new_active_blocks = [] + for block in reversed(self._active_blocks): # Close blocks that are not globally active anymore if block not in global_active_blocks: - self._active_blocks.pop(block) if block.closer is not None: closer = block.closer if isinstance(closer, str): closer = _parse_tags(closer) - if isinstance(closer, Function): - raise NotImplementedError( - "Stateful block opener/closer functions are not yet supported" - ) - self = self._apply_node(closer) - # Update capture regardless of whether or not it's been closed - if block.name is not None: - self = self.set(block.name, str(self)[start_index:]) + self._add_to_pending(closer) + if block.name is not None: + self._add_to_pending(CaptureEnd(name=block.name)) + else: + # Not closed, so keep it + new_active_blocks.append(block) + new_active_blocks = list(reversed(new_active_blocks)) for block in global_active_blocks: # Open blocks that are not yet locally active if block not in self._active_blocks: - # Set start_index to the current length - self._active_blocks[block] = len(self) + new_active_blocks.append(block) + if block.name is not None: + self._add_to_pending(CaptureStart(name=block.name)) if block.opener is not None: opener = block.opener if isinstance(opener, str): opener = _parse_tags(opener) - if isinstance(opener, Function): - raise NotImplementedError( - "Stateful block opener/closer functions are not yet supported" - ) - self = self._apply_node(opener) - return self - - def _update_open_block_captures(self) -> Self: - self = self.copy() - for block, start_index in self._active_blocks.items(): - if block.name is not None: - self = self.set(block.name, str(self)[start_index:]) - return self - - def copy(self) -> Self: - obj = object.__new__(self.__class__) - obj.__dict__.update(self.__dict__) - - obj._interpreter = deepcopy(self._interpreter) - obj._active_blocks = {**self._active_blocks} - obj._id = _gen_id() - obj._parent_id = self._id - obj._trace_nodes = set() - obj._parent = self - obj._update_trace_node(obj._id, obj._parent_id, None) - return obj + self._add_to_pending(opener) + self._active_blocks = tuple(new_active_blocks) def __str__(self) -> str: - return str(self._interpreter.state) + return str(self._get_state()) def __len__(self): return 
len(str(self)) @@ -222,7 +183,7 @@ def __setitem__(self, key, value): def __getitem__(self, key: str) -> Any: try: - captures = self._interpreter.state.captures[key] + captures = self._get_state().captures[key] except KeyError: raise KeyError(f"Model does not contain the variable '{key}'") if isinstance(captures, list): @@ -231,7 +192,7 @@ def __getitem__(self, key: str) -> Any: return captures["value"] def __contains__(self, key: str) -> bool: - return key in self._interpreter.state.captures + return key in self._get_state().captures def get(self, key: str, default: Optional[D] = None) -> Union[str, list[str], None, D]: """Return the value of a variable, or a default value if the variable is not present. @@ -260,9 +221,9 @@ def set(self, key: str, value: Union[str, list[str]]) -> Self: """ self = self.copy() if isinstance(value, list): - self._interpreter.state.captures[key] = [{"value": v, "log_prob": None} for v in value] + self._get_state().captures[key] = [{"value": v, "log_prob": None} for v in value] else: - self._interpreter.state.captures[key] = {"value": value, "log_prob": None} + self._get_state().captures[key] = {"value": value, "log_prob": None} return self def remove(self, key: str) -> Self: @@ -274,7 +235,7 @@ def remove(self, key: str) -> Self: The variable name to remove. """ self = self.copy() - self._interpreter.state.captures.pop(key) + self._get_state().captures.pop(key) return self def log_prob( @@ -290,7 +251,7 @@ def log_prob( The value to return if the variable is not currently set. """ try: - captures = self._interpreter.state.captures[key] + captures = self._get_state().captures[key] except KeyError: return default if isinstance(captures, list): @@ -304,37 +265,118 @@ def __getattribute__(self, name): return getattr(self._interpreter, "engine") return super().__getattribute__(name) + async def run_batched_async(self, items: Sequence[Union[str, Function, AsyncFunction, ASTNode]]) -> list[Self]: + lms = [self + item for item in items] + coros = [lm._run() for lm in lms] + await asyncio.gather(*coros) + return lms + + def run_batched(self, items: Sequence[Union[str, Function, AsyncFunction, ASTNode]]) -> list[Self]: + if not _below_entry_point.get(): + return run_async_coroutine_in_bg_async(self.run_batched_async(items)) + return reentrant_await(self.run_batched_async(items)) + + async def _run(self) -> None: + # TODO: trace `InputAttr`s + async def inner(): + new_self = self.copy() + # may be some pending blocks + new_self._apply_blocks() + while isinstance(new_self._pending, (Function, AsyncFunction)): + func = new_self._pending + new_self._pending = None + new_self._active_blocks = () + if isinstance(func, AsyncFunction): + new_self = await func(new_self) + else: + # If someone awaits us directly (i.e. we're not below a `reentrant_await`), + # we need to wrap the sync part in `sync_to_reentrant_async` to avoid blocking our caller's + # event loop. + # Otherwise, this is effectively equivalent to func(new_self) + new_self = await sync_to_reentrant_async(func)(new_self) + # may be some pending blocks + new_self._apply_blocks() + self.__dict__ = new_self.__dict__ # I guess + if self._pending is None: + return + + assert isinstance(self._pending, ASTNode) + node = self._pending + self._pending = None + await self._run_node(node) + + # Mark that we are below the entry point so that + # `_run_sync` knows to use `reentrant_await` instead of + # running in the background thread.
+ token = _below_entry_point.set(True) + try: + return await inner() + finally: + _below_entry_point.reset(token) + + def _run_sync(self) -> None: + if not _below_entry_point.get(): + return run_async_coroutine_in_bg_async(self._run()) + return reentrant_await(self._run()) + + async def _run_node(self, node: ASTNode) -> None: + async for node_attr in self._interpreter.run(node): + self._increment_trace_id() + self._update_trace_node(self._id, self._parent_id, node_attr) + # Stream current model state + self._send_to_event_queue() + + async def _get_state_async(self) -> State: + """Get the state of the model.""" + await self._run() + return self._interpreter.state + + def _get_state(self) -> State: + """Get the state of the model.""" + self._run_sync() + return self._interpreter.state + + async def get_async(self, key: str) -> Any: + try: + captures = (await self._get_state_async()).captures[key] + except KeyError: + raise KeyError(f"Model does not contain the variable '{key}'") + if isinstance(captures, list): + return [c["value"] for c in captures] + else: + return captures["value"] + + async def to_string_async(self) -> str: + """Get the string representation of the model.""" + return str(await self._get_state_async()) + + async def length_async(self) -> int: + """Get the length of the model.""" + return len(await self.to_string_async()) + + async def get_token_count_async(self) -> int: + """Get the token count of the model.""" + return (await self._get_state_async()).token_count + + def get_token_count(self) -> int: + """Get the token count of the model.""" + return self._get_state().token_count class ModelStream: def __init__( self, model: Model, - grammar: Union["ModelStream", str, ASTNode, Function, None] = None, - timeout=5, + timeout: float = 5.0, ) -> None: """Create a model stream object that delays execution until it is iterated over.""" if model.echo: model = model.copy() model.echo = False # turn off display echoing self.model = model - self.grammar = grammar self.timeout = timeout - def __add__(self, grammar: Union[str, ASTNode]) -> Self: - """Extend this delayed chain of execution with another grammar append.""" - if self.grammar is None: - return ModelStream(self.model, grammar) - else: - return ModelStream(self.model, self.grammar + grammar) - - def _inner_run(self, model): - """This runs the model stream without iterating, and is only using internally by __iter__.""" - if isinstance(self.grammar, ModelStream): - model = self.grammar._inner_run(model) - elif self.grammar is None: - model = self.model + "" - else: - model = self.model + self.grammar + def __add__(self, other: Any) -> Self: + return ModelStream(self.model + other) def __iter__(self) -> Iterator[Model]: """Starts a thread to execute the model and grammar, yielding events as they occur.""" @@ -347,7 +389,7 @@ def __iter__(self) -> Iterator[Model]: def target(ctx): _event_queues.set(ctx[_event_queues]) try: - self._inner_run(self.model) + self.model._run_sync() events.put(None) # mark that we are done except BaseException as ex: events.put(ex) diff --git a/guidance/models/_base/_state.py b/guidance/models/_base/_state.py index 38dd79de5..d27fdf3b9 100644 --- a/guidance/models/_base/_state.py +++ b/guidance/models/_base/_state.py @@ -13,6 +13,21 @@ class State(ABC): def __init__(self) -> None: self.captures: dict[str, Union[CaptureVar, list[CaptureVar]]] = {} self.active_role: Optional[str] = None + self.open_capture_blocks: dict[str, int] = {} + self.token_count: int = 0 + + def open_capture(self, name: str) -> 
None: + self.open_capture_blocks[name] = len(str(self)) + + def close_capture(self, name: str) -> CaptureOutput: + start_index = self.open_capture_blocks.pop(name) + value = str(self)[start_index:] + return self.apply_capture( + name=name, + value=value, + log_prob=None, + is_append=False, + ) @abstractmethod def __str__(self) -> str: diff --git a/guidance/models/_engine/_interpreter.py b/guidance/models/_engine/_interpreter.py index 8fdf599dd..72183c757 100644 --- a/guidance/models/_engine/_interpreter.py +++ b/guidance/models/_engine/_interpreter.py @@ -1,6 +1,6 @@ from base64 import b64decode from io import BytesIO -from typing import Iterator +from typing import AsyncIterable from copy import deepcopy from ..._ast import GrammarNode, ImageBlob, LiteralNode, RoleEnd, RoleStart @@ -39,21 +39,21 @@ def get_role_end(self, role: str) -> str: raise ValueError("Cannot use roles without a chat template") return self.chat_template.get_role_end(role) - def role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = node.role # TODO: mark these as special tokens..? - yield from self.run(LiteralNode(value=self.get_role_start(node.role)), **kwargs) + return self.run(LiteralNode(value=self.get_role_start(node.role)), **kwargs) - def role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = None # TODO: mark these as special tokens..? - yield from self.run(LiteralNode(value=self.get_role_end(node.role)), **kwargs) + return self.run(LiteralNode(value=self.get_role_end(node.role)), **kwargs) - def text(self, node: LiteralNode, **kwargs) -> Iterator[OutputAttr]: + async def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: self.state.prompt += node.value yield TextOutput(value=node.value, is_input=True) - def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: + async def grammar(self, node: GrammarNode, **kwargs) -> AsyncIterable[OutputAttr]: engine_gen = self.engine( state=self.state, grammar=node.ll_grammar(), @@ -62,13 +62,14 @@ def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: ) delayed_bytes = b"" + # TODO: this should be async some day for chunk in engine_gen: new_bytes = chunk.new_bytes new_text, delayed_bytes = partial_decode(new_bytes) # Update the state self.state.prompt += new_text - yield TextOutput(value=new_text, token_count=chunk.new_token_count, is_generated=True) + yield TextOutput(value=new_text, token_count=chunk.new_token_count, is_generated=chunk.is_generated) # TODO -- rewrite engine internals to make sure chunk.{generated,fast_forwarded}_tokens aren't empty... 
# # TODO: GenTokenExtra @@ -112,7 +113,7 @@ def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: class Llama3VisionInterpreter(EngineInterpreter): - def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + async def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: try: import PIL.Image except ImportError: @@ -129,7 +130,7 @@ def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: class Phi3VisionInterpreter(EngineInterpreter): - def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + async def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: try: import PIL.Image except ImportError: diff --git a/guidance/models/_openai.py b/guidance/models/_openai.py index e7a2489de..11f959aab 100644 --- a/guidance/models/_openai.py +++ b/guidance/models/_openai.py @@ -1,20 +1,8 @@ -from typing import TYPE_CHECKING, Callable, Iterator, Optional, Union +from typing import Optional -from pydantic import TypeAdapter - - -from .._ast import ( - JsonNode, - RuleNode, -) -from ..trace import OutputAttr from ._base import Model - from ._openai_base import ( BaseOpenAIInterpreter, - AudioContent, - OpenAIState, - Message, OpenAIImageMixin, OpenAIAudioMixin, ) @@ -33,7 +21,7 @@ def __init__( raise Exception( "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.OpenAI!" ) - client = openai.OpenAI(api_key=api_key, **kwargs) + client = openai.AsyncOpenAI(api_key=api_key, **kwargs) super().__init__(model=model, client=client) @@ -64,13 +52,13 @@ def __init__( if "audio-preview" in model: interpreter_cls = type( - "OpenAIAudioInterpreter", (OpenAIInterpreter, OpenAIAudioMixin), {} + "OpenAIAudioInterpreter", (OpenAIAudioMixin, OpenAIInterpreter), {} ) elif model.startswith("gpt-4o") or model.startswith("o1"): interpreter_cls = type( - "OpenAIImageInterpreter", (OpenAIInterpreter, OpenAIImageMixin), {} + "OpenAIImageInterpreter", (OpenAIImageMixin, OpenAIInterpreter), {} ) else: - interpreter_cls = OpenAIInterpreter + interpreter_cls = BaseOpenAIInterpreter super().__init__(interpreter=interpreter_cls(model, api_key=api_key, **kwargs), echo=echo) diff --git a/guidance/models/_openai_base.py b/guidance/models/_openai_base.py index cf822f292..27cd14126 100644 --- a/guidance/models/_openai_base.py +++ b/guidance/models/_openai_base.py @@ -2,7 +2,7 @@ import wave from copy import deepcopy from io import BytesIO -from typing import TYPE_CHECKING, Callable, Iterator, Literal, Optional, Union +from typing import TYPE_CHECKING, Callable, AsyncIterable, Literal, Optional, Union from pydantic import BaseModel, Discriminator, Field, TypeAdapter from typing_extensions import Annotated, assert_never @@ -157,7 +157,7 @@ class BaseOpenAIInterpreter(Interpreter[OpenAIState]): def __init__( self, model: str, - client: "openai.OpenAI", + client: "openai.AsyncOpenAI", ): try: import openai @@ -169,31 +169,26 @@ def __init__( self.model = model self.client = client - def run(self, node: ASTNode, **kwargs) -> Iterator[OutputAttr]: - if not isinstance(node, RoleStart) and self.state.active_role is None: - raise ValueError( - "OpenAI models require an active role (e.g. 
use `with assistant(): ...`)" - ) - return super().run(node, **kwargs) - - def role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = node.role # TODO: drop this and yield nothing. We need to add this for now as a workaround for the # fact that current vis code assumes that there is actually a role start message yield TextOutput(value=get_role_start(node.role), is_input=True) - def role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + async def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: self.state.messages.append(self.state.get_active_message()) self.state.audio = None self.state.content = [] self.state.active_role = None - yield from () + if False: + # I know this is weird, but this is how async generators work + yield - def text(self, node: LiteralNode, **kwargs) -> Iterator[OutputAttr]: + async def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: self.state.apply_text(node.value) yield TextOutput(value=node.value, is_input=True) - def rule(self, node: RuleNode, **kwargs) -> Iterator[OutputAttr]: + async def rule(self, node: RuleNode, **kwargs) -> AsyncIterable[OutputAttr]: if node.stop: raise ValueError("Stop condition not yet supported for OpenAI") if node.suffix: @@ -210,7 +205,7 @@ def rule(self, node: RuleNode, **kwargs) -> Iterator[OutputAttr]: chunks = self.run(node.value, **kwargs) if node.capture: buffered_text = "" - for chunk in chunks: + async for chunk in chunks: # TODO: this isinstance check is pretty darn fragile. # ~there must be a better way~ if isinstance(chunk, TextOutput): @@ -223,15 +218,16 @@ def rule(self, node: RuleNode, **kwargs) -> Iterator[OutputAttr]: is_append=node.list_append, ) else: - yield from chunks + async for chunk in chunks: + yield chunk - def regex(self, node: RegexNode, **kwargs) -> Iterator[OutputAttr]: + def regex(self, node: RegexNode, **kwargs) -> AsyncIterable[OutputAttr]: if node.regex is not None: raise ValueError("Regex not yet supported for OpenAI") # We're in unconstrained mode now. return self._run(**kwargs) - def json(self, node: JsonNode, **kwargs) -> Iterator[OutputAttr]: + def json(self, node: JsonNode, **kwargs) -> AsyncIterable[OutputAttr]: return self._run( response_format={ "type": "json_schema", @@ -244,7 +240,7 @@ def json(self, node: JsonNode, **kwargs) -> Iterator[OutputAttr]: **kwargs, ) - def _run(self, **kwargs) -> Iterator[OutputAttr]: + async def _run(self, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is None: # Should never happen? raise ValueError( @@ -259,18 +255,21 @@ def _run(self, **kwargs) -> Iterator[OutputAttr]: f"OpenAI models do not support pre-filled assistant messages: got data {self.state.content}." 
) - with self.client.chat.completions.create( + async with await self.client.chat.completions.create( model=self.model, messages=TypeAdapter(list[Message]).dump_python(self.state.messages), # type: ignore[arg-type] logprobs=self.log_probs, stream=True, **kwargs, ) as chunks: - yield from self._handle_stream(chunks) + async for output in self._handle_stream(chunks): + yield output - def _handle_stream(self, chunks: Iterator["ChatCompletionChunk"]) -> Iterator[OutputAttr]: + async def _handle_stream( + self, chunks: AsyncIterable["ChatCompletionChunk"] + ) -> AsyncIterable[OutputAttr]: audio: Optional[AssistantAudio] = None - for chunk in chunks: + async for chunk in chunks: try: choice = chunk.choices[0] except IndexError: @@ -357,7 +356,7 @@ def __deepcopy__(self, memo): class OpenAIImageMixin: - def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + async def image_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: try: import PIL.Image except ImportError: @@ -379,7 +378,7 @@ def image_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: ) yield ImageOutput(value=node.data, input=True) - def image_url(self, node: ImageUrl, **kwargs) -> Iterator[OutputAttr]: + async def image_url(self, node: ImageUrl, **kwargs) -> AsyncIterable[OutputAttr]: self.state.content.append({"type": "image_url", "image_url": {"url": node.url}}) image_bytes = bytes_from(node.url, allow_local=False) base64_string = base64.b64encode(image_bytes).decode("utf-8") @@ -387,7 +386,9 @@ def image_url(self, node: ImageUrl, **kwargs) -> Iterator[OutputAttr]: class OpenAIAudioMixin: - def audio_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: + log_probs: bool = False + + async def audio_blob(self, node: ImageBlob, **kwargs) -> AsyncIterable[OutputAttr]: format = "wav" # TODO: infer from node self.state.content.append( AudioContent( @@ -400,8 +401,8 @@ def audio_blob(self, node: ImageBlob, **kwargs) -> Iterator[OutputAttr]: ) yield AudioOutput(value=node.data, format=format, input=True) - def gen_audio(self, node: GenAudio, **kwargs) -> Iterator[OutputAttr]: - yield from self._run( + def gen_audio(self, node: GenAudio, **kwargs) -> AsyncIterable[OutputAttr]: + return self._run( modalities=["text", "audio"], # Has to be both? 
audio={ "voice": node.kwargs.get("voice", "alloy"), diff --git a/guidance/models/experimental/__init__.py b/guidance/models/experimental/__init__.py index 09d1311b5..5e0180708 100644 --- a/guidance/models/experimental/__init__.py +++ b/guidance/models/experimental/__init__.py @@ -1 +1 @@ -from ._vllm import VLLMModel \ No newline at end of file +from ._vllm import VLLMModel diff --git a/guidance/models/experimental/_vllm.py b/guidance/models/experimental/_vllm.py index 61c79afd4..2928d43ae 100644 --- a/guidance/models/experimental/_vllm.py +++ b/guidance/models/experimental/_vllm.py @@ -1,10 +1,12 @@ -from typing import Iterator, Optional, TYPE_CHECKING +from typing import AsyncIterable, Optional, TYPE_CHECKING import wave import base64 from io import BytesIO from copy import deepcopy from pydantic import TypeAdapter + if TYPE_CHECKING: + import openai from openai.types.chat import ChatCompletionChunk from ..._ast import GrammarNode, RoleStart, RoleEnd, ASTNode, LiteralNode @@ -13,15 +15,14 @@ from .._openai_base import OpenAIState, AssistantAudio, Message, get_role_start from .._base import Model, Interpreter + class BaseOpenAIInterpreterForVLLM(Interpreter[OpenAIState]): log_probs: bool = True def __init__( self, model: str, - base_url: Optional[str] = None, - api_key: Optional[str] = None, - **kwargs, + client: "openai.AsyncOpenAI", ): try: import openai @@ -31,33 +32,28 @@ def __init__( ) self.state = OpenAIState() self.model = model - self.client = openai.OpenAI(base_url=base_url, api_key=api_key, **kwargs) + self.client = client - def role_start(self, node: RoleStart, **kwargs) -> Iterator[OutputAttr]: + async def role_start(self, node: RoleStart, **kwargs) -> AsyncIterable[OutputAttr]: self.state.active_role = node.role # TODO: drop this and yield nothing. We need to add this for now as a workaround for the # fact that current vis code assumes that there is actually a role start message yield TextOutput(value=get_role_start(node.role), is_input=True) - def role_end(self, node: RoleEnd, **kwargs) -> Iterator[OutputAttr]: + async def role_end(self, node: RoleEnd, **kwargs) -> AsyncIterable[OutputAttr]: self.state.messages.append(self.state.get_active_message()) self.state.audio = None self.state.content = [] self.state.active_role = None - yield from () + if False: + # I know this is weird, but this is how async generators work + yield - def text(self, node: LiteralNode, **kwargs) -> Iterator[OutputAttr]: + async def text(self, node: LiteralNode, **kwargs) -> AsyncIterable[OutputAttr]: self.state.apply_text(node.value) yield TextOutput(value=node.value, is_input=True) - def run(self, node: ASTNode, **kwargs) -> Iterator[OutputAttr]: - if not isinstance(node, RoleStart) and self.state.active_role is None: - raise ValueError( - "OpenAI models require an active role (e.g. use `with assistant(): ...`)" - ) - return super().run(node, **kwargs) - - def _run(self, **kwargs) -> Iterator[OutputAttr]: + async def _run(self, **kwargs) -> AsyncIterable[OutputAttr]: if self.state.active_role is None: # Should never happen? raise ValueError( @@ -72,21 +68,27 @@ def _run(self, **kwargs) -> Iterator[OutputAttr]: f"OpenAI models do not support pre-filled assistant messages: got data {self.state.content}." 
) - with self.client.chat.completions.create( + async with await self.client.chat.completions.create( model=self.model, messages=TypeAdapter(list[Message]).dump_python(self.state.messages), # type: ignore[arg-type] logprobs=self.log_probs, stream=True, **kwargs, ) as chunks: - yield from self._handle_stream(chunks) + async for output in self._handle_stream(chunks): + yield output - def _handle_stream( - self, chunks: Iterator["ChatCompletionChunk"] - ) -> Iterator[OutputAttr]: + async def _handle_stream( + self, chunks: AsyncIterable["ChatCompletionChunk"] + ) -> AsyncIterable[OutputAttr]: audio: Optional[AssistantAudio] = None - for chunk in chunks: - choice = chunk.choices[0] + async for chunk in chunks: + try: + choice = chunk.choices[0] + except IndexError: + # TODO: azure seems to return empty choices sometimes (on first chunk?) + # Need to make this more robust + continue delta = choice.delta if delta.content is not None: assert audio is None @@ -94,7 +96,7 @@ def _handle_stream( if len(content) == 0: continue self.state.apply_text(content) - if choice.logprobs is not None: + if getattr(choice, "logprobs", None) is not None: # TODO: actually get tokens from this and be less lazy prob = 2.718 ** choice.logprobs.content[0].logprob else: @@ -165,11 +167,12 @@ def __deepcopy__(self, memo): setattr(result, k, deepcopy(v, memo)) return result + class VLLMInterpreter(BaseOpenAIInterpreterForVLLM): - def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: + async def grammar(self, node: GrammarNode, **kwargs) -> AsyncIterable[OutputAttr]: buffer: str = "" - for attr in self._run( - extra_body = dict( + async for attr in self._run( + extra_body=dict( guided_decoding_backend="guidance", guided_grammar=node.ll_grammar(), ) @@ -194,13 +197,29 @@ def grammar(self, node: GrammarNode, **kwargs) -> Iterator[OutputAttr]: assert isinstance(log_probs, list) assert len(value) == len(log_probs) for v, l in zip(value, log_probs): - yield self.state.apply_capture(name=name, value=v, log_prob=l, is_append=True) + yield self.state.apply_capture( + name=name, value=v, log_prob=l, is_append=True + ) else: - yield self.state.apply_capture(name=name, value=value, log_prob=log_probs, is_append=False) + yield self.state.apply_capture( + name=name, value=value, log_prob=log_probs, is_append=False + ) + class VLLMModel(Model): - def __init__(self, model: str, echo=True, **kwargs): + def __init__( + self, model: str, base_url: str, api_key: Optional[str] = None, echo: bool = True, **kwargs + ): + try: + import openai + except ImportError: + raise Exception( + "Please install the openai package version >= 1 using `pip install openai -U` in order to use guidance.models.experimental.VLLMModel!"
diff --git a/setup.py b/setup.py
index 0d1a3cfa5..2c7778f38 100644
--- a/setup.py
+++ b/setup.py
@@ -27,6 +27,7 @@
     "pydantic",
    "requests",
     "psutil",
+    "greenlet",
     "guidance-stitch",
     "llguidance==0.7.19",
 ]
diff --git a/tests/model_integration/library/test_gen.py b/tests/model_integration/library/test_gen.py
index 3a03aa91c..7653284b8 100644
--- a/tests/model_integration/library/test_gen.py
+++ b/tests/model_integration/library/test_gen.py
@@ -52,8 +52,10 @@ def test_metrics_smoke(selected_model: models.Model):
     lm.engine.reset_metrics()
 
     lm += "abcd"
+    str(lm)  # trigger execution
     print(f"{lm.engine.metrics=}")
     lm += gen("first", max_tokens=1)
+    str(lm)  # trigger execution
     print(f"{lm.engine.metrics=}")
     # Can't be sure of exact count due to token healing
     assert (
@@ -65,6 +67,7 @@ def test_metrics_smoke(selected_model: models.Model):
 
     lm += "fg"
     lm += gen("second", max_tokens=1)
+    str(lm)  # trigger execution
     # Again, trouble with healing
     assert (
         lm.engine.metrics.engine_output_tokens >= 2
@@ -85,7 +88,7 @@ def test_metrics_select(selected_model: models.Model):
             "go for a swim in the ocean",
         ]
     )
-    print(f"lm={str(lm)}")
+    print(f"lm={str(lm)}")  # trigger execution
     print(f"{lm.engine.metrics=}")
     assert lm.engine.metrics.engine_input_tokens > 1
     assert lm.engine.metrics.engine_output_tokens > 0
@@ -106,6 +109,7 @@ def test_unicode(selected_model: models.Model):
     Step 1''' + gen('steps', list_append=True, stop=['\nStep', '\n\n', '\nAnswer'], temperature=0.7, max_tokens=20) + '\n'
     i = 2
     lm + f'Step {i}:' + gen('steps', list_append=True, stop=['\nStep', '\n\n', '\nAnswer'], temperature=0.7, max_tokens=20) + '\n'
+    str(lm)  # trigger execution
     # fmt: on
@@ -114,6 +118,7 @@ def test_unicode2(selected_model: models.Model):
     lm.engine.reset_metrics()
     prompt = "Janet’s ducks lay 16 eggs per day"
     lm += prompt + gen(max_tokens=10)
+    str(lm)  # trigger execution
     assert lm.engine.metrics.engine_input_tokens > 1
     # Due to token healing, we can't be sure of the
     # precise output count
diff --git a/tests/model_integration/test_model.py b/tests/model_integration/test_model.py
index affbe3c34..83582b180 100644
--- a/tests/model_integration/test_model.py
+++ b/tests/model_integration/test_model.py
@@ -26,7 +26,7 @@ def test_token_count(selected_model):
     lm = selected_model
     lm2 = lm + " 1 1 1 1 1" + gen(max_tokens=9) + gen(max_tokens=9)
     assert (
-        18 <= lm2.token_count <= 20
+        18 <= lm2.get_token_count() <= 20
     )  # note we allow ourselves to be off by one because it is hard to know when we are continuing vs starting a new token in the parser
 
@@ -76,11 +76,15 @@ def test_associativity(selected_model: models.Model):
     engine = selected_model.engine
 
     with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_1:
-        _ = selected_model + (prompt + grammar)
+        lm = selected_model + (prompt + grammar)
+        _ = str(lm)  # trigger execution
     prompt_tokens_1 = get_logits_1.call_args_list[0].kwargs["token_ids"]
 
     with patch.object(engine, "get_logits", side_effect=engine.get_logits) as get_logits_2:
-        _ = (selected_model + prompt) + grammar
+        lm = selected_model + prompt
+        _ = str(lm)  # trigger execution
+        lm += grammar
+        _ = str(lm)  # trigger execution
     prompt_tokens_2 = get_logits_2.call_args_list[0].kwargs["token_ids"]
 
     # Main assertion: the prompt tokens should be the same
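The recurring `str(lm)  # trigger execution` additions in these tests reflect the new lazy execution model: `+=` now only queues work, and the program actually runs the first time state is read (via `str()`, a capture accessor, or `get_token_count()`). A sketch of the pattern the tests now rely on, using the offline `Mock` model for illustration; checking engine metrics before the forced read would observe stale counters:

```python
from guidance import gen, models

lm = models.Mock()
lm += "abcd" + gen("first", max_tokens=1)  # builds the program; nothing runs yet
text = str(lm)  # reading the state forces the queued program to execute
```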
diff --git a/tests/model_specific/test_transformers.py b/tests/model_specific/test_transformers.py
index 07e332db2..f26e03254 100644
--- a/tests/model_specific/test_transformers.py
+++ b/tests/model_specific/test_transformers.py
@@ -41,14 +41,17 @@ def ff_prompt(lm):
 
     gpt2_noff = models.Transformers("gpt2", enable_backtrack=False, enable_ff_tokens=False)
     gpt2_noff += ff_prompt()
+    str(gpt2_noff)  # Trigger execution
     noff_count = gpt2_noff.engine.metrics.engine_output_tokens
 
     gpt2_nobt = models.Transformers("gpt2", enable_backtrack=False)
     gpt2_nobt += ff_prompt()
+    str(gpt2_nobt)  # Trigger execution
     nobt_count = gpt2_nobt.engine.metrics.engine_output_tokens
 
     gpt2_ff = models.Transformers("gpt2")
     gpt2_ff += ff_prompt()
+    str(gpt2_ff)  # Trigger execution
     ff_count = gpt2_ff.engine.metrics.engine_output_tokens
 
     assert nobt_count == 2
diff --git a/tests/unit/library/test_block.py b/tests/unit/library/test_block.py
index 0726ef192..c86120a7c 100644
--- a/tests/unit/library/test_block.py
+++ b/tests/unit/library/test_block.py
@@ -10,12 +10,12 @@ def test_text_opener():
 
 def test_text_closer():
-    # NOTE(nopdive): Behavioral change, no longer need closer for str call.
     model = models.Mock("a")
     model += ""
     with block(closer="close text"):
         model += regex(r".")
-    assert str(model) == "a"
+        assert str(model) == "a"
+    assert str(model) == "aclose text"
 
 def test_grammar_opener():
diff --git a/tests/unit/test_async.py b/tests/unit/test_async.py
new file mode 100644
index 000000000..abcbe5702
--- /dev/null
+++ b/tests/unit/test_async.py
@@ -0,0 +1,135 @@
+import asyncio
+
+import pytest
+
+from guidance import guidance, models, select
+from guidance._reentrant_async import ReentrantAsyncException
+
+
+@guidance
+def sync_func(lm: models.Model):
+    lm += select(["a", "b"], name="choice")
+    if lm.get("choice") == "a":
+        lm += "lpha"
+    else:
+        lm += "eta"
+    return lm
+
+
+@guidance
+async def async_func(lm: models.Model):
+    lm += select(["a", "b"], name="choice")
+    if (await lm.get_async("choice")) == "a":
+        lm += "lpha"
+    else:
+        lm += "eta"
+    return lm
+
+
+@guidance
+def sync_func_calling_sync_func(lm: models.Model):
+    lm += sync_func()
+    lm += select(["a", "b"], name="choice")
+    if lm.get("choice") == "a":
+        lm += "lpha"
+    else:
+        lm += "eta"
+    return lm
+
+
+@guidance
+def sync_func_calling_async_func(lm: models.Model):
+    lm += async_func()
+    lm += select(["a", "b"], name="choice")
+    if lm.get("choice") == "a":
+        lm += "lpha"
+    else:
+        lm += "eta"
+    return lm
+
+
+@guidance
+async def async_func_calling_async_func(lm: models.Model):
+    lm += async_func()
+    lm += select(["a", "b"], name="choice")
+    if (await lm.get_async("choice")) == "a":
+        lm += "lpha"
+    else:
+        lm += "eta"
+    return lm
+
+
+@guidance
+async def async_func_calling_sync_func(lm: models.Model):
+    lm += sync_func()
+    lm += select(["a", "b"], name="choice")
+    if (await lm.get_async("choice")) == "a":
+        lm += "lpha"
+    else:
+        lm += "eta"
+    return lm
+
+
+@guidance
+def sync_then_async(lm: models.Model):
+    lm += sync_func()
+    lm += async_func()
+    return lm
+
+
+@guidance
+def async_then_sync(lm: models.Model):
+    lm += async_func()
+    lm += sync_func()
+    return lm
+
+
+def run(gfunc, sync: bool) -> str:
+    lm = models.Mock()
+    lm += gfunc()
+    if sync:
+        s = str(lm)
+    else:
+        s = asyncio.run(lm.to_string_async())
+    return s
+
+
+@pytest.mark.parametrize(
+    "sync",
+    [
+        True,
+        False,
+    ],
+)
+@pytest.mark.parametrize(
+    "gfunc, expected",
+    [
+        (sync_func, {"alpha", "beta"}),
+        (async_func, {"alpha", "beta"}),
+        (sync_then_async, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}),
+        (async_then_sync, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}),
+        (sync_func_calling_sync_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}),
+        (async_func_calling_async_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}),
+        (async_func_calling_sync_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}),
+        (sync_func_calling_async_func, {"alphabeta", "betaalpha", "alphaalpha", "betabeta"}),
+    ],
+)
+def test_async(gfunc, expected, sync):
+    s = run(gfunc, sync)
+    assert s in expected
+
+
+@pytest.mark.parametrize(
+    "sync",
+    [
+        True,
+        False,
+    ],
+)
+def test_async_with_sync_accessor(sync):
+    @guidance
+    async def async_func_with_sync_accessor(lm: models.Model):
+        lm += select(["a", "b"], name="choice")
+        if lm.get("choice") == "a":
+            lm += "lpha"
+        else:
+            lm += "eta"
+        return lm
+
+    # This should raise a ReentrantAsyncException because the sync accessor is not
+    # allowed inside an async guidance function
+    with pytest.raises(ReentrantAsyncException):
+        run(async_func_with_sync_accessor, sync)
+
+
+def test_sync_accessor_in_foreign_event_loop():
+    async def main():
+        lm = models.Mock()
+        lm += sync_func()
+        assert str(lm) in {"alpha", "beta"}
+
+    asyncio.run(main())
diff --git a/tests/unit/test_visual.py b/tests/unit/test_visual.py
index 4cfc82cd6..00620fbd6 100644
--- a/tests/unit/test_visual.py
+++ b/tests/unit/test_visual.py
@@ -8,6 +8,7 @@
 from guidance.visual import serialize_message, deserialize_message
 from guidance.visual._environment import Environment
 import asyncio
+import threading
 
 from guidance.visual._exchange import DEFAULT_TOPIC
 
@@ -34,8 +35,12 @@ def test_serialization(message):
 
 def test_async():
-    _, loop = get_bg_async()._thread_and_loop()
-    assert loop != asyncio.get_event_loop()
+    thread, loop = get_bg_async()._thread_and_loop()
+    try:
+        assert loop != asyncio.get_event_loop()
+    except RuntimeError:
+        pass
+    assert thread != threading.current_thread()
 
     async def f():
         return True
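Taken together, `test_async.py` pins down the intended calling conventions: sync and async guidance functions compose freely in either direction, but inside an `async` guidance function the sync accessors raise `ReentrantAsyncException`, so `get_async` and `to_string_async` must be used instead. A usage sketch distilled from those tests (the function name `finish_word` is mine, not from the diff; `Mock` keeps it runnable offline):

```python
import asyncio

from guidance import guidance, models, select


@guidance
async def finish_word(lm: models.Model):
    lm += select(["a", "b"], name="choice")
    # Use the async accessor here; a plain lm.get() would raise
    # ReentrantAsyncException inside an async guidance function.
    if (await lm.get_async("choice")) == "a":
        lm += "lpha"
    else:
        lm += "eta"
    return lm


async def main():
    lm = models.Mock() + finish_word()
    print(await lm.to_string_async())  # async counterpart of str(lm)


asyncio.run(main())
```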